mirror of
https://github.com/bitwarden/server.git
synced 2024-11-21 12:05:42 +01:00
Update ReplaceAsync
Implementation in EF CollectionRepository
(#4611)
* Add Collections Tests * Update CollectionRepository Implementation * Test Adding And Deleting Through Replace * Format
This commit is contained in:
parent
db4ff79c91
commit
3d7fe4f8af
@ -47,7 +47,7 @@ public interface ICollectionRepository : IRepository<Collection, Guid>
|
||||
/// </summary>
|
||||
Task<CollectionAdminDetails?> GetByIdWithPermissionsAsync(Guid collectionId, Guid? userId, bool includeAccessRelationships);
|
||||
|
||||
Task CreateAsync(Collection obj, IEnumerable<CollectionAccessSelection> groups, IEnumerable<CollectionAccessSelection> users);
|
||||
Task CreateAsync(Collection obj, IEnumerable<CollectionAccessSelection>? groups, IEnumerable<CollectionAccessSelection>? users);
|
||||
Task ReplaceAsync(Collection obj, IEnumerable<CollectionAccessSelection> groups, IEnumerable<CollectionAccessSelection> users);
|
||||
Task DeleteUserAsync(Guid collectionId, Guid organizationUserId);
|
||||
Task UpdateUsersAsync(Guid id, IEnumerable<CollectionAccessSelection> users);
|
||||
|
@ -50,7 +50,7 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
|
||||
}
|
||||
}
|
||||
|
||||
public async Task CreateAsync(Core.Entities.Collection obj, IEnumerable<CollectionAccessSelection> groups, IEnumerable<CollectionAccessSelection> users)
|
||||
public async Task CreateAsync(Core.Entities.Collection obj, IEnumerable<CollectionAccessSelection>? groups, IEnumerable<CollectionAccessSelection>? users)
|
||||
{
|
||||
await CreateAsync(obj);
|
||||
using (var scope = ServiceScopeFactory.CreateScope())
|
||||
@ -523,6 +523,7 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
|
||||
await ReplaceCollectionGroupsAsync(dbContext, collection, groups);
|
||||
await ReplaceCollectionUsersAsync(dbContext, collection, users);
|
||||
await dbContext.UserBumpAccountRevisionDateByCollectionIdAsync(collection.Id, collection.OrganizationId);
|
||||
await dbContext.SaveChangesAsync();
|
||||
}
|
||||
}
|
||||
|
||||
@ -689,133 +690,75 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
|
||||
}
|
||||
}
|
||||
|
||||
private async Task ReplaceCollectionGroupsAsync(DatabaseContext dbContext, Core.Entities.Collection collection, IEnumerable<CollectionAccessSelection> groups)
|
||||
private static async Task ReplaceCollectionGroupsAsync(DatabaseContext dbContext, Core.Entities.Collection collection, IEnumerable<CollectionAccessSelection> groups)
|
||||
{
|
||||
var groupsInOrg = dbContext.Groups.Where(g => g.OrganizationId == collection.OrganizationId);
|
||||
var modifiedGroupEntities = dbContext.Groups.Where(x => groups.Select(x => x.Id).Contains(x.Id));
|
||||
var target = (from cg in dbContext.CollectionGroups
|
||||
join g in modifiedGroupEntities
|
||||
on cg.CollectionId equals collection.Id into s_g
|
||||
from g in s_g.DefaultIfEmpty()
|
||||
where g == null || cg.GroupId == g.Id
|
||||
select new { cg, g }).AsNoTracking();
|
||||
var source = (from g in modifiedGroupEntities
|
||||
from cg in dbContext.CollectionGroups
|
||||
.Where(cg => cg.CollectionId == collection.Id && cg.GroupId == g.Id).DefaultIfEmpty()
|
||||
select new { cg, g }).AsNoTracking();
|
||||
var union = await target
|
||||
.Union(source)
|
||||
.Where(x =>
|
||||
x.cg == null ||
|
||||
((x.g == null || x.g.Id == x.cg.GroupId) &&
|
||||
(x.cg.CollectionId == collection.Id)))
|
||||
.AsNoTracking()
|
||||
.ToListAsync();
|
||||
var insert = union.Where(x => x.cg == null && groupsInOrg.Any(c => x.g.Id == c.Id))
|
||||
.Select(x => new CollectionGroup
|
||||
{
|
||||
CollectionId = collection.Id,
|
||||
GroupId = x.g.Id,
|
||||
ReadOnly = groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly,
|
||||
HidePasswords = groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords,
|
||||
Manage = groups.FirstOrDefault(g => g.Id == x.g.Id).Manage
|
||||
}).ToList();
|
||||
var update = union
|
||||
.Where(
|
||||
x => x.g != null &&
|
||||
x.cg != null &&
|
||||
(x.cg.ReadOnly != groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly ||
|
||||
x.cg.HidePasswords != groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords ||
|
||||
x.cg.Manage != groups.FirstOrDefault(g => g.Id == x.g.Id).Manage)
|
||||
)
|
||||
.Select(x => new CollectionGroup
|
||||
{
|
||||
CollectionId = collection.Id,
|
||||
GroupId = x.g.Id,
|
||||
ReadOnly = groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly,
|
||||
HidePasswords = groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords,
|
||||
Manage = groups.FirstOrDefault(g => g.Id == x.g.Id).Manage,
|
||||
});
|
||||
var delete = union
|
||||
.Where(
|
||||
x => x.g == null &&
|
||||
x.cg.CollectionId == collection.Id
|
||||
)
|
||||
.Select(x => new CollectionGroup
|
||||
{
|
||||
CollectionId = collection.Id,
|
||||
GroupId = x.cg.GroupId,
|
||||
})
|
||||
.ToList();
|
||||
var existingCollectionGroups = await dbContext.CollectionGroups
|
||||
.Where(cg => cg.CollectionId == collection.Id)
|
||||
.ToDictionaryAsync(cg => cg.GroupId);
|
||||
|
||||
await dbContext.AddRangeAsync(insert);
|
||||
dbContext.UpdateRange(update);
|
||||
dbContext.RemoveRange(delete);
|
||||
await dbContext.SaveChangesAsync();
|
||||
foreach (var group in groups)
|
||||
{
|
||||
if (existingCollectionGroups.TryGetValue(group.Id, out var existingCollectionGroup))
|
||||
{
|
||||
// It already exists, update it
|
||||
existingCollectionGroup.HidePasswords = group.HidePasswords;
|
||||
existingCollectionGroup.ReadOnly = group.ReadOnly;
|
||||
existingCollectionGroup.Manage = group.Manage;
|
||||
dbContext.CollectionGroups.Update(existingCollectionGroup);
|
||||
}
|
||||
else
|
||||
{
|
||||
// This is a brand new entry, add it
|
||||
dbContext.CollectionGroups.Add(new CollectionGroup
|
||||
{
|
||||
GroupId = group.Id,
|
||||
CollectionId = collection.Id,
|
||||
HidePasswords = group.HidePasswords,
|
||||
ReadOnly = group.ReadOnly,
|
||||
Manage = group.Manage,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
var requestedGroupIds = groups.Select(g => g.Id).ToArray();
|
||||
var toDelete = existingCollectionGroups.Values.Where(cg => !requestedGroupIds.Contains(cg.GroupId));
|
||||
dbContext.CollectionGroups.RemoveRange(toDelete);
|
||||
// SaveChangesAsync is expected to be called outside this method
|
||||
}
|
||||
|
||||
private async Task ReplaceCollectionUsersAsync(DatabaseContext dbContext, Core.Entities.Collection collection, IEnumerable<CollectionAccessSelection> users)
|
||||
private static async Task ReplaceCollectionUsersAsync(DatabaseContext dbContext, Core.Entities.Collection collection, IEnumerable<CollectionAccessSelection> users)
|
||||
{
|
||||
var usersInOrg = dbContext.OrganizationUsers.Where(u => u.OrganizationId == collection.OrganizationId);
|
||||
var modifiedUserEntities = dbContext.OrganizationUsers.Where(x => users.Select(x => x.Id).Contains(x.Id));
|
||||
var target = (from cu in dbContext.CollectionUsers
|
||||
join u in modifiedUserEntities
|
||||
on cu.CollectionId equals collection.Id into s_g
|
||||
from u in s_g.DefaultIfEmpty()
|
||||
where u == null || cu.OrganizationUserId == u.Id
|
||||
select new { cu, u }).AsNoTracking();
|
||||
var source = (from u in modifiedUserEntities
|
||||
from cu in dbContext.CollectionUsers
|
||||
.Where(cu => cu.CollectionId == collection.Id && cu.OrganizationUserId == u.Id).DefaultIfEmpty()
|
||||
select new { cu, u }).AsNoTracking();
|
||||
var union = await target
|
||||
.Union(source)
|
||||
.Where(x =>
|
||||
x.cu == null ||
|
||||
((x.u == null || x.u.Id == x.cu.OrganizationUserId) &&
|
||||
(x.cu.CollectionId == collection.Id)))
|
||||
.AsNoTracking()
|
||||
.ToListAsync();
|
||||
var insert = union.Where(x => x.u == null && usersInOrg.Any(c => x.u.Id == c.Id))
|
||||
.Select(x => new CollectionUser
|
||||
{
|
||||
CollectionId = collection.Id,
|
||||
OrganizationUserId = x.u.Id,
|
||||
ReadOnly = users.FirstOrDefault(u => u.Id == x.u.Id).ReadOnly,
|
||||
HidePasswords = users.FirstOrDefault(u => u.Id == x.u.Id).HidePasswords,
|
||||
Manage = users.FirstOrDefault(u => u.Id == x.u.Id).Manage,
|
||||
}).ToList();
|
||||
var update = union
|
||||
.Where(
|
||||
x => x.u != null &&
|
||||
x.cu != null &&
|
||||
(x.cu.ReadOnly != users.FirstOrDefault(u => u.Id == x.u.Id).ReadOnly ||
|
||||
x.cu.HidePasswords != users.FirstOrDefault(u => u.Id == x.u.Id).HidePasswords ||
|
||||
x.cu.Manage != users.FirstOrDefault(u => u.Id == x.u.Id).Manage)
|
||||
)
|
||||
.Select(x => new CollectionUser
|
||||
{
|
||||
CollectionId = collection.Id,
|
||||
OrganizationUserId = x.u.Id,
|
||||
ReadOnly = users.FirstOrDefault(u => u.Id == x.u.Id).ReadOnly,
|
||||
HidePasswords = users.FirstOrDefault(u => u.Id == x.u.Id).HidePasswords,
|
||||
Manage = users.FirstOrDefault(u => u.Id == x.u.Id).Manage,
|
||||
});
|
||||
var delete = union
|
||||
.Where(
|
||||
x => x.u == null &&
|
||||
x.cu.CollectionId == collection.Id
|
||||
)
|
||||
.Select(x => new CollectionUser
|
||||
{
|
||||
CollectionId = collection.Id,
|
||||
OrganizationUserId = x.cu.OrganizationUserId,
|
||||
})
|
||||
.ToList();
|
||||
var existingCollectionUsers = await dbContext.CollectionUsers
|
||||
.Where(cu => cu.CollectionId == collection.Id)
|
||||
.ToDictionaryAsync(cu => cu.OrganizationUserId);
|
||||
|
||||
await dbContext.AddRangeAsync(insert);
|
||||
dbContext.UpdateRange(update);
|
||||
dbContext.RemoveRange(delete);
|
||||
await dbContext.SaveChangesAsync();
|
||||
foreach (var user in users)
|
||||
{
|
||||
if (existingCollectionUsers.TryGetValue(user.Id, out var existingCollectionUser))
|
||||
{
|
||||
// This is an existing entry, update it.
|
||||
existingCollectionUser.HidePasswords = user.HidePasswords;
|
||||
existingCollectionUser.ReadOnly = user.ReadOnly;
|
||||
existingCollectionUser.Manage = user.Manage;
|
||||
dbContext.CollectionUsers.Update(existingCollectionUser);
|
||||
}
|
||||
else
|
||||
{
|
||||
// This is a brand new entry, add it
|
||||
dbContext.CollectionUsers.Add(new CollectionUser
|
||||
{
|
||||
OrganizationUserId = user.Id,
|
||||
CollectionId = collection.Id,
|
||||
HidePasswords = user.HidePasswords,
|
||||
ReadOnly = user.ReadOnly,
|
||||
Manage = user.Manage,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
var requestedUserIds = users.Select(u => u.Id).ToArray();
|
||||
var toDelete = existingCollectionUsers.Values.Where(cu => !requestedUserIds.Contains(cu.OrganizationUserId));
|
||||
dbContext.CollectionUsers.RemoveRange(toDelete);
|
||||
// SaveChangesAsync is expected to be called outside this method
|
||||
}
|
||||
}
|
||||
|
@ -13,8 +13,8 @@
|
||||
<PackageReference Include="Microsoft.Extensions.Logging" Version="8.0.0" />
|
||||
<PackageReference Include="Microsoft.Extensions.TimeProvider.Testing" Version="8.6.0" />
|
||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNetTestSdkVersion)" />
|
||||
<PackageReference Include="xunit" Version="2.4.1" />
|
||||
<PackageReference Include="xunit.runner.visualstudio" Version="2.4.3">
|
||||
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
|
||||
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitRunnerVisualStudioVersion)">
|
||||
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
|
||||
<PrivateAssets>all</PrivateAssets>
|
||||
</PackageReference>
|
||||
|
@ -463,4 +463,141 @@ public class CollectionRepositoryTests
|
||||
Assert.False(c3.Unmanaged);
|
||||
});
|
||||
}
|
||||
|
||||
[DatabaseTheory, DatabaseData]
|
||||
public async Task ReplaceAsync_Works(
|
||||
IUserRepository userRepository,
|
||||
IOrganizationRepository organizationRepository,
|
||||
IOrganizationUserRepository organizationUserRepository,
|
||||
IGroupRepository groupRepository,
|
||||
ICollectionRepository collectionRepository)
|
||||
{
|
||||
var user = await userRepository.CreateAsync(new User
|
||||
{
|
||||
Name = "Test User",
|
||||
Email = $"test+{Guid.NewGuid()}@email.com",
|
||||
ApiKey = "TEST",
|
||||
SecurityStamp = "stamp",
|
||||
});
|
||||
|
||||
var organization = await organizationRepository.CreateAsync(new Organization
|
||||
{
|
||||
Name = "Test Org",
|
||||
PlanType = PlanType.EnterpriseAnnually,
|
||||
Plan = "Test Plan",
|
||||
BillingEmail = "billing@email.com"
|
||||
});
|
||||
|
||||
var orgUser1 = await organizationUserRepository.CreateAsync(new OrganizationUser
|
||||
{
|
||||
OrganizationId = organization.Id,
|
||||
UserId = user.Id,
|
||||
Status = OrganizationUserStatusType.Confirmed,
|
||||
});
|
||||
|
||||
var orgUser2 = await organizationUserRepository.CreateAsync(new OrganizationUser
|
||||
{
|
||||
OrganizationId = organization.Id,
|
||||
UserId = user.Id,
|
||||
Status = OrganizationUserStatusType.Confirmed,
|
||||
});
|
||||
|
||||
var orgUser3 = await organizationUserRepository.CreateAsync(new OrganizationUser
|
||||
{
|
||||
OrganizationId = organization.Id,
|
||||
UserId = user.Id,
|
||||
Status = OrganizationUserStatusType.Confirmed,
|
||||
});
|
||||
|
||||
var group1 = await groupRepository.CreateAsync(new Group
|
||||
{
|
||||
Name = "Test Group #1",
|
||||
OrganizationId = organization.Id,
|
||||
});
|
||||
|
||||
var group2 = await groupRepository.CreateAsync(new Group
|
||||
{
|
||||
Name = "Test Group #2",
|
||||
OrganizationId = organization.Id,
|
||||
});
|
||||
|
||||
var group3 = await groupRepository.CreateAsync(new Group
|
||||
{
|
||||
Name = "Test Group #3",
|
||||
OrganizationId = organization.Id,
|
||||
});
|
||||
|
||||
var collection = new Collection
|
||||
{
|
||||
Name = "Test Collection Name",
|
||||
OrganizationId = organization.Id,
|
||||
};
|
||||
|
||||
await collectionRepository.CreateAsync(collection,
|
||||
[
|
||||
new CollectionAccessSelection { Id = group1.Id, Manage = true, HidePasswords = true, ReadOnly = false, },
|
||||
new CollectionAccessSelection { Id = group2.Id, Manage = false, HidePasswords = false, ReadOnly = true, },
|
||||
],
|
||||
[
|
||||
new CollectionAccessSelection { Id = orgUser1.Id, Manage = true, HidePasswords = false, ReadOnly = true },
|
||||
new CollectionAccessSelection { Id = orgUser2.Id, Manage = false, HidePasswords = true, ReadOnly = false },
|
||||
]
|
||||
);
|
||||
|
||||
collection.Name = "Updated Collection Name";
|
||||
|
||||
await collectionRepository.ReplaceAsync(collection,
|
||||
[
|
||||
// Should delete group1
|
||||
new CollectionAccessSelection { Id = group2.Id, Manage = true, HidePasswords = true, ReadOnly = false, },
|
||||
// Should add group3
|
||||
new CollectionAccessSelection { Id = group3.Id, Manage = false, HidePasswords = false, ReadOnly = true, },
|
||||
],
|
||||
[
|
||||
// Should delete orgUser1
|
||||
new CollectionAccessSelection { Id = orgUser2.Id, Manage = false, HidePasswords = false, ReadOnly = true },
|
||||
// Should add orgUser3
|
||||
new CollectionAccessSelection { Id = orgUser3.Id, Manage = true, HidePasswords = false, ReadOnly = true },
|
||||
]
|
||||
);
|
||||
|
||||
// Assert it
|
||||
var info = await collectionRepository.GetByIdWithPermissionsAsync(collection.Id, user.Id, true);
|
||||
|
||||
Assert.NotNull(info);
|
||||
|
||||
Assert.Equal("Updated Collection Name", info.Name);
|
||||
|
||||
var groups = info.Groups.ToArray();
|
||||
|
||||
Assert.Equal(2, groups.Length);
|
||||
|
||||
var actualGroup2 = Assert.Single(groups.Where(g => g.Id == group2.Id));
|
||||
|
||||
Assert.True(actualGroup2.Manage);
|
||||
Assert.True(actualGroup2.HidePasswords);
|
||||
Assert.False(actualGroup2.ReadOnly);
|
||||
|
||||
var actualGroup3 = Assert.Single(groups.Where(g => g.Id == group3.Id));
|
||||
|
||||
Assert.False(actualGroup3.Manage);
|
||||
Assert.False(actualGroup3.HidePasswords);
|
||||
Assert.True(actualGroup3.ReadOnly);
|
||||
|
||||
var users = info.Users.ToArray();
|
||||
|
||||
Assert.Equal(2, users.Length);
|
||||
|
||||
var actualOrgUser2 = Assert.Single(users.Where(u => u.Id == orgUser2.Id));
|
||||
|
||||
Assert.False(actualOrgUser2.Manage);
|
||||
Assert.False(actualOrgUser2.HidePasswords);
|
||||
Assert.True(actualOrgUser2.ReadOnly);
|
||||
|
||||
var actualOrgUser3 = Assert.Single(users.Where(u => u.Id == orgUser3.Id));
|
||||
|
||||
Assert.True(actualOrgUser3.Manage);
|
||||
Assert.False(actualOrgUser3.HidePasswords);
|
||||
Assert.True(actualOrgUser3.ReadOnly);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user