1
0
mirror of https://github.com/bitwarden/server.git synced 2024-11-24 12:35:25 +01:00

[PS-1928] Add BumpAccountRevisionDate methods (#2458)

* Move RevisionDate Bumps to Extension Class

* Add Tests against live databases

* Run Formatting

* Fix Typo

* Fix Test Solution Typo

* Await ReplaceAsync
This commit is contained in:
Justin Baur 2022-12-02 14:24:30 -05:00 committed by GitHub
parent 41db511872
commit efe91fd0d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 3788 additions and 309 deletions

View File

@ -19,6 +19,12 @@
"commands": [
"reportgenerator"
]
},
"dotnet-ef": {
"version": "6.0.11",
"commands": [
"dotnet-ef"
]
}
}
}

View File

@ -104,6 +104,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "MicroBenchmarks", "perf\Mic
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Scim.Test", "bitwarden_license\test\Scim.Test\Scim.Test.csproj", "{B1595DA3-4C60-41AA-8BF0-499A5F75A885}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Infrastructure.IntegrationTest", "test\Infrastructure.IntegrationTest\Infrastructure.IntegrationTest.csproj", "{7E9A7DD5-EB78-4AC5-BFD5-64573FD2533B}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@ -252,6 +254,10 @@ Global
{B1595DA3-4C60-41AA-8BF0-499A5F75A885}.Debug|Any CPU.Build.0 = Debug|Any CPU
{B1595DA3-4C60-41AA-8BF0-499A5F75A885}.Release|Any CPU.ActiveCfg = Release|Any CPU
{B1595DA3-4C60-41AA-8BF0-499A5F75A885}.Release|Any CPU.Build.0 = Release|Any CPU
{7E9A7DD5-EB78-4AC5-BFD5-64573FD2533B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{7E9A7DD5-EB78-4AC5-BFD5-64573FD2533B}.Debug|Any CPU.Build.0 = Debug|Any CPU
{7E9A7DD5-EB78-4AC5-BFD5-64573FD2533B}.Release|Any CPU.ActiveCfg = Release|Any CPU
{7E9A7DD5-EB78-4AC5-BFD5-64573FD2533B}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@ -292,6 +298,7 @@ Global
{FE998849-5FC8-41A2-B7C9-9227901471A0} = {287CFF34-BBDB-4BC4-AF88-1E19A5A4679B}
{9C8F8255-5F74-4085-AB9C-9075CF6DDC61} = {EC2D422A-6060-48E2-AAD2-37220D759F03}
{B1595DA3-4C60-41AA-8BF0-499A5F75A885} = {287CFF34-BBDB-4BC4-AF88-1E19A5A4679B}
{7E9A7DD5-EB78-4AC5-BFD5-64573FD2533B} = {DD5BD056-4AAE-43EF-BBD2-0B569B8DA84F}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {E01CBF68-2E20-425F-9EDB-E0A6510CA92F}

View File

@ -1,13 +1,10 @@
using System.Text.Json;
using AutoMapper;
using Bit.Core.Enums;
using Bit.Core.Enums.Provider;
using Bit.Infrastructure.EntityFramework.Models;
using Bit.Infrastructure.EntityFramework.Repositories.Queries;
using LinqToDB.Data;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;
using Cipher = Bit.Core.Entities.Cipher;
using User = Bit.Core.Entities.User;
namespace Bit.Infrastructure.EntityFramework.Repositories;
@ -51,68 +48,6 @@ public abstract class BaseEntityFrameworkRepository
}
}
protected async Task UserBumpAccountRevisionDateByCipherId(Cipher cipher)
{
var list = new List<Cipher> { cipher };
await UserBumpAccountRevisionDateByCipherId(list);
}
protected async Task UserBumpAccountRevisionDateByCipherId(IEnumerable<Cipher> ciphers)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
foreach (var cipher in ciphers)
{
var dbContext = GetDatabaseContext(scope);
var query = new UserBumpAccountRevisionDateByCipherIdQuery(cipher);
var users = query.Run(dbContext);
await users.ForEachAsync(e =>
{
dbContext.Attach(e);
e.RevisionDate = DateTime.UtcNow;
});
await dbContext.SaveChangesAsync();
}
}
}
protected async Task UserBumpAccountRevisionDateByOrganizationId(Guid organizationId)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var query = new UserBumpAccountRevisionDateByOrganizationIdQuery(organizationId);
var users = query.Run(dbContext);
await users.ForEachAsync(e =>
{
dbContext.Attach(e);
e.RevisionDate = DateTime.UtcNow;
});
await dbContext.SaveChangesAsync();
}
}
protected async Task UserBumpAccountRevisionDate(Guid userId)
{
await UserBumpManyAccountRevisionDates(new[] { userId });
}
protected async Task UserBumpManyAccountRevisionDates(ICollection<Guid> userIds)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var users = dbContext.Users.Where(u => userIds.Contains(u.Id));
await users.ForEachAsync(u =>
{
dbContext.Attach(u);
u.RevisionDate = DateTime.UtcNow;
});
await dbContext.SaveChangesAsync();
}
}
protected async Task OrganizationUpdateStorage(Guid organizationId)
{
using (var scope = ServiceScopeFactory.CreateScope())
@ -197,81 +132,4 @@ public abstract class BaseEntityFrameworkRepository
await dbContext.SaveChangesAsync();
}
}
protected async Task UserBumpAccountRevisionDateByCollectionId(Guid collectionId, Guid organizationId)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var query = from u in dbContext.Users
join ou in dbContext.OrganizationUsers
on u.Id equals ou.UserId
join cu in dbContext.CollectionUsers
on ou.Id equals cu.OrganizationUserId into cu_g
from cu in cu_g.DefaultIfEmpty()
where !ou.AccessAll && cu.CollectionId.Equals(collectionId)
join gu in dbContext.GroupUsers
on ou.Id equals gu.OrganizationUserId into gu_g
from gu in gu_g.DefaultIfEmpty()
where cu.CollectionId == default(Guid) && !ou.AccessAll
join g in dbContext.Groups
on gu.GroupId equals g.Id into g_g
from g in g_g.DefaultIfEmpty()
join cg in dbContext.CollectionGroups
on gu.GroupId equals cg.GroupId into cg_g
from cg in cg_g.DefaultIfEmpty()
where !g.AccessAll && cg.CollectionId == collectionId &&
(ou.OrganizationId == organizationId && ou.Status == OrganizationUserStatusType.Confirmed &&
(cu.CollectionId != default(Guid) || cg.CollectionId != default(Guid) || ou.AccessAll || g.AccessAll))
select new { u, ou, cu, gu, g, cg };
var users = query.Select(x => x.u);
await users.ForEachAsync(u =>
{
dbContext.Attach(u);
u.RevisionDate = DateTime.UtcNow;
});
await dbContext.SaveChangesAsync();
}
}
protected async Task UserBumpAccountRevisionDateByOrganizationUserId(Guid organizationUserId)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var query = from u in dbContext.Users
join ou in dbContext.OrganizationUsers
on u.Id equals ou.UserId
where ou.Id.Equals(organizationUserId) && ou.Status.Equals(OrganizationUserStatusType.Confirmed)
select new { u, ou };
var users = query.Select(x => x.u);
await users.ForEachAsync(u =>
{
dbContext.Attach(u);
u.AccountRevisionDate = DateTime.UtcNow;
});
await dbContext.SaveChangesAsync();
}
}
protected async Task UserBumpAccountRevisionDateByProviderUserIds(ICollection<Guid> providerUserIds)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var query = from pu in dbContext.ProviderUsers
join u in dbContext.Users
on pu.UserId equals u.Id
where pu.Status.Equals(ProviderUserStatusType.Confirmed) &&
providerUserIds.Contains(pu.Id)
select new { pu, u };
var users = query.Select(x => x.u);
await users.ForEachAsync(u =>
{
dbContext.Attach(u);
u.AccountRevisionDate = DateTime.UtcNow;
});
await dbContext.SaveChangesAsync();
}
}
}

View File

@ -29,40 +29,59 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
var dbContext = GetDatabaseContext(scope);
if (cipher.OrganizationId.HasValue)
{
await UserBumpAccountRevisionDateByCipherId(cipher);
await dbContext.UserBumpAccountRevisionDateByCipherIdAsync(cipher.Id, cipher.OrganizationId);
}
else if (cipher.UserId.HasValue)
{
await UserBumpAccountRevisionDate(cipher.UserId.Value);
await dbContext.UserBumpAccountRevisionDateAsync(cipher.UserId.Value);
}
await dbContext.SaveChangesAsync();
}
return cipher;
}
public IQueryable<User> GetBumpedAccountsByCipherId(Core.Entities.Cipher cipher)
public override async Task DeleteAsync(Core.Entities.Cipher cipher)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var query = new UserBumpAccountRevisionDateByCipherIdQuery(cipher);
return query.Run(dbContext);
var cipherInfo = await dbContext.Ciphers
.Where(c => c.Id == cipher.Id)
.Select(c => new { c.UserId, c.OrganizationId, HasAttachments = c.Attachments != null })
.FirstOrDefaultAsync();
await base.DeleteAsync(cipher);
if (cipherInfo?.OrganizationId != null)
{
if (cipherInfo.HasAttachments == true)
{
await OrganizationUpdateStorage(cipherInfo.OrganizationId.Value);
}
await dbContext.UserBumpAccountRevisionDateByCipherIdAsync(cipher.Id, cipherInfo.OrganizationId);
}
else if (cipherInfo?.UserId != null)
{
if (cipherInfo.HasAttachments)
{
await UserUpdateStorage(cipherInfo.UserId.Value);
}
await dbContext.UserBumpAccountRevisionDateAsync(cipherInfo.UserId.Value);
}
await dbContext.SaveChangesAsync();
}
}
public async Task CreateAsync(Core.Entities.Cipher cipher, IEnumerable<Guid> collectionIds)
{
cipher = await base.CreateAsync(cipher);
await UpdateCollections(cipher, collectionIds);
}
private async Task UpdateCollections(Core.Entities.Cipher cipher, IEnumerable<Guid> collectionIds)
{
cipher = await CreateAsync(cipher);
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var cipherEntity = await dbContext.Ciphers.FindAsync(cipher.Id);
var query = new CipherUpdateCollectionsQuery(cipherEntity, collectionIds).Run(dbContext);
await dbContext.AddRangeAsync(query);
await UpdateCollectionsAsync(dbContext, cipher.Id,
cipher.UserId, cipher.OrganizationId, collectionIds);
await dbContext.SaveChangesAsync();
}
}
@ -88,16 +107,22 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
null;
var entity = Mapper.Map<Cipher>((Core.Entities.Cipher)cipher);
await dbContext.AddAsync(entity);
await dbContext.UserBumpAccountRevisionDateByCipherIdAsync(cipher.Id, cipher.OrganizationId.GetValueOrDefault());
await dbContext.SaveChangesAsync();
}
await UserBumpAccountRevisionDateByCipherId(cipher);
return cipher;
}
public async Task CreateAsync(CipherDetails cipher, IEnumerable<Guid> collectionIds)
{
cipher = await CreateAsyncReturnCipher(cipher);
await UpdateCollections(cipher, collectionIds);
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
await UpdateCollectionsAsync(dbContext, cipher.Id,
cipher.UserId, cipher.OrganizationId, collectionIds);
await dbContext.SaveChangesAsync();
}
}
public async Task CreateAsync(IEnumerable<Core.Entities.Cipher> ciphers, IEnumerable<Core.Entities.Folder> folders)
@ -114,7 +139,8 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, folderEntities);
var cipherEntities = Mapper.Map<List<Cipher>>(ciphers);
await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, cipherEntities);
await UserBumpAccountRevisionDateByCipherId(ciphers);
await dbContext.UserBumpAccountRevisionDateByCipherIdAsync(ciphers);
await dbContext.SaveChangesAsync();
}
}
@ -140,7 +166,8 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, collectionCipherEntities);
}
}
await UserBumpAccountRevisionDateByOrganizationId(ciphers.First().OrganizationId.Value);
await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(ciphers.First().OrganizationId.Value);
await dbContext.SaveChangesAsync();
}
}
@ -163,13 +190,14 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
if (cipher.OrganizationId.HasValue)
{
await OrganizationUpdateStorage(cipher.OrganizationId.Value);
await UserBumpAccountRevisionDateByCipherId(cipher);
await dbContext.UserBumpAccountRevisionDateByCipherIdAsync(cipher.Id, cipher.OrganizationId.Value);
}
else if (cipher.UserId.HasValue)
{
await UserUpdateStorage(cipher.UserId.Value);
await UserBumpAccountRevisionDate(cipher.UserId.Value);
await dbContext.UserBumpAccountRevisionDateAsync(cipher.UserId.Value);
}
await dbContext.SaveChangesAsync();
}
}
@ -184,9 +212,10 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
select c;
dbContext.RemoveRange(ciphers);
await dbContext.SaveChangesAsync();
await OrganizationUpdateStorage(organizationId);
await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(organizationId);
await dbContext.SaveChangesAsync();
}
await OrganizationUpdateStorage(organizationId);
await UserBumpAccountRevisionDateByOrganizationId(organizationId);
}
public async Task DeleteByOrganizationIdAsync(Guid organizationId)
@ -207,10 +236,10 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
select c;
dbContext.RemoveRange(ciphers);
await OrganizationUpdateStorage(organizationId);
await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(organizationId);
await dbContext.SaveChangesAsync();
}
await OrganizationUpdateStorage(organizationId);
await UserBumpAccountRevisionDateByOrganizationId(organizationId);
}
public async Task DeleteByUserIdAsync(Guid userId)
@ -228,7 +257,8 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
dbContext.RemoveRange(folders);
await dbContext.SaveChangesAsync();
await UserUpdateStorage(userId);
await UserBumpAccountRevisionDate(userId);
await dbContext.UserBumpAccountRevisionDateAsync(userId);
await dbContext.SaveChangesAsync();
}
}
@ -364,8 +394,8 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
dbContext.Attach(cipher);
cipher.Folders = JsonConvert.SerializeObject(foldersJson);
});
await dbContext.UserBumpAccountRevisionDateAsync(userId);
await dbContext.SaveChangesAsync();
await UserBumpAccountRevisionDate(userId);
}
}
@ -427,26 +457,100 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
}
var mappedEntity = Mapper.Map<Cipher>((Core.Entities.Cipher)cipher);
dbContext.Entry(entity).CurrentValues.SetValues(mappedEntity);
await UserBumpAccountRevisionDateByCipherId(cipher);
await dbContext.UserBumpAccountRevisionDateByCipherIdAsync(cipher.Id, cipher.OrganizationId.GetValueOrDefault());
await dbContext.SaveChangesAsync();
}
}
}
public async Task<bool> ReplaceAsync(Core.Entities.Cipher obj, IEnumerable<Guid> collectionIds)
private static async Task<int> UpdateCollectionsAsync(DatabaseContext context, Guid id, Guid? userId, Guid? organizationId, IEnumerable<Guid> collectionIds)
{
if (!organizationId.HasValue || !collectionIds.Any())
{
return -1;
}
IQueryable<Guid> availableCollectionsQuery;
if (!userId.HasValue)
{
availableCollectionsQuery = context.Collections
.Where(c => c.OrganizationId == organizationId.Value)
.Select(c => c.Id);
}
else
{
availableCollectionsQuery = from c in context.Collections
join o in context.Organizations
on c.OrganizationId equals o.Id
join ou in context.OrganizationUsers
on new { OrganizationId = o.Id, UserId = (Guid?)userId.Value } equals
new { ou.OrganizationId, ou.UserId }
join cu in context.CollectionUsers
on new { ou.AccessAll, CollectionId = c.Id, OrganizationUserId = ou.Id } equals
new { AccessAll = false, cu.CollectionId, cu.OrganizationUserId } into cu_g
from cu in cu_g.DefaultIfEmpty()
join gu in context.GroupUsers
on new { CollectionId = (Guid?)cu.CollectionId, ou.AccessAll, OrganizationUserId = ou.Id } equals
new { CollectionId = (Guid?)null, AccessAll = false, gu.OrganizationUserId } into gu_g
from gu in gu_g.DefaultIfEmpty()
join g in context.Groups
on gu.GroupId equals g.Id into g_g
from g in g_g.DefaultIfEmpty()
join cg in context.CollectionGroups
on new { g.AccessAll, CollectionId = c.Id, gu.GroupId } equals
new { AccessAll = false, cg.CollectionId, cg.GroupId }
where o.Id == organizationId &&
o.Enabled &&
ou.Status == OrganizationUserStatusType.Confirmed &&
(ou.AccessAll || !cu.ReadOnly || g.AccessAll || !cg.ReadOnly)
select c.Id;
}
var availableCollections = await availableCollectionsQuery.ToListAsync();
if (!availableCollections.Any())
{
return -1;
}
var collectionCiphers = collectionIds
.Where(collectionId => availableCollections.Contains(collectionId))
.Select(collectionId => new CollectionCipher
{
CollectionId = collectionId,
CipherId = id,
});
context.CollectionCiphers.AddRange(collectionCiphers);
return 0;
}
public async Task<bool> ReplaceAsync(Core.Entities.Cipher cipher, IEnumerable<Guid> collectionIds)
{
await UpdateCollections(obj, collectionIds);
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var cipher = await dbContext.Ciphers.FindAsync(obj.Id);
cipher.UserId = null;
cipher.OrganizationId = obj.OrganizationId;
cipher.Data = obj.Data;
cipher.Attachments = obj.Attachments;
cipher.RevisionDate = obj.RevisionDate;
cipher.DeletedDate = obj.DeletedDate;
await dbContext.SaveChangesAsync();
var transaction = await dbContext.Database.BeginTransactionAsync();
var successes = await UpdateCollectionsAsync(
dbContext, cipher.Id, cipher.UserId,
cipher.OrganizationId, collectionIds);
if (successes < 0)
{
await transaction.CommitAsync();
return false;
}
var trackedCipher = await dbContext.Ciphers.FindAsync(cipher.Id);
trackedCipher.UserId = null;
trackedCipher.OrganizationId = cipher.OrganizationId;
trackedCipher.Data = cipher.Data;
trackedCipher.Attachments = cipher.Attachments;
trackedCipher.RevisionDate = cipher.RevisionDate;
trackedCipher.DeletedDate = cipher.DeletedDate;
await transaction.CommitAsync();
if (!string.IsNullOrWhiteSpace(cipher.Attachments))
{
@ -460,7 +564,8 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
}
}
await UserBumpAccountRevisionDateByCipherId(cipher);
await dbContext.UserBumpAccountRevisionDateByCipherIdAsync(cipher.Id, cipher.OrganizationId.GetValueOrDefault());
await dbContext.SaveChangesAsync();
return true;
}
}
@ -522,13 +627,13 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
foreach (var orgId in orgIds)
{
await OrganizationUpdateStorage(orgId.Value);
await UserBumpAccountRevisionDateByOrganizationId(orgId.Value);
await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(orgId.Value);
}
if (query.Any(c => c.UserId.HasValue && !string.IsNullOrWhiteSpace(c.Attachments)))
{
await UserUpdateStorage(userId);
}
await UserBumpAccountRevisionDate(userId);
await dbContext.UserBumpAccountRevisionDateAsync(userId);
await dbContext.SaveChangesAsync();
return utcNow;
}
@ -547,9 +652,9 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
cipher.DeletedDate = utcNow;
cipher.RevisionDate = utcNow;
});
await dbContext.SaveChangesAsync();
await OrganizationUpdateStorage(organizationId);
await UserBumpAccountRevisionDateByOrganizationId(organizationId);
await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(organizationId);
await dbContext.SaveChangesAsync();
}
}
@ -570,13 +675,14 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
if (attachment.OrganizationId.HasValue)
{
await OrganizationUpdateStorage(cipher.OrganizationId.Value);
await UserBumpAccountRevisionDateByCipherId(new List<Core.Entities.Cipher> { cipher });
await dbContext.UserBumpAccountRevisionDateByCipherIdAsync(cipher.Id, cipher.OrganizationId);
}
else if (attachment.UserId.HasValue)
{
await UserUpdateStorage(attachment.UserId.Value);
await UserBumpAccountRevisionDate(attachment.UserId.Value);
await dbContext.UserBumpAccountRevisionDateAsync(attachment.UserId.Value);
}
await dbContext.SaveChangesAsync();
}
}
@ -591,7 +697,8 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
var dbContext = GetDatabaseContext(scope);
var entities = Mapper.Map<List<Cipher>>(ciphers);
await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, entities);
await UserBumpAccountRevisionDate(userId);
await dbContext.UserBumpAccountRevisionDateAsync(userId);
await dbContext.SaveChangesAsync();
}
}
@ -626,8 +733,8 @@ public class CipherRepository : Repository<Core.Entities.Cipher, Cipher, Guid>,
favoritesJson.Remove(userId.ToString());
}
await dbContext.UserBumpAccountRevisionDateAsync(userId);
await dbContext.SaveChangesAsync();
await UserBumpAccountRevisionDate(userId);
}
}

View File

@ -25,7 +25,8 @@ public class CollectionCipherRepository : BaseEntityFrameworkRepository, ICollec
var organizationId = (await dbContext.Ciphers.FirstOrDefaultAsync(c => c.Id.Equals(obj.CipherId))).OrganizationId;
if (organizationId.HasValue)
{
await UserBumpAccountRevisionDateByCollectionId(obj.CollectionId, organizationId.Value);
await dbContext.UserBumpAccountRevisionDateByCollectionIdAsync(obj.CollectionId, organizationId.Value);
await dbContext.SaveChangesAsync();
}
return obj;
}
@ -132,12 +133,12 @@ public class CollectionCipherRepository : BaseEntityFrameworkRepository, ICollec
});
await dbContext.AddRangeAsync(insert);
dbContext.RemoveRange(delete);
await dbContext.SaveChangesAsync();
if (organizationId.HasValue)
{
await UserBumpAccountRevisionDateByOrganizationId(organizationId.Value);
await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(organizationId.Value);
}
await dbContext.SaveChangesAsync();
}
}
@ -182,8 +183,8 @@ public class CollectionCipherRepository : BaseEntityFrameworkRepository, ICollec
});
await dbContext.AddRangeAsync(insert);
dbContext.RemoveRange(delete);
await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(organizationId);
await dbContext.SaveChangesAsync();
await UserBumpAccountRevisionDateByOrganizationId(organizationId);
}
}
@ -231,7 +232,8 @@ public class CollectionCipherRepository : BaseEntityFrameworkRepository, ICollec
CipherId = cipherId,
};
await dbContext.AddRangeAsync(insertData);
await UserBumpAccountRevisionDateByOrganizationId(organizationId);
await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(organizationId);
await dbContext.SaveChangesAsync();
}
}
}

View File

@ -14,16 +14,43 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
: base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Collections)
{ }
public override async Task<Core.Entities.Collection> CreateAsync(Core.Entities.Collection obj)
public override async Task<Core.Entities.Collection> CreateAsync(Core.Entities.Collection collection)
{
await base.CreateAsync(obj);
await UserBumpAccountRevisionDateByCollectionId(obj.Id, obj.OrganizationId);
return obj;
await base.CreateAsync(collection);
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
await dbContext.UserBumpAccountRevisionDateByCollectionIdAsync(collection.Id, collection.OrganizationId);
await dbContext.SaveChangesAsync();
}
return collection;
}
public override async Task DeleteAsync(Core.Entities.Collection collection)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
await dbContext.UserBumpAccountRevisionDateByCollectionIdAsync(collection.Id, collection.OrganizationId);
await dbContext.SaveChangesAsync();
}
await base.DeleteAsync(collection);
}
public override async Task UpsertAsync(Core.Entities.Collection collection)
{
await base.UpsertAsync(collection);
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
await dbContext.UserBumpAccountRevisionDateByCollectionIdAsync(collection.Id, collection.OrganizationId);
await dbContext.SaveChangesAsync();
}
}
public async Task CreateAsync(Core.Entities.Collection obj, IEnumerable<SelectionReadOnly> groups)
{
await base.CreateAsync(obj);
await CreateAsync(obj);
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
@ -40,8 +67,8 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
HidePasswords = g.HidePasswords,
});
await dbContext.AddRangeAsync(collectionGroups);
await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(obj.OrganizationId);
await dbContext.SaveChangesAsync();
await UserBumpAccountRevisionDateByOrganizationId(obj.OrganizationId);
}
}
@ -55,8 +82,8 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
cu.OrganizationUserId == organizationUserId
select cu;
dbContext.RemoveRange(await query.ToListAsync());
await dbContext.UserBumpAccountRevisionDateByOrganizationUserIdAsync(organizationUserId);
await dbContext.SaveChangesAsync();
await UserBumpAccountRevisionDateByOrganizationUserId(organizationUserId);
}
}
@ -167,7 +194,7 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
public async Task ReplaceAsync(Core.Entities.Collection collection, IEnumerable<SelectionReadOnly> groups)
{
await base.ReplaceAsync(collection);
await UpsertAsync(collection);
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
@ -228,8 +255,8 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
await dbContext.AddRangeAsync(insert);
dbContext.UpdateRange(update);
dbContext.RemoveRange(delete);
await dbContext.UserBumpAccountRevisionDateByCollectionIdAsync(collection.Id, collection.OrganizationId);
await dbContext.SaveChangesAsync();
await UserBumpAccountRevisionDateByCollectionId(collection.Id, collection.OrganizationId);
}
}
@ -273,7 +300,7 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
// Remove all existing ones that are no longer requested
var requestedUserIds = requestedUsers.Select(u => u.Id);
dbContext.CollectionUsers.RemoveRange(existingCollectionUsers.Where(cu => !requestedUserIds.Contains(cu.OrganizationUserId)));
await UserBumpAccountRevisionDateByCollectionId(id, organizationId);
await dbContext.UserBumpAccountRevisionDateByCollectionIdAsync(id, organizationId);
await dbContext.SaveChangesAsync();
}
}

View File

@ -0,0 +1,146 @@
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Enums.Provider;
using Bit.Infrastructure.EntityFramework.Repositories.Queries;
using Microsoft.EntityFrameworkCore;
namespace Bit.Infrastructure.EntityFramework.Repositories;
public static class DatabaseContextExtensions
{
public static async Task UserBumpAccountRevisionDateAsync(this DatabaseContext context, Guid userId)
{
var user = await context.Users.FindAsync(userId);
user.AccountRevisionDate = DateTime.UtcNow;
}
public static async Task UserBumpManyAccountRevisionDatesAsync(this DatabaseContext context, ICollection<Guid> userIds)
{
var users = context.Users.Where(u => userIds.Contains(u.Id));
var currentTime = DateTime.UtcNow;
await users.ForEachAsync(u =>
{
context.Attach(u);
u.AccountRevisionDate = currentTime;
});
}
public static async Task UserBumpAccountRevisionDateByOrganizationIdAsync(this DatabaseContext context, Guid organizationId)
{
var users = await (from u in context.Users
join ou in context.OrganizationUsers on u.Id equals ou.UserId
where ou.OrganizationId == organizationId && ou.Status == OrganizationUserStatusType.Confirmed
select u).ToListAsync();
UpdateUserRevisionDate(users);
}
public static async Task UserBumpAccountRevisionDateByCipherIdAsync(this DatabaseContext context, Guid cipherId, Guid? organizationId)
{
var query = new UserBumpAccountRevisionDateByCipherIdQuery(cipherId, organizationId);
var users = await query.Run(context).ToListAsync();
UpdateUserRevisionDate(users);
}
public static async Task UserBumpAccountRevisionDateByCipherIdAsync(this DatabaseContext context, IEnumerable<Cipher> ciphers)
{
foreach (var cipher in ciphers)
{
await context.UserBumpAccountRevisionDateByCipherIdAsync(cipher.Id, cipher.OrganizationId);
}
}
public static async Task UserBumpAccountRevisionDateByCollectionIdAsync(this DatabaseContext context, Guid collectionId, Guid organizationId)
{
var query = from u in context.Users
join ou in context.OrganizationUsers
on u.Id equals ou.UserId
join cu in context.CollectionUsers
on new { ou.AccessAll, OrganizationUserId = ou.Id, CollectionId = collectionId } equals
new { AccessAll = false, cu.OrganizationUserId, cu.CollectionId } into cu_g
from cu in cu_g.DefaultIfEmpty()
join gu in context.GroupUsers
on new { CollectionId = (Guid?)cu.CollectionId, ou.AccessAll, OrganizationUserId = ou.Id } equals
new { CollectionId = (Guid?)null, AccessAll = false, gu.OrganizationUserId } into gu_g
from gu in gu_g.DefaultIfEmpty()
join g in context.Groups
on gu.GroupId equals g.Id into g_g
from g in g_g.DefaultIfEmpty()
join cg in context.CollectionGroups
on new { g.AccessAll, gu.GroupId, CollectionId = collectionId } equals
new { AccessAll = false, cg.GroupId, cg.CollectionId } into cg_g
from cg in cg_g.DefaultIfEmpty()
where ou.OrganizationId == organizationId &&
ou.Status == OrganizationUserStatusType.Confirmed &&
cg.CollectionId != null &&
ou.AccessAll == true &&
g.AccessAll == true
select u;
var users = await query.ToListAsync();
UpdateUserRevisionDate(users);
}
public static async Task UserBumpAccountRevisionDateByOrganizationUserIdAsync(this DatabaseContext context, Guid organizationUserId)
{
var query = from u in context.Users
join ou in context.OrganizationUsers
on u.Id equals ou.UserId
where ou.Id == organizationUserId && ou.Status == OrganizationUserStatusType.Confirmed
select u;
var users = await query.ToListAsync();
UpdateUserRevisionDate(users);
}
public static async Task UserBumpAccountRevisionDateByOrganizationUserIdsAsync(this DatabaseContext context, IEnumerable<Guid> organizationUserIds)
{
foreach (var organizationUserId in organizationUserIds)
{
await context.UserBumpAccountRevisionDateByOrganizationUserIdAsync(organizationUserId);
}
}
public static async Task UserBumpAccountRevisionDateByEmergencyAccessGranteeIdAsync(this DatabaseContext context, Guid emergencyAccessId)
{
var query = from u in context.Users
join ea in context.EmergencyAccesses on u.Id equals ea.GranteeId
where ea.Id == emergencyAccessId && ea.Status == EmergencyAccessStatusType.Confirmed
select u;
var users = await query.ToListAsync();
UpdateUserRevisionDate(users);
}
public static async Task UserBumpAccountRevisionDateByProviderIdAsync(this DatabaseContext context, Guid providerId)
{
var query = from u in context.Users
join pu in context.ProviderUsers on u.Id equals pu.UserId
where pu.ProviderId == providerId && pu.Status == ProviderUserStatusType.Confirmed
select u;
var users = await query.ToListAsync();
UpdateUserRevisionDate(users);
}
public static async Task UserBumpAccountRevisionDateByProviderUserIdAsync(this DatabaseContext context, Guid providerUserId)
{
var query = from u in context.Users
join pu in context.ProviderUsers on u.Id equals pu.UserId
where pu.ProviderId == providerUserId && pu.Status == ProviderUserStatusType.Confirmed
select u;
var users = await query.ToListAsync();
UpdateUserRevisionDate(users);
}
private static void UpdateUserRevisionDate(List<Models.User> users)
{
var time = DateTime.UtcNow;
foreach (var user in users)
{
user.AccountRevisionDate = time;
}
}
}

View File

@ -21,6 +21,17 @@ public class EmergencyAccessRepository : Repository<Core.Entities.EmergencyAcces
return await GetCountFromQuery(query);
}
public override async Task DeleteAsync(Core.Entities.EmergencyAccess emergencyAccess)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
await dbContext.UserBumpAccountRevisionDateByEmergencyAccessGranteeIdAsync(emergencyAccess.Id);
await dbContext.SaveChangesAsync();
}
await base.DeleteAsync(emergencyAccess);
}
public async Task<EmergencyAccessDetails> GetDetailsByIdGrantorIdAsync(Guid id, Guid grantorId)
{
using (var scope = ServiceScopeFactory.CreateScope())

View File

@ -46,6 +46,7 @@ public class GroupRepository : Repository<Core.Entities.Group, Group, Guid>, IGr
gu.OrganizationUserId == organizationUserId
select gu;
dbContext.RemoveRange(await query.ToListAsync());
await dbContext.UserBumpAccountRevisionDateByOrganizationUserIdAsync(organizationUserId);
await dbContext.SaveChangesAsync();
}
}
@ -134,7 +135,8 @@ public class GroupRepository : Repository<Core.Entities.Group, Group, Guid>, IGr
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
await UserBumpAccountRevisionDateByOrganizationId(obj.OrganizationId);
await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(obj.OrganizationId);
await dbContext.SaveChangesAsync();
}
}
@ -161,7 +163,8 @@ public class GroupRepository : Repository<Core.Entities.Group, Group, Guid>, IGr
select gu;
dbContext.RemoveRange(delete);
await dbContext.SaveChangesAsync();
await UserBumpAccountRevisionDateByOrganizationId(orgId);
await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(orgId);
await dbContext.SaveChangesAsync();
}
}
}

View File

@ -70,9 +70,39 @@ public class OrganizationUserRepository : Repository<Core.Entities.OrganizationU
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var orgUser = await dbContext.FindAsync<OrganizationUser>(organizationUserId);
await dbContext.UserBumpAccountRevisionDateByOrganizationUserIdAsync(organizationUserId);
var orgUser = await dbContext.OrganizationUsers
.Where(ou => ou.Id == organizationUserId)
.FirstAsync();
dbContext.Remove(orgUser);
var organizationId = orgUser?.OrganizationId;
var userId = orgUser?.UserId;
if (orgUser?.OrganizationId != null && orgUser?.UserId != null)
{
var ssoUsers = dbContext.SsoUsers
.Where(su => su.UserId == userId && su.OrganizationId == organizationId);
dbContext.SsoUsers.RemoveRange(ssoUsers);
}
var collectionUsers = dbContext.CollectionUsers
.Where(cu => cu.OrganizationUserId == organizationUserId);
dbContext.CollectionUsers.RemoveRange(collectionUsers);
var groupUsers = dbContext.GroupUsers
.Where(gu => gu.OrganizationUserId == organizationUserId);
dbContext.GroupUsers.RemoveRange(groupUsers);
var orgSponsorships = await dbContext.OrganizationSponsorships
.Where(os => os.SponsoringOrganizationUserId == organizationUserId)
.ToListAsync();
foreach (var orgSponsorship in orgSponsorships)
{
orgSponsorship.ToDelete = true;
}
dbContext.OrganizationUsers.Remove(orgUser);
await dbContext.SaveChangesAsync();
}
}
@ -82,7 +112,9 @@ public class OrganizationUserRepository : Repository<Core.Entities.OrganizationU
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
await dbContext.UserBumpAccountRevisionDateByOrganizationUserIdsAsync(organizationUserIds);
var entities = await dbContext.OrganizationUsers
// TODO: Does this work?
.Where(ou => organizationUserIds.Contains(ou.Id))
.ToListAsync();
@ -309,9 +341,20 @@ public class OrganizationUserRepository : Repository<Core.Entities.OrganizationU
}
}
public async override Task ReplaceAsync(Core.Entities.OrganizationUser organizationUser)
{
await base.ReplaceAsync(organizationUser);
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
await dbContext.UserBumpAccountRevisionDateAsync(organizationUser.UserId.GetValueOrDefault());
await dbContext.SaveChangesAsync();
}
}
public async Task ReplaceAsync(Core.Entities.OrganizationUser obj, IEnumerable<SelectionReadOnly> requestedCollections)
{
await base.ReplaceAsync(obj);
await ReplaceAsync(obj);
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
@ -356,7 +399,7 @@ public class OrganizationUserRepository : Repository<Core.Entities.OrganizationU
var dbContext = GetDatabaseContext(scope);
dbContext.UpdateRange(organizationUsers);
await dbContext.SaveChangesAsync();
await UserBumpManyAccountRevisionDates(organizationUsers
await dbContext.UserBumpManyAccountRevisionDatesAsync(organizationUsers
.Where(ou => ou.UserId.HasValue)
.Select(ou => ou.UserId.Value).ToArray());
}
@ -400,7 +443,7 @@ public class OrganizationUserRepository : Repository<Core.Entities.OrganizationU
var delete = procedure.Delete.Run(dbContext);
var deleteData = await delete.ToListAsync();
dbContext.RemoveRange(deleteData);
await UserBumpAccountRevisionDateByOrganizationUserId(orgUserId);
await dbContext.UserBumpAccountRevisionDateByOrganizationUserIdAsync(orgUserId);
await dbContext.SaveChangesAsync();
}
}
@ -449,17 +492,15 @@ public class OrganizationUserRepository : Repository<Core.Entities.OrganizationU
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var orgUser = await GetDbSet(dbContext).FindAsync(id);
if (orgUser != null)
var orgUser = await dbContext.OrganizationUsers.FindAsync(id);
if (orgUser == null)
{
dbContext.Update(orgUser);
orgUser.Status = OrganizationUserStatusType.Revoked;
await dbContext.SaveChangesAsync();
if (orgUser.UserId.HasValue)
{
await UserBumpAccountRevisionDate(orgUser.UserId.Value);
}
return;
}
orgUser.Status = OrganizationUserStatusType.Revoked;
await dbContext.UserBumpAccountRevisionDateByOrganizationUserIdAsync(id);
await dbContext.SaveChangesAsync();
}
}
@ -468,17 +509,17 @@ public class OrganizationUserRepository : Repository<Core.Entities.OrganizationU
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
var orgUser = await GetDbSet(dbContext).FindAsync(id);
if (orgUser != null)
var orgUser = await dbContext.OrganizationUsers
.FirstOrDefaultAsync(ou => ou.Id == id && ou.Status == OrganizationUserStatusType.Revoked);
if (orgUser == null)
{
dbContext.Update(orgUser);
orgUser.Status = status;
await dbContext.SaveChangesAsync();
if (orgUser.UserId.HasValue)
{
await UserBumpAccountRevisionDate(orgUser.UserId.Value);
}
return;
}
orgUser.Status = status;
await dbContext.UserBumpAccountRevisionDateByOrganizationUserIdAsync(id);
await dbContext.SaveChangesAsync();
}
}
}

View File

@ -14,6 +14,17 @@ public class ProviderRepository : Repository<Provider, Models.Provider, Guid>, I
: base(serviceScopeFactory, mapper, context => context.Providers)
{ }
public override async Task DeleteAsync(Provider provider)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
await dbContext.UserBumpAccountRevisionDateByProviderIdAsync(provider.Id);
await dbContext.SaveChangesAsync();
}
await base.DeleteAsync(provider);
}
public async Task<ICollection<Provider>> SearchAsync(string name, string userEmail, int skip, int take)
{
using (var scope = ServiceScopeFactory.CreateScope())

View File

@ -16,6 +16,17 @@ public class ProviderUserRepository :
: base(serviceScopeFactory, mapper, (DatabaseContext context) => context.ProviderUsers)
{ }
public override async Task DeleteAsync(ProviderUser providerUser)
{
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
await dbContext.UserBumpAccountRevisionDateByProviderUserIdAsync(providerUser.Id);
await dbContext.SaveChangesAsync();
}
await base.DeleteAsync(providerUser);
}
public async Task<int> GetCountByProviderAsync(Guid providerId, string email, bool onlyRegisteredUsers)
{
using (var scope = ServiceScopeFactory.CreateScope())
@ -59,7 +70,10 @@ public class ProviderUserRepository :
using (var scope = ServiceScopeFactory.CreateScope())
{
var dbContext = GetDatabaseContext(scope);
await UserBumpAccountRevisionDateByProviderUserIds(providerUserIds.ToArray());
foreach (var providerUserId in providerUserIds)
{
await dbContext.UserBumpAccountRevisionDateByProviderUserIdAsync(providerUserId);
}
var entities = dbContext.ProviderUsers.Where(pu => providerUserIds.Contains(pu.Id));
dbContext.ProviderUsers.RemoveRange(entities);
await dbContext.SaveChangesAsync();

View File

@ -1,73 +0,0 @@
using Bit.Core.Entities;
using Bit.Core.Enums;
using CollectionCipher = Bit.Infrastructure.EntityFramework.Models.CollectionCipher;
namespace Bit.Infrastructure.EntityFramework.Repositories.Queries;
public class CipherUpdateCollectionsQuery : IQuery<CollectionCipher>
{
private readonly Cipher _cipher;
private readonly IEnumerable<Guid> _collectionIds;
public CipherUpdateCollectionsQuery(Cipher cipher, IEnumerable<Guid> collectionIds)
{
_cipher = cipher;
_collectionIds = collectionIds;
}
public virtual IQueryable<CollectionCipher> Run(DatabaseContext dbContext)
{
if (!_cipher.OrganizationId.HasValue || !_collectionIds.Any())
{
return null;
}
var availibleCollections = !_cipher.UserId.HasValue ?
from c in dbContext.Collections
where c.OrganizationId == _cipher.OrganizationId
select c.Id :
from c in dbContext.Collections
join o in dbContext.Organizations
on c.OrganizationId equals o.Id
join ou in dbContext.OrganizationUsers
on new { OrganizationId = o.Id, _cipher.UserId } equals new { ou.OrganizationId, ou.UserId }
join cu in dbContext.CollectionUsers
on new { ou.AccessAll, CollectionId = c.Id, OrganizationUserId = ou.Id } equals
new { AccessAll = false, cu.CollectionId, cu.OrganizationUserId } into cu_g
from cu in cu_g.DefaultIfEmpty()
join gu in dbContext.GroupUsers
on new { CollectionId = (Guid?)cu.CollectionId, ou.AccessAll, OrganizationUserId = ou.Id } equals
new { CollectionId = (Guid?)null, AccessAll = false, gu.OrganizationUserId } into gu_g
from gu in gu_g.DefaultIfEmpty()
join g in dbContext.Groups
on gu.GroupId equals g.Id into g_g
from g in g_g.DefaultIfEmpty()
join cg in dbContext.CollectionGroups
on new { g.AccessAll, CollectionId = c.Id, gu.GroupId } equals
new { AccessAll = false, cg.CollectionId, cg.GroupId } into cg_g
from cg in cg_g.DefaultIfEmpty()
where o.Id == _cipher.OrganizationId &&
o.Enabled &&
ou.Status == OrganizationUserStatusType.Confirmed &&
(ou.AccessAll || !cu.ReadOnly || g.AccessAll || !cg.ReadOnly)
select c.Id;
if (!availibleCollections.Any())
{
return null;
}
var query = from c in availibleCollections
select new CollectionCipher { CollectionId = c, CipherId = _cipher.Id };
return query;
}
}

View File

@ -6,11 +6,19 @@ namespace Bit.Infrastructure.EntityFramework.Repositories.Queries;
public class UserBumpAccountRevisionDateByCipherIdQuery : IQuery<User>
{
private readonly Cipher _cipher;
private readonly Guid _cipherId;
private readonly Guid? _organizationId;
public UserBumpAccountRevisionDateByCipherIdQuery(Cipher cipher)
{
_cipher = cipher;
_cipherId = cipher.Id;
_organizationId = cipher.OrganizationId;
}
public UserBumpAccountRevisionDateByCipherIdQuery(Guid cipherId, Guid? organizationId)
{
_cipherId = cipherId;
_organizationId = organizationId;
}
public IQueryable<User> Run(DatabaseContext dbContext)
@ -21,7 +29,7 @@ public class UserBumpAccountRevisionDateByCipherIdQuery : IQuery<User>
on u.Id equals ou.UserId
join collectionCipher in dbContext.CollectionCiphers
on _cipher.Id equals collectionCipher.CipherId into cc_g
on _cipherId equals collectionCipher.CipherId into cc_g
from cc in cc_g.DefaultIfEmpty()
join collectionUser in dbContext.CollectionUsers
@ -43,7 +51,7 @@ public class UserBumpAccountRevisionDateByCipherIdQuery : IQuery<User>
new { AccessAll = false, collectionGroup.GroupId, collectionGroup.CollectionId } into cg_g
from cg in cg_g.DefaultIfEmpty()
where ou.OrganizationId == _cipher.OrganizationId &&
where ou.OrganizationId == _organizationId &&
ou.Status == OrganizationUserStatusType.Confirmed &&
(cu.CollectionId != null ||
cg.CollectionId != null ||

View File

@ -15,11 +15,17 @@ public class SendRepository : Repository<Core.Entities.Send, Send, Guid>, ISendR
public override async Task<Core.Entities.Send> CreateAsync(Core.Entities.Send send)
{
send = await base.CreateAsync(send);
if (send.UserId.HasValue)
using (var scope = ServiceScopeFactory.CreateScope())
{
await UserUpdateStorage(send.UserId.Value);
await UserBumpAccountRevisionDate(send.UserId.Value);
var dbContext = GetDatabaseContext(scope);
if (send.UserId.HasValue)
{
await UserUpdateStorage(send.UserId.Value);
await dbContext.UserBumpAccountRevisionDateAsync(send.UserId.Value);
await dbContext.SaveChangesAsync();
}
}
return send;
}

View File

@ -0,0 +1,17 @@
using Microsoft.Extensions.Configuration;
namespace Bit.Infrastructure.IntegrationTest;
public static class ConfigurationExtensions
{
public static bool TryGetConnectionString(this IConfiguration config, string key, out string connectionString)
{
connectionString = config[key];
if (string.IsNullOrEmpty(connectionString))
{
return false;
}
return true;
}
}

View File

@ -0,0 +1,84 @@
using System.Reflection;
using Bit.Core.Enums;
using Bit.Core.Settings;
using Bit.Infrastructure.Dapper;
using Bit.Infrastructure.EntityFramework;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Xunit.Sdk;
namespace Bit.Infrastructure.IntegrationTest;
public class DatabaseDataAttribute : DataAttribute
{
public bool SelfHosted { get; set; }
public override IEnumerable<object[]> GetData(MethodInfo testMethod)
{
var parameters = testMethod.GetParameters();
var config = DatabaseTheoryAttribute.GetConfiguration();
var serviceProviders = GetDatabaseProviders(config);
foreach (var provider in serviceProviders)
{
var objects = new object[parameters.Length];
for (var i = 0; i < parameters.Length; i++)
{
objects[i] = provider.GetRequiredService(parameters[i].ParameterType);
}
yield return objects;
}
}
private IEnumerable<IServiceProvider> GetDatabaseProviders(IConfiguration config)
{
var configureLogging = (ILoggingBuilder builder) =>
{
if (!config.GetValue<bool>("Quiet"))
{
builder.AddConfiguration(config);
builder.AddConsole();
builder.AddDebug();
}
};
if (config.TryGetConnectionString(DatabaseTheoryAttribute.DapperSqlServerKey, out var dapperSqlServerConnectionString))
{
var dapperSqlServerCollection = new ServiceCollection();
dapperSqlServerCollection.AddLogging(configureLogging);
dapperSqlServerCollection.AddDapperRepositories(SelfHosted);
var globalSettings = new GlobalSettings
{
DatabaseProvider = "sqlServer",
SqlServer = new GlobalSettings.SqlSettings
{
ConnectionString = dapperSqlServerConnectionString,
},
};
dapperSqlServerCollection.AddSingleton(globalSettings);
dapperSqlServerCollection.AddSingleton<IGlobalSettings>(globalSettings);
yield return dapperSqlServerCollection.BuildServiceProvider();
}
if (config.TryGetConnectionString(DatabaseTheoryAttribute.EfPostgresKey, out var efPostgresConnectionString))
{
var efPostgresCollection = new ServiceCollection();
efPostgresCollection.AddLogging(configureLogging);
efPostgresCollection.AddEFRepositories(SelfHosted, efPostgresConnectionString, SupportedDatabaseProviders.Postgres);
efPostgresCollection.AddTransient<ITestDatabaseHelper, EfTestDatabaseHelper>();
yield return efPostgresCollection.BuildServiceProvider();
}
if (config.TryGetConnectionString(DatabaseTheoryAttribute.EfMySqlKey, out var efMySqlConnectionString))
{
var efMySqlCollection = new ServiceCollection();
efMySqlCollection.AddLogging(configureLogging);
efMySqlCollection.AddEFRepositories(SelfHosted, efMySqlConnectionString, SupportedDatabaseProviders.MySql);
efMySqlCollection.AddTransient<ITestDatabaseHelper, EfTestDatabaseHelper>();
yield return efMySqlCollection.BuildServiceProvider();
}
}
}

View File

@ -0,0 +1,38 @@
using Microsoft.Extensions.Configuration;
using Xunit;
namespace Bit.Infrastructure.IntegrationTest;
public class DatabaseTheoryAttribute : TheoryAttribute
{
private static IConfiguration? _cachedConfiguration;
public const string DapperSqlServerKey = "Dapper:SqlServer";
public const string EfPostgresKey = "Ef:Postgres";
public const string EfMySqlKey = "Ef:MySql";
public DatabaseTheoryAttribute()
{
if (!HasAnyDatabaseSetup())
{
Skip = "No database connections strings setup.";
}
}
private static bool HasAnyDatabaseSetup()
{
var config = GetConfiguration();
return config.TryGetConnectionString(DapperSqlServerKey, out _) ||
config.TryGetConnectionString(EfPostgresKey, out _) ||
config.TryGetConnectionString(EfMySqlKey, out _);
}
public static IConfiguration GetConfiguration()
{
return _cachedConfiguration ??= new ConfigurationBuilder()
.AddUserSecrets<DatabaseDataAttribute>(optional: true, reloadOnChange: false)
.AddEnvironmentVariables("BW_TEST_")
.AddCommandLine(Environment.GetCommandLineArgs())
.Build();
}
}

View File

@ -0,0 +1,31 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<Nullable>enable</Nullable>
<IsPackable>false</IsPackable>
<UserSecretsId>6570f288-5c2c-47ad-8978-f3da255079c2</UserSecretsId>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Configuration" Version="6.0.1" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="6.0.1" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="6.0.0" />
<PackageReference Include="Microsoft.Extensions.Logging" Version="6.0.0" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.11.0" />
<PackageReference Include="xunit" Version="2.4.1" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.4.3">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="coverlet.collector" Version="3.1.0">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\Infrastructure.Dapper\Infrastructure.Dapper.csproj" />
<ProjectReference Include="..\..\src\Infrastructure.EntityFramework\Infrastructure.EntityFramework.csproj" />
</ItemGroup>
</Project>

View File

@ -0,0 +1,97 @@
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Core.Models.Data;
using Xunit;
namespace Bit.Infrastructure.IntegrationTest.Repositories;
public class CipherRepositoryTests
{
[DatabaseTheory, DatabaseData]
public async Task DeleteAsync_UpdatesUserRevisionDate(
IUserRepository userRepository,
ICipherRepository cipherRepository,
ITestDatabaseHelper helper)
{
var user = await userRepository.CreateAsync(new User
{
Name = "Test User",
Email = "test@email.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
});
var cipher = await cipherRepository.CreateAsync(new Cipher
{
Type = CipherType.Login,
UserId = user.Id,
});
helper.ClearTracker();
await cipherRepository.DeleteAsync(cipher);
var deletedCipher = await cipherRepository.GetByIdAsync(cipher.Id);
Assert.Null(deletedCipher);
var updatedUser = await userRepository.GetByIdAsync(user.Id);
Assert.NotEqual(updatedUser.AccountRevisionDate, user.AccountRevisionDate);
}
[DatabaseTheory, DatabaseData]
public async Task CreateAsync_UpdateWithCollecitons_Works(
IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
ICollectionRepository collectionRepository,
ICipherRepository cipherRepository,
ICollectionCipherRepository collectionCipherRepository,
ITestDatabaseHelper helper)
{
var user = await userRepository.CreateAsync(new User
{
Name = "Test User",
Email = "test@email.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
});
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Organization",
});
await organizationUserRepository.CreateAsync(new OrganizationUser
{
UserId = user.Id,
OrganizationId = organization.Id,
Status = OrganizationUserStatusType.Accepted,
Type = OrganizationUserType.Owner,
});
var collection = await collectionRepository.CreateAsync(new Collection
{
Name = "Test Collection",
OrganizationId = organization.Id
});
helper.ClearTracker();
await cipherRepository.CreateAsync(new CipherDetails
{
Type = CipherType.Login,
OrganizationId = organization.Id,
}, new List<Guid>
{
collection.Id,
});
var updatedUser = await userRepository.GetByIdAsync(user.Id);
Assert.NotEqual(updatedUser.AccountRevisionDate, user.AccountRevisionDate);
var collectionCiphers = await collectionCipherRepository.GetManyByOrganizationIdAsync(organization.Id);
Assert.NotEmpty(collectionCiphers);
}
}

View File

@ -0,0 +1,46 @@
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Xunit;
namespace Bit.Infrastructure.IntegrationTest.Repositories;
public class EmergencyAccessRepositoriesTests
{
[DatabaseTheory, DatabaseData]
public async Task DeleteAsync_UpdatesRevisionDate(IUserRepository userRepository,
IEmergencyAccessRepository emergencyAccessRepository,
ITestDatabaseHelper helper)
{
var grantorUser = await userRepository.CreateAsync(new User
{
Name = "Test Grantor User",
Email = "test+grantor@email.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
});
var granteeUser = await userRepository.CreateAsync(new User
{
Name = "Test Grantee User",
Email = "test+grantee@email.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
});
var emergencyAccess = await emergencyAccessRepository.CreateAsync(new EmergencyAccess
{
GrantorId = grantorUser.Id,
GranteeId = granteeUser.Id,
Status = EmergencyAccessStatusType.Confirmed,
});
helper.ClearTracker();
await emergencyAccessRepository.DeleteAsync(emergencyAccess);
var updatedGrantee = await userRepository.GetByIdAsync(granteeUser.Id);
Assert.NotEqual(updatedGrantee.AccountRevisionDate, granteeUser.AccountRevisionDate);
}
}

View File

@ -0,0 +1,97 @@
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Xunit;
namespace Bit.Infrastructure.IntegrationTest.Repositories;
public class OrganizationUserRepositoryTests
{
[DatabaseTheory, DatabaseData]
public async Task DeleteAsync_Works(IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
ITestDatabaseHelper helper)
{
var user = await userRepository.CreateAsync(new User
{
Name = "Test User",
Email = "test@email.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
});
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
});
var orgUser = await organizationUserRepository.CreateAsync(new OrganizationUser
{
OrganizationId = organization.Id,
UserId = user.Id,
Status = OrganizationUserStatusType.Confirmed,
});
helper.ClearTracker();
await organizationUserRepository.DeleteAsync(orgUser);
var newUser = await userRepository.GetByIdAsync(user.Id);
Assert.NotEqual(newUser.AccountRevisionDate, user.AccountRevisionDate);
}
[DatabaseTheory, DatabaseData]
public async Task DeleteManyAsync_Works(IUserRepository userRepository,
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
ITestDatabaseHelper helper)
{
var user1 = await userRepository.CreateAsync(new User
{
Name = "Test User 1",
Email = "test1@email.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
});
var user2 = await userRepository.CreateAsync(new User
{
Name = "Test User 2",
Email = "test1@email.com",
ApiKey = "TEST",
SecurityStamp = "stamp",
});
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
});
var orgUser1 = await organizationUserRepository.CreateAsync(new OrganizationUser
{
OrganizationId = organization.Id,
UserId = user1.Id,
});
var orgUser2 = await organizationUserRepository.CreateAsync(new OrganizationUser
{
OrganizationId = organization.Id,
UserId = user2.Id,
});
helper.ClearTracker();
await organizationUserRepository.DeleteManyAsync(new List<Guid>
{
orgUser1.Id,
orgUser2.Id,
});
var updatedUser1 = await userRepository.GetByIdAsync(user1.Id);
var updatedUser2 = await userRepository.GetByIdAsync(user2.Id);
Assert.NotEqual(updatedUser1.AccountRevisionDate, user1.AccountRevisionDate);
Assert.NotEqual(updatedUser2.AccountRevisionDate, user2.AccountRevisionDate);
}
}

View File

@ -0,0 +1,36 @@
using Bit.Infrastructure.EntityFramework.Repositories;
namespace Bit.Infrastructure.IntegrationTest;
public interface ITestDatabaseHelper
{
void ClearTracker();
}
public class EfTestDatabaseHelper : ITestDatabaseHelper
{
private readonly DatabaseContext _databaseContext;
public EfTestDatabaseHelper(DatabaseContext databaseContext)
{
_databaseContext = databaseContext;
}
public void ClearTracker()
{
_databaseContext.ChangeTracker.Clear();
}
}
public class DapperSqlServerTestDatabaseHelper : ITestDatabaseHelper
{
public DapperSqlServerTestDatabaseHelper()
{
}
public void ClearTracker()
{
// There are no tracked entities in Dapper SQL Server
}
}

File diff suppressed because it is too large Load Diff

View File

@ -21,6 +21,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Infrastructure.EFIntegratio
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Api.IntegrationTest", "Api.IntegrationTest\Api.IntegrationTest.csproj", "{6ED94433-3423-498C-96C9-F24756357D95}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Infrastructure.IntegrationTest", "Infrastructure.IntegrationTest\Infrastructure.IntegrationTest.csproj", "{5827E256-D1C5-4BBE-BB74-ED28A83578FA}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@ -142,5 +144,17 @@ Global
{6ED94433-3423-498C-96C9-F24756357D95}.Release|x64.Build.0 = Release|Any CPU
{6ED94433-3423-498C-96C9-F24756357D95}.Release|x86.ActiveCfg = Release|Any CPU
{6ED94433-3423-498C-96C9-F24756357D95}.Release|x86.Build.0 = Release|Any CPU
{5827E256-D1C5-4BBE-BB74-ED28A83578FA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{5827E256-D1C5-4BBE-BB74-ED28A83578FA}.Debug|Any CPU.Build.0 = Debug|Any CPU
{5827E256-D1C5-4BBE-BB74-ED28A83578FA}.Debug|x64.ActiveCfg = Debug|Any CPU
{5827E256-D1C5-4BBE-BB74-ED28A83578FA}.Debug|x64.Build.0 = Debug|Any CPU
{5827E256-D1C5-4BBE-BB74-ED28A83578FA}.Debug|x86.ActiveCfg = Debug|Any CPU
{5827E256-D1C5-4BBE-BB74-ED28A83578FA}.Debug|x86.Build.0 = Debug|Any CPU
{5827E256-D1C5-4BBE-BB74-ED28A83578FA}.Release|Any CPU.ActiveCfg = Release|Any CPU
{5827E256-D1C5-4BBE-BB74-ED28A83578FA}.Release|Any CPU.Build.0 = Release|Any CPU
{5827E256-D1C5-4BBE-BB74-ED28A83578FA}.Release|x64.ActiveCfg = Release|Any CPU
{5827E256-D1C5-4BBE-BB74-ED28A83578FA}.Release|x64.Build.0 = Release|Any CPU
{5827E256-D1C5-4BBE-BB74-ED28A83578FA}.Release|x86.ActiveCfg = Release|Any CPU
{5827E256-D1C5-4BBE-BB74-ED28A83578FA}.Release|x86.Build.0 = Release|Any CPU
EndGlobalSection
EndGlobal