1
0
mirror of https://github.com/bitwarden/server.git synced 2024-11-21 12:05:42 +01:00

Update ReplaceAsync Implementation in EF CollectionRepository (#4611)

* Add Collections Tests

* Update CollectionRepository Implementation

* Test Adding And Deleting Through Replace

* Format
This commit is contained in:
Justin Baur 2024-08-14 13:50:29 -04:00 committed by GitHub
parent db4ff79c91
commit 3d7fe4f8af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 206 additions and 126 deletions

View File

@ -47,7 +47,7 @@ public interface ICollectionRepository : IRepository<Collection, Guid>
/// </summary>
Task<CollectionAdminDetails?> GetByIdWithPermissionsAsync(Guid collectionId, Guid? userId, bool includeAccessRelationships);
Task CreateAsync(Collection obj, IEnumerable<CollectionAccessSelection> groups, IEnumerable<CollectionAccessSelection> users);
Task CreateAsync(Collection obj, IEnumerable<CollectionAccessSelection>? groups, IEnumerable<CollectionAccessSelection>? users);
Task ReplaceAsync(Collection obj, IEnumerable<CollectionAccessSelection> groups, IEnumerable<CollectionAccessSelection> users);
Task DeleteUserAsync(Guid collectionId, Guid organizationUserId);
Task UpdateUsersAsync(Guid id, IEnumerable<CollectionAccessSelection> users);

View File

@ -50,7 +50,7 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
}
}
public async Task CreateAsync(Core.Entities.Collection obj, IEnumerable<CollectionAccessSelection> groups, IEnumerable<CollectionAccessSelection> users)
public async Task CreateAsync(Core.Entities.Collection obj, IEnumerable<CollectionAccessSelection>? groups, IEnumerable<CollectionAccessSelection>? users)
{
await CreateAsync(obj);
using (var scope = ServiceScopeFactory.CreateScope())
@ -523,6 +523,7 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
await ReplaceCollectionGroupsAsync(dbContext, collection, groups);
await ReplaceCollectionUsersAsync(dbContext, collection, users);
await dbContext.UserBumpAccountRevisionDateByCollectionIdAsync(collection.Id, collection.OrganizationId);
await dbContext.SaveChangesAsync();
}
}
@ -689,133 +690,75 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
}
}
private async Task ReplaceCollectionGroupsAsync(DatabaseContext dbContext, Core.Entities.Collection collection, IEnumerable<CollectionAccessSelection> groups)
private static async Task ReplaceCollectionGroupsAsync(DatabaseContext dbContext, Core.Entities.Collection collection, IEnumerable<CollectionAccessSelection> groups)
{
var groupsInOrg = dbContext.Groups.Where(g => g.OrganizationId == collection.OrganizationId);
var modifiedGroupEntities = dbContext.Groups.Where(x => groups.Select(x => x.Id).Contains(x.Id));
var target = (from cg in dbContext.CollectionGroups
join g in modifiedGroupEntities
on cg.CollectionId equals collection.Id into s_g
from g in s_g.DefaultIfEmpty()
where g == null || cg.GroupId == g.Id
select new { cg, g }).AsNoTracking();
var source = (from g in modifiedGroupEntities
from cg in dbContext.CollectionGroups
.Where(cg => cg.CollectionId == collection.Id && cg.GroupId == g.Id).DefaultIfEmpty()
select new { cg, g }).AsNoTracking();
var union = await target
.Union(source)
.Where(x =>
x.cg == null ||
((x.g == null || x.g.Id == x.cg.GroupId) &&
(x.cg.CollectionId == collection.Id)))
.AsNoTracking()
.ToListAsync();
var insert = union.Where(x => x.cg == null && groupsInOrg.Any(c => x.g.Id == c.Id))
.Select(x => new CollectionGroup
{
CollectionId = collection.Id,
GroupId = x.g.Id,
ReadOnly = groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly,
HidePasswords = groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords,
Manage = groups.FirstOrDefault(g => g.Id == x.g.Id).Manage
}).ToList();
var update = union
.Where(
x => x.g != null &&
x.cg != null &&
(x.cg.ReadOnly != groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly ||
x.cg.HidePasswords != groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords ||
x.cg.Manage != groups.FirstOrDefault(g => g.Id == x.g.Id).Manage)
)
.Select(x => new CollectionGroup
{
CollectionId = collection.Id,
GroupId = x.g.Id,
ReadOnly = groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly,
HidePasswords = groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords,
Manage = groups.FirstOrDefault(g => g.Id == x.g.Id).Manage,
});
var delete = union
.Where(
x => x.g == null &&
x.cg.CollectionId == collection.Id
)
.Select(x => new CollectionGroup
{
CollectionId = collection.Id,
GroupId = x.cg.GroupId,
})
.ToList();
var existingCollectionGroups = await dbContext.CollectionGroups
.Where(cg => cg.CollectionId == collection.Id)
.ToDictionaryAsync(cg => cg.GroupId);
await dbContext.AddRangeAsync(insert);
dbContext.UpdateRange(update);
dbContext.RemoveRange(delete);
await dbContext.SaveChangesAsync();
foreach (var group in groups)
{
if (existingCollectionGroups.TryGetValue(group.Id, out var existingCollectionGroup))
{
// It already exists, update it
existingCollectionGroup.HidePasswords = group.HidePasswords;
existingCollectionGroup.ReadOnly = group.ReadOnly;
existingCollectionGroup.Manage = group.Manage;
dbContext.CollectionGroups.Update(existingCollectionGroup);
}
else
{
// This is a brand new entry, add it
dbContext.CollectionGroups.Add(new CollectionGroup
{
GroupId = group.Id,
CollectionId = collection.Id,
HidePasswords = group.HidePasswords,
ReadOnly = group.ReadOnly,
Manage = group.Manage,
});
}
}
var requestedGroupIds = groups.Select(g => g.Id).ToArray();
var toDelete = existingCollectionGroups.Values.Where(cg => !requestedGroupIds.Contains(cg.GroupId));
dbContext.CollectionGroups.RemoveRange(toDelete);
// SaveChangesAsync is expected to be called outside this method
}
private async Task ReplaceCollectionUsersAsync(DatabaseContext dbContext, Core.Entities.Collection collection, IEnumerable<CollectionAccessSelection> users)
private static async Task ReplaceCollectionUsersAsync(DatabaseContext dbContext, Core.Entities.Collection collection, IEnumerable<CollectionAccessSelection> users)
{
var usersInOrg = dbContext.OrganizationUsers.Where(u => u.OrganizationId == collection.OrganizationId);
var modifiedUserEntities = dbContext.OrganizationUsers.Where(x => users.Select(x => x.Id).Contains(x.Id));
var target = (from cu in dbContext.CollectionUsers
join u in modifiedUserEntities
on cu.CollectionId equals collection.Id into s_g
from u in s_g.DefaultIfEmpty()
where u == null || cu.OrganizationUserId == u.Id
select new { cu, u }).AsNoTracking();
var source = (from u in modifiedUserEntities
from cu in dbContext.CollectionUsers
.Where(cu => cu.CollectionId == collection.Id && cu.OrganizationUserId == u.Id).DefaultIfEmpty()
select new { cu, u }).AsNoTracking();
var union = await target
.Union(source)
.Where(x =>
x.cu == null ||
((x.u == null || x.u.Id == x.cu.OrganizationUserId) &&
(x.cu.CollectionId == collection.Id)))
.AsNoTracking()
.ToListAsync();
var insert = union.Where(x => x.u == null && usersInOrg.Any(c => x.u.Id == c.Id))
.Select(x => new CollectionUser
{
CollectionId = collection.Id,
OrganizationUserId = x.u.Id,
ReadOnly = users.FirstOrDefault(u => u.Id == x.u.Id).ReadOnly,
HidePasswords = users.FirstOrDefault(u => u.Id == x.u.Id).HidePasswords,
Manage = users.FirstOrDefault(u => u.Id == x.u.Id).Manage,
}).ToList();
var update = union
.Where(
x => x.u != null &&
x.cu != null &&
(x.cu.ReadOnly != users.FirstOrDefault(u => u.Id == x.u.Id).ReadOnly ||
x.cu.HidePasswords != users.FirstOrDefault(u => u.Id == x.u.Id).HidePasswords ||
x.cu.Manage != users.FirstOrDefault(u => u.Id == x.u.Id).Manage)
)
.Select(x => new CollectionUser
{
CollectionId = collection.Id,
OrganizationUserId = x.u.Id,
ReadOnly = users.FirstOrDefault(u => u.Id == x.u.Id).ReadOnly,
HidePasswords = users.FirstOrDefault(u => u.Id == x.u.Id).HidePasswords,
Manage = users.FirstOrDefault(u => u.Id == x.u.Id).Manage,
});
var delete = union
.Where(
x => x.u == null &&
x.cu.CollectionId == collection.Id
)
.Select(x => new CollectionUser
{
CollectionId = collection.Id,
OrganizationUserId = x.cu.OrganizationUserId,
})
.ToList();
var existingCollectionUsers = await dbContext.CollectionUsers
.Where(cu => cu.CollectionId == collection.Id)
.ToDictionaryAsync(cu => cu.OrganizationUserId);
await dbContext.AddRangeAsync(insert);
dbContext.UpdateRange(update);
dbContext.RemoveRange(delete);
await dbContext.SaveChangesAsync();
foreach (var user in users)
{
if (existingCollectionUsers.TryGetValue(user.Id, out var existingCollectionUser))
{
// This is an existing entry, update it.
existingCollectionUser.HidePasswords = user.HidePasswords;
existingCollectionUser.ReadOnly = user.ReadOnly;
existingCollectionUser.Manage = user.Manage;
dbContext.CollectionUsers.Update(existingCollectionUser);
}
else
{
// This is a brand new entry, add it
dbContext.CollectionUsers.Add(new CollectionUser
{
OrganizationUserId = user.Id,
CollectionId = collection.Id,
HidePasswords = user.HidePasswords,
ReadOnly = user.ReadOnly,
Manage = user.Manage,
});
}
}
var requestedUserIds = users.Select(u => u.Id).ToArray();
var toDelete = existingCollectionUsers.Values.Where(cu => !requestedUserIds.Contains(cu.OrganizationUserId));
dbContext.CollectionUsers.RemoveRange(toDelete);
// SaveChangesAsync is expected to be called outside this method
}
}

View File

@ -13,8 +13,8 @@
<PackageReference Include="Microsoft.Extensions.Logging" Version="8.0.0" />
<PackageReference Include="Microsoft.Extensions.TimeProvider.Testing" Version="8.6.0" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNetTestSdkVersion)" />
<PackageReference Include="xunit" Version="2.4.1" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.4.3">
<PackageReference Include="xunit" Version="$(XUnitVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="$(XUnitRunnerVisualStudioVersion)">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>

View File

@ -463,4 +463,141 @@ public class CollectionRepositoryTests
Assert.False(c3.Unmanaged);
});
}
[DatabaseTheory, DatabaseData]
public async Task ReplaceAsync_Works(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
IGroupRepository groupRepository,
ICollectionRepository collectionRepository)
{
var user = await userRepository.CreateAsync(new User
{
Name = "Test User",
Email = $"test+{Guid.NewGuid()}@email.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
});
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
PlanType = PlanType.EnterpriseAnnually,
Plan = "Test Plan",
BillingEmail = "billing@email.com"
});
var orgUser1 = await organizationUserRepository.CreateAsync(new OrganizationUser
{
OrganizationId = organization.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Confirmed,
});
var orgUser2 = await organizationUserRepository.CreateAsync(new OrganizationUser
{
OrganizationId = organization.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Confirmed,
});
var orgUser3 = await organizationUserRepository.CreateAsync(new OrganizationUser
{
OrganizationId = organization.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Confirmed,
});
var group1 = await groupRepository.CreateAsync(new Group
{
Name = "Test Group #1",
OrganizationId = organization.Id,
});
var group2 = await groupRepository.CreateAsync(new Group
{
Name = "Test Group #2",
OrganizationId = organization.Id,
});
var group3 = await groupRepository.CreateAsync(new Group
{
Name = "Test Group #3",
OrganizationId = organization.Id,
});
var collection = new Collection
{
Name = "Test Collection Name",
OrganizationId = organization.Id,
};
await collectionRepository.CreateAsync(collection,
[
new CollectionAccessSelection { Id = group1.Id, Manage = true, HidePasswords = true, ReadOnly = false, },
new CollectionAccessSelection { Id = group2.Id, Manage = false, HidePasswords = false, ReadOnly = true, },
],
[
new CollectionAccessSelection { Id = orgUser1.Id, Manage = true, HidePasswords = false, ReadOnly = true },
new CollectionAccessSelection { Id = orgUser2.Id, Manage = false, HidePasswords = true, ReadOnly = false },
]
);
collection.Name = "Updated Collection Name";
await collectionRepository.ReplaceAsync(collection,
[
// Should delete group1
new CollectionAccessSelection { Id = group2.Id, Manage = true, HidePasswords = true, ReadOnly = false, },
// Should add group3
new CollectionAccessSelection { Id = group3.Id, Manage = false, HidePasswords = false, ReadOnly = true, },
],
[
// Should delete orgUser1
new CollectionAccessSelection { Id = orgUser2.Id, Manage = false, HidePasswords = false, ReadOnly = true },
// Should add orgUser3
new CollectionAccessSelection { Id = orgUser3.Id, Manage = true, HidePasswords = false, ReadOnly = true },
]
);
// Assert it
var info = await collectionRepository.GetByIdWithPermissionsAsync(collection.Id, user.Id, true);
Assert.NotNull(info);
Assert.Equal("Updated Collection Name", info.Name);
var groups = info.Groups.ToArray();
Assert.Equal(2, groups.Length);
var actualGroup2 = Assert.Single(groups.Where(g => g.Id == group2.Id));
Assert.True(actualGroup2.Manage);
Assert.True(actualGroup2.HidePasswords);
Assert.False(actualGroup2.ReadOnly);
var actualGroup3 = Assert.Single(groups.Where(g => g.Id == group3.Id));
Assert.False(actualGroup3.Manage);
Assert.False(actualGroup3.HidePasswords);
Assert.True(actualGroup3.ReadOnly);
var users = info.Users.ToArray();
Assert.Equal(2, users.Length);
var actualOrgUser2 = Assert.Single(users.Where(u => u.Id == orgUser2.Id));
Assert.False(actualOrgUser2.Manage);
Assert.False(actualOrgUser2.HidePasswords);
Assert.True(actualOrgUser2.ReadOnly);
var actualOrgUser3 = Assert.Single(users.Where(u => u.Id == orgUser3.Id));
Assert.True(actualOrgUser3.Manage);
Assert.False(actualOrgUser3.HidePasswords);
Assert.True(actualOrgUser3.ReadOnly);
}
}