diff --git a/.editorconfig b/.editorconfig index 9c2beed95..828d0a7b3 100644 --- a/.editorconfig +++ b/.editorconfig @@ -114,6 +114,9 @@ csharp_new_line_before_finally = true csharp_new_line_before_members_in_object_initializers = true csharp_new_line_before_members_in_anonymous_types = true +# Namespace settigns +csharp_style_namespace_declarations = file_scoped:warning + # All files [*] guidelines = 120 diff --git a/bitwarden_license/src/Commercial.Core/Services/ProviderService.cs b/bitwarden_license/src/Commercial.Core/Services/ProviderService.cs index 22c180524..e50388f40 100644 --- a/bitwarden_license/src/Commercial.Core/Services/ProviderService.cs +++ b/bitwarden_license/src/Commercial.Core/Services/ProviderService.cs @@ -13,497 +13,496 @@ using Bit.Core.Settings; using Bit.Core.Utilities; using Microsoft.AspNetCore.DataProtection; -namespace Bit.Commercial.Core.Services +namespace Bit.Commercial.Core.Services; + +public class ProviderService : IProviderService { - public class ProviderService : IProviderService + public static PlanType[] ProviderDisllowedOrganizationTypes = new[] { PlanType.Free, PlanType.FamiliesAnnually, PlanType.FamiliesAnnually2019 }; + + private readonly IDataProtector _dataProtector; + private readonly IMailService _mailService; + private readonly IEventService _eventService; + private readonly GlobalSettings _globalSettings; + private readonly IProviderRepository _providerRepository; + private readonly IProviderUserRepository _providerUserRepository; + private readonly IProviderOrganizationRepository _providerOrganizationRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly IUserRepository _userRepository; + private readonly IUserService _userService; + private readonly IOrganizationService _organizationService; + private readonly ICurrentContext _currentContext; + + public ProviderService(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository, + IProviderOrganizationRepository providerOrganizationRepository, IUserRepository userRepository, + IUserService userService, IOrganizationService organizationService, IMailService mailService, + IDataProtectionProvider dataProtectionProvider, IEventService eventService, + IOrganizationRepository organizationRepository, GlobalSettings globalSettings, + ICurrentContext currentContext) { - public static PlanType[] ProviderDisllowedOrganizationTypes = new[] { PlanType.Free, PlanType.FamiliesAnnually, PlanType.FamiliesAnnually2019 }; + _providerRepository = providerRepository; + _providerUserRepository = providerUserRepository; + _providerOrganizationRepository = providerOrganizationRepository; + _organizationRepository = organizationRepository; + _userRepository = userRepository; + _userService = userService; + _organizationService = organizationService; + _mailService = mailService; + _eventService = eventService; + _globalSettings = globalSettings; + _dataProtector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); + _currentContext = currentContext; + } - private readonly IDataProtector _dataProtector; - private readonly IMailService _mailService; - private readonly IEventService _eventService; - private readonly GlobalSettings _globalSettings; - private readonly IProviderRepository _providerRepository; - private readonly IProviderUserRepository _providerUserRepository; - private readonly IProviderOrganizationRepository _providerOrganizationRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly IUserRepository _userRepository; - private readonly IUserService _userService; - private readonly IOrganizationService _organizationService; - private readonly ICurrentContext _currentContext; - - public ProviderService(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository, - IProviderOrganizationRepository providerOrganizationRepository, IUserRepository userRepository, - IUserService userService, IOrganizationService organizationService, IMailService mailService, - IDataProtectionProvider dataProtectionProvider, IEventService eventService, - IOrganizationRepository organizationRepository, GlobalSettings globalSettings, - ICurrentContext currentContext) + public async Task CreateAsync(string ownerEmail) + { + var owner = await _userRepository.GetByEmailAsync(ownerEmail); + if (owner == null) { - _providerRepository = providerRepository; - _providerUserRepository = providerUserRepository; - _providerOrganizationRepository = providerOrganizationRepository; - _organizationRepository = organizationRepository; - _userRepository = userRepository; - _userService = userService; - _organizationService = organizationService; - _mailService = mailService; - _eventService = eventService; - _globalSettings = globalSettings; - _dataProtector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); - _currentContext = currentContext; + throw new BadRequestException("Invalid owner. Owner must be an existing Bitwarden user."); } - public async Task CreateAsync(string ownerEmail) + var provider = new Provider { - var owner = await _userRepository.GetByEmailAsync(ownerEmail); - if (owner == null) - { - throw new BadRequestException("Invalid owner. Owner must be an existing Bitwarden user."); - } + Status = ProviderStatusType.Pending, + Enabled = true, + UseEvents = true, + }; + await _providerRepository.CreateAsync(provider); - var provider = new Provider + var providerUser = new ProviderUser + { + ProviderId = provider.Id, + UserId = owner.Id, + Type = ProviderUserType.ProviderAdmin, + Status = ProviderUserStatusType.Confirmed, + }; + await _providerUserRepository.CreateAsync(providerUser); + await SendProviderSetupInviteEmailAsync(provider, owner.Email); + } + + public async Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key) + { + var owner = await _userService.GetUserByIdAsync(ownerUserId); + if (owner == null) + { + throw new BadRequestException("Invalid owner."); + } + + if (provider.Status != ProviderStatusType.Pending) + { + throw new BadRequestException("Provider is already setup."); + } + + if (!CoreHelpers.TokenIsValid("ProviderSetupInvite", _dataProtector, token, owner.Email, provider.Id, + _globalSettings.OrganizationInviteExpirationHours)) + { + throw new BadRequestException("Invalid token."); + } + + var providerUser = await _providerUserRepository.GetByProviderUserAsync(provider.Id, ownerUserId); + if (!(providerUser is { Type: ProviderUserType.ProviderAdmin })) + { + throw new BadRequestException("Invalid owner."); + } + + provider.Status = ProviderStatusType.Created; + await _providerRepository.UpsertAsync(provider); + + providerUser.Key = key; + await _providerUserRepository.ReplaceAsync(providerUser); + + return provider; + } + + public async Task UpdateAsync(Provider provider, bool updateBilling = false) + { + if (provider.Id == default) + { + throw new ArgumentException("Cannot create provider this way."); + } + + await _providerRepository.ReplaceAsync(provider); + } + + public async Task> InviteUserAsync(ProviderUserInvite invite) + { + if (!_currentContext.ProviderManageUsers(invite.ProviderId)) + { + throw new InvalidOperationException("Invalid permissions."); + } + + var emails = invite?.UserIdentifiers; + var invitingUser = await _providerUserRepository.GetByProviderUserAsync(invite.ProviderId, invite.InvitingUserId); + + var provider = await _providerRepository.GetByIdAsync(invite.ProviderId); + if (provider == null || emails == null || !emails.Any()) + { + throw new NotFoundException(); + } + + var providerUsers = new List(); + foreach (var email in emails) + { + // Make sure user is not already invited + var existingProviderUserCount = + await _providerUserRepository.GetCountByProviderAsync(invite.ProviderId, email, false); + if (existingProviderUserCount > 0) { - Status = ProviderStatusType.Pending, - Enabled = true, - UseEvents = true, - }; - await _providerRepository.CreateAsync(provider); + continue; + } var providerUser = new ProviderUser { - ProviderId = provider.Id, - UserId = owner.Id, - Type = ProviderUserType.ProviderAdmin, - Status = ProviderUserStatusType.Confirmed, + ProviderId = invite.ProviderId, + UserId = null, + Email = email.ToLowerInvariant(), + Key = null, + Type = invite.Type, + Status = ProviderUserStatusType.Invited, + CreationDate = DateTime.UtcNow, + RevisionDate = DateTime.UtcNow, }; + await _providerUserRepository.CreateAsync(providerUser); - await SendProviderSetupInviteEmailAsync(provider, owner.Email); + + await SendInviteAsync(providerUser, provider); + providerUsers.Add(providerUser); } - public async Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key) + await _eventService.LogProviderUsersEventAsync(providerUsers.Select(pu => (pu, EventType.ProviderUser_Invited, null as DateTime?))); + + return providerUsers; + } + + public async Task>> ResendInvitesAsync(ProviderUserInvite invite) + { + if (!_currentContext.ProviderManageUsers(invite.ProviderId)) { - var owner = await _userService.GetUserByIdAsync(ownerUserId); - if (owner == null) - { - throw new BadRequestException("Invalid owner."); - } - - if (provider.Status != ProviderStatusType.Pending) - { - throw new BadRequestException("Provider is already setup."); - } - - if (!CoreHelpers.TokenIsValid("ProviderSetupInvite", _dataProtector, token, owner.Email, provider.Id, - _globalSettings.OrganizationInviteExpirationHours)) - { - throw new BadRequestException("Invalid token."); - } - - var providerUser = await _providerUserRepository.GetByProviderUserAsync(provider.Id, ownerUserId); - if (!(providerUser is { Type: ProviderUserType.ProviderAdmin })) - { - throw new BadRequestException("Invalid owner."); - } - - provider.Status = ProviderStatusType.Created; - await _providerRepository.UpsertAsync(provider); - - providerUser.Key = key; - await _providerUserRepository.ReplaceAsync(providerUser); - - return provider; + throw new BadRequestException("Invalid permissions."); } - public async Task UpdateAsync(Provider provider, bool updateBilling = false) + var providerUsers = await _providerUserRepository.GetManyAsync(invite.UserIdentifiers); + var provider = await _providerRepository.GetByIdAsync(invite.ProviderId); + + var result = new List>(); + foreach (var providerUser in providerUsers) { - if (provider.Id == default) + if (providerUser.Status != ProviderUserStatusType.Invited || providerUser.ProviderId != invite.ProviderId) { - throw new ArgumentException("Cannot create provider this way."); + result.Add(Tuple.Create(providerUser, "User invalid.")); + continue; } - await _providerRepository.ReplaceAsync(provider); + await SendInviteAsync(providerUser, provider); + result.Add(Tuple.Create(providerUser, "")); } - public async Task> InviteUserAsync(ProviderUserInvite invite) + return result; + } + + public async Task AcceptUserAsync(Guid providerUserId, User user, string token) + { + var providerUser = await _providerUserRepository.GetByIdAsync(providerUserId); + if (providerUser == null) { - if (!_currentContext.ProviderManageUsers(invite.ProviderId)) + throw new BadRequestException("User invalid."); + } + + if (providerUser.Status != ProviderUserStatusType.Invited) + { + throw new BadRequestException("Already accepted."); + } + + if (!CoreHelpers.TokenIsValid("ProviderUserInvite", _dataProtector, token, user.Email, providerUser.Id, + _globalSettings.OrganizationInviteExpirationHours)) + { + throw new BadRequestException("Invalid token."); + } + + if (string.IsNullOrWhiteSpace(providerUser.Email) || + !providerUser.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase)) + { + throw new BadRequestException("User email does not match invite."); + } + + providerUser.Status = ProviderUserStatusType.Accepted; + providerUser.UserId = user.Id; + providerUser.Email = null; + + await _providerUserRepository.ReplaceAsync(providerUser); + + return providerUser; + } + + public async Task>> ConfirmUsersAsync(Guid providerId, Dictionary keys, + Guid confirmingUserId) + { + var providerUsers = await _providerUserRepository.GetManyAsync(keys.Keys); + var validProviderUsers = providerUsers + .Where(u => u.UserId != null) + .ToList(); + + if (!validProviderUsers.Any()) + { + return new List>(); + } + + var validOrganizationUserIds = validProviderUsers.Select(u => u.UserId.Value).ToList(); + + var provider = await _providerRepository.GetByIdAsync(providerId); + var users = await _userRepository.GetManyAsync(validOrganizationUserIds); + + var keyedFilteredUsers = validProviderUsers.ToDictionary(u => u.UserId.Value, u => u); + + var result = new List>(); + var events = new List<(ProviderUser, EventType, DateTime?)>(); + + foreach (var user in users) + { + if (!keyedFilteredUsers.ContainsKey(user.Id)) { - throw new InvalidOperationException("Invalid permissions."); + continue; } - - var emails = invite?.UserIdentifiers; - var invitingUser = await _providerUserRepository.GetByProviderUserAsync(invite.ProviderId, invite.InvitingUserId); - - var provider = await _providerRepository.GetByIdAsync(invite.ProviderId); - if (provider == null || emails == null || !emails.Any()) + var providerUser = keyedFilteredUsers[user.Id]; + try { - throw new NotFoundException(); - } - - var providerUsers = new List(); - foreach (var email in emails) - { - // Make sure user is not already invited - var existingProviderUserCount = - await _providerUserRepository.GetCountByProviderAsync(invite.ProviderId, email, false); - if (existingProviderUserCount > 0) + if (providerUser.Status != ProviderUserStatusType.Accepted || providerUser.ProviderId != providerId) { - continue; + throw new BadRequestException("Invalid user."); } - var providerUser = new ProviderUser - { - ProviderId = invite.ProviderId, - UserId = null, - Email = email.ToLowerInvariant(), - Key = null, - Type = invite.Type, - Status = ProviderUserStatusType.Invited, - CreationDate = DateTime.UtcNow, - RevisionDate = DateTime.UtcNow, - }; + providerUser.Status = ProviderUserStatusType.Confirmed; + providerUser.Key = keys[providerUser.Id]; + providerUser.Email = null; - await _providerUserRepository.CreateAsync(providerUser); - - await SendInviteAsync(providerUser, provider); - providerUsers.Add(providerUser); - } - - await _eventService.LogProviderUsersEventAsync(providerUsers.Select(pu => (pu, EventType.ProviderUser_Invited, null as DateTime?))); - - return providerUsers; - } - - public async Task>> ResendInvitesAsync(ProviderUserInvite invite) - { - if (!_currentContext.ProviderManageUsers(invite.ProviderId)) - { - throw new BadRequestException("Invalid permissions."); - } - - var providerUsers = await _providerUserRepository.GetManyAsync(invite.UserIdentifiers); - var provider = await _providerRepository.GetByIdAsync(invite.ProviderId); - - var result = new List>(); - foreach (var providerUser in providerUsers) - { - if (providerUser.Status != ProviderUserStatusType.Invited || providerUser.ProviderId != invite.ProviderId) - { - result.Add(Tuple.Create(providerUser, "User invalid.")); - continue; - } - - await SendInviteAsync(providerUser, provider); + await _providerUserRepository.ReplaceAsync(providerUser); + events.Add((providerUser, EventType.ProviderUser_Confirmed, null)); + await _mailService.SendProviderConfirmedEmailAsync(provider.Name, user.Email); result.Add(Tuple.Create(providerUser, "")); } - - return result; + catch (BadRequestException e) + { + result.Add(Tuple.Create(providerUser, e.Message)); + } } - public async Task AcceptUserAsync(Guid providerUserId, User user, string token) + await _eventService.LogProviderUsersEventAsync(events); + + return result; + } + + public async Task SaveUserAsync(ProviderUser user, Guid savingUserId) + { + if (user.Id.Equals(default)) { - var providerUser = await _providerUserRepository.GetByIdAsync(providerUserId); - if (providerUser == null) - { - throw new BadRequestException("User invalid."); - } - - if (providerUser.Status != ProviderUserStatusType.Invited) - { - throw new BadRequestException("Already accepted."); - } - - if (!CoreHelpers.TokenIsValid("ProviderUserInvite", _dataProtector, token, user.Email, providerUser.Id, - _globalSettings.OrganizationInviteExpirationHours)) - { - throw new BadRequestException("Invalid token."); - } - - if (string.IsNullOrWhiteSpace(providerUser.Email) || - !providerUser.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase)) - { - throw new BadRequestException("User email does not match invite."); - } - - providerUser.Status = ProviderUserStatusType.Accepted; - providerUser.UserId = user.Id; - providerUser.Email = null; - - await _providerUserRepository.ReplaceAsync(providerUser); - - return providerUser; + throw new BadRequestException("Invite the user first."); } - public async Task>> ConfirmUsersAsync(Guid providerId, Dictionary keys, - Guid confirmingUserId) + if (user.Type != ProviderUserType.ProviderAdmin && + !await HasConfirmedProviderAdminExceptAsync(user.ProviderId, new[] { user.Id })) { - var providerUsers = await _providerUserRepository.GetManyAsync(keys.Keys); - var validProviderUsers = providerUsers - .Where(u => u.UserId != null) - .ToList(); + throw new BadRequestException("Provider must have at least one confirmed ProviderAdmin."); + } - if (!validProviderUsers.Any()) + await _providerUserRepository.ReplaceAsync(user); + await _eventService.LogProviderUserEventAsync(user, EventType.ProviderUser_Updated); + } + + public async Task>> DeleteUsersAsync(Guid providerId, + IEnumerable providerUserIds, Guid deletingUserId) + { + var provider = await _providerRepository.GetByIdAsync(providerId); + + if (provider == null) + { + throw new NotFoundException(); + } + + var providerUsers = await _providerUserRepository.GetManyAsync(providerUserIds); + var users = await _userRepository.GetManyAsync(providerUsers.Where(pu => pu.UserId.HasValue) + .Select(pu => pu.UserId.Value)); + var keyedUsers = users.ToDictionary(u => u.Id); + + if (!await HasConfirmedProviderAdminExceptAsync(providerId, providerUserIds)) + { + throw new BadRequestException("Provider must have at least one confirmed ProviderAdmin."); + } + + var result = new List>(); + var deletedUserIds = new List(); + var events = new List<(ProviderUser, EventType, DateTime?)>(); + + foreach (var providerUser in providerUsers) + { + try { - return new List>(); - } - - var validOrganizationUserIds = validProviderUsers.Select(u => u.UserId.Value).ToList(); - - var provider = await _providerRepository.GetByIdAsync(providerId); - var users = await _userRepository.GetManyAsync(validOrganizationUserIds); - - var keyedFilteredUsers = validProviderUsers.ToDictionary(u => u.UserId.Value, u => u); - - var result = new List>(); - var events = new List<(ProviderUser, EventType, DateTime?)>(); - - foreach (var user in users) - { - if (!keyedFilteredUsers.ContainsKey(user.Id)) + if (providerUser.ProviderId != providerId) { - continue; + throw new BadRequestException("Invalid user."); } - var providerUser = keyedFilteredUsers[user.Id]; - try + if (providerUser.UserId == deletingUserId) { - if (providerUser.Status != ProviderUserStatusType.Accepted || providerUser.ProviderId != providerId) + throw new BadRequestException("You cannot remove yourself."); + } + + events.Add((providerUser, EventType.ProviderUser_Removed, null)); + + var user = keyedUsers.GetValueOrDefault(providerUser.UserId.GetValueOrDefault()); + var email = user == null ? providerUser.Email : user.Email; + if (!string.IsNullOrWhiteSpace(email)) + { + await _mailService.SendProviderUserRemoved(provider.Name, email); + } + + result.Add(Tuple.Create(providerUser, "")); + deletedUserIds.Add(providerUser.Id); + } + catch (BadRequestException e) + { + result.Add(Tuple.Create(providerUser, e.Message)); + } + + await _providerUserRepository.DeleteManyAsync(deletedUserIds); + } + + await _eventService.LogProviderUsersEventAsync(events); + + return result; + } + + public async Task AddOrganization(Guid providerId, Guid organizationId, Guid addingUserId, string key) + { + var po = await _providerOrganizationRepository.GetByOrganizationId(organizationId); + if (po != null) + { + throw new BadRequestException("Organization already belongs to a provider."); + } + + var organization = await _organizationRepository.GetByIdAsync(organizationId); + ThrowOnInvalidPlanType(organization.PlanType); + + var providerOrganization = new ProviderOrganization + { + ProviderId = providerId, + OrganizationId = organizationId, + Key = key, + }; + + await _providerOrganizationRepository.CreateAsync(providerOrganization); + await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Added); + } + + public async Task CreateOrganizationAsync(Guid providerId, + OrganizationSignup organizationSignup, string clientOwnerEmail, User user) + { + ThrowOnInvalidPlanType(organizationSignup.Plan); + + var (organization, _) = await _organizationService.SignUpAsync(organizationSignup, true); + + var providerOrganization = new ProviderOrganization + { + ProviderId = providerId, + OrganizationId = organization.Id, + Key = organizationSignup.OwnerKey, + }; + + await _providerOrganizationRepository.CreateAsync(providerOrganization); + await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Created); + + await _organizationService.InviteUsersAsync(organization.Id, user.Id, + new (OrganizationUserInvite, string)[] + { + ( + new OrganizationUserInvite { - throw new BadRequestException("Invalid user."); - } + Emails = new[] { clientOwnerEmail }, + AccessAll = true, + Type = OrganizationUserType.Owner, + Permissions = null, + Collections = Array.Empty(), + }, + null + ) + }); - providerUser.Status = ProviderUserStatusType.Confirmed; - providerUser.Key = keys[providerUser.Id]; - providerUser.Email = null; + return providerOrganization; + } - await _providerUserRepository.ReplaceAsync(providerUser); - events.Add((providerUser, EventType.ProviderUser_Confirmed, null)); - await _mailService.SendProviderConfirmedEmailAsync(provider.Name, user.Email); - result.Add(Tuple.Create(providerUser, "")); - } - catch (BadRequestException e) - { - result.Add(Tuple.Create(providerUser, e.Message)); - } - } - - await _eventService.LogProviderUsersEventAsync(events); - - return result; + public async Task RemoveOrganizationAsync(Guid providerId, Guid providerOrganizationId, Guid removingUserId) + { + var providerOrganization = await _providerOrganizationRepository.GetByIdAsync(providerOrganizationId); + if (providerOrganization == null || providerOrganization.ProviderId != providerId) + { + throw new BadRequestException("Invalid organization."); } - public async Task SaveUserAsync(ProviderUser user, Guid savingUserId) + if (!await _organizationService.HasConfirmedOwnersExceptAsync(providerOrganization.OrganizationId, new Guid[] { }, includeProvider: false)) { - if (user.Id.Equals(default)) - { - throw new BadRequestException("Invite the user first."); - } - - if (user.Type != ProviderUserType.ProviderAdmin && - !await HasConfirmedProviderAdminExceptAsync(user.ProviderId, new[] { user.Id })) - { - throw new BadRequestException("Provider must have at least one confirmed ProviderAdmin."); - } - - await _providerUserRepository.ReplaceAsync(user); - await _eventService.LogProviderUserEventAsync(user, EventType.ProviderUser_Updated); + throw new BadRequestException("Organization needs to have at least one confirmed owner."); } - public async Task>> DeleteUsersAsync(Guid providerId, - IEnumerable providerUserIds, Guid deletingUserId) + await _providerOrganizationRepository.DeleteAsync(providerOrganization); + await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Removed); + } + + public async Task ResendProviderSetupInviteEmailAsync(Guid providerId, Guid ownerId) + { + var provider = await _providerRepository.GetByIdAsync(providerId); + var owner = await _userRepository.GetByIdAsync(ownerId); + if (owner == null) { - var provider = await _providerRepository.GetByIdAsync(providerId); + throw new BadRequestException("Invalid owner."); + } + await SendProviderSetupInviteEmailAsync(provider, owner.Email); + } - if (provider == null) - { - throw new NotFoundException(); - } + private async Task SendProviderSetupInviteEmailAsync(Provider provider, string ownerEmail) + { + var token = _dataProtector.Protect($"ProviderSetupInvite {provider.Id} {ownerEmail} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); + await _mailService.SendProviderSetupInviteEmailAsync(provider, token, ownerEmail); + } - var providerUsers = await _providerUserRepository.GetManyAsync(providerUserIds); - var users = await _userRepository.GetManyAsync(providerUsers.Where(pu => pu.UserId.HasValue) - .Select(pu => pu.UserId.Value)); - var keyedUsers = users.ToDictionary(u => u.Id); - - if (!await HasConfirmedProviderAdminExceptAsync(providerId, providerUserIds)) - { - throw new BadRequestException("Provider must have at least one confirmed ProviderAdmin."); - } - - var result = new List>(); - var deletedUserIds = new List(); - var events = new List<(ProviderUser, EventType, DateTime?)>(); - - foreach (var providerUser in providerUsers) - { - try - { - if (providerUser.ProviderId != providerId) - { - throw new BadRequestException("Invalid user."); - } - if (providerUser.UserId == deletingUserId) - { - throw new BadRequestException("You cannot remove yourself."); - } - - events.Add((providerUser, EventType.ProviderUser_Removed, null)); - - var user = keyedUsers.GetValueOrDefault(providerUser.UserId.GetValueOrDefault()); - var email = user == null ? providerUser.Email : user.Email; - if (!string.IsNullOrWhiteSpace(email)) - { - await _mailService.SendProviderUserRemoved(provider.Name, email); - } - - result.Add(Tuple.Create(providerUser, "")); - deletedUserIds.Add(providerUser.Id); - } - catch (BadRequestException e) - { - result.Add(Tuple.Create(providerUser, e.Message)); - } - - await _providerUserRepository.DeleteManyAsync(deletedUserIds); - } - - await _eventService.LogProviderUsersEventAsync(events); - - return result; + public async Task LogProviderAccessToOrganizationAsync(Guid organizationId) + { + if (organizationId == default) + { + return; } - public async Task AddOrganization(Guid providerId, Guid organizationId, Guid addingUserId, string key) + var providerOrganization = await _providerOrganizationRepository.GetByOrganizationId(organizationId); + var organization = await _organizationRepository.GetByIdAsync(organizationId); + if (providerOrganization != null) { - var po = await _providerOrganizationRepository.GetByOrganizationId(organizationId); - if (po != null) - { - throw new BadRequestException("Organization already belongs to a provider."); - } - - var organization = await _organizationRepository.GetByIdAsync(organizationId); - ThrowOnInvalidPlanType(organization.PlanType); - - var providerOrganization = new ProviderOrganization - { - ProviderId = providerId, - OrganizationId = organizationId, - Key = key, - }; - - await _providerOrganizationRepository.CreateAsync(providerOrganization); - await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Added); + await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_VaultAccessed); } - - public async Task CreateOrganizationAsync(Guid providerId, - OrganizationSignup organizationSignup, string clientOwnerEmail, User user) + if (organization != null) { - ThrowOnInvalidPlanType(organizationSignup.Plan); - - var (organization, _) = await _organizationService.SignUpAsync(organizationSignup, true); - - var providerOrganization = new ProviderOrganization - { - ProviderId = providerId, - OrganizationId = organization.Id, - Key = organizationSignup.OwnerKey, - }; - - await _providerOrganizationRepository.CreateAsync(providerOrganization); - await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Created); - - await _organizationService.InviteUsersAsync(organization.Id, user.Id, - new (OrganizationUserInvite, string)[] - { - ( - new OrganizationUserInvite - { - Emails = new[] { clientOwnerEmail }, - AccessAll = true, - Type = OrganizationUserType.Owner, - Permissions = null, - Collections = Array.Empty(), - }, - null - ) - }); - - return providerOrganization; + await _eventService.LogOrganizationEventAsync(organization, EventType.Organization_VaultAccessed); } + } - public async Task RemoveOrganizationAsync(Guid providerId, Guid providerOrganizationId, Guid removingUserId) + private async Task SendInviteAsync(ProviderUser providerUser, Provider provider) + { + var nowMillis = CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow); + var token = _dataProtector.Protect( + $"ProviderUserInvite {providerUser.Id} {providerUser.Email} {nowMillis}"); + await _mailService.SendProviderInviteEmailAsync(provider.Name, providerUser, token, providerUser.Email); + } + + private async Task HasConfirmedProviderAdminExceptAsync(Guid providerId, IEnumerable providerUserIds) + { + var providerAdmins = await _providerUserRepository.GetManyByProviderAsync(providerId, + ProviderUserType.ProviderAdmin); + var confirmedOwners = providerAdmins.Where(o => o.Status == ProviderUserStatusType.Confirmed); + var confirmedOwnersIds = confirmedOwners.Select(u => u.Id); + return confirmedOwnersIds.Except(providerUserIds).Any(); + } + + private void ThrowOnInvalidPlanType(PlanType requestedType) + { + if (ProviderDisllowedOrganizationTypes.Contains(requestedType)) { - var providerOrganization = await _providerOrganizationRepository.GetByIdAsync(providerOrganizationId); - if (providerOrganization == null || providerOrganization.ProviderId != providerId) - { - throw new BadRequestException("Invalid organization."); - } - - if (!await _organizationService.HasConfirmedOwnersExceptAsync(providerOrganization.OrganizationId, new Guid[] { }, includeProvider: false)) - { - throw new BadRequestException("Organization needs to have at least one confirmed owner."); - } - - await _providerOrganizationRepository.DeleteAsync(providerOrganization); - await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Removed); - } - - public async Task ResendProviderSetupInviteEmailAsync(Guid providerId, Guid ownerId) - { - var provider = await _providerRepository.GetByIdAsync(providerId); - var owner = await _userRepository.GetByIdAsync(ownerId); - if (owner == null) - { - throw new BadRequestException("Invalid owner."); - } - await SendProviderSetupInviteEmailAsync(provider, owner.Email); - } - - private async Task SendProviderSetupInviteEmailAsync(Provider provider, string ownerEmail) - { - var token = _dataProtector.Protect($"ProviderSetupInvite {provider.Id} {ownerEmail} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); - await _mailService.SendProviderSetupInviteEmailAsync(provider, token, ownerEmail); - } - - public async Task LogProviderAccessToOrganizationAsync(Guid organizationId) - { - if (organizationId == default) - { - return; - } - - var providerOrganization = await _providerOrganizationRepository.GetByOrganizationId(organizationId); - var organization = await _organizationRepository.GetByIdAsync(organizationId); - if (providerOrganization != null) - { - await _eventService.LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_VaultAccessed); - } - if (organization != null) - { - await _eventService.LogOrganizationEventAsync(organization, EventType.Organization_VaultAccessed); - } - } - - private async Task SendInviteAsync(ProviderUser providerUser, Provider provider) - { - var nowMillis = CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow); - var token = _dataProtector.Protect( - $"ProviderUserInvite {providerUser.Id} {providerUser.Email} {nowMillis}"); - await _mailService.SendProviderInviteEmailAsync(provider.Name, providerUser, token, providerUser.Email); - } - - private async Task HasConfirmedProviderAdminExceptAsync(Guid providerId, IEnumerable providerUserIds) - { - var providerAdmins = await _providerUserRepository.GetManyByProviderAsync(providerId, - ProviderUserType.ProviderAdmin); - var confirmedOwners = providerAdmins.Where(o => o.Status == ProviderUserStatusType.Confirmed); - var confirmedOwnersIds = confirmedOwners.Select(u => u.Id); - return confirmedOwnersIds.Except(providerUserIds).Any(); - } - - private void ThrowOnInvalidPlanType(PlanType requestedType) - { - if (ProviderDisllowedOrganizationTypes.Contains(requestedType)) - { - throw new BadRequestException($"Providers cannot manage organizations with the requested plan type ({requestedType}). Only Teams and Enterprise accounts are allowed."); - } + throw new BadRequestException($"Providers cannot manage organizations with the requested plan type ({requestedType}). Only Teams and Enterprise accounts are allowed."); } } } diff --git a/bitwarden_license/src/Commercial.Core/Utilities/ServiceCollectionExtensions.cs b/bitwarden_license/src/Commercial.Core/Utilities/ServiceCollectionExtensions.cs index 4074fe5f7..5bb1a5bde 100644 --- a/bitwarden_license/src/Commercial.Core/Utilities/ServiceCollectionExtensions.cs +++ b/bitwarden_license/src/Commercial.Core/Utilities/ServiceCollectionExtensions.cs @@ -2,13 +2,12 @@ using Bit.Core.Services; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Commercial.Core.Utilities +namespace Bit.Commercial.Core.Utilities; + +public static class ServiceCollectionExtensions { - public static class ServiceCollectionExtensions + public static void AddCommCoreServices(this IServiceCollection services) { - public static void AddCommCoreServices(this IServiceCollection services) - { - services.AddScoped(); - } + services.AddScoped(); } } diff --git a/bitwarden_license/src/Scim/Context/IScimContext.cs b/bitwarden_license/src/Scim/Context/IScimContext.cs index 90e5aca3a..1e7010bd2 100644 --- a/bitwarden_license/src/Scim/Context/IScimContext.cs +++ b/bitwarden_license/src/Scim/Context/IScimContext.cs @@ -4,18 +4,17 @@ using Bit.Core.Models.OrganizationConnectionConfigs; using Bit.Core.Repositories; using Bit.Core.Settings; -namespace Bit.Scim.Context +namespace Bit.Scim.Context; + +public interface IScimContext { - public interface IScimContext - { - ScimProviderType RequestScimProvider { get; set; } - ScimConfig ScimConfiguration { get; set; } - Guid? OrganizationId { get; set; } - Organization Organization { get; set; } - Task BuildAsync( - HttpContext httpContext, - GlobalSettings globalSettings, - IOrganizationRepository organizationRepository, - IOrganizationConnectionRepository organizationConnectionRepository); - } + ScimProviderType RequestScimProvider { get; set; } + ScimConfig ScimConfiguration { get; set; } + Guid? OrganizationId { get; set; } + Organization Organization { get; set; } + Task BuildAsync( + HttpContext httpContext, + GlobalSettings globalSettings, + IOrganizationRepository organizationRepository, + IOrganizationConnectionRepository organizationConnectionRepository); } diff --git a/bitwarden_license/src/Scim/Context/ScimContext.cs b/bitwarden_license/src/Scim/Context/ScimContext.cs index 0e489d33d..ae8d30807 100644 --- a/bitwarden_license/src/Scim/Context/ScimContext.cs +++ b/bitwarden_license/src/Scim/Context/ScimContext.cs @@ -4,61 +4,60 @@ using Bit.Core.Models.OrganizationConnectionConfigs; using Bit.Core.Repositories; using Bit.Core.Settings; -namespace Bit.Scim.Context +namespace Bit.Scim.Context; + +public class ScimContext : IScimContext { - public class ScimContext : IScimContext + private bool _builtHttpContext; + + public ScimProviderType RequestScimProvider { get; set; } = ScimProviderType.Default; + public ScimConfig ScimConfiguration { get; set; } + public Guid? OrganizationId { get; set; } + public Organization Organization { get; set; } + + public async virtual Task BuildAsync( + HttpContext httpContext, + GlobalSettings globalSettings, + IOrganizationRepository organizationRepository, + IOrganizationConnectionRepository organizationConnectionRepository) { - private bool _builtHttpContext; - - public ScimProviderType RequestScimProvider { get; set; } = ScimProviderType.Default; - public ScimConfig ScimConfiguration { get; set; } - public Guid? OrganizationId { get; set; } - public Organization Organization { get; set; } - - public async virtual Task BuildAsync( - HttpContext httpContext, - GlobalSettings globalSettings, - IOrganizationRepository organizationRepository, - IOrganizationConnectionRepository organizationConnectionRepository) + if (_builtHttpContext) { - if (_builtHttpContext) - { - return; - } + return; + } - _builtHttpContext = true; + _builtHttpContext = true; - string orgIdString = null; - if (httpContext.Request.RouteValues.TryGetValue("organizationId", out var orgIdObject)) - { - orgIdString = orgIdObject?.ToString(); - } + string orgIdString = null; + if (httpContext.Request.RouteValues.TryGetValue("organizationId", out var orgIdObject)) + { + orgIdString = orgIdObject?.ToString(); + } - if (Guid.TryParse(orgIdString, out var orgId)) + if (Guid.TryParse(orgIdString, out var orgId)) + { + OrganizationId = orgId; + Organization = await organizationRepository.GetByIdAsync(orgId); + if (Organization != null) { - OrganizationId = orgId; - Organization = await organizationRepository.GetByIdAsync(orgId); - if (Organization != null) - { - var scimConnections = await organizationConnectionRepository.GetByOrganizationIdTypeAsync(Organization.Id, - OrganizationConnectionType.Scim); - ScimConfiguration = scimConnections?.FirstOrDefault()?.GetConfig(); - } + var scimConnections = await organizationConnectionRepository.GetByOrganizationIdTypeAsync(Organization.Id, + OrganizationConnectionType.Scim); + ScimConfiguration = scimConnections?.FirstOrDefault()?.GetConfig(); } + } - if (RequestScimProvider == ScimProviderType.Default && - httpContext.Request.Headers.TryGetValue("User-Agent", out var userAgent)) + if (RequestScimProvider == ScimProviderType.Default && + httpContext.Request.Headers.TryGetValue("User-Agent", out var userAgent)) + { + if (userAgent.ToString().StartsWith("Okta")) { - if (userAgent.ToString().StartsWith("Okta")) - { - RequestScimProvider = ScimProviderType.Okta; - } - } - if (RequestScimProvider == ScimProviderType.Default && - httpContext.Request.Headers.ContainsKey("Adscimversion")) - { - RequestScimProvider = ScimProviderType.AzureAd; + RequestScimProvider = ScimProviderType.Okta; } } + if (RequestScimProvider == ScimProviderType.Default && + httpContext.Request.Headers.ContainsKey("Adscimversion")) + { + RequestScimProvider = ScimProviderType.AzureAd; + } } } diff --git a/bitwarden_license/src/Scim/Controllers/InfoController.cs b/bitwarden_license/src/Scim/Controllers/InfoController.cs index 67967ed37..aa08ce9bf 100644 --- a/bitwarden_license/src/Scim/Controllers/InfoController.cs +++ b/bitwarden_license/src/Scim/Controllers/InfoController.cs @@ -2,22 +2,21 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Scim.Controllers -{ - [AllowAnonymous] - public class InfoController : Controller - { - [HttpGet("~/alive")] - [HttpGet("~/now")] - public DateTime GetAlive() - { - return DateTime.UtcNow; - } +namespace Bit.Scim.Controllers; - [HttpGet("~/version")] - public JsonResult GetVersion() - { - return Json(CoreHelpers.GetVersion()); - } +[AllowAnonymous] +public class InfoController : Controller +{ + [HttpGet("~/alive")] + [HttpGet("~/now")] + public DateTime GetAlive() + { + return DateTime.UtcNow; + } + + [HttpGet("~/version")] + public JsonResult GetVersion() + { + return Json(CoreHelpers.GetVersion()); } } diff --git a/bitwarden_license/src/Scim/Controllers/v2/GroupsController.cs b/bitwarden_license/src/Scim/Controllers/v2/GroupsController.cs index ff55c411d..6fe47db87 100644 --- a/bitwarden_license/src/Scim/Controllers/v2/GroupsController.cs +++ b/bitwarden_license/src/Scim/Controllers/v2/GroupsController.cs @@ -8,321 +8,320 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.Options; -namespace Bit.Scim.Controllers.v2 +namespace Bit.Scim.Controllers.v2; + +[Authorize("Scim")] +[Route("v2/{organizationId}/groups")] +public class GroupsController : Controller { - [Authorize("Scim")] - [Route("v2/{organizationId}/groups")] - public class GroupsController : Controller + private readonly ScimSettings _scimSettings; + private readonly IGroupRepository _groupRepository; + private readonly IGroupService _groupService; + private readonly IScimContext _scimContext; + private readonly ILogger _logger; + + public GroupsController( + IGroupRepository groupRepository, + IGroupService groupService, + IOptions scimSettings, + IScimContext scimContext, + ILogger logger) { - private readonly ScimSettings _scimSettings; - private readonly IGroupRepository _groupRepository; - private readonly IGroupService _groupService; - private readonly IScimContext _scimContext; - private readonly ILogger _logger; + _scimSettings = scimSettings?.Value; + _groupRepository = groupRepository; + _groupService = groupService; + _scimContext = scimContext; + _logger = logger; + } - public GroupsController( - IGroupRepository groupRepository, - IGroupService groupService, - IOptions scimSettings, - IScimContext scimContext, - ILogger logger) + [HttpGet("{id}")] + public async Task Get(Guid organizationId, Guid id) + { + var group = await _groupRepository.GetByIdAsync(id); + if (group == null || group.OrganizationId != organizationId) { - _scimSettings = scimSettings?.Value; - _groupRepository = groupRepository; - _groupService = groupService; - _scimContext = scimContext; - _logger = logger; + return new NotFoundObjectResult(new ScimErrorResponseModel + { + Status = 404, + Detail = "Group not found." + }); + } + return new ObjectResult(new ScimGroupResponseModel(group)); + } + + [HttpGet("")] + public async Task Get( + Guid organizationId, + [FromQuery] string filter, + [FromQuery] int? count, + [FromQuery] int? startIndex) + { + string nameFilter = null; + string externalIdFilter = null; + if (!string.IsNullOrWhiteSpace(filter)) + { + if (filter.StartsWith("displayName eq ")) + { + nameFilter = filter.Substring(15).Trim('"'); + } + else if (filter.StartsWith("externalId eq ")) + { + externalIdFilter = filter.Substring(14).Trim('"'); + } } - [HttpGet("{id}")] - public async Task Get(Guid organizationId, Guid id) + var groupList = new List(); + var groups = await _groupRepository.GetManyByOrganizationIdAsync(organizationId); + var totalResults = 0; + if (!string.IsNullOrWhiteSpace(nameFilter)) { - var group = await _groupRepository.GetByIdAsync(id); - if (group == null || group.OrganizationId != organizationId) + var group = groups.FirstOrDefault(g => g.Name == nameFilter); + if (group != null) { - return new NotFoundObjectResult(new ScimErrorResponseModel - { - Status = 404, - Detail = "Group not found." - }); + groupList.Add(new ScimGroupResponseModel(group)); } - return new ObjectResult(new ScimGroupResponseModel(group)); + totalResults = groupList.Count; + } + else if (!string.IsNullOrWhiteSpace(externalIdFilter)) + { + var group = groups.FirstOrDefault(ou => ou.ExternalId == externalIdFilter); + if (group != null) + { + groupList.Add(new ScimGroupResponseModel(group)); + } + totalResults = groupList.Count; + } + else if (string.IsNullOrWhiteSpace(filter) && startIndex.HasValue && count.HasValue) + { + groupList = groups.OrderBy(g => g.Name) + .Skip(startIndex.Value - 1) + .Take(count.Value) + .Select(g => new ScimGroupResponseModel(g)) + .ToList(); + totalResults = groups.Count; } - [HttpGet("")] - public async Task Get( - Guid organizationId, - [FromQuery] string filter, - [FromQuery] int? count, - [FromQuery] int? startIndex) + var result = new ScimListResponseModel { - string nameFilter = null; - string externalIdFilter = null; - if (!string.IsNullOrWhiteSpace(filter)) - { - if (filter.StartsWith("displayName eq ")) - { - nameFilter = filter.Substring(15).Trim('"'); - } - else if (filter.StartsWith("externalId eq ")) - { - externalIdFilter = filter.Substring(14).Trim('"'); - } - } + Resources = groupList, + ItemsPerPage = count.GetValueOrDefault(groupList.Count), + TotalResults = totalResults, + StartIndex = startIndex.GetValueOrDefault(1), + }; + return new ObjectResult(result); + } - var groupList = new List(); - var groups = await _groupRepository.GetManyByOrganizationIdAsync(organizationId); - var totalResults = 0; - if (!string.IsNullOrWhiteSpace(nameFilter)) - { - var group = groups.FirstOrDefault(g => g.Name == nameFilter); - if (group != null) - { - groupList.Add(new ScimGroupResponseModel(group)); - } - totalResults = groupList.Count; - } - else if (!string.IsNullOrWhiteSpace(externalIdFilter)) - { - var group = groups.FirstOrDefault(ou => ou.ExternalId == externalIdFilter); - if (group != null) - { - groupList.Add(new ScimGroupResponseModel(group)); - } - totalResults = groupList.Count; - } - else if (string.IsNullOrWhiteSpace(filter) && startIndex.HasValue && count.HasValue) - { - groupList = groups.OrderBy(g => g.Name) - .Skip(startIndex.Value - 1) - .Take(count.Value) - .Select(g => new ScimGroupResponseModel(g)) - .ToList(); - totalResults = groups.Count; - } - - var result = new ScimListResponseModel - { - Resources = groupList, - ItemsPerPage = count.GetValueOrDefault(groupList.Count), - TotalResults = totalResults, - StartIndex = startIndex.GetValueOrDefault(1), - }; - return new ObjectResult(result); + [HttpPost("")] + public async Task Post(Guid organizationId, [FromBody] ScimGroupRequestModel model) + { + if (string.IsNullOrWhiteSpace(model.DisplayName)) + { + return new BadRequestResult(); } - [HttpPost("")] - public async Task Post(Guid organizationId, [FromBody] ScimGroupRequestModel model) + var groups = await _groupRepository.GetManyByOrganizationIdAsync(organizationId); + if (!string.IsNullOrWhiteSpace(model.ExternalId) && groups.Any(g => g.ExternalId == model.ExternalId)) { - if (string.IsNullOrWhiteSpace(model.DisplayName)) - { - return new BadRequestResult(); - } - - var groups = await _groupRepository.GetManyByOrganizationIdAsync(organizationId); - if (!string.IsNullOrWhiteSpace(model.ExternalId) && groups.Any(g => g.ExternalId == model.ExternalId)) - { - return new ConflictResult(); - } - - var group = model.ToGroup(organizationId); - await _groupService.SaveAsync(group, null); - await UpdateGroupMembersAsync(group, model, true); - var response = new ScimGroupResponseModel(group); - return new CreatedResult(Url.Action(nameof(Get), new { group.OrganizationId, group.Id }), response); + return new ConflictResult(); } - [HttpPut("{id}")] - public async Task Put(Guid organizationId, Guid id, [FromBody] ScimGroupRequestModel model) - { - var group = await _groupRepository.GetByIdAsync(id); - if (group == null || group.OrganizationId != organizationId) - { - return new NotFoundObjectResult(new ScimErrorResponseModel - { - Status = 404, - Detail = "Group not found." - }); - } + var group = model.ToGroup(organizationId); + await _groupService.SaveAsync(group, null); + await UpdateGroupMembersAsync(group, model, true); + var response = new ScimGroupResponseModel(group); + return new CreatedResult(Url.Action(nameof(Get), new { group.OrganizationId, group.Id }), response); + } - group.Name = model.DisplayName; - await _groupService.SaveAsync(group); - await UpdateGroupMembersAsync(group, model, false); - return new ObjectResult(new ScimGroupResponseModel(group)); + [HttpPut("{id}")] + public async Task Put(Guid organizationId, Guid id, [FromBody] ScimGroupRequestModel model) + { + var group = await _groupRepository.GetByIdAsync(id); + if (group == null || group.OrganizationId != organizationId) + { + return new NotFoundObjectResult(new ScimErrorResponseModel + { + Status = 404, + Detail = "Group not found." + }); } - [HttpPatch("{id}")] - public async Task Patch(Guid organizationId, Guid id, [FromBody] ScimPatchModel model) - { - var group = await _groupRepository.GetByIdAsync(id); - if (group == null || group.OrganizationId != organizationId) - { - return new NotFoundObjectResult(new ScimErrorResponseModel - { - Status = 404, - Detail = "Group not found." - }); - } + group.Name = model.DisplayName; + await _groupService.SaveAsync(group); + await UpdateGroupMembersAsync(group, model, false); + return new ObjectResult(new ScimGroupResponseModel(group)); + } - var operationHandled = false; - foreach (var operation in model.Operations) + [HttpPatch("{id}")] + public async Task Patch(Guid organizationId, Guid id, [FromBody] ScimPatchModel model) + { + var group = await _groupRepository.GetByIdAsync(id); + if (group == null || group.OrganizationId != organizationId) + { + return new NotFoundObjectResult(new ScimErrorResponseModel { - // Replace operations - if (operation.Op?.ToLowerInvariant() == "replace") + Status = 404, + Detail = "Group not found." + }); + } + + var operationHandled = false; + foreach (var operation in model.Operations) + { + // Replace operations + if (operation.Op?.ToLowerInvariant() == "replace") + { + // Replace a list of members + if (operation.Path?.ToLowerInvariant() == "members") { - // Replace a list of members - if (operation.Path?.ToLowerInvariant() == "members") - { - var ids = GetOperationValueIds(operation.Value); - await _groupRepository.UpdateUsersAsync(group.Id, ids); - operationHandled = true; - } - // Replace group name from path - else if (operation.Path?.ToLowerInvariant() == "displayname") - { - group.Name = operation.Value.GetString(); - await _groupService.SaveAsync(group); - operationHandled = true; - } - // Replace group name from value object - else if (string.IsNullOrWhiteSpace(operation.Path) && - operation.Value.TryGetProperty("displayName", out var displayNameProperty)) - { - group.Name = displayNameProperty.GetString(); - await _groupService.SaveAsync(group); - operationHandled = true; - } - } - // Add a single member - else if (operation.Op?.ToLowerInvariant() == "add" && - !string.IsNullOrWhiteSpace(operation.Path) && - operation.Path.ToLowerInvariant().StartsWith("members[value eq ")) - { - var addId = GetOperationPathId(operation.Path); - if (addId.HasValue) - { - var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet(); - orgUserIds.Add(addId.Value); - await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds); - operationHandled = true; - } - } - // Add a list of members - else if (operation.Op?.ToLowerInvariant() == "add" && - operation.Path?.ToLowerInvariant() == "members") - { - var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet(); - foreach (var v in GetOperationValueIds(operation.Value)) - { - orgUserIds.Add(v); - } - await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds); + var ids = GetOperationValueIds(operation.Value); + await _groupRepository.UpdateUsersAsync(group.Id, ids); operationHandled = true; } - // Remove a single member - else if (operation.Op?.ToLowerInvariant() == "remove" && - !string.IsNullOrWhiteSpace(operation.Path) && - operation.Path.ToLowerInvariant().StartsWith("members[value eq ")) + // Replace group name from path + else if (operation.Path?.ToLowerInvariant() == "displayname") { - var removeId = GetOperationPathId(operation.Path); - if (removeId.HasValue) - { - await _groupService.DeleteUserAsync(group, removeId.Value); - operationHandled = true; - } + group.Name = operation.Value.GetString(); + await _groupService.SaveAsync(group); + operationHandled = true; } - // Remove a list of members - else if (operation.Op?.ToLowerInvariant() == "remove" && - operation.Path?.ToLowerInvariant() == "members") + // Replace group name from value object + else if (string.IsNullOrWhiteSpace(operation.Path) && + operation.Value.TryGetProperty("displayName", out var displayNameProperty)) + { + group.Name = displayNameProperty.GetString(); + await _groupService.SaveAsync(group); + operationHandled = true; + } + } + // Add a single member + else if (operation.Op?.ToLowerInvariant() == "add" && + !string.IsNullOrWhiteSpace(operation.Path) && + operation.Path.ToLowerInvariant().StartsWith("members[value eq ")) + { + var addId = GetOperationPathId(operation.Path); + if (addId.HasValue) { var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet(); - foreach (var v in GetOperationValueIds(operation.Value)) - { - orgUserIds.Remove(v); - } + orgUserIds.Add(addId.Value); await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds); operationHandled = true; } } - - if (!operationHandled) + // Add a list of members + else if (operation.Op?.ToLowerInvariant() == "add" && + operation.Path?.ToLowerInvariant() == "members") { - _logger.LogWarning("Group patch operation not handled: {0} : ", - string.Join(", ", model.Operations.Select(o => $"{o.Op}:{o.Path}"))); - } - - return new NoContentResult(); - } - - [HttpDelete("{id}")] - public async Task Delete(Guid organizationId, Guid id) - { - var group = await _groupRepository.GetByIdAsync(id); - if (group == null || group.OrganizationId != organizationId) - { - return new NotFoundObjectResult(new ScimErrorResponseModel + var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet(); + foreach (var v in GetOperationValueIds(operation.Value)) { - Status = 404, - Detail = "Group not found." - }); + orgUserIds.Add(v); + } + await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds); + operationHandled = true; } - await _groupService.DeleteAsync(group); - return new NoContentResult(); - } - - private List GetOperationValueIds(JsonElement objArray) - { - var ids = new List(); - foreach (var obj in objArray.EnumerateArray()) + // Remove a single member + else if (operation.Op?.ToLowerInvariant() == "remove" && + !string.IsNullOrWhiteSpace(operation.Path) && + operation.Path.ToLowerInvariant().StartsWith("members[value eq ")) { - if (obj.TryGetProperty("value", out var valueProperty)) + var removeId = GetOperationPathId(operation.Path); + if (removeId.HasValue) { - if (valueProperty.TryGetGuid(out var guid)) - { - ids.Add(guid); - } + await _groupService.DeleteUserAsync(group, removeId.Value); + operationHandled = true; } } - return ids; - } - - private Guid? GetOperationPathId(string path) - { - // Parse Guid from string like: members[value eq "{GUID}"}] - if (Guid.TryParse(path.Substring(18).Replace("\"]", string.Empty), out var id)) + // Remove a list of members + else if (operation.Op?.ToLowerInvariant() == "remove" && + operation.Path?.ToLowerInvariant() == "members") { - return id; - } - return null; - } - - private async Task UpdateGroupMembersAsync(Group group, ScimGroupRequestModel model, bool skipIfEmpty) - { - if (_scimContext.RequestScimProvider != Core.Enums.ScimProviderType.Okta) - { - return; - } - - if (model.Members == null) - { - return; - } - - var memberIds = new List(); - foreach (var id in model.Members.Select(i => i.Value)) - { - if (Guid.TryParse(id, out var guidId)) + var orgUserIds = (await _groupRepository.GetManyUserIdsByIdAsync(group.Id)).ToHashSet(); + foreach (var v in GetOperationValueIds(operation.Value)) { - memberIds.Add(guidId); + orgUserIds.Remove(v); + } + await _groupRepository.UpdateUsersAsync(group.Id, orgUserIds); + operationHandled = true; + } + } + + if (!operationHandled) + { + _logger.LogWarning("Group patch operation not handled: {0} : ", + string.Join(", ", model.Operations.Select(o => $"{o.Op}:{o.Path}"))); + } + + return new NoContentResult(); + } + + [HttpDelete("{id}")] + public async Task Delete(Guid organizationId, Guid id) + { + var group = await _groupRepository.GetByIdAsync(id); + if (group == null || group.OrganizationId != organizationId) + { + return new NotFoundObjectResult(new ScimErrorResponseModel + { + Status = 404, + Detail = "Group not found." + }); + } + await _groupService.DeleteAsync(group); + return new NoContentResult(); + } + + private List GetOperationValueIds(JsonElement objArray) + { + var ids = new List(); + foreach (var obj in objArray.EnumerateArray()) + { + if (obj.TryGetProperty("value", out var valueProperty)) + { + if (valueProperty.TryGetGuid(out var guid)) + { + ids.Add(guid); } } - - if (!memberIds.Any() && skipIfEmpty) - { - return; - } - - await _groupRepository.UpdateUsersAsync(group.Id, memberIds); } + return ids; + } + + private Guid? GetOperationPathId(string path) + { + // Parse Guid from string like: members[value eq "{GUID}"}] + if (Guid.TryParse(path.Substring(18).Replace("\"]", string.Empty), out var id)) + { + return id; + } + return null; + } + + private async Task UpdateGroupMembersAsync(Group group, ScimGroupRequestModel model, bool skipIfEmpty) + { + if (_scimContext.RequestScimProvider != Core.Enums.ScimProviderType.Okta) + { + return; + } + + if (model.Members == null) + { + return; + } + + var memberIds = new List(); + foreach (var id in model.Members.Select(i => i.Value)) + { + if (Guid.TryParse(id, out var guidId)) + { + memberIds.Add(guidId); + } + } + + if (!memberIds.Any() && skipIfEmpty) + { + return; + } + + await _groupRepository.UpdateUsersAsync(group.Id, memberIds); } } diff --git a/bitwarden_license/src/Scim/Controllers/v2/UsersController.cs b/bitwarden_license/src/Scim/Controllers/v2/UsersController.cs index ff650c64e..7291be7f6 100644 --- a/bitwarden_license/src/Scim/Controllers/v2/UsersController.cs +++ b/bitwarden_license/src/Scim/Controllers/v2/UsersController.cs @@ -9,287 +9,286 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.Options; -namespace Bit.Scim.Controllers.v2 +namespace Bit.Scim.Controllers.v2; + +[Authorize("Scim")] +[Route("v2/{organizationId}/users")] +public class UsersController : Controller { - [Authorize("Scim")] - [Route("v2/{organizationId}/users")] - public class UsersController : Controller + private readonly IUserService _userService; + private readonly IUserRepository _userRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IOrganizationService _organizationService; + private readonly IScimContext _scimContext; + private readonly ScimSettings _scimSettings; + private readonly ILogger _logger; + + public UsersController( + IUserService userService, + IUserRepository userRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationService organizationService, + IScimContext scimContext, + IOptions scimSettings, + ILogger logger) { - private readonly IUserService _userService; - private readonly IUserRepository _userRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IOrganizationService _organizationService; - private readonly IScimContext _scimContext; - private readonly ScimSettings _scimSettings; - private readonly ILogger _logger; + _userService = userService; + _userRepository = userRepository; + _organizationUserRepository = organizationUserRepository; + _organizationService = organizationService; + _scimContext = scimContext; + _scimSettings = scimSettings?.Value; + _logger = logger; + } - public UsersController( - IUserService userService, - IUserRepository userRepository, - IOrganizationUserRepository organizationUserRepository, - IOrganizationService organizationService, - IScimContext scimContext, - IOptions scimSettings, - ILogger logger) + [HttpGet("{id}")] + public async Task Get(Guid organizationId, Guid id) + { + var orgUser = await _organizationUserRepository.GetDetailsByIdAsync(id); + if (orgUser == null || orgUser.OrganizationId != organizationId) { - _userService = userService; - _userRepository = userRepository; - _organizationUserRepository = organizationUserRepository; - _organizationService = organizationService; - _scimContext = scimContext; - _scimSettings = scimSettings?.Value; - _logger = logger; - } - - [HttpGet("{id}")] - public async Task Get(Guid organizationId, Guid id) - { - var orgUser = await _organizationUserRepository.GetDetailsByIdAsync(id); - if (orgUser == null || orgUser.OrganizationId != organizationId) + return new NotFoundObjectResult(new ScimErrorResponseModel { - return new NotFoundObjectResult(new ScimErrorResponseModel + Status = 404, + Detail = "User not found." + }); + } + return new ObjectResult(new ScimUserResponseModel(orgUser)); + } + + [HttpGet("")] + public async Task Get( + Guid organizationId, + [FromQuery] string filter, + [FromQuery] int? count, + [FromQuery] int? startIndex) + { + string emailFilter = null; + string usernameFilter = null; + string externalIdFilter = null; + if (!string.IsNullOrWhiteSpace(filter)) + { + if (filter.StartsWith("userName eq ")) + { + usernameFilter = filter.Substring(12).Trim('"').ToLowerInvariant(); + if (usernameFilter.Contains("@")) { - Status = 404, - Detail = "User not found." - }); + emailFilter = usernameFilter; + } + } + else if (filter.StartsWith("externalId eq ")) + { + externalIdFilter = filter.Substring(14).Trim('"'); } - return new ObjectResult(new ScimUserResponseModel(orgUser)); } - [HttpGet("")] - public async Task Get( - Guid organizationId, - [FromQuery] string filter, - [FromQuery] int? count, - [FromQuery] int? startIndex) + var userList = new List { }; + var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId); + var totalResults = 0; + if (!string.IsNullOrWhiteSpace(emailFilter)) { - string emailFilter = null; - string usernameFilter = null; - string externalIdFilter = null; - if (!string.IsNullOrWhiteSpace(filter)) + var orgUser = orgUsers.FirstOrDefault(ou => ou.Email.ToLowerInvariant() == emailFilter); + if (orgUser != null) { - if (filter.StartsWith("userName eq ")) - { - usernameFilter = filter.Substring(12).Trim('"').ToLowerInvariant(); - if (usernameFilter.Contains("@")) + userList.Add(new ScimUserResponseModel(orgUser)); + } + totalResults = userList.Count; + } + else if (!string.IsNullOrWhiteSpace(externalIdFilter)) + { + var orgUser = orgUsers.FirstOrDefault(ou => ou.ExternalId == externalIdFilter); + if (orgUser != null) + { + userList.Add(new ScimUserResponseModel(orgUser)); + } + totalResults = userList.Count; + } + else if (string.IsNullOrWhiteSpace(filter) && startIndex.HasValue && count.HasValue) + { + userList = orgUsers.OrderBy(ou => ou.Email) + .Skip(startIndex.Value - 1) + .Take(count.Value) + .Select(ou => new ScimUserResponseModel(ou)) + .ToList(); + totalResults = orgUsers.Count; + } + + var result = new ScimListResponseModel + { + Resources = userList, + ItemsPerPage = count.GetValueOrDefault(userList.Count), + TotalResults = totalResults, + StartIndex = startIndex.GetValueOrDefault(1), + }; + return new ObjectResult(result); + } + + [HttpPost("")] + public async Task Post(Guid organizationId, [FromBody] ScimUserRequestModel model) + { + var email = model.PrimaryEmail?.ToLowerInvariant(); + if (string.IsNullOrWhiteSpace(email)) + { + switch (_scimContext.RequestScimProvider) + { + case ScimProviderType.AzureAd: + email = model.UserName?.ToLowerInvariant(); + break; + default: + email = model.WorkEmail?.ToLowerInvariant(); + if (string.IsNullOrWhiteSpace(email)) { - emailFilter = usernameFilter; + email = model.Emails?.FirstOrDefault()?.Value?.ToLowerInvariant(); + } + break; + } + } + + if (string.IsNullOrWhiteSpace(email) || !model.Active) + { + return new BadRequestResult(); + } + + var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId); + var orgUserByEmail = orgUsers.FirstOrDefault(ou => ou.Email?.ToLowerInvariant() == email); + if (orgUserByEmail != null) + { + return new ConflictResult(); + } + + string externalId = null; + if (!string.IsNullOrWhiteSpace(model.ExternalId)) + { + externalId = model.ExternalId; + } + else if (!string.IsNullOrWhiteSpace(model.UserName)) + { + externalId = model.UserName; + } + else + { + externalId = CoreHelpers.RandomString(15); + } + + var orgUserByExternalId = orgUsers.FirstOrDefault(ou => ou.ExternalId == externalId); + if (orgUserByExternalId != null) + { + return new ConflictResult(); + } + + var invitedOrgUser = await _organizationService.InviteUserAsync(organizationId, null, email, + OrganizationUserType.User, false, externalId, new List()); + var orgUser = await _organizationUserRepository.GetDetailsByIdAsync(invitedOrgUser.Id); + var response = new ScimUserResponseModel(orgUser); + return new CreatedResult(Url.Action(nameof(Get), new { orgUser.OrganizationId, orgUser.Id }), response); + } + + [HttpPut("{id}")] + public async Task Put(Guid organizationId, Guid id, [FromBody] ScimUserRequestModel model) + { + var orgUser = await _organizationUserRepository.GetByIdAsync(id); + if (orgUser == null || orgUser.OrganizationId != organizationId) + { + return new NotFoundObjectResult(new ScimErrorResponseModel + { + Status = 404, + Detail = "User not found." + }); + } + + if (model.Active && orgUser.Status == OrganizationUserStatusType.Revoked) + { + await _organizationService.RestoreUserAsync(orgUser, null, _userService); + } + else if (!model.Active && orgUser.Status != OrganizationUserStatusType.Revoked) + { + await _organizationService.RevokeUserAsync(orgUser, null); + } + + // Have to get full details object for response model + var orgUserDetails = await _organizationUserRepository.GetDetailsByIdAsync(id); + return new ObjectResult(new ScimUserResponseModel(orgUserDetails)); + } + + [HttpPatch("{id}")] + public async Task Patch(Guid organizationId, Guid id, [FromBody] ScimPatchModel model) + { + var orgUser = await _organizationUserRepository.GetByIdAsync(id); + if (orgUser == null || orgUser.OrganizationId != organizationId) + { + return new NotFoundObjectResult(new ScimErrorResponseModel + { + Status = 404, + Detail = "User not found." + }); + } + + var operationHandled = false; + foreach (var operation in model.Operations) + { + // Replace operations + if (operation.Op?.ToLowerInvariant() == "replace") + { + // Active from path + if (operation.Path?.ToLowerInvariant() == "active") + { + var active = operation.Value.ToString()?.ToLowerInvariant(); + var handled = await HandleActiveOperationAsync(orgUser, active == "true"); + if (!operationHandled) + { + operationHandled = handled; } } - else if (filter.StartsWith("externalId eq ")) + // Active from value object + else if (string.IsNullOrWhiteSpace(operation.Path) && + operation.Value.TryGetProperty("active", out var activeProperty)) { - externalIdFilter = filter.Substring(14).Trim('"'); - } - } - - var userList = new List { }; - var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId); - var totalResults = 0; - if (!string.IsNullOrWhiteSpace(emailFilter)) - { - var orgUser = orgUsers.FirstOrDefault(ou => ou.Email.ToLowerInvariant() == emailFilter); - if (orgUser != null) - { - userList.Add(new ScimUserResponseModel(orgUser)); - } - totalResults = userList.Count; - } - else if (!string.IsNullOrWhiteSpace(externalIdFilter)) - { - var orgUser = orgUsers.FirstOrDefault(ou => ou.ExternalId == externalIdFilter); - if (orgUser != null) - { - userList.Add(new ScimUserResponseModel(orgUser)); - } - totalResults = userList.Count; - } - else if (string.IsNullOrWhiteSpace(filter) && startIndex.HasValue && count.HasValue) - { - userList = orgUsers.OrderBy(ou => ou.Email) - .Skip(startIndex.Value - 1) - .Take(count.Value) - .Select(ou => new ScimUserResponseModel(ou)) - .ToList(); - totalResults = orgUsers.Count; - } - - var result = new ScimListResponseModel - { - Resources = userList, - ItemsPerPage = count.GetValueOrDefault(userList.Count), - TotalResults = totalResults, - StartIndex = startIndex.GetValueOrDefault(1), - }; - return new ObjectResult(result); - } - - [HttpPost("")] - public async Task Post(Guid organizationId, [FromBody] ScimUserRequestModel model) - { - var email = model.PrimaryEmail?.ToLowerInvariant(); - if (string.IsNullOrWhiteSpace(email)) - { - switch (_scimContext.RequestScimProvider) - { - case ScimProviderType.AzureAd: - email = model.UserName?.ToLowerInvariant(); - break; - default: - email = model.WorkEmail?.ToLowerInvariant(); - if (string.IsNullOrWhiteSpace(email)) - { - email = model.Emails?.FirstOrDefault()?.Value?.ToLowerInvariant(); - } - break; - } - } - - if (string.IsNullOrWhiteSpace(email) || !model.Active) - { - return new BadRequestResult(); - } - - var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId); - var orgUserByEmail = orgUsers.FirstOrDefault(ou => ou.Email?.ToLowerInvariant() == email); - if (orgUserByEmail != null) - { - return new ConflictResult(); - } - - string externalId = null; - if (!string.IsNullOrWhiteSpace(model.ExternalId)) - { - externalId = model.ExternalId; - } - else if (!string.IsNullOrWhiteSpace(model.UserName)) - { - externalId = model.UserName; - } - else - { - externalId = CoreHelpers.RandomString(15); - } - - var orgUserByExternalId = orgUsers.FirstOrDefault(ou => ou.ExternalId == externalId); - if (orgUserByExternalId != null) - { - return new ConflictResult(); - } - - var invitedOrgUser = await _organizationService.InviteUserAsync(organizationId, null, email, - OrganizationUserType.User, false, externalId, new List()); - var orgUser = await _organizationUserRepository.GetDetailsByIdAsync(invitedOrgUser.Id); - var response = new ScimUserResponseModel(orgUser); - return new CreatedResult(Url.Action(nameof(Get), new { orgUser.OrganizationId, orgUser.Id }), response); - } - - [HttpPut("{id}")] - public async Task Put(Guid organizationId, Guid id, [FromBody] ScimUserRequestModel model) - { - var orgUser = await _organizationUserRepository.GetByIdAsync(id); - if (orgUser == null || orgUser.OrganizationId != organizationId) - { - return new NotFoundObjectResult(new ScimErrorResponseModel - { - Status = 404, - Detail = "User not found." - }); - } - - if (model.Active && orgUser.Status == OrganizationUserStatusType.Revoked) - { - await _organizationService.RestoreUserAsync(orgUser, null, _userService); - } - else if (!model.Active && orgUser.Status != OrganizationUserStatusType.Revoked) - { - await _organizationService.RevokeUserAsync(orgUser, null); - } - - // Have to get full details object for response model - var orgUserDetails = await _organizationUserRepository.GetDetailsByIdAsync(id); - return new ObjectResult(new ScimUserResponseModel(orgUserDetails)); - } - - [HttpPatch("{id}")] - public async Task Patch(Guid organizationId, Guid id, [FromBody] ScimPatchModel model) - { - var orgUser = await _organizationUserRepository.GetByIdAsync(id); - if (orgUser == null || orgUser.OrganizationId != organizationId) - { - return new NotFoundObjectResult(new ScimErrorResponseModel - { - Status = 404, - Detail = "User not found." - }); - } - - var operationHandled = false; - foreach (var operation in model.Operations) - { - // Replace operations - if (operation.Op?.ToLowerInvariant() == "replace") - { - // Active from path - if (operation.Path?.ToLowerInvariant() == "active") + var handled = await HandleActiveOperationAsync(orgUser, activeProperty.GetBoolean()); + if (!operationHandled) { - var active = operation.Value.ToString()?.ToLowerInvariant(); - var handled = await HandleActiveOperationAsync(orgUser, active == "true"); - if (!operationHandled) - { - operationHandled = handled; - } - } - // Active from value object - else if (string.IsNullOrWhiteSpace(operation.Path) && - operation.Value.TryGetProperty("active", out var activeProperty)) - { - var handled = await HandleActiveOperationAsync(orgUser, activeProperty.GetBoolean()); - if (!operationHandled) - { - operationHandled = handled; - } + operationHandled = handled; } } } - - if (!operationHandled) - { - _logger.LogWarning("User patch operation not handled: {operation} : ", - string.Join(", ", model.Operations.Select(o => $"{o.Op}:{o.Path}"))); - } - - return new NoContentResult(); } - [HttpDelete("{id}")] - public async Task Delete(Guid organizationId, Guid id, [FromBody] ScimUserRequestModel model) + if (!operationHandled) { - var orgUser = await _organizationUserRepository.GetByIdAsync(id); - if (orgUser == null || orgUser.OrganizationId != organizationId) - { - return new NotFoundObjectResult(new ScimErrorResponseModel - { - Status = 404, - Detail = "User not found." - }); - } - await _organizationService.DeleteUserAsync(organizationId, id, null); - return new NoContentResult(); + _logger.LogWarning("User patch operation not handled: {operation} : ", + string.Join(", ", model.Operations.Select(o => $"{o.Op}:{o.Path}"))); } - private async Task HandleActiveOperationAsync(Core.Entities.OrganizationUser orgUser, bool active) + return new NoContentResult(); + } + + [HttpDelete("{id}")] + public async Task Delete(Guid organizationId, Guid id, [FromBody] ScimUserRequestModel model) + { + var orgUser = await _organizationUserRepository.GetByIdAsync(id); + if (orgUser == null || orgUser.OrganizationId != organizationId) { - if (active && orgUser.Status == OrganizationUserStatusType.Revoked) + return new NotFoundObjectResult(new ScimErrorResponseModel { - await _organizationService.RestoreUserAsync(orgUser, null, _userService); - return true; - } - else if (!active && orgUser.Status != OrganizationUserStatusType.Revoked) - { - await _organizationService.RevokeUserAsync(orgUser, null); - return true; - } - return false; + Status = 404, + Detail = "User not found." + }); } + await _organizationService.DeleteUserAsync(organizationId, id, null); + return new NoContentResult(); + } + + private async Task HandleActiveOperationAsync(Core.Entities.OrganizationUser orgUser, bool active) + { + if (active && orgUser.Status == OrganizationUserStatusType.Revoked) + { + await _organizationService.RestoreUserAsync(orgUser, null, _userService); + return true; + } + else if (!active && orgUser.Status != OrganizationUserStatusType.Revoked) + { + await _organizationService.RevokeUserAsync(orgUser, null); + return true; + } + return false; } } diff --git a/bitwarden_license/src/Scim/Models/BaseScimGroupModel.cs b/bitwarden_license/src/Scim/Models/BaseScimGroupModel.cs index 06d57bfad..150885fb5 100644 --- a/bitwarden_license/src/Scim/Models/BaseScimGroupModel.cs +++ b/bitwarden_license/src/Scim/Models/BaseScimGroupModel.cs @@ -1,18 +1,17 @@ using Bit.Scim.Utilities; -namespace Bit.Scim.Models -{ - public abstract class BaseScimGroupModel : BaseScimModel - { - public BaseScimGroupModel(bool initSchema = false) - { - if (initSchema) - { - Schemas = new List { ScimConstants.Scim2SchemaGroup }; - } - } +namespace Bit.Scim.Models; - public string DisplayName { get; set; } - public string ExternalId { get; set; } +public abstract class BaseScimGroupModel : BaseScimModel +{ + public BaseScimGroupModel(bool initSchema = false) + { + if (initSchema) + { + Schemas = new List { ScimConstants.Scim2SchemaGroup }; + } } + + public string DisplayName { get; set; } + public string ExternalId { get; set; } } diff --git a/bitwarden_license/src/Scim/Models/BaseScimModel.cs b/bitwarden_license/src/Scim/Models/BaseScimModel.cs index a2a071786..8f3adfbe4 100644 --- a/bitwarden_license/src/Scim/Models/BaseScimModel.cs +++ b/bitwarden_license/src/Scim/Models/BaseScimModel.cs @@ -1,15 +1,14 @@ -namespace Bit.Scim.Models +namespace Bit.Scim.Models; + +public abstract class BaseScimModel { - public abstract class BaseScimModel + public BaseScimModel() + { } + + public BaseScimModel(string schema) { - public BaseScimModel() - { } - - public BaseScimModel(string schema) - { - Schemas = new List { schema }; - } - - public List Schemas { get; set; } + Schemas = new List { schema }; } + + public List Schemas { get; set; } } diff --git a/bitwarden_license/src/Scim/Models/BaseScimUserModel.cs b/bitwarden_license/src/Scim/Models/BaseScimUserModel.cs index 0af9e652b..d3c69d574 100644 --- a/bitwarden_license/src/Scim/Models/BaseScimUserModel.cs +++ b/bitwarden_license/src/Scim/Models/BaseScimUserModel.cs @@ -1,56 +1,55 @@ using Bit.Scim.Utilities; -namespace Bit.Scim.Models +namespace Bit.Scim.Models; + +public abstract class BaseScimUserModel : BaseScimModel { - public abstract class BaseScimUserModel : BaseScimModel + public BaseScimUserModel(bool initSchema = false) { - public BaseScimUserModel(bool initSchema = false) + if (initSchema) { - if (initSchema) - { - Schemas = new List { ScimConstants.Scim2SchemaUser }; - } - } - - public string UserName { get; set; } - public NameModel Name { get; set; } - public List Emails { get; set; } - public string PrimaryEmail => Emails?.FirstOrDefault(e => e.Primary)?.Value; - public string WorkEmail => Emails?.FirstOrDefault(e => e.Type == "work")?.Value; - public string DisplayName { get; set; } - public bool Active { get; set; } - public List Groups { get; set; } - public string ExternalId { get; set; } - - public class NameModel - { - public NameModel() { } - - public NameModel(string name) - { - Formatted = name; - } - - public string Formatted { get; set; } - public string GivenName { get; set; } - public string MiddleName { get; set; } - public string FamilyName { get; set; } - } - - public class EmailModel - { - public EmailModel() { } - - public EmailModel(string email) - { - Primary = true; - Value = email; - Type = "work"; - } - - public bool Primary { get; set; } - public string Value { get; set; } - public string Type { get; set; } + Schemas = new List { ScimConstants.Scim2SchemaUser }; } } + + public string UserName { get; set; } + public NameModel Name { get; set; } + public List Emails { get; set; } + public string PrimaryEmail => Emails?.FirstOrDefault(e => e.Primary)?.Value; + public string WorkEmail => Emails?.FirstOrDefault(e => e.Type == "work")?.Value; + public string DisplayName { get; set; } + public bool Active { get; set; } + public List Groups { get; set; } + public string ExternalId { get; set; } + + public class NameModel + { + public NameModel() { } + + public NameModel(string name) + { + Formatted = name; + } + + public string Formatted { get; set; } + public string GivenName { get; set; } + public string MiddleName { get; set; } + public string FamilyName { get; set; } + } + + public class EmailModel + { + public EmailModel() { } + + public EmailModel(string email) + { + Primary = true; + Value = email; + Type = "work"; + } + + public bool Primary { get; set; } + public string Value { get; set; } + public string Type { get; set; } + } } diff --git a/bitwarden_license/src/Scim/Models/ScimErrorResponseModel.cs b/bitwarden_license/src/Scim/Models/ScimErrorResponseModel.cs index 6055001f5..d1dce35ef 100644 --- a/bitwarden_license/src/Scim/Models/ScimErrorResponseModel.cs +++ b/bitwarden_license/src/Scim/Models/ScimErrorResponseModel.cs @@ -1,14 +1,13 @@ using Bit.Scim.Utilities; -namespace Bit.Scim.Models -{ - public class ScimErrorResponseModel : BaseScimModel - { - public ScimErrorResponseModel() - : base(ScimConstants.Scim2SchemaError) - { } +namespace Bit.Scim.Models; - public string Detail { get; set; } - public int Status { get; set; } - } +public class ScimErrorResponseModel : BaseScimModel +{ + public ScimErrorResponseModel() + : base(ScimConstants.Scim2SchemaError) + { } + + public string Detail { get; set; } + public int Status { get; set; } } diff --git a/bitwarden_license/src/Scim/Models/ScimGroupRequestModel.cs b/bitwarden_license/src/Scim/Models/ScimGroupRequestModel.cs index 6de96655b..ac99eca2e 100644 --- a/bitwarden_license/src/Scim/Models/ScimGroupRequestModel.cs +++ b/bitwarden_license/src/Scim/Models/ScimGroupRequestModel.cs @@ -1,31 +1,30 @@ using Bit.Core.Entities; using Bit.Core.Utilities; -namespace Bit.Scim.Models +namespace Bit.Scim.Models; + +public class ScimGroupRequestModel : BaseScimGroupModel { - public class ScimGroupRequestModel : BaseScimGroupModel + public ScimGroupRequestModel() + : base(false) + { } + + public Group ToGroup(Guid organizationId) { - public ScimGroupRequestModel() - : base(false) - { } - - public Group ToGroup(Guid organizationId) + var externalId = string.IsNullOrWhiteSpace(ExternalId) ? CoreHelpers.RandomString(15) : ExternalId; + return new Group { - var externalId = string.IsNullOrWhiteSpace(ExternalId) ? CoreHelpers.RandomString(15) : ExternalId; - return new Group - { - Name = DisplayName, - ExternalId = externalId, - OrganizationId = organizationId - }; - } + Name = DisplayName, + ExternalId = externalId, + OrganizationId = organizationId + }; + } - public List Members { get; set; } + public List Members { get; set; } - public class GroupMembersModel - { - public string Value { get; set; } - public string Display { get; set; } - } + public class GroupMembersModel + { + public string Value { get; set; } + public string Display { get; set; } } } diff --git a/bitwarden_license/src/Scim/Models/ScimGroupResponseModel.cs b/bitwarden_license/src/Scim/Models/ScimGroupResponseModel.cs index df5d9b22a..d5bd64a32 100644 --- a/bitwarden_license/src/Scim/Models/ScimGroupResponseModel.cs +++ b/bitwarden_license/src/Scim/Models/ScimGroupResponseModel.cs @@ -1,26 +1,25 @@ using Bit.Core.Entities; -namespace Bit.Scim.Models +namespace Bit.Scim.Models; + +public class ScimGroupResponseModel : BaseScimGroupModel { - public class ScimGroupResponseModel : BaseScimGroupModel + public ScimGroupResponseModel() + : base(true) { - public ScimGroupResponseModel() - : base(true) - { - Meta = new ScimMetaModel("Group"); - } - - public ScimGroupResponseModel(Group group) - : this() - { - Id = group.Id.ToString(); - DisplayName = group.Name; - ExternalId = group.ExternalId; - Meta.Created = group.CreationDate; - Meta.LastModified = group.RevisionDate; - } - - public string Id { get; set; } - public ScimMetaModel Meta { get; private set; } + Meta = new ScimMetaModel("Group"); } + + public ScimGroupResponseModel(Group group) + : this() + { + Id = group.Id.ToString(); + DisplayName = group.Name; + ExternalId = group.ExternalId; + Meta.Created = group.CreationDate; + Meta.LastModified = group.RevisionDate; + } + + public string Id { get; set; } + public ScimMetaModel Meta { get; private set; } } diff --git a/bitwarden_license/src/Scim/Models/ScimListResponseModel.cs b/bitwarden_license/src/Scim/Models/ScimListResponseModel.cs index e7b952168..77ab52356 100644 --- a/bitwarden_license/src/Scim/Models/ScimListResponseModel.cs +++ b/bitwarden_license/src/Scim/Models/ScimListResponseModel.cs @@ -1,16 +1,15 @@ using Bit.Scim.Utilities; -namespace Bit.Scim.Models -{ - public class ScimListResponseModel : BaseScimModel - { - public ScimListResponseModel() - : base(ScimConstants.Scim2SchemaListResponse) - { } +namespace Bit.Scim.Models; - public int TotalResults { get; set; } - public int StartIndex { get; set; } - public int ItemsPerPage { get; set; } - public List Resources { get; set; } - } +public class ScimListResponseModel : BaseScimModel +{ + public ScimListResponseModel() + : base(ScimConstants.Scim2SchemaListResponse) + { } + + public int TotalResults { get; set; } + public int StartIndex { get; set; } + public int ItemsPerPage { get; set; } + public List Resources { get; set; } } diff --git a/bitwarden_license/src/Scim/Models/ScimMetaModel.cs b/bitwarden_license/src/Scim/Models/ScimMetaModel.cs index f3d95f5f3..862c054b7 100644 --- a/bitwarden_license/src/Scim/Models/ScimMetaModel.cs +++ b/bitwarden_license/src/Scim/Models/ScimMetaModel.cs @@ -1,14 +1,13 @@ -namespace Bit.Scim.Models -{ - public class ScimMetaModel - { - public ScimMetaModel(string resourceType) - { - ResourceType = resourceType; - } +namespace Bit.Scim.Models; - public string ResourceType { get; set; } - public DateTime? Created { get; set; } - public DateTime? LastModified { get; set; } +public class ScimMetaModel +{ + public ScimMetaModel(string resourceType) + { + ResourceType = resourceType; } + + public string ResourceType { get; set; } + public DateTime? Created { get; set; } + public DateTime? LastModified { get; set; } } diff --git a/bitwarden_license/src/Scim/Models/ScimPatchModel.cs b/bitwarden_license/src/Scim/Models/ScimPatchModel.cs index d42126765..6707ced85 100644 --- a/bitwarden_license/src/Scim/Models/ScimPatchModel.cs +++ b/bitwarden_license/src/Scim/Models/ScimPatchModel.cs @@ -1,19 +1,18 @@ using System.Text.Json; -namespace Bit.Scim.Models +namespace Bit.Scim.Models; + +public class ScimPatchModel : BaseScimModel { - public class ScimPatchModel : BaseScimModel + public ScimPatchModel() + : base() { } + + public List Operations { get; set; } + + public class OperationModel { - public ScimPatchModel() - : base() { } - - public List Operations { get; set; } - - public class OperationModel - { - public string Op { get; set; } - public string Path { get; set; } - public JsonElement Value { get; set; } - } + public string Op { get; set; } + public string Path { get; set; } + public JsonElement Value { get; set; } } } diff --git a/bitwarden_license/src/Scim/Models/ScimUserRequestModel.cs b/bitwarden_license/src/Scim/Models/ScimUserRequestModel.cs index 17f5e8593..a489e03ad 100644 --- a/bitwarden_license/src/Scim/Models/ScimUserRequestModel.cs +++ b/bitwarden_license/src/Scim/Models/ScimUserRequestModel.cs @@ -1,9 +1,8 @@ -namespace Bit.Scim.Models +namespace Bit.Scim.Models; + +public class ScimUserRequestModel : BaseScimUserModel { - public class ScimUserRequestModel : BaseScimUserModel - { - public ScimUserRequestModel() - : base(false) - { } - } + public ScimUserRequestModel() + : base(false) + { } } diff --git a/bitwarden_license/src/Scim/Models/ScimUserResponseModel.cs b/bitwarden_license/src/Scim/Models/ScimUserResponseModel.cs index 6f9650661..95d5184da 100644 --- a/bitwarden_license/src/Scim/Models/ScimUserResponseModel.cs +++ b/bitwarden_license/src/Scim/Models/ScimUserResponseModel.cs @@ -1,29 +1,28 @@ using Bit.Core.Models.Data.Organizations.OrganizationUsers; -namespace Bit.Scim.Models +namespace Bit.Scim.Models; + +public class ScimUserResponseModel : BaseScimUserModel { - public class ScimUserResponseModel : BaseScimUserModel + public ScimUserResponseModel() + : base(true) { - public ScimUserResponseModel() - : base(true) - { - Meta = new ScimMetaModel("User"); - Groups = new List(); - } - - public ScimUserResponseModel(OrganizationUserUserDetails orgUser) - : this() - { - Id = orgUser.Id.ToString(); - ExternalId = orgUser.ExternalId; - UserName = orgUser.Email; - DisplayName = orgUser.Name; - Emails = new List { new EmailModel(orgUser.Email) }; - Name = new NameModel(orgUser.Name); - Active = orgUser.Status != Core.Enums.OrganizationUserStatusType.Revoked; - } - - public string Id { get; set; } - public ScimMetaModel Meta { get; private set; } + Meta = new ScimMetaModel("User"); + Groups = new List(); } + + public ScimUserResponseModel(OrganizationUserUserDetails orgUser) + : this() + { + Id = orgUser.Id.ToString(); + ExternalId = orgUser.ExternalId; + UserName = orgUser.Email; + DisplayName = orgUser.Name; + Emails = new List { new EmailModel(orgUser.Email) }; + Name = new NameModel(orgUser.Name); + Active = orgUser.Status != Core.Enums.OrganizationUserStatusType.Revoked; + } + + public string Id { get; set; } + public ScimMetaModel Meta { get; private set; } } diff --git a/bitwarden_license/src/Scim/Program.cs b/bitwarden_license/src/Scim/Program.cs index f8d6cb15b..48d5711e1 100644 --- a/bitwarden_license/src/Scim/Program.cs +++ b/bitwarden_license/src/Scim/Program.cs @@ -1,34 +1,33 @@ using Bit.Core.Utilities; using Serilog.Events; -namespace Bit.Scim +namespace Bit.Scim; + +public class Program { - public class Program + public static void Main(string[] args) { - public static void Main(string[] args) - { - Host - .CreateDefaultBuilder(args) - .ConfigureWebHostDefaults(webBuilder => - { - webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, e => + Host + .CreateDefaultBuilder(args) + .ConfigureWebHostDefaults(webBuilder => + { + webBuilder.UseStartup(); + webBuilder.ConfigureLogging((hostingContext, logging) => + logging.AddSerilog(hostingContext, e => + { + var context = e.Properties["SourceContext"].ToString(); + + if (e.Properties.ContainsKey("RequestPath") && + !string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) && + (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) { - var context = e.Properties["SourceContext"].ToString(); + return false; + } - if (e.Properties.ContainsKey("RequestPath") && - !string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) && - (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) - { - return false; - } - - return e.Level >= LogEventLevel.Warning; - })); - }) - .Build() - .Run(); - } + return e.Level >= LogEventLevel.Warning; + })); + }) + .Build() + .Run(); } } diff --git a/bitwarden_license/src/Scim/ScimSettings.cs b/bitwarden_license/src/Scim/ScimSettings.cs index 5c25dbf37..ef4ebfb50 100644 --- a/bitwarden_license/src/Scim/ScimSettings.cs +++ b/bitwarden_license/src/Scim/ScimSettings.cs @@ -1,6 +1,5 @@ -namespace Bit.Scim +namespace Bit.Scim; + +public class ScimSettings { - public class ScimSettings - { - } } diff --git a/bitwarden_license/src/Scim/Startup.cs b/bitwarden_license/src/Scim/Startup.cs index daa5752e9..65e9220a7 100644 --- a/bitwarden_license/src/Scim/Startup.cs +++ b/bitwarden_license/src/Scim/Startup.cs @@ -9,108 +9,107 @@ using IdentityModel; using Microsoft.Extensions.DependencyInjection.Extensions; using Stripe; -namespace Bit.Scim +namespace Bit.Scim; + +public class Startup { - public class Startup + public Startup(IWebHostEnvironment env, IConfiguration configuration) { - public Startup(IWebHostEnvironment env, IConfiguration configuration) + CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); + Configuration = configuration; + Environment = env; + } + + public IConfiguration Configuration { get; } + public IWebHostEnvironment Environment { get; set; } + + public void ConfigureServices(IServiceCollection services) + { + // Options + services.AddOptions(); + + // Settings + var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); + services.Configure(Configuration.GetSection("ScimSettings")); + + // Data Protection + services.AddCustomDataProtectionServices(Environment, globalSettings); + + // Stripe Billing + StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey; + StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries; + + // Repositories + services.AddSqlServerRepositories(globalSettings); + + // Context + services.AddScoped(); + services.AddScoped(); + + // Authentication + services.AddAuthentication(ApiKeyAuthenticationOptions.DefaultScheme) + .AddScheme( + ApiKeyAuthenticationOptions.DefaultScheme, null); + + services.AddAuthorization(config => { - CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); - Configuration = configuration; - Environment = env; - } - - public IConfiguration Configuration { get; } - public IWebHostEnvironment Environment { get; set; } - - public void ConfigureServices(IServiceCollection services) - { - // Options - services.AddOptions(); - - // Settings - var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); - services.Configure(Configuration.GetSection("ScimSettings")); - - // Data Protection - services.AddCustomDataProtectionServices(Environment, globalSettings); - - // Stripe Billing - StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey; - StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries; - - // Repositories - services.AddSqlServerRepositories(globalSettings); - - // Context - services.AddScoped(); - services.AddScoped(); - - // Authentication - services.AddAuthentication(ApiKeyAuthenticationOptions.DefaultScheme) - .AddScheme( - ApiKeyAuthenticationOptions.DefaultScheme, null); - - services.AddAuthorization(config => + config.AddPolicy("Scim", policy => { - config.AddPolicy("Scim", policy => - { - policy.RequireAuthenticatedUser(); - policy.RequireClaim(JwtClaimTypes.Scope, "api.scim"); - }); + policy.RequireAuthenticatedUser(); + policy.RequireClaim(JwtClaimTypes.Scope, "api.scim"); }); + }); - // Identity - services.AddCustomIdentityServices(globalSettings); + // Identity + services.AddCustomIdentityServices(globalSettings); - // Services - services.AddBaseServices(globalSettings); - services.AddDefaultServices(globalSettings); + // Services + services.AddBaseServices(globalSettings); + services.AddDefaultServices(globalSettings); - services.TryAddSingleton(); + services.TryAddSingleton(); - // Mvc - services.AddMvc(config => - { - config.Filters.Add(new LoggingExceptionHandlerFilterAttribute()); - }); - services.Configure(options => options.LowercaseUrls = true); - } - - public void Configure( - IApplicationBuilder app, - IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, - GlobalSettings globalSettings) + // Mvc + services.AddMvc(config => { - app.UseSerilog(env, appLifetime, globalSettings); + config.Filters.Add(new LoggingExceptionHandlerFilterAttribute()); + }); + services.Configure(options => options.LowercaseUrls = true); + } - // Add general security headers - app.UseMiddleware(); + public void Configure( + IApplicationBuilder app, + IWebHostEnvironment env, + IHostApplicationLifetime appLifetime, + GlobalSettings globalSettings) + { + app.UseSerilog(env, appLifetime, globalSettings); - if (env.IsDevelopment()) - { - app.UseDeveloperExceptionPage(); - } + // Add general security headers + app.UseMiddleware(); - // Default Middleware - app.UseDefaultMiddleware(env, globalSettings); - - // Add routing - app.UseRouting(); - - // Add Scim context - app.UseMiddleware(); - - // Add authentication and authorization to the request pipeline. - app.UseAuthentication(); - app.UseAuthorization(); - - // Add current context - app.UseMiddleware(); - - // Add MVC to the request pipeline. - app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); + if (env.IsDevelopment()) + { + app.UseDeveloperExceptionPage(); } + + // Default Middleware + app.UseDefaultMiddleware(env, globalSettings); + + // Add routing + app.UseRouting(); + + // Add Scim context + app.UseMiddleware(); + + // Add authentication and authorization to the request pipeline. + app.UseAuthentication(); + app.UseAuthorization(); + + // Add current context + app.UseMiddleware(); + + // Add MVC to the request pipeline. + app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); } } diff --git a/bitwarden_license/src/Scim/Utilities/ApiKeyAuthenticationHandler.cs b/bitwarden_license/src/Scim/Utilities/ApiKeyAuthenticationHandler.cs index c1b08b1b9..4e7e7ceb7 100644 --- a/bitwarden_license/src/Scim/Utilities/ApiKeyAuthenticationHandler.cs +++ b/bitwarden_license/src/Scim/Utilities/ApiKeyAuthenticationHandler.cs @@ -8,83 +8,82 @@ using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Authorization; using Microsoft.Extensions.Options; -namespace Bit.Scim.Utilities +namespace Bit.Scim.Utilities; + +public class ApiKeyAuthenticationHandler : AuthenticationHandler { - public class ApiKeyAuthenticationHandler : AuthenticationHandler + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; + private readonly IScimContext _scimContext; + + public ApiKeyAuthenticationHandler( + IOptionsMonitor options, + ILoggerFactory logger, + UrlEncoder encoder, + ISystemClock clock, + IOrganizationRepository organizationRepository, + IOrganizationApiKeyRepository organizationApiKeyRepository, + IScimContext scimContext) : + base(options, logger, encoder, clock) { - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; - private readonly IScimContext _scimContext; + _organizationRepository = organizationRepository; + _organizationApiKeyRepository = organizationApiKeyRepository; + _scimContext = scimContext; + } - public ApiKeyAuthenticationHandler( - IOptionsMonitor options, - ILoggerFactory logger, - UrlEncoder encoder, - ISystemClock clock, - IOrganizationRepository organizationRepository, - IOrganizationApiKeyRepository organizationApiKeyRepository, - IScimContext scimContext) : - base(options, logger, encoder, clock) + protected override async Task HandleAuthenticateAsync() + { + var endpoint = Context.GetEndpoint(); + if (endpoint?.Metadata?.GetMetadata() != null) { - _organizationRepository = organizationRepository; - _organizationApiKeyRepository = organizationApiKeyRepository; - _scimContext = scimContext; + return AuthenticateResult.NoResult(); } - protected override async Task HandleAuthenticateAsync() + if (!_scimContext.OrganizationId.HasValue || _scimContext.Organization == null) { - var endpoint = Context.GetEndpoint(); - if (endpoint?.Metadata?.GetMetadata() != null) - { - return AuthenticateResult.NoResult(); - } - - if (!_scimContext.OrganizationId.HasValue || _scimContext.Organization == null) - { - Logger.LogWarning("No organization."); - return AuthenticateResult.Fail("Invalid parameters"); - } - - if (!Request.Headers.TryGetValue("Authorization", out var authHeader) || authHeader.Count != 1) - { - Logger.LogWarning("An API request was received without the Authorization header"); - return AuthenticateResult.Fail("Invalid parameters"); - } - var apiKey = authHeader.ToString(); - if (apiKey.StartsWith("Bearer ")) - { - apiKey = apiKey.Substring(7); - } - - if (!_scimContext.Organization.Enabled || !_scimContext.Organization.UseScim || - _scimContext.ScimConfiguration == null || !_scimContext.ScimConfiguration.Enabled) - { - Logger.LogInformation("Org {organizationId} not able to use Scim.", _scimContext.OrganizationId); - return AuthenticateResult.Fail("Invalid parameters"); - } - - var orgApiKey = (await _organizationApiKeyRepository - .GetManyByOrganizationIdTypeAsync(_scimContext.Organization.Id, OrganizationApiKeyType.Scim)) - .FirstOrDefault(); - if (orgApiKey?.ApiKey != apiKey) - { - Logger.LogWarning("An API request was received with an invalid API key: {apiKey}", apiKey); - return AuthenticateResult.Fail("Invalid parameters"); - } - - Logger.LogInformation("Org {organizationId} authenticated", _scimContext.OrganizationId); - - var claims = new[] - { - new Claim(JwtClaimTypes.ClientId, $"organization.{_scimContext.OrganizationId.Value}"), - new Claim("client_sub", _scimContext.OrganizationId.Value.ToString()), - new Claim(JwtClaimTypes.Scope, "api.scim"), - }; - var identity = new ClaimsIdentity(claims, nameof(ApiKeyAuthenticationHandler)); - var ticket = new AuthenticationTicket(new ClaimsPrincipal(identity), - ApiKeyAuthenticationOptions.DefaultScheme); - - return AuthenticateResult.Success(ticket); + Logger.LogWarning("No organization."); + return AuthenticateResult.Fail("Invalid parameters"); } + + if (!Request.Headers.TryGetValue("Authorization", out var authHeader) || authHeader.Count != 1) + { + Logger.LogWarning("An API request was received without the Authorization header"); + return AuthenticateResult.Fail("Invalid parameters"); + } + var apiKey = authHeader.ToString(); + if (apiKey.StartsWith("Bearer ")) + { + apiKey = apiKey.Substring(7); + } + + if (!_scimContext.Organization.Enabled || !_scimContext.Organization.UseScim || + _scimContext.ScimConfiguration == null || !_scimContext.ScimConfiguration.Enabled) + { + Logger.LogInformation("Org {organizationId} not able to use Scim.", _scimContext.OrganizationId); + return AuthenticateResult.Fail("Invalid parameters"); + } + + var orgApiKey = (await _organizationApiKeyRepository + .GetManyByOrganizationIdTypeAsync(_scimContext.Organization.Id, OrganizationApiKeyType.Scim)) + .FirstOrDefault(); + if (orgApiKey?.ApiKey != apiKey) + { + Logger.LogWarning("An API request was received with an invalid API key: {apiKey}", apiKey); + return AuthenticateResult.Fail("Invalid parameters"); + } + + Logger.LogInformation("Org {organizationId} authenticated", _scimContext.OrganizationId); + + var claims = new[] + { + new Claim(JwtClaimTypes.ClientId, $"organization.{_scimContext.OrganizationId.Value}"), + new Claim("client_sub", _scimContext.OrganizationId.Value.ToString()), + new Claim(JwtClaimTypes.Scope, "api.scim"), + }; + var identity = new ClaimsIdentity(claims, nameof(ApiKeyAuthenticationHandler)); + var ticket = new AuthenticationTicket(new ClaimsPrincipal(identity), + ApiKeyAuthenticationOptions.DefaultScheme); + + return AuthenticateResult.Success(ticket); } } diff --git a/bitwarden_license/src/Scim/Utilities/ApiKeyAuthenticationOptions.cs b/bitwarden_license/src/Scim/Utilities/ApiKeyAuthenticationOptions.cs index 7d2bb3e81..f0015226b 100644 --- a/bitwarden_license/src/Scim/Utilities/ApiKeyAuthenticationOptions.cs +++ b/bitwarden_license/src/Scim/Utilities/ApiKeyAuthenticationOptions.cs @@ -1,9 +1,8 @@ using Microsoft.AspNetCore.Authentication; -namespace Bit.Scim.Utilities +namespace Bit.Scim.Utilities; + +public class ApiKeyAuthenticationOptions : AuthenticationSchemeOptions { - public class ApiKeyAuthenticationOptions : AuthenticationSchemeOptions - { - public const string DefaultScheme = "ScimApiKey"; - } + public const string DefaultScheme = "ScimApiKey"; } diff --git a/bitwarden_license/src/Scim/Utilities/ScimConstants.cs b/bitwarden_license/src/Scim/Utilities/ScimConstants.cs index 4c9d11f6c..219be6534 100644 --- a/bitwarden_license/src/Scim/Utilities/ScimConstants.cs +++ b/bitwarden_license/src/Scim/Utilities/ScimConstants.cs @@ -1,10 +1,9 @@ -namespace Bit.Scim.Utilities +namespace Bit.Scim.Utilities; + +public static class ScimConstants { - public static class ScimConstants - { - public const string Scim2SchemaListResponse = "urn:ietf:params:scim:api:messages:2.0:ListResponse"; - public const string Scim2SchemaError = "urn:ietf:params:scim:api:messages:2.0:Error"; - public const string Scim2SchemaUser = "urn:ietf:params:scim:schemas:core:2.0:User"; - public const string Scim2SchemaGroup = "urn:ietf:params:scim:schemas:core:2.0:Group"; - } + public const string Scim2SchemaListResponse = "urn:ietf:params:scim:api:messages:2.0:ListResponse"; + public const string Scim2SchemaError = "urn:ietf:params:scim:api:messages:2.0:Error"; + public const string Scim2SchemaUser = "urn:ietf:params:scim:schemas:core:2.0:User"; + public const string Scim2SchemaGroup = "urn:ietf:params:scim:schemas:core:2.0:Group"; } diff --git a/bitwarden_license/src/Scim/Utilities/ScimContextMiddleware.cs b/bitwarden_license/src/Scim/Utilities/ScimContextMiddleware.cs index 9550814de..6d5f3e1bf 100644 --- a/bitwarden_license/src/Scim/Utilities/ScimContextMiddleware.cs +++ b/bitwarden_license/src/Scim/Utilities/ScimContextMiddleware.cs @@ -2,22 +2,21 @@ using Bit.Core.Settings; using Bit.Scim.Context; -namespace Bit.Scim.Utilities +namespace Bit.Scim.Utilities; + +public class ScimContextMiddleware { - public class ScimContextMiddleware + private readonly RequestDelegate _next; + + public ScimContextMiddleware(RequestDelegate next) { - private readonly RequestDelegate _next; + _next = next; + } - public ScimContextMiddleware(RequestDelegate next) - { - _next = next; - } - - public async Task Invoke(HttpContext httpContext, IScimContext scimContext, GlobalSettings globalSettings, - IOrganizationRepository organizationRepository, IOrganizationConnectionRepository organizationConnectionRepository) - { - await scimContext.BuildAsync(httpContext, globalSettings, organizationRepository, organizationConnectionRepository); - await _next.Invoke(httpContext); - } + public async Task Invoke(HttpContext httpContext, IScimContext scimContext, GlobalSettings globalSettings, + IOrganizationRepository organizationRepository, IOrganizationConnectionRepository organizationConnectionRepository) + { + await scimContext.BuildAsync(httpContext, globalSettings, organizationRepository, organizationConnectionRepository); + await _next.Invoke(httpContext); } } diff --git a/bitwarden_license/src/Sso/Controllers/AccountController.cs b/bitwarden_license/src/Sso/Controllers/AccountController.cs index fbbf3084d..4ab9d7ef0 100644 --- a/bitwarden_license/src/Sso/Controllers/AccountController.cs +++ b/bitwarden_license/src/Sso/Controllers/AccountController.cs @@ -22,689 +22,688 @@ using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Identity; using Microsoft.AspNetCore.Mvc; -namespace Bit.Sso.Controllers +namespace Bit.Sso.Controllers; + +public class AccountController : Controller { - public class AccountController : Controller + private readonly IAuthenticationSchemeProvider _schemeProvider; + private readonly IClientStore _clientStore; + + private readonly IIdentityServerInteractionService _interaction; + private readonly ILogger _logger; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IOrganizationService _organizationService; + private readonly ISsoConfigRepository _ssoConfigRepository; + private readonly ISsoUserRepository _ssoUserRepository; + private readonly IUserRepository _userRepository; + private readonly IPolicyRepository _policyRepository; + private readonly IUserService _userService; + private readonly II18nService _i18nService; + private readonly UserManager _userManager; + private readonly IGlobalSettings _globalSettings; + private readonly Core.Services.IEventService _eventService; + private readonly IDataProtectorTokenFactory _dataProtector; + + public AccountController( + IAuthenticationSchemeProvider schemeProvider, + IClientStore clientStore, + IIdentityServerInteractionService interaction, + ILogger logger, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationService organizationService, + ISsoConfigRepository ssoConfigRepository, + ISsoUserRepository ssoUserRepository, + IUserRepository userRepository, + IPolicyRepository policyRepository, + IUserService userService, + II18nService i18nService, + UserManager userManager, + IGlobalSettings globalSettings, + Core.Services.IEventService eventService, + IDataProtectorTokenFactory dataProtector) { - private readonly IAuthenticationSchemeProvider _schemeProvider; - private readonly IClientStore _clientStore; + _schemeProvider = schemeProvider; + _clientStore = clientStore; + _interaction = interaction; + _logger = logger; + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _organizationService = organizationService; + _userRepository = userRepository; + _ssoConfigRepository = ssoConfigRepository; + _ssoUserRepository = ssoUserRepository; + _policyRepository = policyRepository; + _userService = userService; + _i18nService = i18nService; + _userManager = userManager; + _eventService = eventService; + _globalSettings = globalSettings; + _dataProtector = dataProtector; + } - private readonly IIdentityServerInteractionService _interaction; - private readonly ILogger _logger; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IOrganizationService _organizationService; - private readonly ISsoConfigRepository _ssoConfigRepository; - private readonly ISsoUserRepository _ssoUserRepository; - private readonly IUserRepository _userRepository; - private readonly IPolicyRepository _policyRepository; - private readonly IUserService _userService; - private readonly II18nService _i18nService; - private readonly UserManager _userManager; - private readonly IGlobalSettings _globalSettings; - private readonly Core.Services.IEventService _eventService; - private readonly IDataProtectorTokenFactory _dataProtector; - - public AccountController( - IAuthenticationSchemeProvider schemeProvider, - IClientStore clientStore, - IIdentityServerInteractionService interaction, - ILogger logger, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IOrganizationService organizationService, - ISsoConfigRepository ssoConfigRepository, - ISsoUserRepository ssoUserRepository, - IUserRepository userRepository, - IPolicyRepository policyRepository, - IUserService userService, - II18nService i18nService, - UserManager userManager, - IGlobalSettings globalSettings, - Core.Services.IEventService eventService, - IDataProtectorTokenFactory dataProtector) + [HttpGet] + public async Task PreValidate(string domainHint) + { + try { - _schemeProvider = schemeProvider; - _clientStore = clientStore; - _interaction = interaction; - _logger = logger; - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _organizationService = organizationService; - _userRepository = userRepository; - _ssoConfigRepository = ssoConfigRepository; - _ssoUserRepository = ssoUserRepository; - _policyRepository = policyRepository; - _userService = userService; - _i18nService = i18nService; - _userManager = userManager; - _eventService = eventService; - _globalSettings = globalSettings; - _dataProtector = dataProtector; - } - - [HttpGet] - public async Task PreValidate(string domainHint) - { - try + // Validate domain_hint provided + if (string.IsNullOrWhiteSpace(domainHint)) { - // Validate domain_hint provided - if (string.IsNullOrWhiteSpace(domainHint)) - { - return InvalidJson("NoOrganizationIdentifierProvidedError"); - } - - // Validate organization exists from domain_hint - var organization = await _organizationRepository.GetByIdentifierAsync(domainHint); - if (organization == null) - { - return InvalidJson("OrganizationNotFoundByIdentifierError"); - } - if (!organization.UseSso) - { - return InvalidJson("SsoNotAllowedForOrganizationError"); - } - - // Validate SsoConfig exists and is Enabled - var ssoConfig = await _ssoConfigRepository.GetByIdentifierAsync(domainHint); - if (ssoConfig == null) - { - return InvalidJson("SsoConfigurationNotFoundForOrganizationError"); - } - if (!ssoConfig.Enabled) - { - return InvalidJson("SsoNotEnabledForOrganizationError"); - } - - // Validate Authentication Scheme exists and is loaded (cache) - var scheme = await _schemeProvider.GetSchemeAsync(organization.Id.ToString()); - if (scheme == null || !(scheme is IDynamicAuthenticationScheme dynamicScheme)) - { - return InvalidJson("NoSchemeOrHandlerForSsoConfigurationFoundError"); - } - - // Run scheme validation - try - { - await dynamicScheme.Validate(); - } - catch (Exception ex) - { - var translatedException = _i18nService.GetLocalizedHtmlString(ex.Message); - var errorKey = "InvalidSchemeConfigurationError"; - if (!translatedException.ResourceNotFound) - { - errorKey = ex.Message; - } - return InvalidJson(errorKey, translatedException.ResourceNotFound ? ex : null); - } - - var tokenable = new SsoTokenable(organization, _globalSettings.Sso.SsoTokenLifetimeInSeconds); - var token = _dataProtector.Protect(tokenable); - - return new SsoPreValidateResponseModel(token); - } - catch (Exception ex) - { - return InvalidJson("PreValidationError", ex); - } - } - - [HttpGet] - public async Task Login(string returnUrl) - { - var context = await _interaction.GetAuthorizationContextAsync(returnUrl); - - if (!context.Parameters.AllKeys.Contains("domain_hint") || - string.IsNullOrWhiteSpace(context.Parameters["domain_hint"])) - { - throw new Exception(_i18nService.T("NoDomainHintProvided")); + return InvalidJson("NoOrganizationIdentifierProvidedError"); } - var ssoToken = context.Parameters[SsoTokenable.TokenIdentifier]; - - if (string.IsNullOrWhiteSpace(ssoToken)) - { - return Unauthorized("A valid SSO token is required to continue with SSO login"); - } - - var domainHint = context.Parameters["domain_hint"]; + // Validate organization exists from domain_hint var organization = await _organizationRepository.GetByIdentifierAsync(domainHint); - if (organization == null) { return InvalidJson("OrganizationNotFoundByIdentifierError"); } - - var tokenable = _dataProtector.Unprotect(ssoToken); - - if (!tokenable.TokenIsValid(organization)) + if (!organization.UseSso) { - return Unauthorized("The SSO token associated with your request is expired. A valid SSO token is required to continue."); + return InvalidJson("SsoNotAllowedForOrganizationError"); } - return RedirectToAction(nameof(ExternalChallenge), new + // Validate SsoConfig exists and is Enabled + var ssoConfig = await _ssoConfigRepository.GetByIdentifierAsync(domainHint); + if (ssoConfig == null) { - scheme = organization.Id.ToString(), - returnUrl, - state = context.Parameters["state"], - userIdentifier = context.Parameters["session_state"], - }); - } - - [HttpGet] - public IActionResult ExternalChallenge(string scheme, string returnUrl, string state, string userIdentifier) - { - if (string.IsNullOrEmpty(returnUrl)) + return InvalidJson("SsoConfigurationNotFoundForOrganizationError"); + } + if (!ssoConfig.Enabled) { - returnUrl = "~/"; + return InvalidJson("SsoNotEnabledForOrganizationError"); } - if (!Url.IsLocalUrl(returnUrl) && !_interaction.IsValidReturnUrl(returnUrl)) + // Validate Authentication Scheme exists and is loaded (cache) + var scheme = await _schemeProvider.GetSchemeAsync(organization.Id.ToString()); + if (scheme == null || !(scheme is IDynamicAuthenticationScheme dynamicScheme)) { - throw new Exception(_i18nService.T("InvalidReturnUrl")); + return InvalidJson("NoSchemeOrHandlerForSsoConfigurationFoundError"); } - var props = new AuthenticationProperties + // Run scheme validation + try { - RedirectUri = Url.Action(nameof(ExternalCallback)), - Items = + await dynamicScheme.Validate(); + } + catch (Exception ex) + { + var translatedException = _i18nService.GetLocalizedHtmlString(ex.Message); + var errorKey = "InvalidSchemeConfigurationError"; + if (!translatedException.ResourceNotFound) { - // scheme will get serialized into `State` and returned back - { "scheme", scheme }, - { "return_url", returnUrl }, - { "state", state }, - { "user_identifier", userIdentifier }, + errorKey = ex.Message; } - }; + return InvalidJson(errorKey, translatedException.ResourceNotFound ? ex : null); + } - return Challenge(props, scheme); + var tokenable = new SsoTokenable(organization, _globalSettings.Sso.SsoTokenLifetimeInSeconds); + var token = _dataProtector.Protect(tokenable); + + return new SsoPreValidateResponseModel(token); } - - [HttpGet] - public async Task ExternalCallback() + catch (Exception ex) { - // Read external identity from the temporary cookie - var result = await HttpContext.AuthenticateAsync( - AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); - if (result?.Succeeded != true) - { - throw new Exception(_i18nService.T("ExternalAuthenticationError")); - } - - // Debugging - var externalClaims = result.Principal.Claims.Select(c => $"{c.Type}: {c.Value}"); - _logger.LogDebug("External claims: {@claims}", externalClaims); - - // Lookup our user and external provider info - var (user, provider, providerUserId, claims, ssoConfigData) = await FindUserFromExternalProviderAsync(result); - if (user == null) - { - // This might be where you might initiate a custom workflow for user registration - // in this sample we don't show how that would be done, as our sample implementation - // simply auto-provisions new external user - var userIdentifier = result.Properties.Items.Keys.Contains("user_identifier") ? - result.Properties.Items["user_identifier"] : null; - user = await AutoProvisionUserAsync(provider, providerUserId, claims, userIdentifier, ssoConfigData); - } - - if (user != null) - { - // This allows us to collect any additional claims or properties - // for the specific protocols used and store them in the local auth cookie. - // this is typically used to store data needed for signout from those protocols. - var additionalLocalClaims = new List(); - var localSignInProps = new AuthenticationProperties - { - IsPersistent = true, - ExpiresUtc = DateTimeOffset.UtcNow.AddMinutes(1) - }; - ProcessLoginCallback(result, additionalLocalClaims, localSignInProps); - - // Issue authentication cookie for user - await HttpContext.SignInAsync(new IdentityServerUser(user.Id.ToString()) - { - DisplayName = user.Email, - IdentityProvider = provider, - AdditionalClaims = additionalLocalClaims.ToArray() - }, localSignInProps); - } - - // Delete temporary cookie used during external authentication - await HttpContext.SignOutAsync(AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); - - // Retrieve return URL - var returnUrl = result.Properties.Items["return_url"] ?? "~/"; - - // Check if external login is in the context of an OIDC request - var context = await _interaction.GetAuthorizationContextAsync(returnUrl); - if (context != null) - { - if (IsNativeClient(context)) - { - // The client is native, so this change in how to - // return the response is for better UX for the end user. - HttpContext.Response.StatusCode = 200; - HttpContext.Response.Headers["Location"] = string.Empty; - return View("Redirect", new RedirectViewModel { RedirectUrl = returnUrl }); - } - } - - return Redirect(returnUrl); - } - - [HttpGet] - public async Task Logout(string logoutId) - { - // Build a model so the logged out page knows what to display - var (updatedLogoutId, redirectUri, externalAuthenticationScheme) = await GetLoggedOutDataAsync(logoutId); - - if (User?.Identity.IsAuthenticated == true) - { - // Delete local authentication cookie - await HttpContext.SignOutAsync(); - } - - // HACK: Temporary workaroud for the time being that doesn't try to sign out of OneLogin schemes, - // which doesnt support SLO - if (externalAuthenticationScheme != null && !externalAuthenticationScheme.Contains("onelogin")) - { - // Build a return URL so the upstream provider will redirect back - // to us after the user has logged out. this allows us to then - // complete our single sign-out processing. - var url = Url.Action("Logout", new { logoutId = updatedLogoutId }); - - // This triggers a redirect to the external provider for sign-out - return SignOut(new AuthenticationProperties { RedirectUri = url }, externalAuthenticationScheme); - } - if (redirectUri != null) - { - return View("Redirect", new RedirectViewModel { RedirectUrl = redirectUri }); - } - else - { - return Redirect("~/"); - } - } - - private async Task<(User user, string provider, string providerUserId, IEnumerable claims, SsoConfigurationData config)> - FindUserFromExternalProviderAsync(AuthenticateResult result) - { - var provider = result.Properties.Items["scheme"]; - var orgId = new Guid(provider); - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(orgId); - if (ssoConfig == null || !ssoConfig.Enabled) - { - throw new Exception(_i18nService.T("OrganizationOrSsoConfigNotFound")); - } - - var ssoConfigData = ssoConfig.GetData(); - var externalUser = result.Principal; - - // Validate acr claim against expectation before going further - if (!string.IsNullOrWhiteSpace(ssoConfigData.ExpectedReturnAcrValue)) - { - var acrClaim = externalUser.FindFirst(JwtClaimTypes.AuthenticationContextClassReference); - if (acrClaim?.Value != ssoConfigData.ExpectedReturnAcrValue) - { - throw new Exception(_i18nService.T("AcrMissingOrInvalid")); - } - } - - // Ensure the NameIdentifier used is not a transient name ID, if so, we need a different attribute - // for the user identifier. - static bool nameIdIsNotTransient(Claim c) => c.Type == ClaimTypes.NameIdentifier - && (c.Properties == null - || !c.Properties.ContainsKey(SamlPropertyKeys.ClaimFormat) - || c.Properties[SamlPropertyKeys.ClaimFormat] != SamlNameIdFormats.Transient); - - // Try to determine the unique id of the external user (issued by the provider) - // the most common claim type for that are the sub claim and the NameIdentifier - // depending on the external provider, some other claim type might be used - var customUserIdClaimTypes = ssoConfigData.GetAdditionalUserIdClaimTypes(); - var userIdClaim = externalUser.FindFirst(c => customUserIdClaimTypes.Contains(c.Type)) ?? - externalUser.FindFirst(JwtClaimTypes.Subject) ?? - externalUser.FindFirst(nameIdIsNotTransient) ?? - // Some SAML providers may use the `uid` attribute for this - // where a transient NameID has been sent in the subject - externalUser.FindFirst("uid") ?? - externalUser.FindFirst("upn") ?? - externalUser.FindFirst("eppn") ?? - throw new Exception(_i18nService.T("UnknownUserId")); - - // Remove the user id claim so we don't include it as an extra claim if/when we provision the user - var claims = externalUser.Claims.ToList(); - claims.Remove(userIdClaim); - - // find external user - var providerUserId = userIdClaim.Value; - - var user = await _userRepository.GetBySsoUserAsync(providerUserId, orgId); - - return (user, provider, providerUserId, claims, ssoConfigData); - } - - private async Task AutoProvisionUserAsync(string provider, string providerUserId, - IEnumerable claims, string userIdentifier, SsoConfigurationData config) - { - var name = GetName(claims, config.GetAdditionalNameClaimTypes()); - var email = GetEmailAddress(claims, config.GetAdditionalEmailClaimTypes()); - if (string.IsNullOrWhiteSpace(email) && providerUserId.Contains("@")) - { - email = providerUserId; - } - - if (!Guid.TryParse(provider, out var orgId)) - { - // TODO: support non-org (server-wide) SSO in the future? - throw new Exception(_i18nService.T("SSOProviderIsNotAnOrgId", provider)); - } - - User existingUser = null; - if (string.IsNullOrWhiteSpace(userIdentifier)) - { - if (string.IsNullOrWhiteSpace(email)) - { - throw new Exception(_i18nService.T("CannotFindEmailClaim")); - } - existingUser = await _userRepository.GetByEmailAsync(email); - } - else - { - var split = userIdentifier.Split(","); - if (split.Length < 2) - { - throw new Exception(_i18nService.T("InvalidUserIdentifier")); - } - var userId = split[0]; - var token = split[1]; - - var tokenOptions = new TokenOptions(); - - var claimedUser = await _userService.GetUserByIdAsync(userId); - if (claimedUser != null) - { - var tokenIsValid = await _userManager.VerifyUserTokenAsync( - claimedUser, tokenOptions.PasswordResetTokenProvider, TokenPurposes.LinkSso, token); - if (tokenIsValid) - { - existingUser = claimedUser; - } - else - { - throw new Exception(_i18nService.T("UserIdAndTokenMismatch")); - } - } - } - - OrganizationUser orgUser = null; - var organization = await _organizationRepository.GetByIdAsync(orgId); - if (organization == null) - { - throw new Exception(_i18nService.T("CouldNotFindOrganization", orgId)); - } - - // Try to find OrgUser via existing User Id (accepted/confirmed user) - if (existingUser != null) - { - var orgUsersByUserId = await _organizationUserRepository.GetManyByUserAsync(existingUser.Id); - orgUser = orgUsersByUserId.SingleOrDefault(u => u.OrganizationId == orgId); - } - - // If no Org User found by Existing User Id - search all organization users via email - orgUser ??= await _organizationUserRepository.GetByOrganizationEmailAsync(orgId, email); - - // All Existing User flows handled below - if (existingUser != null) - { - if (existingUser.UsesKeyConnector && - (orgUser == null || orgUser.Status == OrganizationUserStatusType.Invited)) - { - throw new Exception(_i18nService.T("UserAlreadyExistsKeyConnector")); - } - - if (orgUser == null) - { - // Org User is not created - no invite has been sent - throw new Exception(_i18nService.T("UserAlreadyExistsInviteProcess")); - } - - if (orgUser.Status == OrganizationUserStatusType.Invited) - { - // Org User is invited - they must manually accept the invite via email and authenticate with MP - throw new Exception(_i18nService.T("UserAlreadyInvited", email, organization.Name)); - } - - // Accepted or Confirmed - create SSO link and return; - await CreateSsoUserRecord(providerUserId, existingUser.Id, orgId, orgUser); - return existingUser; - } - - // Before any user creation - if Org User doesn't exist at this point - make sure there are enough seats to add one - if (orgUser == null && organization.Seats.HasValue) - { - var userCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(orgId); - var initialSeatCount = organization.Seats.Value; - var availableSeats = initialSeatCount - userCount; - var prorationDate = DateTime.UtcNow; - if (availableSeats < 1) - { - try - { - if (_globalSettings.SelfHosted) - { - throw new Exception("Cannot autoscale on self-hosted instance."); - } - - await _organizationService.AutoAddSeatsAsync(organization, 1, prorationDate); - } - catch (Exception e) - { - if (organization.Seats.Value != initialSeatCount) - { - await _organizationService.AdjustSeatsAsync(orgId, initialSeatCount - organization.Seats.Value, prorationDate); - } - _logger.LogInformation(e, "SSO auto provisioning failed"); - throw new Exception(_i18nService.T("NoSeatsAvailable", organization.Name)); - } - } - } - - // Create user record - all existing user flows are handled above - var user = new User - { - Name = name, - Email = email, - ApiKey = CoreHelpers.SecureRandomString(30) - }; - await _userService.RegisterUserAsync(user); - - // If the organization has 2fa policy enabled, make sure to default jit user 2fa to email - var twoFactorPolicy = - await _policyRepository.GetByOrganizationIdTypeAsync(orgId, PolicyType.TwoFactorAuthentication); - if (twoFactorPolicy != null && twoFactorPolicy.Enabled) - { - user.SetTwoFactorProviders(new Dictionary - { - [TwoFactorProviderType.Email] = new TwoFactorProvider - { - MetaData = new Dictionary { ["Email"] = user.Email.ToLowerInvariant() }, - Enabled = true - } - }); - await _userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.Email); - } - - // Create Org User if null or else update existing Org User - if (orgUser == null) - { - orgUser = new OrganizationUser - { - OrganizationId = orgId, - UserId = user.Id, - Type = OrganizationUserType.User, - Status = OrganizationUserStatusType.Invited - }; - await _organizationUserRepository.CreateAsync(orgUser); - } - else - { - orgUser.UserId = user.Id; - await _organizationUserRepository.ReplaceAsync(orgUser); - } - - // Create sso user record - await CreateSsoUserRecord(providerUserId, user.Id, orgId, orgUser); - - return user; - } - - private IActionResult InvalidJson(string errorMessageKey, Exception ex = null) - { - Response.StatusCode = ex == null ? 400 : 500; - return Json(new ErrorResponseModel(_i18nService.T(errorMessageKey)) - { - ExceptionMessage = ex?.Message, - ExceptionStackTrace = ex?.StackTrace, - InnerExceptionMessage = ex?.InnerException?.Message, - }); - } - - private string GetEmailAddress(IEnumerable claims, IEnumerable additionalClaimTypes) - { - var filteredClaims = claims.Where(c => !string.IsNullOrWhiteSpace(c.Value) && c.Value.Contains("@")); - - var email = filteredClaims.GetFirstMatch(additionalClaimTypes.ToArray()) ?? - filteredClaims.GetFirstMatch(JwtClaimTypes.Email, ClaimTypes.Email, - SamlClaimTypes.Email, "mail", "emailaddress"); - if (!string.IsNullOrWhiteSpace(email)) - { - return email; - } - - var username = filteredClaims.GetFirstMatch(JwtClaimTypes.PreferredUserName, - SamlClaimTypes.UserId, "uid"); - if (!string.IsNullOrWhiteSpace(username)) - { - return username; - } - - return null; - } - - private string GetName(IEnumerable claims, IEnumerable additionalClaimTypes) - { - var filteredClaims = claims.Where(c => !string.IsNullOrWhiteSpace(c.Value)); - - var name = filteredClaims.GetFirstMatch(additionalClaimTypes.ToArray()) ?? - filteredClaims.GetFirstMatch(JwtClaimTypes.Name, ClaimTypes.Name, - SamlClaimTypes.DisplayName, SamlClaimTypes.CommonName, "displayname", "cn"); - if (!string.IsNullOrWhiteSpace(name)) - { - return name; - } - - var givenName = filteredClaims.GetFirstMatch(SamlClaimTypes.GivenName, "givenname", "firstname", - "fn", "fname", "nickname"); - var surname = filteredClaims.GetFirstMatch(SamlClaimTypes.Surname, "sn", "surname", "lastname"); - var nameParts = new[] { givenName, surname }.Where(p => !string.IsNullOrWhiteSpace(p)); - if (nameParts.Any()) - { - return string.Join(' ', nameParts); - } - - return null; - } - - private async Task CreateSsoUserRecord(string providerUserId, Guid userId, Guid orgId, OrganizationUser orgUser) - { - // Delete existing SsoUser (if any) - avoids error if providerId has changed and the sso link is stale - var existingSsoUser = await _ssoUserRepository.GetByUserIdOrganizationIdAsync(orgId, userId); - if (existingSsoUser != null) - { - await _ssoUserRepository.DeleteAsync(userId, orgId); - await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_ResetSsoLink); - } - else - { - // If no stale user, this is the user's first Sso login ever - await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_FirstSsoLogin); - } - - var ssoUser = new SsoUser - { - ExternalId = providerUserId, - UserId = userId, - OrganizationId = orgId, - }; - await _ssoUserRepository.CreateAsync(ssoUser); - } - - private void ProcessLoginCallback(AuthenticateResult externalResult, - List localClaims, AuthenticationProperties localSignInProps) - { - // If the external system sent a session id claim, copy it over - // so we can use it for single sign-out - var sid = externalResult.Principal.Claims.FirstOrDefault(x => x.Type == JwtClaimTypes.SessionId); - if (sid != null) - { - localClaims.Add(new Claim(JwtClaimTypes.SessionId, sid.Value)); - } - - // If the external provider issued an idToken, we'll keep it for signout - var idToken = externalResult.Properties.GetTokenValue("id_token"); - if (idToken != null) - { - localSignInProps.StoreTokens( - new[] { new AuthenticationToken { Name = "id_token", Value = idToken } }); - } - } - - private async Task GetProviderAsync(string returnUrl) - { - var context = await _interaction.GetAuthorizationContextAsync(returnUrl); - if (context?.IdP != null && await _schemeProvider.GetSchemeAsync(context.IdP) != null) - { - return context.IdP; - } - var schemes = await _schemeProvider.GetAllSchemesAsync(); - var providers = schemes.Select(x => x.Name).ToList(); - return providers.FirstOrDefault(); - } - - private async Task<(string, string, string)> GetLoggedOutDataAsync(string logoutId) - { - // Get context information (client name, post logout redirect URI and iframe for federated signout) - var logout = await _interaction.GetLogoutContextAsync(logoutId); - string externalAuthenticationScheme = null; - if (User?.Identity.IsAuthenticated == true) - { - var idp = User.FindFirst(JwtClaimTypes.IdentityProvider)?.Value; - if (idp != null && idp != IdentityServerConstants.LocalIdentityProvider) - { - var providerSupportsSignout = await HttpContext.GetSchemeSupportsSignOutAsync(idp); - if (providerSupportsSignout) - { - if (logoutId == null) - { - // If there's no current logout context, we need to create one - // this captures necessary info from the current logged in user - // before we signout and redirect away to the external IdP for signout - logoutId = await _interaction.CreateLogoutContextAsync(); - } - - externalAuthenticationScheme = idp; - } - } - } - - return (logoutId, logout?.PostLogoutRedirectUri, externalAuthenticationScheme); - } - - public bool IsNativeClient(IdentityServer4.Models.AuthorizationRequest context) - { - return !context.RedirectUri.StartsWith("https", StringComparison.Ordinal) - && !context.RedirectUri.StartsWith("http", StringComparison.Ordinal); + return InvalidJson("PreValidationError", ex); } } + + [HttpGet] + public async Task Login(string returnUrl) + { + var context = await _interaction.GetAuthorizationContextAsync(returnUrl); + + if (!context.Parameters.AllKeys.Contains("domain_hint") || + string.IsNullOrWhiteSpace(context.Parameters["domain_hint"])) + { + throw new Exception(_i18nService.T("NoDomainHintProvided")); + } + + var ssoToken = context.Parameters[SsoTokenable.TokenIdentifier]; + + if (string.IsNullOrWhiteSpace(ssoToken)) + { + return Unauthorized("A valid SSO token is required to continue with SSO login"); + } + + var domainHint = context.Parameters["domain_hint"]; + var organization = await _organizationRepository.GetByIdentifierAsync(domainHint); + + if (organization == null) + { + return InvalidJson("OrganizationNotFoundByIdentifierError"); + } + + var tokenable = _dataProtector.Unprotect(ssoToken); + + if (!tokenable.TokenIsValid(organization)) + { + return Unauthorized("The SSO token associated with your request is expired. A valid SSO token is required to continue."); + } + + return RedirectToAction(nameof(ExternalChallenge), new + { + scheme = organization.Id.ToString(), + returnUrl, + state = context.Parameters["state"], + userIdentifier = context.Parameters["session_state"], + }); + } + + [HttpGet] + public IActionResult ExternalChallenge(string scheme, string returnUrl, string state, string userIdentifier) + { + if (string.IsNullOrEmpty(returnUrl)) + { + returnUrl = "~/"; + } + + if (!Url.IsLocalUrl(returnUrl) && !_interaction.IsValidReturnUrl(returnUrl)) + { + throw new Exception(_i18nService.T("InvalidReturnUrl")); + } + + var props = new AuthenticationProperties + { + RedirectUri = Url.Action(nameof(ExternalCallback)), + Items = + { + // scheme will get serialized into `State` and returned back + { "scheme", scheme }, + { "return_url", returnUrl }, + { "state", state }, + { "user_identifier", userIdentifier }, + } + }; + + return Challenge(props, scheme); + } + + [HttpGet] + public async Task ExternalCallback() + { + // Read external identity from the temporary cookie + var result = await HttpContext.AuthenticateAsync( + AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); + if (result?.Succeeded != true) + { + throw new Exception(_i18nService.T("ExternalAuthenticationError")); + } + + // Debugging + var externalClaims = result.Principal.Claims.Select(c => $"{c.Type}: {c.Value}"); + _logger.LogDebug("External claims: {@claims}", externalClaims); + + // Lookup our user and external provider info + var (user, provider, providerUserId, claims, ssoConfigData) = await FindUserFromExternalProviderAsync(result); + if (user == null) + { + // This might be where you might initiate a custom workflow for user registration + // in this sample we don't show how that would be done, as our sample implementation + // simply auto-provisions new external user + var userIdentifier = result.Properties.Items.Keys.Contains("user_identifier") ? + result.Properties.Items["user_identifier"] : null; + user = await AutoProvisionUserAsync(provider, providerUserId, claims, userIdentifier, ssoConfigData); + } + + if (user != null) + { + // This allows us to collect any additional claims or properties + // for the specific protocols used and store them in the local auth cookie. + // this is typically used to store data needed for signout from those protocols. + var additionalLocalClaims = new List(); + var localSignInProps = new AuthenticationProperties + { + IsPersistent = true, + ExpiresUtc = DateTimeOffset.UtcNow.AddMinutes(1) + }; + ProcessLoginCallback(result, additionalLocalClaims, localSignInProps); + + // Issue authentication cookie for user + await HttpContext.SignInAsync(new IdentityServerUser(user.Id.ToString()) + { + DisplayName = user.Email, + IdentityProvider = provider, + AdditionalClaims = additionalLocalClaims.ToArray() + }, localSignInProps); + } + + // Delete temporary cookie used during external authentication + await HttpContext.SignOutAsync(AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); + + // Retrieve return URL + var returnUrl = result.Properties.Items["return_url"] ?? "~/"; + + // Check if external login is in the context of an OIDC request + var context = await _interaction.GetAuthorizationContextAsync(returnUrl); + if (context != null) + { + if (IsNativeClient(context)) + { + // The client is native, so this change in how to + // return the response is for better UX for the end user. + HttpContext.Response.StatusCode = 200; + HttpContext.Response.Headers["Location"] = string.Empty; + return View("Redirect", new RedirectViewModel { RedirectUrl = returnUrl }); + } + } + + return Redirect(returnUrl); + } + + [HttpGet] + public async Task Logout(string logoutId) + { + // Build a model so the logged out page knows what to display + var (updatedLogoutId, redirectUri, externalAuthenticationScheme) = await GetLoggedOutDataAsync(logoutId); + + if (User?.Identity.IsAuthenticated == true) + { + // Delete local authentication cookie + await HttpContext.SignOutAsync(); + } + + // HACK: Temporary workaroud for the time being that doesn't try to sign out of OneLogin schemes, + // which doesnt support SLO + if (externalAuthenticationScheme != null && !externalAuthenticationScheme.Contains("onelogin")) + { + // Build a return URL so the upstream provider will redirect back + // to us after the user has logged out. this allows us to then + // complete our single sign-out processing. + var url = Url.Action("Logout", new { logoutId = updatedLogoutId }); + + // This triggers a redirect to the external provider for sign-out + return SignOut(new AuthenticationProperties { RedirectUri = url }, externalAuthenticationScheme); + } + if (redirectUri != null) + { + return View("Redirect", new RedirectViewModel { RedirectUrl = redirectUri }); + } + else + { + return Redirect("~/"); + } + } + + private async Task<(User user, string provider, string providerUserId, IEnumerable claims, SsoConfigurationData config)> + FindUserFromExternalProviderAsync(AuthenticateResult result) + { + var provider = result.Properties.Items["scheme"]; + var orgId = new Guid(provider); + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(orgId); + if (ssoConfig == null || !ssoConfig.Enabled) + { + throw new Exception(_i18nService.T("OrganizationOrSsoConfigNotFound")); + } + + var ssoConfigData = ssoConfig.GetData(); + var externalUser = result.Principal; + + // Validate acr claim against expectation before going further + if (!string.IsNullOrWhiteSpace(ssoConfigData.ExpectedReturnAcrValue)) + { + var acrClaim = externalUser.FindFirst(JwtClaimTypes.AuthenticationContextClassReference); + if (acrClaim?.Value != ssoConfigData.ExpectedReturnAcrValue) + { + throw new Exception(_i18nService.T("AcrMissingOrInvalid")); + } + } + + // Ensure the NameIdentifier used is not a transient name ID, if so, we need a different attribute + // for the user identifier. + static bool nameIdIsNotTransient(Claim c) => c.Type == ClaimTypes.NameIdentifier + && (c.Properties == null + || !c.Properties.ContainsKey(SamlPropertyKeys.ClaimFormat) + || c.Properties[SamlPropertyKeys.ClaimFormat] != SamlNameIdFormats.Transient); + + // Try to determine the unique id of the external user (issued by the provider) + // the most common claim type for that are the sub claim and the NameIdentifier + // depending on the external provider, some other claim type might be used + var customUserIdClaimTypes = ssoConfigData.GetAdditionalUserIdClaimTypes(); + var userIdClaim = externalUser.FindFirst(c => customUserIdClaimTypes.Contains(c.Type)) ?? + externalUser.FindFirst(JwtClaimTypes.Subject) ?? + externalUser.FindFirst(nameIdIsNotTransient) ?? + // Some SAML providers may use the `uid` attribute for this + // where a transient NameID has been sent in the subject + externalUser.FindFirst("uid") ?? + externalUser.FindFirst("upn") ?? + externalUser.FindFirst("eppn") ?? + throw new Exception(_i18nService.T("UnknownUserId")); + + // Remove the user id claim so we don't include it as an extra claim if/when we provision the user + var claims = externalUser.Claims.ToList(); + claims.Remove(userIdClaim); + + // find external user + var providerUserId = userIdClaim.Value; + + var user = await _userRepository.GetBySsoUserAsync(providerUserId, orgId); + + return (user, provider, providerUserId, claims, ssoConfigData); + } + + private async Task AutoProvisionUserAsync(string provider, string providerUserId, + IEnumerable claims, string userIdentifier, SsoConfigurationData config) + { + var name = GetName(claims, config.GetAdditionalNameClaimTypes()); + var email = GetEmailAddress(claims, config.GetAdditionalEmailClaimTypes()); + if (string.IsNullOrWhiteSpace(email) && providerUserId.Contains("@")) + { + email = providerUserId; + } + + if (!Guid.TryParse(provider, out var orgId)) + { + // TODO: support non-org (server-wide) SSO in the future? + throw new Exception(_i18nService.T("SSOProviderIsNotAnOrgId", provider)); + } + + User existingUser = null; + if (string.IsNullOrWhiteSpace(userIdentifier)) + { + if (string.IsNullOrWhiteSpace(email)) + { + throw new Exception(_i18nService.T("CannotFindEmailClaim")); + } + existingUser = await _userRepository.GetByEmailAsync(email); + } + else + { + var split = userIdentifier.Split(","); + if (split.Length < 2) + { + throw new Exception(_i18nService.T("InvalidUserIdentifier")); + } + var userId = split[0]; + var token = split[1]; + + var tokenOptions = new TokenOptions(); + + var claimedUser = await _userService.GetUserByIdAsync(userId); + if (claimedUser != null) + { + var tokenIsValid = await _userManager.VerifyUserTokenAsync( + claimedUser, tokenOptions.PasswordResetTokenProvider, TokenPurposes.LinkSso, token); + if (tokenIsValid) + { + existingUser = claimedUser; + } + else + { + throw new Exception(_i18nService.T("UserIdAndTokenMismatch")); + } + } + } + + OrganizationUser orgUser = null; + var organization = await _organizationRepository.GetByIdAsync(orgId); + if (organization == null) + { + throw new Exception(_i18nService.T("CouldNotFindOrganization", orgId)); + } + + // Try to find OrgUser via existing User Id (accepted/confirmed user) + if (existingUser != null) + { + var orgUsersByUserId = await _organizationUserRepository.GetManyByUserAsync(existingUser.Id); + orgUser = orgUsersByUserId.SingleOrDefault(u => u.OrganizationId == orgId); + } + + // If no Org User found by Existing User Id - search all organization users via email + orgUser ??= await _organizationUserRepository.GetByOrganizationEmailAsync(orgId, email); + + // All Existing User flows handled below + if (existingUser != null) + { + if (existingUser.UsesKeyConnector && + (orgUser == null || orgUser.Status == OrganizationUserStatusType.Invited)) + { + throw new Exception(_i18nService.T("UserAlreadyExistsKeyConnector")); + } + + if (orgUser == null) + { + // Org User is not created - no invite has been sent + throw new Exception(_i18nService.T("UserAlreadyExistsInviteProcess")); + } + + if (orgUser.Status == OrganizationUserStatusType.Invited) + { + // Org User is invited - they must manually accept the invite via email and authenticate with MP + throw new Exception(_i18nService.T("UserAlreadyInvited", email, organization.Name)); + } + + // Accepted or Confirmed - create SSO link and return; + await CreateSsoUserRecord(providerUserId, existingUser.Id, orgId, orgUser); + return existingUser; + } + + // Before any user creation - if Org User doesn't exist at this point - make sure there are enough seats to add one + if (orgUser == null && organization.Seats.HasValue) + { + var userCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(orgId); + var initialSeatCount = organization.Seats.Value; + var availableSeats = initialSeatCount - userCount; + var prorationDate = DateTime.UtcNow; + if (availableSeats < 1) + { + try + { + if (_globalSettings.SelfHosted) + { + throw new Exception("Cannot autoscale on self-hosted instance."); + } + + await _organizationService.AutoAddSeatsAsync(organization, 1, prorationDate); + } + catch (Exception e) + { + if (organization.Seats.Value != initialSeatCount) + { + await _organizationService.AdjustSeatsAsync(orgId, initialSeatCount - organization.Seats.Value, prorationDate); + } + _logger.LogInformation(e, "SSO auto provisioning failed"); + throw new Exception(_i18nService.T("NoSeatsAvailable", organization.Name)); + } + } + } + + // Create user record - all existing user flows are handled above + var user = new User + { + Name = name, + Email = email, + ApiKey = CoreHelpers.SecureRandomString(30) + }; + await _userService.RegisterUserAsync(user); + + // If the organization has 2fa policy enabled, make sure to default jit user 2fa to email + var twoFactorPolicy = + await _policyRepository.GetByOrganizationIdTypeAsync(orgId, PolicyType.TwoFactorAuthentication); + if (twoFactorPolicy != null && twoFactorPolicy.Enabled) + { + user.SetTwoFactorProviders(new Dictionary + { + [TwoFactorProviderType.Email] = new TwoFactorProvider + { + MetaData = new Dictionary { ["Email"] = user.Email.ToLowerInvariant() }, + Enabled = true + } + }); + await _userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.Email); + } + + // Create Org User if null or else update existing Org User + if (orgUser == null) + { + orgUser = new OrganizationUser + { + OrganizationId = orgId, + UserId = user.Id, + Type = OrganizationUserType.User, + Status = OrganizationUserStatusType.Invited + }; + await _organizationUserRepository.CreateAsync(orgUser); + } + else + { + orgUser.UserId = user.Id; + await _organizationUserRepository.ReplaceAsync(orgUser); + } + + // Create sso user record + await CreateSsoUserRecord(providerUserId, user.Id, orgId, orgUser); + + return user; + } + + private IActionResult InvalidJson(string errorMessageKey, Exception ex = null) + { + Response.StatusCode = ex == null ? 400 : 500; + return Json(new ErrorResponseModel(_i18nService.T(errorMessageKey)) + { + ExceptionMessage = ex?.Message, + ExceptionStackTrace = ex?.StackTrace, + InnerExceptionMessage = ex?.InnerException?.Message, + }); + } + + private string GetEmailAddress(IEnumerable claims, IEnumerable additionalClaimTypes) + { + var filteredClaims = claims.Where(c => !string.IsNullOrWhiteSpace(c.Value) && c.Value.Contains("@")); + + var email = filteredClaims.GetFirstMatch(additionalClaimTypes.ToArray()) ?? + filteredClaims.GetFirstMatch(JwtClaimTypes.Email, ClaimTypes.Email, + SamlClaimTypes.Email, "mail", "emailaddress"); + if (!string.IsNullOrWhiteSpace(email)) + { + return email; + } + + var username = filteredClaims.GetFirstMatch(JwtClaimTypes.PreferredUserName, + SamlClaimTypes.UserId, "uid"); + if (!string.IsNullOrWhiteSpace(username)) + { + return username; + } + + return null; + } + + private string GetName(IEnumerable claims, IEnumerable additionalClaimTypes) + { + var filteredClaims = claims.Where(c => !string.IsNullOrWhiteSpace(c.Value)); + + var name = filteredClaims.GetFirstMatch(additionalClaimTypes.ToArray()) ?? + filteredClaims.GetFirstMatch(JwtClaimTypes.Name, ClaimTypes.Name, + SamlClaimTypes.DisplayName, SamlClaimTypes.CommonName, "displayname", "cn"); + if (!string.IsNullOrWhiteSpace(name)) + { + return name; + } + + var givenName = filteredClaims.GetFirstMatch(SamlClaimTypes.GivenName, "givenname", "firstname", + "fn", "fname", "nickname"); + var surname = filteredClaims.GetFirstMatch(SamlClaimTypes.Surname, "sn", "surname", "lastname"); + var nameParts = new[] { givenName, surname }.Where(p => !string.IsNullOrWhiteSpace(p)); + if (nameParts.Any()) + { + return string.Join(' ', nameParts); + } + + return null; + } + + private async Task CreateSsoUserRecord(string providerUserId, Guid userId, Guid orgId, OrganizationUser orgUser) + { + // Delete existing SsoUser (if any) - avoids error if providerId has changed and the sso link is stale + var existingSsoUser = await _ssoUserRepository.GetByUserIdOrganizationIdAsync(orgId, userId); + if (existingSsoUser != null) + { + await _ssoUserRepository.DeleteAsync(userId, orgId); + await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_ResetSsoLink); + } + else + { + // If no stale user, this is the user's first Sso login ever + await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_FirstSsoLogin); + } + + var ssoUser = new SsoUser + { + ExternalId = providerUserId, + UserId = userId, + OrganizationId = orgId, + }; + await _ssoUserRepository.CreateAsync(ssoUser); + } + + private void ProcessLoginCallback(AuthenticateResult externalResult, + List localClaims, AuthenticationProperties localSignInProps) + { + // If the external system sent a session id claim, copy it over + // so we can use it for single sign-out + var sid = externalResult.Principal.Claims.FirstOrDefault(x => x.Type == JwtClaimTypes.SessionId); + if (sid != null) + { + localClaims.Add(new Claim(JwtClaimTypes.SessionId, sid.Value)); + } + + // If the external provider issued an idToken, we'll keep it for signout + var idToken = externalResult.Properties.GetTokenValue("id_token"); + if (idToken != null) + { + localSignInProps.StoreTokens( + new[] { new AuthenticationToken { Name = "id_token", Value = idToken } }); + } + } + + private async Task GetProviderAsync(string returnUrl) + { + var context = await _interaction.GetAuthorizationContextAsync(returnUrl); + if (context?.IdP != null && await _schemeProvider.GetSchemeAsync(context.IdP) != null) + { + return context.IdP; + } + var schemes = await _schemeProvider.GetAllSchemesAsync(); + var providers = schemes.Select(x => x.Name).ToList(); + return providers.FirstOrDefault(); + } + + private async Task<(string, string, string)> GetLoggedOutDataAsync(string logoutId) + { + // Get context information (client name, post logout redirect URI and iframe for federated signout) + var logout = await _interaction.GetLogoutContextAsync(logoutId); + string externalAuthenticationScheme = null; + if (User?.Identity.IsAuthenticated == true) + { + var idp = User.FindFirst(JwtClaimTypes.IdentityProvider)?.Value; + if (idp != null && idp != IdentityServerConstants.LocalIdentityProvider) + { + var providerSupportsSignout = await HttpContext.GetSchemeSupportsSignOutAsync(idp); + if (providerSupportsSignout) + { + if (logoutId == null) + { + // If there's no current logout context, we need to create one + // this captures necessary info from the current logged in user + // before we signout and redirect away to the external IdP for signout + logoutId = await _interaction.CreateLogoutContextAsync(); + } + + externalAuthenticationScheme = idp; + } + } + } + + return (logoutId, logout?.PostLogoutRedirectUri, externalAuthenticationScheme); + } + + public bool IsNativeClient(IdentityServer4.Models.AuthorizationRequest context) + { + return !context.RedirectUri.StartsWith("https", StringComparison.Ordinal) + && !context.RedirectUri.StartsWith("http", StringComparison.Ordinal); + } } diff --git a/bitwarden_license/src/Sso/Controllers/HomeController.cs b/bitwarden_license/src/Sso/Controllers/HomeController.cs index 5ce112fa4..ee15fefc9 100644 --- a/bitwarden_license/src/Sso/Controllers/HomeController.cs +++ b/bitwarden_license/src/Sso/Controllers/HomeController.cs @@ -5,51 +5,50 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Diagnostics; using Microsoft.AspNetCore.Mvc; -namespace Bit.Sso.Controllers +namespace Bit.Sso.Controllers; + +public class HomeController : Controller { - public class HomeController : Controller + private readonly IIdentityServerInteractionService _interaction; + + public HomeController(IIdentityServerInteractionService interaction) { - private readonly IIdentityServerInteractionService _interaction; + _interaction = interaction; + } - public HomeController(IIdentityServerInteractionService interaction) + [Route("~/Error")] + [Route("~/Home/Error")] + [AllowAnonymous] + public async Task Error(string errorId) + { + var vm = new ErrorViewModel(); + + // retrieve error details from identityserver + var message = string.IsNullOrWhiteSpace(errorId) ? null : + await _interaction.GetErrorContextAsync(errorId); + if (message != null) { - _interaction = interaction; + vm.Error = message; + } + else + { + vm.RequestId = Activity.Current?.Id ?? HttpContext.TraceIdentifier; + var exceptionHandlerPathFeature = HttpContext.Features.Get(); + var exception = exceptionHandlerPathFeature?.Error; + if (exception is InvalidOperationException opEx && opEx.Message.Contains("schemes are: ")) + { + // Messages coming from aspnetcore with a message + // similar to "The registered sign-in schemes are: {schemes}." + // will expose other Org IDs and sign-in schemes enabled on + // the server. These errors should be truncated to just the + // scheme impacted (always the first sentence) + var cleanupPoint = opEx.Message.IndexOf(". ") + 1; + var exMessage = opEx.Message.Substring(0, cleanupPoint); + exception = new InvalidOperationException(exMessage, opEx); + } + vm.Exception = exception; } - [Route("~/Error")] - [Route("~/Home/Error")] - [AllowAnonymous] - public async Task Error(string errorId) - { - var vm = new ErrorViewModel(); - - // retrieve error details from identityserver - var message = string.IsNullOrWhiteSpace(errorId) ? null : - await _interaction.GetErrorContextAsync(errorId); - if (message != null) - { - vm.Error = message; - } - else - { - vm.RequestId = Activity.Current?.Id ?? HttpContext.TraceIdentifier; - var exceptionHandlerPathFeature = HttpContext.Features.Get(); - var exception = exceptionHandlerPathFeature?.Error; - if (exception is InvalidOperationException opEx && opEx.Message.Contains("schemes are: ")) - { - // Messages coming from aspnetcore with a message - // similar to "The registered sign-in schemes are: {schemes}." - // will expose other Org IDs and sign-in schemes enabled on - // the server. These errors should be truncated to just the - // scheme impacted (always the first sentence) - var cleanupPoint = opEx.Message.IndexOf(". ") + 1; - var exMessage = opEx.Message.Substring(0, cleanupPoint); - exception = new InvalidOperationException(exMessage, opEx); - } - vm.Exception = exception; - } - - return View("Error", vm); - } + return View("Error", vm); } } diff --git a/bitwarden_license/src/Sso/Controllers/InfoController.cs b/bitwarden_license/src/Sso/Controllers/InfoController.cs index d652e8cdd..c3641c466 100644 --- a/bitwarden_license/src/Sso/Controllers/InfoController.cs +++ b/bitwarden_license/src/Sso/Controllers/InfoController.cs @@ -1,21 +1,20 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; -namespace Bit.Sso.Controllers -{ - public class InfoController : Controller - { - [HttpGet("~/alive")] - [HttpGet("~/now")] - public DateTime GetAlive() - { - return DateTime.UtcNow; - } +namespace Bit.Sso.Controllers; - [HttpGet("~/version")] - public JsonResult GetVersion() - { - return Json(CoreHelpers.GetVersion()); - } +public class InfoController : Controller +{ + [HttpGet("~/alive")] + [HttpGet("~/now")] + public DateTime GetAlive() + { + return DateTime.UtcNow; + } + + [HttpGet("~/version")] + public JsonResult GetVersion() + { + return Json(CoreHelpers.GetVersion()); } } diff --git a/bitwarden_license/src/Sso/Controllers/MetadataController.cs b/bitwarden_license/src/Sso/Controllers/MetadataController.cs index dbf033e84..54f4f8cd4 100644 --- a/bitwarden_license/src/Sso/Controllers/MetadataController.cs +++ b/bitwarden_license/src/Sso/Controllers/MetadataController.cs @@ -5,66 +5,65 @@ using Microsoft.AspNetCore.Mvc; using Sustainsys.Saml2.AspNetCore2; using Sustainsys.Saml2.WebSso; -namespace Bit.Sso.Controllers +namespace Bit.Sso.Controllers; + +public class MetadataController : Controller { - public class MetadataController : Controller + private readonly IAuthenticationSchemeProvider _schemeProvider; + + public MetadataController( + IAuthenticationSchemeProvider schemeProvider) { - private readonly IAuthenticationSchemeProvider _schemeProvider; + _schemeProvider = schemeProvider; + } - public MetadataController( - IAuthenticationSchemeProvider schemeProvider) + [HttpGet("saml2/{scheme}")] + public async Task ViewAsync(string scheme) + { + if (string.IsNullOrWhiteSpace(scheme)) { - _schemeProvider = schemeProvider; + return NotFound(); } - [HttpGet("saml2/{scheme}")] - public async Task ViewAsync(string scheme) + var authScheme = await _schemeProvider.GetSchemeAsync(scheme); + if (authScheme == null || + !(authScheme is DynamicAuthenticationScheme dynamicAuthScheme) || + dynamicAuthScheme?.SsoType != SsoType.Saml2) { - if (string.IsNullOrWhiteSpace(scheme)) - { - return NotFound(); - } - - var authScheme = await _schemeProvider.GetSchemeAsync(scheme); - if (authScheme == null || - !(authScheme is DynamicAuthenticationScheme dynamicAuthScheme) || - dynamicAuthScheme?.SsoType != SsoType.Saml2) - { - return NotFound(); - } - - if (!(dynamicAuthScheme.Options is Saml2Options options)) - { - return NotFound(); - } - - var uri = new Uri( - Request.Scheme - + "://" - + Request.Host - + Request.Path - + Request.QueryString); - - var pathBase = Request.PathBase.Value; - pathBase = string.IsNullOrEmpty(pathBase) ? "/" : pathBase; - - var requestdata = new HttpRequestData( - Request.Method, - uri, - pathBase, - null, - Request.Cookies, - (data) => data); - - var metadataResult = CommandFactory - .GetCommand(CommandFactory.MetadataCommand) - .Run(requestdata, options); - //Response.Headers.Add("Content-Disposition", $"filename= bitwarden-saml2-meta-{scheme}.xml"); - return new ContentResult - { - Content = metadataResult.Content, - ContentType = "text/xml", - }; + return NotFound(); } + + if (!(dynamicAuthScheme.Options is Saml2Options options)) + { + return NotFound(); + } + + var uri = new Uri( + Request.Scheme + + "://" + + Request.Host + + Request.Path + + Request.QueryString); + + var pathBase = Request.PathBase.Value; + pathBase = string.IsNullOrEmpty(pathBase) ? "/" : pathBase; + + var requestdata = new HttpRequestData( + Request.Method, + uri, + pathBase, + null, + Request.Cookies, + (data) => data); + + var metadataResult = CommandFactory + .GetCommand(CommandFactory.MetadataCommand) + .Run(requestdata, options); + //Response.Headers.Add("Content-Disposition", $"filename= bitwarden-saml2-meta-{scheme}.xml"); + return new ContentResult + { + Content = metadataResult.Content, + ContentType = "text/xml", + }; } } diff --git a/bitwarden_license/src/Sso/Models/ErrorViewModel.cs b/bitwarden_license/src/Sso/Models/ErrorViewModel.cs index 4c0ea8748..46ae8edd9 100644 --- a/bitwarden_license/src/Sso/Models/ErrorViewModel.cs +++ b/bitwarden_license/src/Sso/Models/ErrorViewModel.cs @@ -1,27 +1,26 @@ using IdentityServer4.Models; -namespace Bit.Sso.Models +namespace Bit.Sso.Models; + +public class ErrorViewModel { - public class ErrorViewModel + private string _requestId; + + public ErrorMessage Error { get; set; } + public Exception Exception { get; set; } + + public string Message => Error?.Error; + public string Description => Error?.ErrorDescription ?? Exception?.Message; + public string RedirectUri => Error?.RedirectUri; + public string RequestId { - private string _requestId; - - public ErrorMessage Error { get; set; } - public Exception Exception { get; set; } - - public string Message => Error?.Error; - public string Description => Error?.ErrorDescription ?? Exception?.Message; - public string RedirectUri => Error?.RedirectUri; - public string RequestId + get { - get - { - return Error?.RequestId ?? _requestId; - } - set - { - _requestId = value; - } + return Error?.RequestId ?? _requestId; + } + set + { + _requestId = value; } } } diff --git a/bitwarden_license/src/Sso/Models/RedirectViewModel.cs b/bitwarden_license/src/Sso/Models/RedirectViewModel.cs index 54b5b4715..9bc294d96 100644 --- a/bitwarden_license/src/Sso/Models/RedirectViewModel.cs +++ b/bitwarden_license/src/Sso/Models/RedirectViewModel.cs @@ -1,7 +1,6 @@ -namespace Bit.Sso.Models +namespace Bit.Sso.Models; + +public class RedirectViewModel { - public class RedirectViewModel - { - public string RedirectUrl { get; set; } - } + public string RedirectUrl { get; set; } } diff --git a/bitwarden_license/src/Sso/Models/SamlEnvironment.cs b/bitwarden_license/src/Sso/Models/SamlEnvironment.cs index f1890840f..6de718029 100644 --- a/bitwarden_license/src/Sso/Models/SamlEnvironment.cs +++ b/bitwarden_license/src/Sso/Models/SamlEnvironment.cs @@ -1,9 +1,8 @@ using System.Security.Cryptography.X509Certificates; -namespace Bit.Sso.Models +namespace Bit.Sso.Models; + +public class SamlEnvironment { - public class SamlEnvironment - { - public X509Certificate2 SpSigningCertificate { get; set; } - } + public X509Certificate2 SpSigningCertificate { get; set; } } diff --git a/bitwarden_license/src/Sso/Models/SsoPreValidateResponseModel.cs b/bitwarden_license/src/Sso/Models/SsoPreValidateResponseModel.cs index 9877e1c5a..f96b38775 100644 --- a/bitwarden_license/src/Sso/Models/SsoPreValidateResponseModel.cs +++ b/bitwarden_license/src/Sso/Models/SsoPreValidateResponseModel.cs @@ -1,13 +1,12 @@ using Microsoft.AspNetCore.Mvc; -namespace Bit.Sso.Models +namespace Bit.Sso.Models; + +public class SsoPreValidateResponseModel : JsonResult { - public class SsoPreValidateResponseModel : JsonResult + public SsoPreValidateResponseModel(string token) : base(new { - public SsoPreValidateResponseModel(string token) : base(new - { - token - }) - { } - } + token + }) + { } } diff --git a/bitwarden_license/src/Sso/Program.cs b/bitwarden_license/src/Sso/Program.cs index 910f09332..672c73bfb 100644 --- a/bitwarden_license/src/Sso/Program.cs +++ b/bitwarden_license/src/Sso/Program.cs @@ -2,33 +2,32 @@ using Serilog; using Serilog.Events; -namespace Bit.Sso +namespace Bit.Sso; + +public class Program { - public class Program + public static void Main(string[] args) { - public static void Main(string[] args) - { - Host - .CreateDefaultBuilder(args) - .ConfigureCustomAppConfiguration(args) - .ConfigureWebHostDefaults(webBuilder => + Host + .CreateDefaultBuilder(args) + .ConfigureCustomAppConfiguration(args) + .ConfigureWebHostDefaults(webBuilder => + { + webBuilder.UseStartup(); + webBuilder.ConfigureLogging((hostingContext, logging) => + logging.AddSerilog(hostingContext, e => { - webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, e => + var context = e.Properties["SourceContext"].ToString(); + if (e.Properties.ContainsKey("RequestPath") && + !string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) && + (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) { - var context = e.Properties["SourceContext"].ToString(); - if (e.Properties.ContainsKey("RequestPath") && - !string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) && - (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) - { - return false; - } - return e.Level >= LogEventLevel.Error; - })); - }) - .Build() - .Run(); - } + return false; + } + return e.Level >= LogEventLevel.Error; + })); + }) + .Build() + .Run(); } } diff --git a/bitwarden_license/src/Sso/Startup.cs b/bitwarden_license/src/Sso/Startup.cs index 6116d86c2..99aa5961f 100644 --- a/bitwarden_license/src/Sso/Startup.cs +++ b/bitwarden_license/src/Sso/Startup.cs @@ -8,148 +8,147 @@ using IdentityServer4.Extensions; using Microsoft.IdentityModel.Logging; using Stripe; -namespace Bit.Sso +namespace Bit.Sso; + +public class Startup { - public class Startup + public Startup(IWebHostEnvironment env, IConfiguration configuration) { - public Startup(IWebHostEnvironment env, IConfiguration configuration) + Configuration = configuration; + Environment = env; + } + + public IConfiguration Configuration { get; } + public IWebHostEnvironment Environment { get; set; } + + public void ConfigureServices(IServiceCollection services) + { + // Options + services.AddOptions(); + + // Settings + var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); + + // Stripe Billing + StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey; + StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries; + + // Data Protection + services.AddCustomDataProtectionServices(Environment, globalSettings); + + // Repositories + services.AddSqlServerRepositories(globalSettings); + + // Context + services.AddScoped(); + + // Caching + services.AddMemoryCache(); + services.AddDistributedCache(globalSettings); + + // Mvc + services.AddControllersWithViews(); + + // Cookies + if (Environment.IsDevelopment()) { - Configuration = configuration; - Environment = env; - } - - public IConfiguration Configuration { get; } - public IWebHostEnvironment Environment { get; set; } - - public void ConfigureServices(IServiceCollection services) - { - // Options - services.AddOptions(); - - // Settings - var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); - - // Stripe Billing - StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey; - StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries; - - // Data Protection - services.AddCustomDataProtectionServices(Environment, globalSettings); - - // Repositories - services.AddSqlServerRepositories(globalSettings); - - // Context - services.AddScoped(); - - // Caching - services.AddMemoryCache(); - services.AddDistributedCache(globalSettings); - - // Mvc - services.AddControllersWithViews(); - - // Cookies - if (Environment.IsDevelopment()) + services.Configure(options => { - services.Configure(options => + options.MinimumSameSitePolicy = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; + options.OnAppendCookie = ctx => { - options.MinimumSameSitePolicy = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; - options.OnAppendCookie = ctx => - { - ctx.CookieOptions.SameSite = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; - }; - }); - } - - // Authentication - services.AddDistributedIdentityServices(globalSettings); - services.AddAuthentication() - .AddCookie(AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); - services.AddSsoServices(globalSettings); - - // IdentityServer - services.AddSsoIdentityServerServices(Environment, globalSettings); - - // Identity - services.AddCustomIdentityServices(globalSettings); - - // Services - services.AddBaseServices(globalSettings); - services.AddDefaultServices(globalSettings); - services.AddCoreLocalizationServices(); - } - - public void Configure( - IApplicationBuilder app, - IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, - GlobalSettings globalSettings, - ILogger logger) - { - if (env.IsDevelopment() || globalSettings.SelfHosted) - { - IdentityModelEventSource.ShowPII = true; - } - - app.UseSerilog(env, appLifetime, globalSettings); - - // Add general security headers - app.UseMiddleware(); - - if (!env.IsDevelopment()) - { - var uri = new Uri(globalSettings.BaseServiceUri.Sso); - app.Use(async (ctx, next) => - { - ctx.SetIdentityServerOrigin($"{uri.Scheme}://{uri.Host}"); - await next(); - }); - } - - if (globalSettings.SelfHosted) - { - app.UsePathBase("/sso"); - app.UseForwardedHeaders(globalSettings); - } - - if (env.IsDevelopment()) - { - app.UseDeveloperExceptionPage(); - app.UseCookiePolicy(); - } - else - { - app.UseExceptionHandler("/Error"); - } - - app.UseCoreLocalization(); - - // Add static files to the request pipeline. - app.UseStaticFiles(); - - // Add routing - app.UseRouting(); - - // Add Cors - app.UseCors(policy => policy.SetIsOriginAllowed(o => CoreHelpers.IsCorsOriginAllowed(o, globalSettings)) - .AllowAnyMethod().AllowAnyHeader().AllowCredentials()); - - // Add current context - app.UseMiddleware(); - - // Add IdentityServer to the request pipeline. - app.UseIdentityServer(new IdentityServerMiddlewareOptions - { - AuthenticationMiddleware = app => app.UseMiddleware() + ctx.CookieOptions.SameSite = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; + }; }); - - // Add Mvc stuff - app.UseAuthorization(); - app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); - - // Log startup - logger.LogInformation(Constants.BypassFiltersEventId, globalSettings.ProjectName + " started."); } + + // Authentication + services.AddDistributedIdentityServices(globalSettings); + services.AddAuthentication() + .AddCookie(AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); + services.AddSsoServices(globalSettings); + + // IdentityServer + services.AddSsoIdentityServerServices(Environment, globalSettings); + + // Identity + services.AddCustomIdentityServices(globalSettings); + + // Services + services.AddBaseServices(globalSettings); + services.AddDefaultServices(globalSettings); + services.AddCoreLocalizationServices(); + } + + public void Configure( + IApplicationBuilder app, + IWebHostEnvironment env, + IHostApplicationLifetime appLifetime, + GlobalSettings globalSettings, + ILogger logger) + { + if (env.IsDevelopment() || globalSettings.SelfHosted) + { + IdentityModelEventSource.ShowPII = true; + } + + app.UseSerilog(env, appLifetime, globalSettings); + + // Add general security headers + app.UseMiddleware(); + + if (!env.IsDevelopment()) + { + var uri = new Uri(globalSettings.BaseServiceUri.Sso); + app.Use(async (ctx, next) => + { + ctx.SetIdentityServerOrigin($"{uri.Scheme}://{uri.Host}"); + await next(); + }); + } + + if (globalSettings.SelfHosted) + { + app.UsePathBase("/sso"); + app.UseForwardedHeaders(globalSettings); + } + + if (env.IsDevelopment()) + { + app.UseDeveloperExceptionPage(); + app.UseCookiePolicy(); + } + else + { + app.UseExceptionHandler("/Error"); + } + + app.UseCoreLocalization(); + + // Add static files to the request pipeline. + app.UseStaticFiles(); + + // Add routing + app.UseRouting(); + + // Add Cors + app.UseCors(policy => policy.SetIsOriginAllowed(o => CoreHelpers.IsCorsOriginAllowed(o, globalSettings)) + .AllowAnyMethod().AllowAnyHeader().AllowCredentials()); + + // Add current context + app.UseMiddleware(); + + // Add IdentityServer to the request pipeline. + app.UseIdentityServer(new IdentityServerMiddlewareOptions + { + AuthenticationMiddleware = app => app.UseMiddleware() + }); + + // Add Mvc stuff + app.UseAuthorization(); + app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); + + // Log startup + logger.LogInformation(Constants.BypassFiltersEventId, globalSettings.ProjectName + " started."); } } diff --git a/bitwarden_license/src/Sso/Utilities/ClaimsExtensions.cs b/bitwarden_license/src/Sso/Utilities/ClaimsExtensions.cs index 93a6fd146..735c7bc0a 100644 --- a/bitwarden_license/src/Sso/Utilities/ClaimsExtensions.cs +++ b/bitwarden_license/src/Sso/Utilities/ClaimsExtensions.cs @@ -1,46 +1,45 @@ using System.Security.Claims; using System.Text.RegularExpressions; -namespace Bit.Sso.Utilities +namespace Bit.Sso.Utilities; + +public static class ClaimsExtensions { - public static class ClaimsExtensions + private static readonly Regex _normalizeTextRegEx = + new Regex(@"[^a-zA-Z]", RegexOptions.CultureInvariant | RegexOptions.Singleline); + + public static string GetFirstMatch(this IEnumerable claims, params string[] possibleNames) { - private static readonly Regex _normalizeTextRegEx = - new Regex(@"[^a-zA-Z]", RegexOptions.CultureInvariant | RegexOptions.Singleline); + var normalizedClaims = claims.Select(c => (Normalize(c.Type), c.Value)).ToList(); - public static string GetFirstMatch(this IEnumerable claims, params string[] possibleNames) + // Order of prescendence is by passed in names + foreach (var name in possibleNames.Select(Normalize)) { - var normalizedClaims = claims.Select(c => (Normalize(c.Type), c.Value)).ToList(); - - // Order of prescendence is by passed in names - foreach (var name in possibleNames.Select(Normalize)) + // Second by order of claims (find claim by name) + foreach (var claim in normalizedClaims) { - // Second by order of claims (find claim by name) - foreach (var claim in normalizedClaims) + if (Equals(claim.Item1, name)) { - if (Equals(claim.Item1, name)) - { - return claim.Value; - } + return claim.Value; } } - return null; } + return null; + } - private static bool Equals(string text, string compare) - { - return text == compare || - (string.IsNullOrWhiteSpace(text) && string.IsNullOrWhiteSpace(compare)) || - string.Equals(Normalize(text), compare, StringComparison.InvariantCultureIgnoreCase); - } + private static bool Equals(string text, string compare) + { + return text == compare || + (string.IsNullOrWhiteSpace(text) && string.IsNullOrWhiteSpace(compare)) || + string.Equals(Normalize(text), compare, StringComparison.InvariantCultureIgnoreCase); + } - private static string Normalize(string text) + private static string Normalize(string text) + { + if (string.IsNullOrWhiteSpace(text)) { - if (string.IsNullOrWhiteSpace(text)) - { - return text; - } - return _normalizeTextRegEx.Replace(text, string.Empty); + return text; } + return _normalizeTextRegEx.Replace(text, string.Empty); } } diff --git a/bitwarden_license/src/Sso/Utilities/DiscoveryResponseGenerator.cs b/bitwarden_license/src/Sso/Utilities/DiscoveryResponseGenerator.cs index bd58fc612..7a7f56963 100644 --- a/bitwarden_license/src/Sso/Utilities/DiscoveryResponseGenerator.cs +++ b/bitwarden_license/src/Sso/Utilities/DiscoveryResponseGenerator.cs @@ -5,32 +5,31 @@ using IdentityServer4.Services; using IdentityServer4.Stores; using IdentityServer4.Validation; -namespace Bit.Sso.Utilities +namespace Bit.Sso.Utilities; + +public class DiscoveryResponseGenerator : IdentityServer4.ResponseHandling.DiscoveryResponseGenerator { - public class DiscoveryResponseGenerator : IdentityServer4.ResponseHandling.DiscoveryResponseGenerator + private readonly GlobalSettings _globalSettings; + + public DiscoveryResponseGenerator( + IdentityServerOptions options, + IResourceStore resourceStore, + IKeyMaterialService keys, + ExtensionGrantValidator extensionGrants, + ISecretsListParser secretParsers, + IResourceOwnerPasswordValidator resourceOwnerValidator, + ILogger logger, + GlobalSettings globalSettings) + : base(options, resourceStore, keys, extensionGrants, secretParsers, resourceOwnerValidator, logger) { - private readonly GlobalSettings _globalSettings; + _globalSettings = globalSettings; + } - public DiscoveryResponseGenerator( - IdentityServerOptions options, - IResourceStore resourceStore, - IKeyMaterialService keys, - ExtensionGrantValidator extensionGrants, - ISecretsListParser secretParsers, - IResourceOwnerPasswordValidator resourceOwnerValidator, - ILogger logger, - GlobalSettings globalSettings) - : base(options, resourceStore, keys, extensionGrants, secretParsers, resourceOwnerValidator, logger) - { - _globalSettings = globalSettings; - } - - public override async Task> CreateDiscoveryDocumentAsync( - string baseUrl, string issuerUri) - { - var dict = await base.CreateDiscoveryDocumentAsync(baseUrl, issuerUri); - return CoreHelpers.AdjustIdentityServerConfig(dict, _globalSettings.BaseServiceUri.Sso, - _globalSettings.BaseServiceUri.InternalSso); - } + public override async Task> CreateDiscoveryDocumentAsync( + string baseUrl, string issuerUri) + { + var dict = await base.CreateDiscoveryDocumentAsync(baseUrl, issuerUri); + return CoreHelpers.AdjustIdentityServerConfig(dict, _globalSettings.BaseServiceUri.Sso, + _globalSettings.BaseServiceUri.InternalSso); } } diff --git a/bitwarden_license/src/Sso/Utilities/DynamicAuthenticationScheme.cs b/bitwarden_license/src/Sso/Utilities/DynamicAuthenticationScheme.cs index 5a7ab6523..96a316bc6 100644 --- a/bitwarden_license/src/Sso/Utilities/DynamicAuthenticationScheme.cs +++ b/bitwarden_license/src/Sso/Utilities/DynamicAuthenticationScheme.cs @@ -3,88 +3,87 @@ using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Authentication.OpenIdConnect; using Sustainsys.Saml2.AspNetCore2; -namespace Bit.Sso.Utilities +namespace Bit.Sso.Utilities; + +public class DynamicAuthenticationScheme : AuthenticationScheme, IDynamicAuthenticationScheme { - public class DynamicAuthenticationScheme : AuthenticationScheme, IDynamicAuthenticationScheme + public DynamicAuthenticationScheme(string name, string displayName, Type handlerType, + AuthenticationSchemeOptions options) + : base(name, displayName, handlerType) { - public DynamicAuthenticationScheme(string name, string displayName, Type handlerType, - AuthenticationSchemeOptions options) - : base(name, displayName, handlerType) - { - Options = options; - } - public DynamicAuthenticationScheme(string name, string displayName, Type handlerType, - AuthenticationSchemeOptions options, SsoType ssoType) - : this(name, displayName, handlerType, options) - { - SsoType = ssoType; - } + Options = options; + } + public DynamicAuthenticationScheme(string name, string displayName, Type handlerType, + AuthenticationSchemeOptions options, SsoType ssoType) + : this(name, displayName, handlerType, options) + { + SsoType = ssoType; + } - public AuthenticationSchemeOptions Options { get; set; } - public SsoType SsoType { get; set; } + public AuthenticationSchemeOptions Options { get; set; } + public SsoType SsoType { get; set; } - public async Task Validate() + public async Task Validate() + { + switch (SsoType) { - switch (SsoType) - { - case SsoType.OpenIdConnect: - await ValidateOpenIdConnectAsync(); - break; - case SsoType.Saml2: - ValidateSaml(); - break; - default: - break; - } + case SsoType.OpenIdConnect: + await ValidateOpenIdConnectAsync(); + break; + case SsoType.Saml2: + ValidateSaml(); + break; + default: + break; } + } - private void ValidateSaml() + private void ValidateSaml() + { + if (SsoType != SsoType.Saml2) { - if (SsoType != SsoType.Saml2) - { - return; - } - if (!(Options is Saml2Options samlOptions)) - { - throw new Exception("InvalidAuthenticationOptionsForSaml2SchemeError"); - } - samlOptions.Validate(Name); + return; } - - private async Task ValidateOpenIdConnectAsync() + if (!(Options is Saml2Options samlOptions)) { - if (SsoType != SsoType.OpenIdConnect) + throw new Exception("InvalidAuthenticationOptionsForSaml2SchemeError"); + } + samlOptions.Validate(Name); + } + + private async Task ValidateOpenIdConnectAsync() + { + if (SsoType != SsoType.OpenIdConnect) + { + return; + } + if (!(Options is OpenIdConnectOptions oidcOptions)) + { + throw new Exception("InvalidAuthenticationOptionsForOidcSchemeError"); + } + oidcOptions.Validate(); + if (oidcOptions.Configuration == null) + { + if (oidcOptions.ConfigurationManager == null) { - return; - } - if (!(Options is OpenIdConnectOptions oidcOptions)) - { - throw new Exception("InvalidAuthenticationOptionsForOidcSchemeError"); - } - oidcOptions.Validate(); - if (oidcOptions.Configuration == null) - { - if (oidcOptions.ConfigurationManager == null) - { - throw new Exception("PostConfigurationNotExecutedError"); - } - if (oidcOptions.Configuration == null) - { - try - { - oidcOptions.Configuration = await oidcOptions.ConfigurationManager - .GetConfigurationAsync(CancellationToken.None); - } - catch (Exception ex) - { - throw new Exception("ReadingOpenIdConnectMetadataFailedError", ex); - } - } + throw new Exception("PostConfigurationNotExecutedError"); } if (oidcOptions.Configuration == null) { - throw new Exception("NoOpenIdConnectMetadataError"); + try + { + oidcOptions.Configuration = await oidcOptions.ConfigurationManager + .GetConfigurationAsync(CancellationToken.None); + } + catch (Exception ex) + { + throw new Exception("ReadingOpenIdConnectMetadataFailedError", ex); + } } } + if (oidcOptions.Configuration == null) + { + throw new Exception("NoOpenIdConnectMetadataError"); + } } } diff --git a/bitwarden_license/src/Sso/Utilities/DynamicAuthenticationSchemeProvider.cs b/bitwarden_license/src/Sso/Utilities/DynamicAuthenticationSchemeProvider.cs index 22f897998..b02e83ded 100644 --- a/bitwarden_license/src/Sso/Utilities/DynamicAuthenticationSchemeProvider.cs +++ b/bitwarden_license/src/Sso/Utilities/DynamicAuthenticationSchemeProvider.cs @@ -18,441 +18,440 @@ using Sustainsys.Saml2.AspNetCore2; using Sustainsys.Saml2.Configuration; using Sustainsys.Saml2.Saml2P; -namespace Bit.Core.Business.Sso +namespace Bit.Core.Business.Sso; + +public class DynamicAuthenticationSchemeProvider : AuthenticationSchemeProvider { - public class DynamicAuthenticationSchemeProvider : AuthenticationSchemeProvider + private readonly IPostConfigureOptions _oidcPostConfigureOptions; + private readonly IExtendedOptionsMonitorCache _extendedOidcOptionsMonitorCache; + private readonly IPostConfigureOptions _saml2PostConfigureOptions; + private readonly IExtendedOptionsMonitorCache _extendedSaml2OptionsMonitorCache; + private readonly ISsoConfigRepository _ssoConfigRepository; + private readonly ILogger _logger; + private readonly GlobalSettings _globalSettings; + private readonly SamlEnvironment _samlEnvironment; + private readonly TimeSpan _schemeCacheLifetime; + private readonly Dictionary _cachedSchemes; + private readonly Dictionary _cachedHandlerSchemes; + private readonly SemaphoreSlim _semaphore; + private readonly IHttpContextAccessor _httpContextAccessor; + + private DateTime? _lastSchemeLoad; + private IEnumerable _schemesCopy = Array.Empty(); + private IEnumerable _handlerSchemesCopy = Array.Empty(); + + public DynamicAuthenticationSchemeProvider( + IOptions options, + IPostConfigureOptions oidcPostConfigureOptions, + IOptionsMonitorCache oidcOptionsMonitorCache, + IPostConfigureOptions saml2PostConfigureOptions, + IOptionsMonitorCache saml2OptionsMonitorCache, + ISsoConfigRepository ssoConfigRepository, + ILogger logger, + GlobalSettings globalSettings, + SamlEnvironment samlEnvironment, + IHttpContextAccessor httpContextAccessor) + : base(options) { - private readonly IPostConfigureOptions _oidcPostConfigureOptions; - private readonly IExtendedOptionsMonitorCache _extendedOidcOptionsMonitorCache; - private readonly IPostConfigureOptions _saml2PostConfigureOptions; - private readonly IExtendedOptionsMonitorCache _extendedSaml2OptionsMonitorCache; - private readonly ISsoConfigRepository _ssoConfigRepository; - private readonly ILogger _logger; - private readonly GlobalSettings _globalSettings; - private readonly SamlEnvironment _samlEnvironment; - private readonly TimeSpan _schemeCacheLifetime; - private readonly Dictionary _cachedSchemes; - private readonly Dictionary _cachedHandlerSchemes; - private readonly SemaphoreSlim _semaphore; - private readonly IHttpContextAccessor _httpContextAccessor; - - private DateTime? _lastSchemeLoad; - private IEnumerable _schemesCopy = Array.Empty(); - private IEnumerable _handlerSchemesCopy = Array.Empty(); - - public DynamicAuthenticationSchemeProvider( - IOptions options, - IPostConfigureOptions oidcPostConfigureOptions, - IOptionsMonitorCache oidcOptionsMonitorCache, - IPostConfigureOptions saml2PostConfigureOptions, - IOptionsMonitorCache saml2OptionsMonitorCache, - ISsoConfigRepository ssoConfigRepository, - ILogger logger, - GlobalSettings globalSettings, - SamlEnvironment samlEnvironment, - IHttpContextAccessor httpContextAccessor) - : base(options) + _oidcPostConfigureOptions = oidcPostConfigureOptions; + _extendedOidcOptionsMonitorCache = oidcOptionsMonitorCache as + IExtendedOptionsMonitorCache; + if (_extendedOidcOptionsMonitorCache == null) { - _oidcPostConfigureOptions = oidcPostConfigureOptions; - _extendedOidcOptionsMonitorCache = oidcOptionsMonitorCache as - IExtendedOptionsMonitorCache; - if (_extendedOidcOptionsMonitorCache == null) - { - throw new ArgumentNullException("_extendedOidcOptionsMonitorCache could not be resolved."); - } - - _saml2PostConfigureOptions = saml2PostConfigureOptions; - _extendedSaml2OptionsMonitorCache = saml2OptionsMonitorCache as - IExtendedOptionsMonitorCache; - if (_extendedSaml2OptionsMonitorCache == null) - { - throw new ArgumentNullException("_extendedSaml2OptionsMonitorCache could not be resolved."); - } - - _ssoConfigRepository = ssoConfigRepository; - _logger = logger; - _globalSettings = globalSettings; - _schemeCacheLifetime = TimeSpan.FromSeconds(_globalSettings.Sso?.CacheLifetimeInSeconds ?? 30); - _samlEnvironment = samlEnvironment; - _cachedSchemes = new Dictionary(); - _cachedHandlerSchemes = new Dictionary(); - _semaphore = new SemaphoreSlim(1); - _httpContextAccessor = httpContextAccessor ?? throw new ArgumentNullException(nameof(httpContextAccessor)); + throw new ArgumentNullException("_extendedOidcOptionsMonitorCache could not be resolved."); } - private bool CacheIsValid + _saml2PostConfigureOptions = saml2PostConfigureOptions; + _extendedSaml2OptionsMonitorCache = saml2OptionsMonitorCache as + IExtendedOptionsMonitorCache; + if (_extendedSaml2OptionsMonitorCache == null) { - get => _lastSchemeLoad.HasValue - && _lastSchemeLoad.Value.Add(_schemeCacheLifetime) >= DateTime.UtcNow; + throw new ArgumentNullException("_extendedSaml2OptionsMonitorCache could not be resolved."); } - public override async Task GetSchemeAsync(string name) + _ssoConfigRepository = ssoConfigRepository; + _logger = logger; + _globalSettings = globalSettings; + _schemeCacheLifetime = TimeSpan.FromSeconds(_globalSettings.Sso?.CacheLifetimeInSeconds ?? 30); + _samlEnvironment = samlEnvironment; + _cachedSchemes = new Dictionary(); + _cachedHandlerSchemes = new Dictionary(); + _semaphore = new SemaphoreSlim(1); + _httpContextAccessor = httpContextAccessor ?? throw new ArgumentNullException(nameof(httpContextAccessor)); + } + + private bool CacheIsValid + { + get => _lastSchemeLoad.HasValue + && _lastSchemeLoad.Value.Add(_schemeCacheLifetime) >= DateTime.UtcNow; + } + + public override async Task GetSchemeAsync(string name) + { + var scheme = await base.GetSchemeAsync(name); + if (scheme != null) { - var scheme = await base.GetSchemeAsync(name); - if (scheme != null) - { - return scheme; - } - - try - { - var dynamicScheme = await GetDynamicSchemeAsync(name); - return dynamicScheme; - } - catch (Exception ex) - { - _logger.LogError(ex, "Unable to load a dynamic authentication scheme for '{0}'", name); - } - - return null; - } - - public override async Task> GetAllSchemesAsync() - { - var existingSchemes = await base.GetAllSchemesAsync(); - var schemes = new List(); - schemes.AddRange(existingSchemes); - - await LoadAllDynamicSchemesIntoCacheAsync(); - schemes.AddRange(_schemesCopy); - - return schemes.ToArray(); - } - - public override async Task> GetRequestHandlerSchemesAsync() - { - var existingSchemes = await base.GetRequestHandlerSchemesAsync(); - var schemes = new List(); - schemes.AddRange(existingSchemes); - - await LoadAllDynamicSchemesIntoCacheAsync(); - schemes.AddRange(_handlerSchemesCopy); - - return schemes.ToArray(); - } - - private async Task LoadAllDynamicSchemesIntoCacheAsync() - { - if (CacheIsValid) - { - // Our cache hasn't expired or been invalidated, ignore request - return; - } - await _semaphore.WaitAsync(); - try - { - if (CacheIsValid) - { - // Just in case (double-checked locking pattern) - return; - } - - // Save time just in case the following operation takes longer - var now = DateTime.UtcNow; - var newSchemes = await _ssoConfigRepository.GetManyByRevisionNotBeforeDate(_lastSchemeLoad); - - foreach (var config in newSchemes) - { - DynamicAuthenticationScheme scheme; - try - { - scheme = GetSchemeFromSsoConfig(config); - } - catch (Exception ex) - { - _logger.LogError(ex, "Error converting configuration to scheme for '{0}'", config.Id); - continue; - } - if (scheme == null) - { - continue; - } - SetSchemeInCache(scheme); - } - - if (newSchemes.Any()) - { - // Maintain "safe" copy for use in enumeration routines - _schemesCopy = _cachedSchemes.Values.ToArray(); - _handlerSchemesCopy = _cachedHandlerSchemes.Values.ToArray(); - } - _lastSchemeLoad = now; - } - finally - { - _semaphore.Release(); - } - } - - private DynamicAuthenticationScheme SetSchemeInCache(DynamicAuthenticationScheme scheme) - { - if (!PostConfigureDynamicScheme(scheme)) - { - return null; - } - _cachedSchemes[scheme.Name] = scheme; - if (typeof(IAuthenticationRequestHandler).IsAssignableFrom(scheme.HandlerType)) - { - _cachedHandlerSchemes[scheme.Name] = scheme; - } return scheme; } - private async Task GetDynamicSchemeAsync(string name) + try { - if (_cachedSchemes.TryGetValue(name, out var cachedScheme)) + var dynamicScheme = await GetDynamicSchemeAsync(name); + return dynamicScheme; + } + catch (Exception ex) + { + _logger.LogError(ex, "Unable to load a dynamic authentication scheme for '{0}'", name); + } + + return null; + } + + public override async Task> GetAllSchemesAsync() + { + var existingSchemes = await base.GetAllSchemesAsync(); + var schemes = new List(); + schemes.AddRange(existingSchemes); + + await LoadAllDynamicSchemesIntoCacheAsync(); + schemes.AddRange(_schemesCopy); + + return schemes.ToArray(); + } + + public override async Task> GetRequestHandlerSchemesAsync() + { + var existingSchemes = await base.GetRequestHandlerSchemesAsync(); + var schemes = new List(); + schemes.AddRange(existingSchemes); + + await LoadAllDynamicSchemesIntoCacheAsync(); + schemes.AddRange(_handlerSchemesCopy); + + return schemes.ToArray(); + } + + private async Task LoadAllDynamicSchemesIntoCacheAsync() + { + if (CacheIsValid) + { + // Our cache hasn't expired or been invalidated, ignore request + return; + } + await _semaphore.WaitAsync(); + try + { + if (CacheIsValid) { - return cachedScheme; + // Just in case (double-checked locking pattern) + return; } - var scheme = await GetSchemeFromSsoConfigAsync(name); + // Save time just in case the following operation takes longer + var now = DateTime.UtcNow; + var newSchemes = await _ssoConfigRepository.GetManyByRevisionNotBeforeDate(_lastSchemeLoad); + + foreach (var config in newSchemes) + { + DynamicAuthenticationScheme scheme; + try + { + scheme = GetSchemeFromSsoConfig(config); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error converting configuration to scheme for '{0}'", config.Id); + continue; + } + if (scheme == null) + { + continue; + } + SetSchemeInCache(scheme); + } + + if (newSchemes.Any()) + { + // Maintain "safe" copy for use in enumeration routines + _schemesCopy = _cachedSchemes.Values.ToArray(); + _handlerSchemesCopy = _cachedHandlerSchemes.Values.ToArray(); + } + _lastSchemeLoad = now; + } + finally + { + _semaphore.Release(); + } + } + + private DynamicAuthenticationScheme SetSchemeInCache(DynamicAuthenticationScheme scheme) + { + if (!PostConfigureDynamicScheme(scheme)) + { + return null; + } + _cachedSchemes[scheme.Name] = scheme; + if (typeof(IAuthenticationRequestHandler).IsAssignableFrom(scheme.HandlerType)) + { + _cachedHandlerSchemes[scheme.Name] = scheme; + } + return scheme; + } + + private async Task GetDynamicSchemeAsync(string name) + { + if (_cachedSchemes.TryGetValue(name, out var cachedScheme)) + { + return cachedScheme; + } + + var scheme = await GetSchemeFromSsoConfigAsync(name); + if (scheme == null) + { + return null; + } + + await _semaphore.WaitAsync(); + try + { + scheme = SetSchemeInCache(scheme); if (scheme == null) { return null; } - await _semaphore.WaitAsync(); - try + if (typeof(IAuthenticationRequestHandler).IsAssignableFrom(scheme.HandlerType)) { - scheme = SetSchemeInCache(scheme); - if (scheme == null) - { - return null; - } + _handlerSchemesCopy = _cachedHandlerSchemes.Values.ToArray(); + } + _schemesCopy = _cachedSchemes.Values.ToArray(); + } + finally + { + // Note: _lastSchemeLoad is not set here, this is a one-off + // and should not impact loading further cache updates + _semaphore.Release(); + } + return scheme; + } - if (typeof(IAuthenticationRequestHandler).IsAssignableFrom(scheme.HandlerType)) - { - _handlerSchemesCopy = _cachedHandlerSchemes.Values.ToArray(); - } - _schemesCopy = _cachedSchemes.Values.ToArray(); - } - finally + private bool PostConfigureDynamicScheme(DynamicAuthenticationScheme scheme) + { + try + { + if (scheme.SsoType == SsoType.OpenIdConnect && scheme.Options is OpenIdConnectOptions oidcOptions) { - // Note: _lastSchemeLoad is not set here, this is a one-off - // and should not impact loading further cache updates - _semaphore.Release(); + _oidcPostConfigureOptions.PostConfigure(scheme.Name, oidcOptions); + _extendedOidcOptionsMonitorCache.AddOrUpdate(scheme.Name, oidcOptions); } - return scheme; + else if (scheme.SsoType == SsoType.Saml2 && scheme.Options is Saml2Options saml2Options) + { + _saml2PostConfigureOptions.PostConfigure(scheme.Name, saml2Options); + _extendedSaml2OptionsMonitorCache.AddOrUpdate(scheme.Name, saml2Options); + } + return true; + } + catch (Exception ex) + { + _logger.LogError(ex, "Error performing post configuration for '{0}' ({1})", + scheme.Name, scheme.DisplayName); + } + return false; + } + + private DynamicAuthenticationScheme GetSchemeFromSsoConfig(SsoConfig config) + { + var data = config.GetData(); + return data.ConfigType switch + { + SsoType.OpenIdConnect => GetOidcAuthenticationScheme(config.OrganizationId.ToString(), data), + SsoType.Saml2 => GetSaml2AuthenticationScheme(config.OrganizationId.ToString(), data), + _ => throw new Exception($"SSO Config Type, '{data.ConfigType}', not supported"), + }; + } + + private async Task GetSchemeFromSsoConfigAsync(string name) + { + if (!Guid.TryParse(name, out var organizationId)) + { + _logger.LogWarning("Could not determine organization id from name, '{0}'", name); + return null; + } + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organizationId); + if (ssoConfig == null || !ssoConfig.Enabled) + { + _logger.LogWarning("Could not find SSO config or config was not enabled for '{0}'", name); + return null; } - private bool PostConfigureDynamicScheme(DynamicAuthenticationScheme scheme) + return GetSchemeFromSsoConfig(ssoConfig); + } + + private DynamicAuthenticationScheme GetOidcAuthenticationScheme(string name, SsoConfigurationData config) + { + var oidcOptions = new OpenIdConnectOptions { - try + Authority = config.Authority, + ClientId = config.ClientId, + ClientSecret = config.ClientSecret, + ResponseType = "code", + ResponseMode = "form_post", + SignInScheme = AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme, + SignOutScheme = IdentityServerConstants.SignoutScheme, + SaveTokens = false, // reduce overall request size + TokenValidationParameters = new TokenValidationParameters { - if (scheme.SsoType == SsoType.OpenIdConnect && scheme.Options is OpenIdConnectOptions oidcOptions) - { - _oidcPostConfigureOptions.PostConfigure(scheme.Name, oidcOptions); - _extendedOidcOptionsMonitorCache.AddOrUpdate(scheme.Name, oidcOptions); - } - else if (scheme.SsoType == SsoType.Saml2 && scheme.Options is Saml2Options saml2Options) - { - _saml2PostConfigureOptions.PostConfigure(scheme.Name, saml2Options); - _extendedSaml2OptionsMonitorCache.AddOrUpdate(scheme.Name, saml2Options); - } - return true; - } - catch (Exception ex) - { - _logger.LogError(ex, "Error performing post configuration for '{0}' ({1})", - scheme.Name, scheme.DisplayName); - } - return false; + NameClaimType = JwtClaimTypes.Name, + RoleClaimType = JwtClaimTypes.Role, + }, + CallbackPath = SsoConfigurationData.BuildCallbackPath(), + SignedOutCallbackPath = SsoConfigurationData.BuildSignedOutCallbackPath(), + MetadataAddress = config.MetadataAddress, + // Prevents URLs that go beyond 1024 characters which may break for some servers + AuthenticationMethod = config.RedirectBehavior, + GetClaimsFromUserInfoEndpoint = config.GetClaimsFromUserInfoEndpoint, + }; + oidcOptions.Scope + .AddIfNotExists(OpenIdConnectScopes.OpenId) + .AddIfNotExists(OpenIdConnectScopes.Email) + .AddIfNotExists(OpenIdConnectScopes.Profile); + foreach (var scope in config.GetAdditionalScopes()) + { + oidcOptions.Scope.AddIfNotExists(scope); + } + if (!string.IsNullOrWhiteSpace(config.ExpectedReturnAcrValue)) + { + oidcOptions.Scope.AddIfNotExists(OpenIdConnectScopes.Acr); } - private DynamicAuthenticationScheme GetSchemeFromSsoConfig(SsoConfig config) + oidcOptions.StateDataFormat = new DistributedCacheStateDataFormatter(_httpContextAccessor, name); + + // see: https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest (acr_values) + if (!string.IsNullOrWhiteSpace(config.AcrValues)) { - var data = config.GetData(); - return data.ConfigType switch + oidcOptions.Events ??= new OpenIdConnectEvents(); + oidcOptions.Events.OnRedirectToIdentityProvider = ctx => { - SsoType.OpenIdConnect => GetOidcAuthenticationScheme(config.OrganizationId.ToString(), data), - SsoType.Saml2 => GetSaml2AuthenticationScheme(config.OrganizationId.ToString(), data), - _ => throw new Exception($"SSO Config Type, '{data.ConfigType}', not supported"), + ctx.ProtocolMessage.AcrValues = config.AcrValues; + return Task.CompletedTask; }; } - private async Task GetSchemeFromSsoConfigAsync(string name) - { - if (!Guid.TryParse(name, out var organizationId)) - { - _logger.LogWarning("Could not determine organization id from name, '{0}'", name); - return null; - } - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organizationId); - if (ssoConfig == null || !ssoConfig.Enabled) - { - _logger.LogWarning("Could not find SSO config or config was not enabled for '{0}'", name); - return null; - } + return new DynamicAuthenticationScheme(name, name, typeof(OpenIdConnectHandler), + oidcOptions, SsoType.OpenIdConnect); + } - return GetSchemeFromSsoConfig(ssoConfig); + private DynamicAuthenticationScheme GetSaml2AuthenticationScheme(string name, SsoConfigurationData config) + { + if (_samlEnvironment == null) + { + throw new Exception($"SSO SAML2 Service Provider profile is missing for {name}"); } - private DynamicAuthenticationScheme GetOidcAuthenticationScheme(string name, SsoConfigurationData config) + var spEntityId = new Sustainsys.Saml2.Metadata.EntityId( + SsoConfigurationData.BuildSaml2ModulePath(_globalSettings.BaseServiceUri.Sso)); + bool? allowCreate = null; + if (config.SpNameIdFormat != Saml2NameIdFormat.Transient) { - var oidcOptions = new OpenIdConnectOptions - { - Authority = config.Authority, - ClientId = config.ClientId, - ClientSecret = config.ClientSecret, - ResponseType = "code", - ResponseMode = "form_post", - SignInScheme = AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme, - SignOutScheme = IdentityServerConstants.SignoutScheme, - SaveTokens = false, // reduce overall request size - TokenValidationParameters = new TokenValidationParameters - { - NameClaimType = JwtClaimTypes.Name, - RoleClaimType = JwtClaimTypes.Role, - }, - CallbackPath = SsoConfigurationData.BuildCallbackPath(), - SignedOutCallbackPath = SsoConfigurationData.BuildSignedOutCallbackPath(), - MetadataAddress = config.MetadataAddress, - // Prevents URLs that go beyond 1024 characters which may break for some servers - AuthenticationMethod = config.RedirectBehavior, - GetClaimsFromUserInfoEndpoint = config.GetClaimsFromUserInfoEndpoint, - }; - oidcOptions.Scope - .AddIfNotExists(OpenIdConnectScopes.OpenId) - .AddIfNotExists(OpenIdConnectScopes.Email) - .AddIfNotExists(OpenIdConnectScopes.Profile); - foreach (var scope in config.GetAdditionalScopes()) - { - oidcOptions.Scope.AddIfNotExists(scope); - } - if (!string.IsNullOrWhiteSpace(config.ExpectedReturnAcrValue)) - { - oidcOptions.Scope.AddIfNotExists(OpenIdConnectScopes.Acr); - } - - oidcOptions.StateDataFormat = new DistributedCacheStateDataFormatter(_httpContextAccessor, name); - - // see: https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest (acr_values) - if (!string.IsNullOrWhiteSpace(config.AcrValues)) - { - oidcOptions.Events ??= new OpenIdConnectEvents(); - oidcOptions.Events.OnRedirectToIdentityProvider = ctx => - { - ctx.ProtocolMessage.AcrValues = config.AcrValues; - return Task.CompletedTask; - }; - } - - return new DynamicAuthenticationScheme(name, name, typeof(OpenIdConnectHandler), - oidcOptions, SsoType.OpenIdConnect); + allowCreate = true; + } + var spOptions = new SPOptions + { + EntityId = spEntityId, + ModulePath = SsoConfigurationData.BuildSaml2ModulePath(null, name), + NameIdPolicy = new Saml2NameIdPolicy(allowCreate, GetNameIdFormat(config.SpNameIdFormat)), + WantAssertionsSigned = config.SpWantAssertionsSigned, + AuthenticateRequestSigningBehavior = GetSigningBehavior(config.SpSigningBehavior), + ValidateCertificates = config.SpValidateCertificates, + }; + if (!string.IsNullOrWhiteSpace(config.SpMinIncomingSigningAlgorithm)) + { + spOptions.MinIncomingSigningAlgorithm = config.SpMinIncomingSigningAlgorithm; + } + if (!string.IsNullOrWhiteSpace(config.SpOutboundSigningAlgorithm)) + { + spOptions.OutboundSigningAlgorithm = config.SpOutboundSigningAlgorithm; + } + if (_samlEnvironment.SpSigningCertificate != null) + { + spOptions.ServiceCertificates.Add(_samlEnvironment.SpSigningCertificate); } - private DynamicAuthenticationScheme GetSaml2AuthenticationScheme(string name, SsoConfigurationData config) + var idpEntityId = new Sustainsys.Saml2.Metadata.EntityId(config.IdpEntityId); + var idp = new Sustainsys.Saml2.IdentityProvider(idpEntityId, spOptions) { - if (_samlEnvironment == null) - { - throw new Exception($"SSO SAML2 Service Provider profile is missing for {name}"); - } - - var spEntityId = new Sustainsys.Saml2.Metadata.EntityId( - SsoConfigurationData.BuildSaml2ModulePath(_globalSettings.BaseServiceUri.Sso)); - bool? allowCreate = null; - if (config.SpNameIdFormat != Saml2NameIdFormat.Transient) - { - allowCreate = true; - } - var spOptions = new SPOptions - { - EntityId = spEntityId, - ModulePath = SsoConfigurationData.BuildSaml2ModulePath(null, name), - NameIdPolicy = new Saml2NameIdPolicy(allowCreate, GetNameIdFormat(config.SpNameIdFormat)), - WantAssertionsSigned = config.SpWantAssertionsSigned, - AuthenticateRequestSigningBehavior = GetSigningBehavior(config.SpSigningBehavior), - ValidateCertificates = config.SpValidateCertificates, - }; - if (!string.IsNullOrWhiteSpace(config.SpMinIncomingSigningAlgorithm)) - { - spOptions.MinIncomingSigningAlgorithm = config.SpMinIncomingSigningAlgorithm; - } - if (!string.IsNullOrWhiteSpace(config.SpOutboundSigningAlgorithm)) - { - spOptions.OutboundSigningAlgorithm = config.SpOutboundSigningAlgorithm; - } - if (_samlEnvironment.SpSigningCertificate != null) - { - spOptions.ServiceCertificates.Add(_samlEnvironment.SpSigningCertificate); - } - - var idpEntityId = new Sustainsys.Saml2.Metadata.EntityId(config.IdpEntityId); - var idp = new Sustainsys.Saml2.IdentityProvider(idpEntityId, spOptions) - { - Binding = GetBindingType(config.IdpBindingType), - AllowUnsolicitedAuthnResponse = config.IdpAllowUnsolicitedAuthnResponse, - DisableOutboundLogoutRequests = config.IdpDisableOutboundLogoutRequests, - WantAuthnRequestsSigned = config.IdpWantAuthnRequestsSigned, - }; - if (!string.IsNullOrWhiteSpace(config.IdpSingleSignOnServiceUrl)) - { - idp.SingleSignOnServiceUrl = new Uri(config.IdpSingleSignOnServiceUrl); - } - if (!string.IsNullOrWhiteSpace(config.IdpSingleLogoutServiceUrl)) - { - idp.SingleLogoutServiceUrl = new Uri(config.IdpSingleLogoutServiceUrl); - } - if (!string.IsNullOrWhiteSpace(config.IdpOutboundSigningAlgorithm)) - { - idp.OutboundSigningAlgorithm = config.IdpOutboundSigningAlgorithm; - } - if (!string.IsNullOrWhiteSpace(config.IdpX509PublicCert)) - { - var cert = CoreHelpers.Base64UrlDecode(config.IdpX509PublicCert); - idp.SigningKeys.AddConfiguredKey(new X509Certificate2(cert)); - } - idp.ArtifactResolutionServiceUrls.Clear(); - // This must happen last since it calls Validate() internally. - idp.LoadMetadata = false; - - var options = new Saml2Options - { - SPOptions = spOptions, - SignInScheme = AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme, - SignOutScheme = IdentityServerConstants.DefaultCookieAuthenticationScheme, - CookieManager = new IdentityServer.DistributedCacheCookieManager(), - }; - options.IdentityProviders.Add(idp); - - return new DynamicAuthenticationScheme(name, name, typeof(Saml2Handler), options, SsoType.Saml2); - } - - private NameIdFormat GetNameIdFormat(Saml2NameIdFormat format) + Binding = GetBindingType(config.IdpBindingType), + AllowUnsolicitedAuthnResponse = config.IdpAllowUnsolicitedAuthnResponse, + DisableOutboundLogoutRequests = config.IdpDisableOutboundLogoutRequests, + WantAuthnRequestsSigned = config.IdpWantAuthnRequestsSigned, + }; + if (!string.IsNullOrWhiteSpace(config.IdpSingleSignOnServiceUrl)) { - return format switch - { - Saml2NameIdFormat.Unspecified => NameIdFormat.Unspecified, - Saml2NameIdFormat.EmailAddress => NameIdFormat.EmailAddress, - Saml2NameIdFormat.X509SubjectName => NameIdFormat.X509SubjectName, - Saml2NameIdFormat.WindowsDomainQualifiedName => NameIdFormat.WindowsDomainQualifiedName, - Saml2NameIdFormat.KerberosPrincipalName => NameIdFormat.KerberosPrincipalName, - Saml2NameIdFormat.EntityIdentifier => NameIdFormat.EntityIdentifier, - Saml2NameIdFormat.Persistent => NameIdFormat.Persistent, - Saml2NameIdFormat.Transient => NameIdFormat.Transient, - _ => NameIdFormat.NotConfigured, - }; + idp.SingleSignOnServiceUrl = new Uri(config.IdpSingleSignOnServiceUrl); } - - private SigningBehavior GetSigningBehavior(Saml2SigningBehavior behavior) + if (!string.IsNullOrWhiteSpace(config.IdpSingleLogoutServiceUrl)) { - return behavior switch - { - Saml2SigningBehavior.IfIdpWantAuthnRequestsSigned => SigningBehavior.IfIdpWantAuthnRequestsSigned, - Saml2SigningBehavior.Always => SigningBehavior.Always, - Saml2SigningBehavior.Never => SigningBehavior.Never, - _ => SigningBehavior.IfIdpWantAuthnRequestsSigned, - }; + idp.SingleLogoutServiceUrl = new Uri(config.IdpSingleLogoutServiceUrl); } - - private Sustainsys.Saml2.WebSso.Saml2BindingType GetBindingType(Saml2BindingType bindingType) + if (!string.IsNullOrWhiteSpace(config.IdpOutboundSigningAlgorithm)) { - return bindingType switch - { - Saml2BindingType.HttpRedirect => Sustainsys.Saml2.WebSso.Saml2BindingType.HttpRedirect, - Saml2BindingType.HttpPost => Sustainsys.Saml2.WebSso.Saml2BindingType.HttpPost, - _ => Sustainsys.Saml2.WebSso.Saml2BindingType.HttpPost, - }; + idp.OutboundSigningAlgorithm = config.IdpOutboundSigningAlgorithm; } + if (!string.IsNullOrWhiteSpace(config.IdpX509PublicCert)) + { + var cert = CoreHelpers.Base64UrlDecode(config.IdpX509PublicCert); + idp.SigningKeys.AddConfiguredKey(new X509Certificate2(cert)); + } + idp.ArtifactResolutionServiceUrls.Clear(); + // This must happen last since it calls Validate() internally. + idp.LoadMetadata = false; + + var options = new Saml2Options + { + SPOptions = spOptions, + SignInScheme = AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme, + SignOutScheme = IdentityServerConstants.DefaultCookieAuthenticationScheme, + CookieManager = new IdentityServer.DistributedCacheCookieManager(), + }; + options.IdentityProviders.Add(idp); + + return new DynamicAuthenticationScheme(name, name, typeof(Saml2Handler), options, SsoType.Saml2); + } + + private NameIdFormat GetNameIdFormat(Saml2NameIdFormat format) + { + return format switch + { + Saml2NameIdFormat.Unspecified => NameIdFormat.Unspecified, + Saml2NameIdFormat.EmailAddress => NameIdFormat.EmailAddress, + Saml2NameIdFormat.X509SubjectName => NameIdFormat.X509SubjectName, + Saml2NameIdFormat.WindowsDomainQualifiedName => NameIdFormat.WindowsDomainQualifiedName, + Saml2NameIdFormat.KerberosPrincipalName => NameIdFormat.KerberosPrincipalName, + Saml2NameIdFormat.EntityIdentifier => NameIdFormat.EntityIdentifier, + Saml2NameIdFormat.Persistent => NameIdFormat.Persistent, + Saml2NameIdFormat.Transient => NameIdFormat.Transient, + _ => NameIdFormat.NotConfigured, + }; + } + + private SigningBehavior GetSigningBehavior(Saml2SigningBehavior behavior) + { + return behavior switch + { + Saml2SigningBehavior.IfIdpWantAuthnRequestsSigned => SigningBehavior.IfIdpWantAuthnRequestsSigned, + Saml2SigningBehavior.Always => SigningBehavior.Always, + Saml2SigningBehavior.Never => SigningBehavior.Never, + _ => SigningBehavior.IfIdpWantAuthnRequestsSigned, + }; + } + + private Sustainsys.Saml2.WebSso.Saml2BindingType GetBindingType(Saml2BindingType bindingType) + { + return bindingType switch + { + Saml2BindingType.HttpRedirect => Sustainsys.Saml2.WebSso.Saml2BindingType.HttpRedirect, + Saml2BindingType.HttpPost => Sustainsys.Saml2.WebSso.Saml2BindingType.HttpPost, + _ => Sustainsys.Saml2.WebSso.Saml2BindingType.HttpPost, + }; } } diff --git a/bitwarden_license/src/Sso/Utilities/ExtendedOptionsMonitorCache.cs b/bitwarden_license/src/Sso/Utilities/ExtendedOptionsMonitorCache.cs index 8e23e1f07..083417f25 100644 --- a/bitwarden_license/src/Sso/Utilities/ExtendedOptionsMonitorCache.cs +++ b/bitwarden_license/src/Sso/Utilities/ExtendedOptionsMonitorCache.cs @@ -1,37 +1,36 @@ using System.Collections.Concurrent; using Microsoft.Extensions.Options; -namespace Bit.Sso.Utilities +namespace Bit.Sso.Utilities; + +public class ExtendedOptionsMonitorCache : IExtendedOptionsMonitorCache where TOptions : class { - public class ExtendedOptionsMonitorCache : IExtendedOptionsMonitorCache where TOptions : class + private readonly ConcurrentDictionary> _cache = + new ConcurrentDictionary>(StringComparer.Ordinal); + + public void AddOrUpdate(string name, TOptions options) { - private readonly ConcurrentDictionary> _cache = - new ConcurrentDictionary>(StringComparer.Ordinal); + _cache.AddOrUpdate(name ?? Options.DefaultName, new Lazy(() => options), + (string s, Lazy lazy) => new Lazy(() => options)); + } - public void AddOrUpdate(string name, TOptions options) - { - _cache.AddOrUpdate(name ?? Options.DefaultName, new Lazy(() => options), - (string s, Lazy lazy) => new Lazy(() => options)); - } + public void Clear() + { + _cache.Clear(); + } - public void Clear() - { - _cache.Clear(); - } + public TOptions GetOrAdd(string name, Func createOptions) + { + return _cache.GetOrAdd(name ?? Options.DefaultName, new Lazy(createOptions)).Value; + } - public TOptions GetOrAdd(string name, Func createOptions) - { - return _cache.GetOrAdd(name ?? Options.DefaultName, new Lazy(createOptions)).Value; - } + public bool TryAdd(string name, TOptions options) + { + return _cache.TryAdd(name ?? Options.DefaultName, new Lazy(() => options)); + } - public bool TryAdd(string name, TOptions options) - { - return _cache.TryAdd(name ?? Options.DefaultName, new Lazy(() => options)); - } - - public bool TryRemove(string name) - { - return _cache.TryRemove(name ?? Options.DefaultName, out _); - } + public bool TryRemove(string name) + { + return _cache.TryRemove(name ?? Options.DefaultName, out _); } } diff --git a/bitwarden_license/src/Sso/Utilities/IDynamicAuthenticationScheme.cs b/bitwarden_license/src/Sso/Utilities/IDynamicAuthenticationScheme.cs index 7deab5440..9ebd0f9cf 100644 --- a/bitwarden_license/src/Sso/Utilities/IDynamicAuthenticationScheme.cs +++ b/bitwarden_license/src/Sso/Utilities/IDynamicAuthenticationScheme.cs @@ -1,13 +1,12 @@ using Bit.Core.Enums; using Microsoft.AspNetCore.Authentication; -namespace Bit.Sso.Utilities -{ - public interface IDynamicAuthenticationScheme - { - AuthenticationSchemeOptions Options { get; set; } - SsoType SsoType { get; set; } +namespace Bit.Sso.Utilities; - Task Validate(); - } +public interface IDynamicAuthenticationScheme +{ + AuthenticationSchemeOptions Options { get; set; } + SsoType SsoType { get; set; } + + Task Validate(); } diff --git a/bitwarden_license/src/Sso/Utilities/IExtendedOptionsMonitorCache.cs b/bitwarden_license/src/Sso/Utilities/IExtendedOptionsMonitorCache.cs index 73a5352a8..0f6284318 100644 --- a/bitwarden_license/src/Sso/Utilities/IExtendedOptionsMonitorCache.cs +++ b/bitwarden_license/src/Sso/Utilities/IExtendedOptionsMonitorCache.cs @@ -1,9 +1,8 @@ using Microsoft.Extensions.Options; -namespace Bit.Sso.Utilities +namespace Bit.Sso.Utilities; + +public interface IExtendedOptionsMonitorCache : IOptionsMonitorCache where TOptions : class { - public interface IExtendedOptionsMonitorCache : IOptionsMonitorCache where TOptions : class - { - void AddOrUpdate(string name, TOptions options); - } + void AddOrUpdate(string name, TOptions options); } diff --git a/bitwarden_license/src/Sso/Utilities/OpenIdConnectOptionsExtensions.cs b/bitwarden_license/src/Sso/Utilities/OpenIdConnectOptionsExtensions.cs index e01ff7111..9221877a0 100644 --- a/bitwarden_license/src/Sso/Utilities/OpenIdConnectOptionsExtensions.cs +++ b/bitwarden_license/src/Sso/Utilities/OpenIdConnectOptionsExtensions.cs @@ -1,63 +1,62 @@ using Microsoft.AspNetCore.Authentication.OpenIdConnect; using Microsoft.IdentityModel.Protocols.OpenIdConnect; -namespace Bit.Sso.Utilities +namespace Bit.Sso.Utilities; + +public static class OpenIdConnectOptionsExtensions { - public static class OpenIdConnectOptionsExtensions + public static async Task CouldHandleAsync(this OpenIdConnectOptions options, string scheme, HttpContext context) { - public static async Task CouldHandleAsync(this OpenIdConnectOptions options, string scheme, HttpContext context) + // Determine this is a valid request for our handler + if (options.CallbackPath != context.Request.Path && + options.RemoteSignOutPath != context.Request.Path && + options.SignedOutCallbackPath != context.Request.Path) { - // Determine this is a valid request for our handler - if (options.CallbackPath != context.Request.Path && - options.RemoteSignOutPath != context.Request.Path && - options.SignedOutCallbackPath != context.Request.Path) - { - return false; - } - - if (context.Request.Query["scheme"].FirstOrDefault() == scheme) - { - return true; - } - - try - { - // Parse out the message - OpenIdConnectMessage message = null; - if (string.Equals(context.Request.Method, "GET", StringComparison.OrdinalIgnoreCase)) - { - message = new OpenIdConnectMessage(context.Request.Query.Select(pair => new KeyValuePair(pair.Key, pair.Value))); - } - else if (string.Equals(context.Request.Method, "POST", StringComparison.OrdinalIgnoreCase) && - !string.IsNullOrEmpty(context.Request.ContentType) && - context.Request.ContentType.StartsWith("application/x-www-form-urlencoded", StringComparison.OrdinalIgnoreCase) && - context.Request.Body.CanRead) - { - var form = await context.Request.ReadFormAsync(); - message = new OpenIdConnectMessage(form.Select(pair => new KeyValuePair(pair.Key, pair.Value))); - } - - var state = message?.State; - if (string.IsNullOrWhiteSpace(state)) - { - // State is required, it will fail later on for this reason. - return false; - } - - // Handle State if we've gotten that back - var decodedState = options.StateDataFormat.Unprotect(state); - if (decodedState != null && decodedState.Items.ContainsKey("scheme")) - { - return decodedState.Items["scheme"] == scheme; - } - } - catch - { - return false; - } - - // This is likely not an appropriate handler return false; } + + if (context.Request.Query["scheme"].FirstOrDefault() == scheme) + { + return true; + } + + try + { + // Parse out the message + OpenIdConnectMessage message = null; + if (string.Equals(context.Request.Method, "GET", StringComparison.OrdinalIgnoreCase)) + { + message = new OpenIdConnectMessage(context.Request.Query.Select(pair => new KeyValuePair(pair.Key, pair.Value))); + } + else if (string.Equals(context.Request.Method, "POST", StringComparison.OrdinalIgnoreCase) && + !string.IsNullOrEmpty(context.Request.ContentType) && + context.Request.ContentType.StartsWith("application/x-www-form-urlencoded", StringComparison.OrdinalIgnoreCase) && + context.Request.Body.CanRead) + { + var form = await context.Request.ReadFormAsync(); + message = new OpenIdConnectMessage(form.Select(pair => new KeyValuePair(pair.Key, pair.Value))); + } + + var state = message?.State; + if (string.IsNullOrWhiteSpace(state)) + { + // State is required, it will fail later on for this reason. + return false; + } + + // Handle State if we've gotten that back + var decodedState = options.StateDataFormat.Unprotect(state); + if (decodedState != null && decodedState.Items.ContainsKey("scheme")) + { + return decodedState.Items["scheme"] == scheme; + } + } + catch + { + return false; + } + + // This is likely not an appropriate handler + return false; } } diff --git a/bitwarden_license/src/Sso/Utilities/OpenIdConnectScopes.cs b/bitwarden_license/src/Sso/Utilities/OpenIdConnectScopes.cs index 983ce8b33..3fae7ce4e 100644 --- a/bitwarden_license/src/Sso/Utilities/OpenIdConnectScopes.cs +++ b/bitwarden_license/src/Sso/Utilities/OpenIdConnectScopes.cs @@ -1,64 +1,63 @@ -namespace Bit.Sso.Utilities +namespace Bit.Sso.Utilities; + +/// +/// OpenID Connect Clients use scope values as defined in 3.3 of OAuth 2.0 +/// [RFC6749]. These values represent the standard scope values supported +/// by OAuth 2.0 and therefore OIDC. +/// +/// +/// See: https://openid.net/specs/openid-connect-basic-1_0.html#Scopes +/// +public static class OpenIdConnectScopes { /// - /// OpenID Connect Clients use scope values as defined in 3.3 of OAuth 2.0 - /// [RFC6749]. These values represent the standard scope values supported - /// by OAuth 2.0 and therefore OIDC. + /// REQUIRED. Informs the Authorization Server that the Client is making + /// an OpenID Connect request. If the openid scope value is not present, + /// the behavior is entirely unspecified. + /// + public const string OpenId = "openid"; + + /// + /// OPTIONAL. This scope value requests access to the End-User's default + /// profile Claims, which are: name, family_name, given_name, + /// middle_name, nickname, preferred_username, profile, picture, + /// website, gender, birthdate, zoneinfo, locale, and updated_at. + /// + public const string Profile = "profile"; + + /// + /// OPTIONAL. This scope value requests access to the email and + /// email_verified Claims. + /// + public const string Email = "email"; + + /// + /// OPTIONAL. This scope value requests access to the address Claim. + /// + public const string Address = "address"; + + /// + /// OPTIONAL. This scope value requests access to the phone_number and + /// phone_number_verified Claims. + /// + public const string Phone = "phone"; + + /// + /// OPTIONAL. This scope value requests that an OAuth 2.0 Refresh Token + /// be issued that can be used to obtain an Access Token that grants + /// access to the End-User's UserInfo Endpoint even when the End-User is + /// not present (not logged in). + /// + public const string OfflineAccess = "offline_access"; + + /// + /// OPTIONAL. Authentication Context Class Reference. String specifying + /// an Authentication Context Class Reference value that identifies the + /// Authentication Context Class that the authentication performed + /// satisfied. /// /// - /// See: https://openid.net/specs/openid-connect-basic-1_0.html#Scopes + /// See: https://openid.net/specs/openid-connect-core-1_0.html#rfc.section.2 /// - public static class OpenIdConnectScopes - { - /// - /// REQUIRED. Informs the Authorization Server that the Client is making - /// an OpenID Connect request. If the openid scope value is not present, - /// the behavior is entirely unspecified. - /// - public const string OpenId = "openid"; - - /// - /// OPTIONAL. This scope value requests access to the End-User's default - /// profile Claims, which are: name, family_name, given_name, - /// middle_name, nickname, preferred_username, profile, picture, - /// website, gender, birthdate, zoneinfo, locale, and updated_at. - /// - public const string Profile = "profile"; - - /// - /// OPTIONAL. This scope value requests access to the email and - /// email_verified Claims. - /// - public const string Email = "email"; - - /// - /// OPTIONAL. This scope value requests access to the address Claim. - /// - public const string Address = "address"; - - /// - /// OPTIONAL. This scope value requests access to the phone_number and - /// phone_number_verified Claims. - /// - public const string Phone = "phone"; - - /// - /// OPTIONAL. This scope value requests that an OAuth 2.0 Refresh Token - /// be issued that can be used to obtain an Access Token that grants - /// access to the End-User's UserInfo Endpoint even when the End-User is - /// not present (not logged in). - /// - public const string OfflineAccess = "offline_access"; - - /// - /// OPTIONAL. Authentication Context Class Reference. String specifying - /// an Authentication Context Class Reference value that identifies the - /// Authentication Context Class that the authentication performed - /// satisfied. - /// - /// - /// See: https://openid.net/specs/openid-connect-core-1_0.html#rfc.section.2 - /// - public const string Acr = "acr"; - } + public const string Acr = "acr"; } diff --git a/bitwarden_license/src/Sso/Utilities/Saml2OptionsExtensions.cs b/bitwarden_license/src/Sso/Utilities/Saml2OptionsExtensions.cs index 9d4870bd7..46a75ca5c 100644 --- a/bitwarden_license/src/Sso/Utilities/Saml2OptionsExtensions.cs +++ b/bitwarden_license/src/Sso/Utilities/Saml2OptionsExtensions.cs @@ -4,102 +4,101 @@ using System.Xml; using Sustainsys.Saml2; using Sustainsys.Saml2.AspNetCore2; -namespace Bit.Sso.Utilities +namespace Bit.Sso.Utilities; + +public static class Saml2OptionsExtensions { - public static class Saml2OptionsExtensions + public static async Task CouldHandleAsync(this Saml2Options options, string scheme, HttpContext context) { - public static async Task CouldHandleAsync(this Saml2Options options, string scheme, HttpContext context) + // Determine this is a valid request for our handler + if (!context.Request.Path.StartsWithSegments(options.SPOptions.ModulePath, StringComparison.Ordinal)) { - // Determine this is a valid request for our handler - if (!context.Request.Path.StartsWithSegments(options.SPOptions.ModulePath, StringComparison.Ordinal)) - { - return false; - } + return false; + } - var idp = options.IdentityProviders.IsEmpty ? null : options.IdentityProviders.Default; - if (idp == null) - { - return false; - } - - if (context.Request.Query["scheme"].FirstOrDefault() == scheme) - { - return true; - } - - // We need to pull out and parse the response or request SAML envelope - XmlElement envelope = null; - try - { - if (string.Equals(context.Request.Method, "POST", StringComparison.OrdinalIgnoreCase) && - context.Request.HasFormContentType) - { - string encodedMessage; - if (context.Request.Form.TryGetValue("SAMLResponse", out var response)) - { - encodedMessage = response.FirstOrDefault(); - } - else - { - encodedMessage = context.Request.Form["SAMLRequest"]; - } - if (string.IsNullOrWhiteSpace(encodedMessage)) - { - return false; - } - envelope = XmlHelpers.XmlDocumentFromString( - Encoding.UTF8.GetString(Convert.FromBase64String(encodedMessage)))?.DocumentElement; - } - else if (string.Equals(context.Request.Method, "GET", StringComparison.OrdinalIgnoreCase)) - { - var encodedPayload = context.Request.Query["SAMLRequest"].FirstOrDefault() ?? - context.Request.Query["SAMLResponse"].FirstOrDefault(); - try - { - var payload = Convert.FromBase64String(encodedPayload); - using var compressed = new MemoryStream(payload); - using var decompressedStream = new DeflateStream(compressed, CompressionMode.Decompress, true); - using var deCompressed = new MemoryStream(); - await decompressedStream.CopyToAsync(deCompressed); - - envelope = XmlHelpers.XmlDocumentFromString( - Encoding.UTF8.GetString(deCompressed.GetBuffer(), 0, (int)deCompressed.Length))?.DocumentElement; - } - catch (FormatException ex) - { - throw new FormatException($"\'{encodedPayload}\' is not a valid Base64 encoded string: {ex.Message}", ex); - } - } - } - catch - { - return false; - } - - if (envelope == null) - { - return false; - } - - // Double check the entity Ids - var entityId = envelope["Issuer", Saml2Namespaces.Saml2Name]?.InnerText.Trim(); - if (!string.Equals(entityId, idp.EntityId.Id, StringComparison.InvariantCultureIgnoreCase)) - { - return false; - } - - if (options.SPOptions.WantAssertionsSigned) - { - var assertion = envelope["Assertion", Saml2Namespaces.Saml2Name]; - var isAssertionSigned = assertion != null && XmlHelpers.IsSignedByAny(assertion, idp.SigningKeys, - options.SPOptions.ValidateCertificates, options.SPOptions.MinIncomingSigningAlgorithm); - if (!isAssertionSigned) - { - throw new Exception("Cannot verify SAML assertion signature."); - } - } + var idp = options.IdentityProviders.IsEmpty ? null : options.IdentityProviders.Default; + if (idp == null) + { + return false; + } + if (context.Request.Query["scheme"].FirstOrDefault() == scheme) + { return true; } + + // We need to pull out and parse the response or request SAML envelope + XmlElement envelope = null; + try + { + if (string.Equals(context.Request.Method, "POST", StringComparison.OrdinalIgnoreCase) && + context.Request.HasFormContentType) + { + string encodedMessage; + if (context.Request.Form.TryGetValue("SAMLResponse", out var response)) + { + encodedMessage = response.FirstOrDefault(); + } + else + { + encodedMessage = context.Request.Form["SAMLRequest"]; + } + if (string.IsNullOrWhiteSpace(encodedMessage)) + { + return false; + } + envelope = XmlHelpers.XmlDocumentFromString( + Encoding.UTF8.GetString(Convert.FromBase64String(encodedMessage)))?.DocumentElement; + } + else if (string.Equals(context.Request.Method, "GET", StringComparison.OrdinalIgnoreCase)) + { + var encodedPayload = context.Request.Query["SAMLRequest"].FirstOrDefault() ?? + context.Request.Query["SAMLResponse"].FirstOrDefault(); + try + { + var payload = Convert.FromBase64String(encodedPayload); + using var compressed = new MemoryStream(payload); + using var decompressedStream = new DeflateStream(compressed, CompressionMode.Decompress, true); + using var deCompressed = new MemoryStream(); + await decompressedStream.CopyToAsync(deCompressed); + + envelope = XmlHelpers.XmlDocumentFromString( + Encoding.UTF8.GetString(deCompressed.GetBuffer(), 0, (int)deCompressed.Length))?.DocumentElement; + } + catch (FormatException ex) + { + throw new FormatException($"\'{encodedPayload}\' is not a valid Base64 encoded string: {ex.Message}", ex); + } + } + } + catch + { + return false; + } + + if (envelope == null) + { + return false; + } + + // Double check the entity Ids + var entityId = envelope["Issuer", Saml2Namespaces.Saml2Name]?.InnerText.Trim(); + if (!string.Equals(entityId, idp.EntityId.Id, StringComparison.InvariantCultureIgnoreCase)) + { + return false; + } + + if (options.SPOptions.WantAssertionsSigned) + { + var assertion = envelope["Assertion", Saml2Namespaces.Saml2Name]; + var isAssertionSigned = assertion != null && XmlHelpers.IsSignedByAny(assertion, idp.SigningKeys, + options.SPOptions.ValidateCertificates, options.SPOptions.MinIncomingSigningAlgorithm); + if (!isAssertionSigned) + { + throw new Exception("Cannot verify SAML assertion signature."); + } + } + + return true; } } diff --git a/bitwarden_license/src/Sso/Utilities/SamlClaimTypes.cs b/bitwarden_license/src/Sso/Utilities/SamlClaimTypes.cs index f62f5b04b..2f314c7ef 100644 --- a/bitwarden_license/src/Sso/Utilities/SamlClaimTypes.cs +++ b/bitwarden_license/src/Sso/Utilities/SamlClaimTypes.cs @@ -1,12 +1,11 @@ -namespace Bit.Sso.Utilities +namespace Bit.Sso.Utilities; + +public static class SamlClaimTypes { - public static class SamlClaimTypes - { - public const string Email = "urn:oid:0.9.2342.19200300.100.1.3"; - public const string GivenName = "urn:oid:2.5.4.42"; - public const string Surname = "urn:oid:2.5.4.4"; - public const string DisplayName = "urn:oid:2.16.840.1.113730.3.1.241"; - public const string CommonName = "urn:oid:2.5.4.3"; - public const string UserId = "urn:oid:0.9.2342.19200300.100.1.1"; - } + public const string Email = "urn:oid:0.9.2342.19200300.100.1.3"; + public const string GivenName = "urn:oid:2.5.4.42"; + public const string Surname = "urn:oid:2.5.4.4"; + public const string DisplayName = "urn:oid:2.16.840.1.113730.3.1.241"; + public const string CommonName = "urn:oid:2.5.4.3"; + public const string UserId = "urn:oid:0.9.2342.19200300.100.1.1"; } diff --git a/bitwarden_license/src/Sso/Utilities/SamlNameIdFormats.cs b/bitwarden_license/src/Sso/Utilities/SamlNameIdFormats.cs index 18ccc140f..94c03bd64 100644 --- a/bitwarden_license/src/Sso/Utilities/SamlNameIdFormats.cs +++ b/bitwarden_license/src/Sso/Utilities/SamlNameIdFormats.cs @@ -1,18 +1,17 @@ -namespace Bit.Sso.Utilities +namespace Bit.Sso.Utilities; + +public static class SamlNameIdFormats { - public static class SamlNameIdFormats - { - // Common - public const string Unspecified = "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified"; - public const string Email = "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress"; - public const string Persistent = "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent"; - public const string Transient = "urn:oasis:names:tc:SAML:2.0:nameid-format:transient"; - // Not-so-common - public const string Upn = "http://schemas.xmlsoap.org/claims/UPN"; - public const string CommonName = "http://schemas.xmlsoap.org/claims/CommonName"; - public const string X509SubjectName = "urn:oasis:names:tc:SAML:1.1:nameid-format:X509SubjectName"; - public const string WindowsQualifiedDomainName = "urn:oasis:names:tc:SAML:1.1:nameid-format:WindowsDomainQualifiedName"; - public const string KerberosPrincipalName = "urn:oasis:names:tc:SAML:2.0:nameid-format:kerberos"; - public const string EntityIdentifier = "urn:oasis:names:tc:SAML:2.0:nameid-format:entity"; - } + // Common + public const string Unspecified = "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified"; + public const string Email = "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress"; + public const string Persistent = "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent"; + public const string Transient = "urn:oasis:names:tc:SAML:2.0:nameid-format:transient"; + // Not-so-common + public const string Upn = "http://schemas.xmlsoap.org/claims/UPN"; + public const string CommonName = "http://schemas.xmlsoap.org/claims/CommonName"; + public const string X509SubjectName = "urn:oasis:names:tc:SAML:1.1:nameid-format:X509SubjectName"; + public const string WindowsQualifiedDomainName = "urn:oasis:names:tc:SAML:1.1:nameid-format:WindowsDomainQualifiedName"; + public const string KerberosPrincipalName = "urn:oasis:names:tc:SAML:2.0:nameid-format:kerberos"; + public const string EntityIdentifier = "urn:oasis:names:tc:SAML:2.0:nameid-format:entity"; } diff --git a/bitwarden_license/src/Sso/Utilities/SamlPropertyKeys.cs b/bitwarden_license/src/Sso/Utilities/SamlPropertyKeys.cs index 21d599b7f..7be7fb4f6 100644 --- a/bitwarden_license/src/Sso/Utilities/SamlPropertyKeys.cs +++ b/bitwarden_license/src/Sso/Utilities/SamlPropertyKeys.cs @@ -1,7 +1,6 @@ -namespace Bit.Sso.Utilities +namespace Bit.Sso.Utilities; + +public static class SamlPropertyKeys { - public static class SamlPropertyKeys - { - public const string ClaimFormat = "http://schemas.xmlsoap.org/ws/2005/05/identity/claimproperties/format"; - } + public const string ClaimFormat = "http://schemas.xmlsoap.org/ws/2005/05/identity/claimproperties/format"; } diff --git a/bitwarden_license/src/Sso/Utilities/ServiceCollectionExtensions.cs b/bitwarden_license/src/Sso/Utilities/ServiceCollectionExtensions.cs index 444ed6c52..d7a5e3b1b 100644 --- a/bitwarden_license/src/Sso/Utilities/ServiceCollectionExtensions.cs +++ b/bitwarden_license/src/Sso/Utilities/ServiceCollectionExtensions.cs @@ -9,70 +9,69 @@ using IdentityServer4.ResponseHandling; using Microsoft.AspNetCore.Authentication.OpenIdConnect; using Sustainsys.Saml2.AspNetCore2; -namespace Bit.Sso.Utilities +namespace Bit.Sso.Utilities; + +public static class ServiceCollectionExtensions { - public static class ServiceCollectionExtensions + public static IServiceCollection AddSsoServices(this IServiceCollection services, + GlobalSettings globalSettings) { - public static IServiceCollection AddSsoServices(this IServiceCollection services, - GlobalSettings globalSettings) + // SAML SP Configuration + var samlEnvironment = new SamlEnvironment { - // SAML SP Configuration - var samlEnvironment = new SamlEnvironment + SpSigningCertificate = CoreHelpers.GetIdentityServerCertificate(globalSettings), + }; + services.AddSingleton(s => samlEnvironment); + + services.AddSingleton(); + // Oidc + services.AddSingleton, + OpenIdConnectPostConfigureOptions>(); + services.AddSingleton, + ExtendedOptionsMonitorCache>(); + // Saml2 + services.AddSingleton, + PostConfigureSaml2Options>(); + services.AddSingleton, + ExtendedOptionsMonitorCache>(); + + return services; + } + + public static IIdentityServerBuilder AddSsoIdentityServerServices(this IServiceCollection services, + IWebHostEnvironment env, GlobalSettings globalSettings) + { + services.AddTransient(); + + var issuerUri = new Uri(globalSettings.BaseServiceUri.InternalSso); + var identityServerBuilder = services + .AddIdentityServer(options => { - SpSigningCertificate = CoreHelpers.GetIdentityServerCertificate(globalSettings), - }; - services.AddSingleton(s => samlEnvironment); - - services.AddSingleton(); - // Oidc - services.AddSingleton, - OpenIdConnectPostConfigureOptions>(); - services.AddSingleton, - ExtendedOptionsMonitorCache>(); - // Saml2 - services.AddSingleton, - PostConfigureSaml2Options>(); - services.AddSingleton, - ExtendedOptionsMonitorCache>(); - - return services; - } - - public static IIdentityServerBuilder AddSsoIdentityServerServices(this IServiceCollection services, - IWebHostEnvironment env, GlobalSettings globalSettings) - { - services.AddTransient(); - - var issuerUri = new Uri(globalSettings.BaseServiceUri.InternalSso); - var identityServerBuilder = services - .AddIdentityServer(options => + options.IssuerUri = $"{issuerUri.Scheme}://{issuerUri.Host}"; + if (env.IsDevelopment()) { - options.IssuerUri = $"{issuerUri.Scheme}://{issuerUri.Host}"; - if (env.IsDevelopment()) - { - options.Authentication.CookieSameSiteMode = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; - } - else - { - options.UserInteraction.ErrorUrl = "/Error"; - options.UserInteraction.ErrorIdParameter = "errorId"; - } - options.InputLengthRestrictions.UserName = 256; - }) - .AddInMemoryCaching() - .AddInMemoryClients(new List + options.Authentication.CookieSameSiteMode = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; + } + else { - new OidcIdentityClient(globalSettings) - }) - .AddInMemoryIdentityResources(new List - { - new IdentityResources.OpenId(), - new IdentityResources.Profile() - }) - .AddIdentityServerCertificate(env, globalSettings); + options.UserInteraction.ErrorUrl = "/Error"; + options.UserInteraction.ErrorIdParameter = "errorId"; + } + options.InputLengthRestrictions.UserName = 256; + }) + .AddInMemoryCaching() + .AddInMemoryClients(new List + { + new OidcIdentityClient(globalSettings) + }) + .AddInMemoryIdentityResources(new List + { + new IdentityResources.OpenId(), + new IdentityResources.Profile() + }) + .AddIdentityServerCertificate(env, globalSettings); - return identityServerBuilder; - } + return identityServerBuilder; } } diff --git a/bitwarden_license/src/Sso/Utilities/SsoAuthenticationMiddleware.cs b/bitwarden_license/src/Sso/Utilities/SsoAuthenticationMiddleware.cs index 4a39082f3..9dca7a690 100644 --- a/bitwarden_license/src/Sso/Utilities/SsoAuthenticationMiddleware.cs +++ b/bitwarden_license/src/Sso/Utilities/SsoAuthenticationMiddleware.cs @@ -3,83 +3,82 @@ using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Authentication.OpenIdConnect; using Sustainsys.Saml2.AspNetCore2; -namespace Bit.Sso.Utilities +namespace Bit.Sso.Utilities; + +public class SsoAuthenticationMiddleware { - public class SsoAuthenticationMiddleware + private readonly RequestDelegate _next; + + public SsoAuthenticationMiddleware(RequestDelegate next, IAuthenticationSchemeProvider schemes) { - private readonly RequestDelegate _next; + _next = next ?? throw new ArgumentNullException(nameof(next)); + Schemes = schemes ?? throw new ArgumentNullException(nameof(schemes)); + } - public SsoAuthenticationMiddleware(RequestDelegate next, IAuthenticationSchemeProvider schemes) + public IAuthenticationSchemeProvider Schemes { get; set; } + + public async Task Invoke(HttpContext context) + { + if ((context.Request.Method == "GET" && context.Request.Query.ContainsKey("SAMLart")) + || (context.Request.Method == "POST" && context.Request.Form.ContainsKey("SAMLart"))) { - _next = next ?? throw new ArgumentNullException(nameof(next)); - Schemes = schemes ?? throw new ArgumentNullException(nameof(schemes)); + throw new Exception("SAMLart parameter detected. SAML Artifact binding is not allowed."); } - public IAuthenticationSchemeProvider Schemes { get; set; } - - public async Task Invoke(HttpContext context) + context.Features.Set(new AuthenticationFeature { - if ((context.Request.Method == "GET" && context.Request.Query.ContainsKey("SAMLart")) - || (context.Request.Method == "POST" && context.Request.Form.ContainsKey("SAMLart"))) - { - throw new Exception("SAMLart parameter detected. SAML Artifact binding is not allowed."); - } + OriginalPath = context.Request.Path, + OriginalPathBase = context.Request.PathBase + }); - context.Features.Set(new AuthenticationFeature + // Give any IAuthenticationRequestHandler schemes a chance to handle the request + var handlers = context.RequestServices.GetRequiredService(); + foreach (var scheme in await Schemes.GetRequestHandlerSchemesAsync()) + { + // Determine if scheme is appropriate for the current context FIRST + if (scheme is IDynamicAuthenticationScheme dynamicScheme) { - OriginalPath = context.Request.Path, - OriginalPathBase = context.Request.PathBase - }); - - // Give any IAuthenticationRequestHandler schemes a chance to handle the request - var handlers = context.RequestServices.GetRequiredService(); - foreach (var scheme in await Schemes.GetRequestHandlerSchemesAsync()) - { - // Determine if scheme is appropriate for the current context FIRST - if (scheme is IDynamicAuthenticationScheme dynamicScheme) + switch (dynamicScheme.SsoType) { - switch (dynamicScheme.SsoType) - { - case SsoType.OpenIdConnect: - default: - if (dynamicScheme.Options is OpenIdConnectOptions oidcOptions && - !await oidcOptions.CouldHandleAsync(scheme.Name, context)) - { - // It's OIDC and Dynamic, but not a good fit - continue; - } - break; - case SsoType.Saml2: - if (dynamicScheme.Options is Saml2Options samlOptions && - !await samlOptions.CouldHandleAsync(scheme.Name, context)) - { - // It's SAML and Dynamic, but not a good fit - continue; - } - break; - } - } - - // This far it's not dynamic OR it is but "could" be handled - if (await handlers.GetHandlerAsync(context, scheme.Name) is IAuthenticationRequestHandler handler && - await handler.HandleRequestAsync()) - { - return; + case SsoType.OpenIdConnect: + default: + if (dynamicScheme.Options is OpenIdConnectOptions oidcOptions && + !await oidcOptions.CouldHandleAsync(scheme.Name, context)) + { + // It's OIDC and Dynamic, but not a good fit + continue; + } + break; + case SsoType.Saml2: + if (dynamicScheme.Options is Saml2Options samlOptions && + !await samlOptions.CouldHandleAsync(scheme.Name, context)) + { + // It's SAML and Dynamic, but not a good fit + continue; + } + break; } } - // Fallback to the default scheme from the provider - var defaultAuthenticate = await Schemes.GetDefaultAuthenticateSchemeAsync(); - if (defaultAuthenticate != null) + // This far it's not dynamic OR it is but "could" be handled + if (await handlers.GetHandlerAsync(context, scheme.Name) is IAuthenticationRequestHandler handler && + await handler.HandleRequestAsync()) { - var result = await context.AuthenticateAsync(defaultAuthenticate.Name); - if (result?.Principal != null) - { - context.User = result.Principal; - } + return; } - - await _next(context); } + + // Fallback to the default scheme from the provider + var defaultAuthenticate = await Schemes.GetDefaultAuthenticateSchemeAsync(); + if (defaultAuthenticate != null) + { + var result = await context.AuthenticateAsync(defaultAuthenticate.Name); + if (result?.Principal != null) + { + context.User = result.Principal; + } + } + + await _next(context); } } diff --git a/bitwarden_license/test/Commercial.Core.Test/AutoFixture/ProviderUserFixtures.cs b/bitwarden_license/test/Commercial.Core.Test/AutoFixture/ProviderUserFixtures.cs index 59bc1c59f..48f70c335 100644 --- a/bitwarden_license/test/Commercial.Core.Test/AutoFixture/ProviderUserFixtures.cs +++ b/bitwarden_license/test/Commercial.Core.Test/AutoFixture/ProviderUserFixtures.cs @@ -3,43 +3,42 @@ using AutoFixture; using AutoFixture.Xunit2; using Bit.Core.Enums.Provider; -namespace Bit.Commercial.Core.Test.AutoFixture +namespace Bit.Commercial.Core.Test.AutoFixture; + +internal class ProviderUser : ICustomization { - internal class ProviderUser : ICustomization + public ProviderUserStatusType Status { get; set; } + public ProviderUserType Type { get; set; } + + public ProviderUser(ProviderUserStatusType status, ProviderUserType type) { - public ProviderUserStatusType Status { get; set; } - public ProviderUserType Type { get; set; } - - public ProviderUser(ProviderUserStatusType status, ProviderUserType type) - { - Status = status; - Type = type; - } - - public void Customize(IFixture fixture) - { - fixture.Customize(composer => composer - .With(o => o.Type, Type) - .With(o => o.Status, Status)); - } + Status = status; + Type = type; } - public class ProviderUserAttribute : CustomizeAttribute + public void Customize(IFixture fixture) { - private readonly ProviderUserStatusType _status; - private readonly ProviderUserType _type; - - public ProviderUserAttribute( - ProviderUserStatusType status = ProviderUserStatusType.Confirmed, - ProviderUserType type = ProviderUserType.ProviderAdmin) - { - _status = status; - _type = type; - } - - public override ICustomization GetCustomization(ParameterInfo parameter) - { - return new ProviderUser(_status, _type); - } + fixture.Customize(composer => composer + .With(o => o.Type, Type) + .With(o => o.Status, Status)); + } +} + +public class ProviderUserAttribute : CustomizeAttribute +{ + private readonly ProviderUserStatusType _status; + private readonly ProviderUserType _type; + + public ProviderUserAttribute( + ProviderUserStatusType status = ProviderUserStatusType.Confirmed, + ProviderUserType type = ProviderUserType.ProviderAdmin) + { + _status = status; + _type = type; + } + + public override ICustomization GetCustomization(ParameterInfo parameter) + { + return new ProviderUser(_status, _type); } } diff --git a/bitwarden_license/test/Commercial.Core.Test/Services/ProviderServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/Services/ProviderServiceTests.cs index a8c08b632..53911ea06 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Services/ProviderServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Services/ProviderServiceTests.cs @@ -19,532 +19,531 @@ using NSubstitute.ReturnsExtensions; using Xunit; using ProviderUser = Bit.Core.Entities.Provider.ProviderUser; -namespace Bit.Commercial.Core.Test.Services +namespace Bit.Commercial.Core.Test.Services; + +public class ProviderServiceTests { - public class ProviderServiceTests + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task CreateAsync_UserIdIsInvalid_Throws(SutProvider sutProvider) { - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task CreateAsync_UserIdIsInvalid_Throws(SutProvider sutProvider) + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.CreateAsync(default)); + Assert.Contains("Invalid owner.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task CreateAsync_Success(User user, SutProvider sutProvider) + { + var userRepository = sutProvider.GetDependency(); + userRepository.GetByEmailAsync(user.Email).Returns(user); + + await sutProvider.Sut.CreateAsync(user.Email); + + await sutProvider.GetDependency().ReceivedWithAnyArgs().CreateAsync(default); + await sutProvider.GetDependency().ReceivedWithAnyArgs().SendProviderSetupInviteEmailAsync(default, default, default); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task CompleteSetupAsync_UserIdIsInvalid_Throws(SutProvider sutProvider) + { + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.CompleteSetupAsync(default, default, default, default)); + Assert.Contains("Invalid owner.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task CompleteSetupAsync_TokenIsInvalid_Throws(User user, Provider provider, + SutProvider sutProvider) + { + var userService = sutProvider.GetDependency(); + userService.GetUserByIdAsync(user.Id).Returns(user); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.CompleteSetupAsync(provider, user.Id, default, default)); + Assert.Contains("Invalid token.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task CompleteSetupAsync_Success(User user, Provider provider, string key, + [ProviderUser(ProviderUserStatusType.Confirmed, ProviderUserType.ProviderAdmin)] ProviderUser providerUser, + SutProvider sutProvider) + { + providerUser.ProviderId = provider.Id; + providerUser.UserId = user.Id; + var userService = sutProvider.GetDependency(); + userService.GetUserByIdAsync(user.Id).Returns(user); + + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id).Returns(providerUser); + + var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName"); + var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); + sutProvider.GetDependency().CreateProtector("ProviderServiceDataProtector") + .Returns(protector); + sutProvider.Create(); + + var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); + + await sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key); + + await sutProvider.GetDependency().Received().UpsertAsync(provider); + await sutProvider.GetDependency().Received() + .ReplaceAsync(Arg.Is(pu => pu.UserId == user.Id && pu.ProviderId == provider.Id && pu.Key == key)); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task UpdateAsync_ProviderIdIsInvalid_Throws(Provider provider, SutProvider sutProvider) + { + provider.Id = default; + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateAsync(provider)); + Assert.Contains("Cannot create provider this way.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task UpdateAsync_Success(Provider provider, SutProvider sutProvider) + { + await sutProvider.Sut.UpdateAsync(provider); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task InviteUserAsync_ProviderIdIsInvalid_Throws(ProviderUserInvite invite, SutProvider sutProvider) + { + sutProvider.GetDependency().ProviderManageUsers(invite.ProviderId).Returns(true); + + await Assert.ThrowsAsync(() => sutProvider.Sut.InviteUserAsync(invite)); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task InviteUserAsync_InvalidPermissions_Throws(ProviderUserInvite invite, SutProvider sutProvider) + { + sutProvider.GetDependency().ProviderManageUsers(invite.ProviderId).Returns(false); + await Assert.ThrowsAsync(() => sutProvider.Sut.InviteUserAsync(invite)); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task InviteUserAsync_EmailsInvalid_Throws(Provider provider, ProviderUserInvite providerUserInvite, + SutProvider sutProvider) + { + var providerRepository = sutProvider.GetDependency(); + providerRepository.GetByIdAsync(providerUserInvite.ProviderId).Returns(provider); + sutProvider.GetDependency().ProviderManageUsers(providerUserInvite.ProviderId).Returns(true); + + providerUserInvite.UserIdentifiers = null; + + await Assert.ThrowsAsync(() => sutProvider.Sut.InviteUserAsync(providerUserInvite)); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task InviteUserAsync_AlreadyInvited(Provider provider, ProviderUserInvite providerUserInvite, + SutProvider sutProvider) + { + var providerRepository = sutProvider.GetDependency(); + providerRepository.GetByIdAsync(providerUserInvite.ProviderId).Returns(provider); + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetCountByProviderAsync(default, default, default).ReturnsForAnyArgs(1); + sutProvider.GetDependency().ProviderManageUsers(providerUserInvite.ProviderId).Returns(true); + + var result = await sutProvider.Sut.InviteUserAsync(providerUserInvite); + Assert.Empty(result); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task InviteUserAsync_Success(Provider provider, ProviderUserInvite providerUserInvite, + SutProvider sutProvider) + { + var providerRepository = sutProvider.GetDependency(); + providerRepository.GetByIdAsync(providerUserInvite.ProviderId).Returns(provider); + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetCountByProviderAsync(default, default, default).ReturnsForAnyArgs(0); + sutProvider.GetDependency().ProviderManageUsers(providerUserInvite.ProviderId).Returns(true); + + var result = await sutProvider.Sut.InviteUserAsync(providerUserInvite); + Assert.Equal(providerUserInvite.UserIdentifiers.Count(), result.Count); + Assert.True(result.TrueForAll(pu => pu.Status == ProviderUserStatusType.Invited), "Status must be invited"); + Assert.True(result.TrueForAll(pu => pu.ProviderId == providerUserInvite.ProviderId), "Provider Id must be correct"); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task ResendInviteUserAsync_InvalidPermissions_Throws(ProviderUserInvite invite, SutProvider sutProvider) + { + sutProvider.GetDependency().ProviderManageUsers(invite.ProviderId).Returns(false); + await Assert.ThrowsAsync(() => sutProvider.Sut.ResendInvitesAsync(invite)); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task ResendInvitesAsync_Errors(Provider provider, + [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser pu1, + [ProviderUser(ProviderUserStatusType.Accepted)] ProviderUser pu2, + [ProviderUser(ProviderUserStatusType.Confirmed)] ProviderUser pu3, + [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser pu4, + SutProvider sutProvider) + { + var providerUsers = new[] { pu1, pu2, pu3, pu4 }; + pu1.ProviderId = pu2.ProviderId = pu3.ProviderId = provider.Id; + + var invite = new ProviderUserInvite { - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.CreateAsync(default)); - Assert.Contains("Invalid owner.", exception.Message); - } + UserIdentifiers = providerUsers.Select(pu => pu.Id), + ProviderId = provider.Id + }; - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task CreateAsync_Success(User user, SutProvider sutProvider) - { - var userRepository = sutProvider.GetDependency(); - userRepository.GetByEmailAsync(user.Email).Returns(user); + var providerRepository = sutProvider.GetDependency(); + providerRepository.GetByIdAsync(provider.Id).Returns(provider); + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetManyAsync(default).ReturnsForAnyArgs(providerUsers.ToList()); + sutProvider.GetDependency().ProviderManageUsers(invite.ProviderId).Returns(true); - await sutProvider.Sut.CreateAsync(user.Email); + var result = await sutProvider.Sut.ResendInvitesAsync(invite); + Assert.Equal("", result[0].Item2); + Assert.Equal("User invalid.", result[1].Item2); + Assert.Equal("User invalid.", result[2].Item2); + Assert.Equal("User invalid.", result[3].Item2); + } - await sutProvider.GetDependency().ReceivedWithAnyArgs().CreateAsync(default); - await sutProvider.GetDependency().ReceivedWithAnyArgs().SendProviderSetupInviteEmailAsync(default, default, default); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task CompleteSetupAsync_UserIdIsInvalid_Throws(SutProvider sutProvider) - { - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.CompleteSetupAsync(default, default, default, default)); - Assert.Contains("Invalid owner.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task CompleteSetupAsync_TokenIsInvalid_Throws(User user, Provider provider, - SutProvider sutProvider) - { - var userService = sutProvider.GetDependency(); - userService.GetUserByIdAsync(user.Id).Returns(user); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.CompleteSetupAsync(provider, user.Id, default, default)); - Assert.Contains("Invalid token.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task CompleteSetupAsync_Success(User user, Provider provider, string key, - [ProviderUser(ProviderUserStatusType.Confirmed, ProviderUserType.ProviderAdmin)] ProviderUser providerUser, - SutProvider sutProvider) + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task ResendInvitesAsync_Success(Provider provider, IEnumerable providerUsers, + SutProvider sutProvider) + { + foreach (var providerUser in providerUsers) { providerUser.ProviderId = provider.Id; - providerUser.UserId = user.Id; - var userService = sutProvider.GetDependency(); - userService.GetUserByIdAsync(user.Id).Returns(user); - - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id).Returns(providerUser); - - var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName"); - var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); - sutProvider.GetDependency().CreateProtector("ProviderServiceDataProtector") - .Returns(protector); - sutProvider.Create(); - - var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); - - await sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key); - - await sutProvider.GetDependency().Received().UpsertAsync(provider); - await sutProvider.GetDependency().Received() - .ReplaceAsync(Arg.Is(pu => pu.UserId == user.Id && pu.ProviderId == provider.Id && pu.Key == key)); + providerUser.Status = ProviderUserStatusType.Invited; } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task UpdateAsync_ProviderIdIsInvalid_Throws(Provider provider, SutProvider sutProvider) + var invite = new ProviderUserInvite { - provider.Id = default; + UserIdentifiers = providerUsers.Select(pu => pu.Id), + ProviderId = provider.Id + }; - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.UpdateAsync(provider)); - Assert.Contains("Cannot create provider this way.", exception.Message); - } + var providerRepository = sutProvider.GetDependency(); + providerRepository.GetByIdAsync(provider.Id).Returns(provider); + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetManyAsync(default).ReturnsForAnyArgs(providerUsers.ToList()); + sutProvider.GetDependency().ProviderManageUsers(invite.ProviderId).Returns(true); - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task UpdateAsync_Success(Provider provider, SutProvider sutProvider) + var result = await sutProvider.Sut.ResendInvitesAsync(invite); + Assert.True(result.All(r => r.Item2 == "")); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task AcceptUserAsync_UserIsInvalid_Throws(SutProvider sutProvider) + { + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.AcceptUserAsync(default, default, default)); + Assert.Equal("User invalid.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task AcceptUserAsync_AlreadyAccepted_Throws( + [ProviderUser(ProviderUserStatusType.Accepted)] ProviderUser providerUser, User user, + SutProvider sutProvider) + { + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetByIdAsync(providerUser.Id).Returns(providerUser); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.AcceptUserAsync(providerUser.Id, user, default)); + Assert.Equal("Already accepted.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task AcceptUserAsync_TokenIsInvalid_Throws( + [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser providerUser, User user, + SutProvider sutProvider) + { + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetByIdAsync(providerUser.Id).Returns(providerUser); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.AcceptUserAsync(providerUser.Id, user, default)); + Assert.Equal("Invalid token.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task AcceptUserAsync_WrongEmail_Throws( + [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser providerUser, User user, + SutProvider sutProvider) + { + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetByIdAsync(providerUser.Id).Returns(providerUser); + + var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName"); + var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); + sutProvider.GetDependency().CreateProtector("ProviderServiceDataProtector") + .Returns(protector); + sutProvider.Create(); + + var token = protector.Protect($"ProviderUserInvite {providerUser.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.AcceptUserAsync(providerUser.Id, user, token)); + Assert.Equal("User email does not match invite.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task AcceptUserAsync_Success( + [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser providerUser, User user, + SutProvider sutProvider) + { + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetByIdAsync(providerUser.Id).Returns(providerUser); + + var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName"); + var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); + sutProvider.GetDependency().CreateProtector("ProviderServiceDataProtector") + .Returns(protector); + sutProvider.Create(); + + providerUser.Email = user.Email; + var token = protector.Protect($"ProviderUserInvite {providerUser.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); + + var pu = await sutProvider.Sut.AcceptUserAsync(providerUser.Id, user, token); + Assert.Null(pu.Email); + Assert.Equal(ProviderUserStatusType.Accepted, pu.Status); + Assert.Equal(user.Id, pu.UserId); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task ConfirmUsersAsync_NoValid( + [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser pu1, + [ProviderUser(ProviderUserStatusType.Accepted)] ProviderUser pu2, + [ProviderUser(ProviderUserStatusType.Confirmed)] ProviderUser pu3, + SutProvider sutProvider) + { + pu1.ProviderId = pu3.ProviderId; + var providerUsers = new[] { pu1, pu2, pu3 }; + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetManyAsync(default).ReturnsForAnyArgs(providerUsers); + + var dict = providerUsers.ToDictionary(pu => pu.Id, _ => "key"); + var result = await sutProvider.Sut.ConfirmUsersAsync(pu1.ProviderId, dict, default); + + Assert.Empty(result); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task ConfirmUsersAsync_Success( + [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser pu1, User u1, + [ProviderUser(ProviderUserStatusType.Accepted)] ProviderUser pu2, User u2, + [ProviderUser(ProviderUserStatusType.Confirmed)] ProviderUser pu3, User u3, + Provider provider, User user, SutProvider sutProvider) + { + pu1.ProviderId = pu2.ProviderId = pu3.ProviderId = provider.Id; + pu1.UserId = u1.Id; + pu2.UserId = u2.Id; + pu3.UserId = u3.Id; + var providerUsers = new[] { pu1, pu2, pu3 }; + + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetManyAsync(default).ReturnsForAnyArgs(providerUsers); + var providerRepository = sutProvider.GetDependency(); + providerRepository.GetByIdAsync(provider.Id).Returns(provider); + var userRepository = sutProvider.GetDependency(); + userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { u1, u2, u3 }); + + var dict = providerUsers.ToDictionary(pu => pu.Id, _ => "key"); + var result = await sutProvider.Sut.ConfirmUsersAsync(pu1.ProviderId, dict, user.Id); + + Assert.Equal("Invalid user.", result[0].Item2); + Assert.Equal("", result[1].Item2); + Assert.Equal("Invalid user.", result[2].Item2); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveUserAsync_UserIdIsInvalid_Throws(ProviderUser providerUser, + SutProvider sutProvider) + { + providerUser.Id = default; + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveUserAsync(providerUser, default)); + Assert.Equal("Invite the user first.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveUserAsync_Success( + [ProviderUser(type: ProviderUserType.ProviderAdmin)] ProviderUser providerUser, User savingUser, + SutProvider sutProvider) + { + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetByIdAsync(providerUser.Id).Returns(providerUser); + + await sutProvider.Sut.SaveUserAsync(providerUser, savingUser.Id); + await providerUserRepository.Received().ReplaceAsync(providerUser); + await sutProvider.GetDependency().Received() + .LogProviderUserEventAsync(providerUser, EventType.ProviderUser_Updated, null); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUsersAsync_NoRemainingOwner_Throws(Provider provider, User deletingUser, + ICollection providerUsers, SutProvider sutProvider) + { + var userIds = providerUsers.Select(pu => pu.Id); + + providerUsers.First().UserId = deletingUser.Id; + foreach (var providerUser in providerUsers) { - await sutProvider.Sut.UpdateAsync(provider); + providerUser.ProviderId = provider.Id; } + providerUsers.Last().ProviderId = default; - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task InviteUserAsync_ProviderIdIsInvalid_Throws(ProviderUserInvite invite, SutProvider sutProvider) + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetManyAsync(default).ReturnsForAnyArgs(providerUsers); + providerUserRepository.GetManyByProviderAsync(default, default).ReturnsForAnyArgs(new ProviderUser[] { }); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteUsersAsync(provider.Id, userIds, deletingUser.Id)); + Assert.Equal("Provider must have at least one confirmed ProviderAdmin.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUsersAsync_Success(Provider provider, User deletingUser, ICollection providerUsers, + [ProviderUser(ProviderUserStatusType.Confirmed, ProviderUserType.ProviderAdmin)] ProviderUser remainingOwner, + SutProvider sutProvider) + { + var userIds = providerUsers.Select(pu => pu.Id); + + providerUsers.First().UserId = deletingUser.Id; + foreach (var providerUser in providerUsers) { - sutProvider.GetDependency().ProviderManageUsers(invite.ProviderId).Returns(true); - - await Assert.ThrowsAsync(() => sutProvider.Sut.InviteUserAsync(invite)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task InviteUserAsync_InvalidPermissions_Throws(ProviderUserInvite invite, SutProvider sutProvider) - { - sutProvider.GetDependency().ProviderManageUsers(invite.ProviderId).Returns(false); - await Assert.ThrowsAsync(() => sutProvider.Sut.InviteUserAsync(invite)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task InviteUserAsync_EmailsInvalid_Throws(Provider provider, ProviderUserInvite providerUserInvite, - SutProvider sutProvider) - { - var providerRepository = sutProvider.GetDependency(); - providerRepository.GetByIdAsync(providerUserInvite.ProviderId).Returns(provider); - sutProvider.GetDependency().ProviderManageUsers(providerUserInvite.ProviderId).Returns(true); - - providerUserInvite.UserIdentifiers = null; - - await Assert.ThrowsAsync(() => sutProvider.Sut.InviteUserAsync(providerUserInvite)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task InviteUserAsync_AlreadyInvited(Provider provider, ProviderUserInvite providerUserInvite, - SutProvider sutProvider) - { - var providerRepository = sutProvider.GetDependency(); - providerRepository.GetByIdAsync(providerUserInvite.ProviderId).Returns(provider); - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetCountByProviderAsync(default, default, default).ReturnsForAnyArgs(1); - sutProvider.GetDependency().ProviderManageUsers(providerUserInvite.ProviderId).Returns(true); - - var result = await sutProvider.Sut.InviteUserAsync(providerUserInvite); - Assert.Empty(result); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task InviteUserAsync_Success(Provider provider, ProviderUserInvite providerUserInvite, - SutProvider sutProvider) - { - var providerRepository = sutProvider.GetDependency(); - providerRepository.GetByIdAsync(providerUserInvite.ProviderId).Returns(provider); - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetCountByProviderAsync(default, default, default).ReturnsForAnyArgs(0); - sutProvider.GetDependency().ProviderManageUsers(providerUserInvite.ProviderId).Returns(true); - - var result = await sutProvider.Sut.InviteUserAsync(providerUserInvite); - Assert.Equal(providerUserInvite.UserIdentifiers.Count(), result.Count); - Assert.True(result.TrueForAll(pu => pu.Status == ProviderUserStatusType.Invited), "Status must be invited"); - Assert.True(result.TrueForAll(pu => pu.ProviderId == providerUserInvite.ProviderId), "Provider Id must be correct"); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task ResendInviteUserAsync_InvalidPermissions_Throws(ProviderUserInvite invite, SutProvider sutProvider) - { - sutProvider.GetDependency().ProviderManageUsers(invite.ProviderId).Returns(false); - await Assert.ThrowsAsync(() => sutProvider.Sut.ResendInvitesAsync(invite)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task ResendInvitesAsync_Errors(Provider provider, - [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser pu1, - [ProviderUser(ProviderUserStatusType.Accepted)] ProviderUser pu2, - [ProviderUser(ProviderUserStatusType.Confirmed)] ProviderUser pu3, - [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser pu4, - SutProvider sutProvider) - { - var providerUsers = new[] { pu1, pu2, pu3, pu4 }; - pu1.ProviderId = pu2.ProviderId = pu3.ProviderId = provider.Id; - - var invite = new ProviderUserInvite - { - UserIdentifiers = providerUsers.Select(pu => pu.Id), - ProviderId = provider.Id - }; - - var providerRepository = sutProvider.GetDependency(); - providerRepository.GetByIdAsync(provider.Id).Returns(provider); - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetManyAsync(default).ReturnsForAnyArgs(providerUsers.ToList()); - sutProvider.GetDependency().ProviderManageUsers(invite.ProviderId).Returns(true); - - var result = await sutProvider.Sut.ResendInvitesAsync(invite); - Assert.Equal("", result[0].Item2); - Assert.Equal("User invalid.", result[1].Item2); - Assert.Equal("User invalid.", result[2].Item2); - Assert.Equal("User invalid.", result[3].Item2); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task ResendInvitesAsync_Success(Provider provider, IEnumerable providerUsers, - SutProvider sutProvider) - { - foreach (var providerUser in providerUsers) - { - providerUser.ProviderId = provider.Id; - providerUser.Status = ProviderUserStatusType.Invited; - } - - var invite = new ProviderUserInvite - { - UserIdentifiers = providerUsers.Select(pu => pu.Id), - ProviderId = provider.Id - }; - - var providerRepository = sutProvider.GetDependency(); - providerRepository.GetByIdAsync(provider.Id).Returns(provider); - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetManyAsync(default).ReturnsForAnyArgs(providerUsers.ToList()); - sutProvider.GetDependency().ProviderManageUsers(invite.ProviderId).Returns(true); - - var result = await sutProvider.Sut.ResendInvitesAsync(invite); - Assert.True(result.All(r => r.Item2 == "")); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task AcceptUserAsync_UserIsInvalid_Throws(SutProvider sutProvider) - { - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.AcceptUserAsync(default, default, default)); - Assert.Equal("User invalid.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task AcceptUserAsync_AlreadyAccepted_Throws( - [ProviderUser(ProviderUserStatusType.Accepted)] ProviderUser providerUser, User user, - SutProvider sutProvider) - { - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetByIdAsync(providerUser.Id).Returns(providerUser); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.AcceptUserAsync(providerUser.Id, user, default)); - Assert.Equal("Already accepted.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task AcceptUserAsync_TokenIsInvalid_Throws( - [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser providerUser, User user, - SutProvider sutProvider) - { - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetByIdAsync(providerUser.Id).Returns(providerUser); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.AcceptUserAsync(providerUser.Id, user, default)); - Assert.Equal("Invalid token.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task AcceptUserAsync_WrongEmail_Throws( - [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser providerUser, User user, - SutProvider sutProvider) - { - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetByIdAsync(providerUser.Id).Returns(providerUser); - - var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName"); - var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); - sutProvider.GetDependency().CreateProtector("ProviderServiceDataProtector") - .Returns(protector); - sutProvider.Create(); - - var token = protector.Protect($"ProviderUserInvite {providerUser.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.AcceptUserAsync(providerUser.Id, user, token)); - Assert.Equal("User email does not match invite.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task AcceptUserAsync_Success( - [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser providerUser, User user, - SutProvider sutProvider) - { - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetByIdAsync(providerUser.Id).Returns(providerUser); - - var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName"); - var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); - sutProvider.GetDependency().CreateProtector("ProviderServiceDataProtector") - .Returns(protector); - sutProvider.Create(); - - providerUser.Email = user.Email; - var token = protector.Protect($"ProviderUserInvite {providerUser.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); - - var pu = await sutProvider.Sut.AcceptUserAsync(providerUser.Id, user, token); - Assert.Null(pu.Email); - Assert.Equal(ProviderUserStatusType.Accepted, pu.Status); - Assert.Equal(user.Id, pu.UserId); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task ConfirmUsersAsync_NoValid( - [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser pu1, - [ProviderUser(ProviderUserStatusType.Accepted)] ProviderUser pu2, - [ProviderUser(ProviderUserStatusType.Confirmed)] ProviderUser pu3, - SutProvider sutProvider) - { - pu1.ProviderId = pu3.ProviderId; - var providerUsers = new[] { pu1, pu2, pu3 }; - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetManyAsync(default).ReturnsForAnyArgs(providerUsers); - - var dict = providerUsers.ToDictionary(pu => pu.Id, _ => "key"); - var result = await sutProvider.Sut.ConfirmUsersAsync(pu1.ProviderId, dict, default); - - Assert.Empty(result); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task ConfirmUsersAsync_Success( - [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser pu1, User u1, - [ProviderUser(ProviderUserStatusType.Accepted)] ProviderUser pu2, User u2, - [ProviderUser(ProviderUserStatusType.Confirmed)] ProviderUser pu3, User u3, - Provider provider, User user, SutProvider sutProvider) - { - pu1.ProviderId = pu2.ProviderId = pu3.ProviderId = provider.Id; - pu1.UserId = u1.Id; - pu2.UserId = u2.Id; - pu3.UserId = u3.Id; - var providerUsers = new[] { pu1, pu2, pu3 }; - - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetManyAsync(default).ReturnsForAnyArgs(providerUsers); - var providerRepository = sutProvider.GetDependency(); - providerRepository.GetByIdAsync(provider.Id).Returns(provider); - var userRepository = sutProvider.GetDependency(); - userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { u1, u2, u3 }); - - var dict = providerUsers.ToDictionary(pu => pu.Id, _ => "key"); - var result = await sutProvider.Sut.ConfirmUsersAsync(pu1.ProviderId, dict, user.Id); - - Assert.Equal("Invalid user.", result[0].Item2); - Assert.Equal("", result[1].Item2); - Assert.Equal("Invalid user.", result[2].Item2); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveUserAsync_UserIdIsInvalid_Throws(ProviderUser providerUser, - SutProvider sutProvider) - { - providerUser.Id = default; - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveUserAsync(providerUser, default)); - Assert.Equal("Invite the user first.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveUserAsync_Success( - [ProviderUser(type: ProviderUserType.ProviderAdmin)] ProviderUser providerUser, User savingUser, - SutProvider sutProvider) - { - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetByIdAsync(providerUser.Id).Returns(providerUser); - - await sutProvider.Sut.SaveUserAsync(providerUser, savingUser.Id); - await providerUserRepository.Received().ReplaceAsync(providerUser); - await sutProvider.GetDependency().Received() - .LogProviderUserEventAsync(providerUser, EventType.ProviderUser_Updated, null); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUsersAsync_NoRemainingOwner_Throws(Provider provider, User deletingUser, - ICollection providerUsers, SutProvider sutProvider) - { - var userIds = providerUsers.Select(pu => pu.Id); - - providerUsers.First().UserId = deletingUser.Id; - foreach (var providerUser in providerUsers) - { - providerUser.ProviderId = provider.Id; - } - providerUsers.Last().ProviderId = default; - - sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetManyAsync(default).ReturnsForAnyArgs(providerUsers); - providerUserRepository.GetManyByProviderAsync(default, default).ReturnsForAnyArgs(new ProviderUser[] { }); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.DeleteUsersAsync(provider.Id, userIds, deletingUser.Id)); - Assert.Equal("Provider must have at least one confirmed ProviderAdmin.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUsersAsync_Success(Provider provider, User deletingUser, ICollection providerUsers, - [ProviderUser(ProviderUserStatusType.Confirmed, ProviderUserType.ProviderAdmin)] ProviderUser remainingOwner, - SutProvider sutProvider) - { - var userIds = providerUsers.Select(pu => pu.Id); - - providerUsers.First().UserId = deletingUser.Id; - foreach (var providerUser in providerUsers) - { - providerUser.ProviderId = provider.Id; - } - providerUsers.Last().ProviderId = default; - - sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetManyAsync(default).ReturnsForAnyArgs(providerUsers); - providerUserRepository.GetManyByProviderAsync(default, default).ReturnsForAnyArgs(new[] { remainingOwner }); - - var result = await sutProvider.Sut.DeleteUsersAsync(provider.Id, userIds, deletingUser.Id); - - Assert.NotEmpty(result); - Assert.Equal("You cannot remove yourself.", result[0].Item2); - Assert.Equal("", result[1].Item2); - Assert.Equal("Invalid user.", result[2].Item2); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task AddOrganization_OrganizationAlreadyBelongsToAProvider_Throws(Provider provider, - Organization organization, ProviderOrganization po, User user, string key, - SutProvider sutProvider) - { - po.OrganizationId = organization.Id; - sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); - sutProvider.GetDependency().GetByOrganizationId(organization.Id) - .Returns(po); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.AddOrganization(provider.Id, organization.Id, user.Id, key)); - Assert.Equal("Organization already belongs to a provider.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task AddOrganization_Success(Provider provider, Organization organization, User user, string key, - SutProvider sutProvider) - { - organization.PlanType = PlanType.EnterpriseAnnually; - - sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); - var providerOrganizationRepository = sutProvider.GetDependency(); - providerOrganizationRepository.GetByOrganizationId(organization.Id).ReturnsNull(); - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - - await sutProvider.Sut.AddOrganization(provider.Id, organization.Id, user.Id, key); - - await providerOrganizationRepository.ReceivedWithAnyArgs().CreateAsync(default); - await sutProvider.GetDependency() - .Received().LogProviderOrganizationEventAsync(Arg.Any(), - EventType.ProviderOrganization_Added); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task CreateOrganizationAsync_Success(Provider provider, OrganizationSignup organizationSignup, - Organization organization, string clientOwnerEmail, User user, SutProvider sutProvider) - { - organizationSignup.Plan = PlanType.EnterpriseAnnually; - - sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); - var providerOrganizationRepository = sutProvider.GetDependency(); - sutProvider.GetDependency().SignUpAsync(organizationSignup, true) - .Returns(Tuple.Create(organization, null as OrganizationUser)); - - var providerOrganization = - await sutProvider.Sut.CreateOrganizationAsync(provider.Id, organizationSignup, clientOwnerEmail, user); - - await providerOrganizationRepository.ReceivedWithAnyArgs().CreateAsync(default); - await sutProvider.GetDependency() - .Received().LogProviderOrganizationEventAsync(providerOrganization, - EventType.ProviderOrganization_Created); - await sutProvider.GetDependency() - .Received().InviteUsersAsync(organization.Id, user.Id, Arg.Is>( - t => t.Count() == 1 && - t.First().Item1.Emails.Count() == 1 && - t.First().Item1.Emails.First() == clientOwnerEmail && - t.First().Item1.Type == OrganizationUserType.Owner && - t.First().Item1.AccessAll && - t.First().Item2 == null)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task RemoveOrganization_ProviderOrganizationIsInvalid_Throws(Provider provider, - ProviderOrganization providerOrganization, User user, SutProvider sutProvider) - { - sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); - sutProvider.GetDependency().GetByIdAsync(providerOrganization.Id) - .ReturnsNull(); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.RemoveOrganizationAsync(provider.Id, providerOrganization.Id, user.Id)); - Assert.Equal("Invalid organization.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task RemoveOrganization_ProviderOrganizationBelongsToWrongProvider_Throws(Provider provider, - ProviderOrganization providerOrganization, User user, SutProvider sutProvider) - { - sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); - sutProvider.GetDependency().GetByIdAsync(providerOrganization.Id) - .Returns(providerOrganization); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.RemoveOrganizationAsync(provider.Id, providerOrganization.Id, user.Id)); - Assert.Equal("Invalid organization.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task RemoveOrganization_HasNoOwners_Throws(Provider provider, - ProviderOrganization providerOrganization, User user, SutProvider sutProvider) - { - providerOrganization.ProviderId = provider.Id; - sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); - sutProvider.GetDependency().GetByIdAsync(providerOrganization.Id) - .Returns(providerOrganization); - sutProvider.GetDependency().HasConfirmedOwnersExceptAsync(default, default, default) - .ReturnsForAnyArgs(false); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.RemoveOrganizationAsync(provider.Id, providerOrganization.Id, user.Id)); - Assert.Equal("Organization needs to have at least one confirmed owner.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task RemoveOrganization_Success(Provider provider, - ProviderOrganization providerOrganization, User user, SutProvider sutProvider) - { - providerOrganization.ProviderId = provider.Id; - sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); - var providerOrganizationRepository = sutProvider.GetDependency(); - providerOrganizationRepository.GetByIdAsync(providerOrganization.Id).Returns(providerOrganization); - sutProvider.GetDependency().HasConfirmedOwnersExceptAsync(default, default, default) - .ReturnsForAnyArgs(true); - - await sutProvider.Sut.RemoveOrganizationAsync(provider.Id, providerOrganization.Id, user.Id); - await providerOrganizationRepository.Received().DeleteAsync(providerOrganization); - await sutProvider.GetDependency().Received() - .LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Removed); + providerUser.ProviderId = provider.Id; } + providerUsers.Last().ProviderId = default; + + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetManyAsync(default).ReturnsForAnyArgs(providerUsers); + providerUserRepository.GetManyByProviderAsync(default, default).ReturnsForAnyArgs(new[] { remainingOwner }); + + var result = await sutProvider.Sut.DeleteUsersAsync(provider.Id, userIds, deletingUser.Id); + + Assert.NotEmpty(result); + Assert.Equal("You cannot remove yourself.", result[0].Item2); + Assert.Equal("", result[1].Item2); + Assert.Equal("Invalid user.", result[2].Item2); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task AddOrganization_OrganizationAlreadyBelongsToAProvider_Throws(Provider provider, + Organization organization, ProviderOrganization po, User user, string key, + SutProvider sutProvider) + { + po.OrganizationId = organization.Id; + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + sutProvider.GetDependency().GetByOrganizationId(organization.Id) + .Returns(po); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.AddOrganization(provider.Id, organization.Id, user.Id, key)); + Assert.Equal("Organization already belongs to a provider.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task AddOrganization_Success(Provider provider, Organization organization, User user, string key, + SutProvider sutProvider) + { + organization.PlanType = PlanType.EnterpriseAnnually; + + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + var providerOrganizationRepository = sutProvider.GetDependency(); + providerOrganizationRepository.GetByOrganizationId(organization.Id).ReturnsNull(); + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + + await sutProvider.Sut.AddOrganization(provider.Id, organization.Id, user.Id, key); + + await providerOrganizationRepository.ReceivedWithAnyArgs().CreateAsync(default); + await sutProvider.GetDependency() + .Received().LogProviderOrganizationEventAsync(Arg.Any(), + EventType.ProviderOrganization_Added); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task CreateOrganizationAsync_Success(Provider provider, OrganizationSignup organizationSignup, + Organization organization, string clientOwnerEmail, User user, SutProvider sutProvider) + { + organizationSignup.Plan = PlanType.EnterpriseAnnually; + + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + var providerOrganizationRepository = sutProvider.GetDependency(); + sutProvider.GetDependency().SignUpAsync(organizationSignup, true) + .Returns(Tuple.Create(organization, null as OrganizationUser)); + + var providerOrganization = + await sutProvider.Sut.CreateOrganizationAsync(provider.Id, organizationSignup, clientOwnerEmail, user); + + await providerOrganizationRepository.ReceivedWithAnyArgs().CreateAsync(default); + await sutProvider.GetDependency() + .Received().LogProviderOrganizationEventAsync(providerOrganization, + EventType.ProviderOrganization_Created); + await sutProvider.GetDependency() + .Received().InviteUsersAsync(organization.Id, user.Id, Arg.Is>( + t => t.Count() == 1 && + t.First().Item1.Emails.Count() == 1 && + t.First().Item1.Emails.First() == clientOwnerEmail && + t.First().Item1.Type == OrganizationUserType.Owner && + t.First().Item1.AccessAll && + t.First().Item2 == null)); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task RemoveOrganization_ProviderOrganizationIsInvalid_Throws(Provider provider, + ProviderOrganization providerOrganization, User user, SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + sutProvider.GetDependency().GetByIdAsync(providerOrganization.Id) + .ReturnsNull(); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.RemoveOrganizationAsync(provider.Id, providerOrganization.Id, user.Id)); + Assert.Equal("Invalid organization.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task RemoveOrganization_ProviderOrganizationBelongsToWrongProvider_Throws(Provider provider, + ProviderOrganization providerOrganization, User user, SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + sutProvider.GetDependency().GetByIdAsync(providerOrganization.Id) + .Returns(providerOrganization); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.RemoveOrganizationAsync(provider.Id, providerOrganization.Id, user.Id)); + Assert.Equal("Invalid organization.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task RemoveOrganization_HasNoOwners_Throws(Provider provider, + ProviderOrganization providerOrganization, User user, SutProvider sutProvider) + { + providerOrganization.ProviderId = provider.Id; + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + sutProvider.GetDependency().GetByIdAsync(providerOrganization.Id) + .Returns(providerOrganization); + sutProvider.GetDependency().HasConfirmedOwnersExceptAsync(default, default, default) + .ReturnsForAnyArgs(false); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.RemoveOrganizationAsync(provider.Id, providerOrganization.Id, user.Id)); + Assert.Equal("Organization needs to have at least one confirmed owner.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task RemoveOrganization_Success(Provider provider, + ProviderOrganization providerOrganization, User user, SutProvider sutProvider) + { + providerOrganization.ProviderId = provider.Id; + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + var providerOrganizationRepository = sutProvider.GetDependency(); + providerOrganizationRepository.GetByIdAsync(providerOrganization.Id).Returns(providerOrganization); + sutProvider.GetDependency().HasConfirmedOwnersExceptAsync(default, default, default) + .ReturnsForAnyArgs(true); + + await sutProvider.Sut.RemoveOrganizationAsync(provider.Id, providerOrganization.Id, user.Id); + await providerOrganizationRepository.Received().DeleteAsync(providerOrganization); + await sutProvider.GetDependency().Received() + .LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Removed); } } diff --git a/src/Admin/AdminSettings.cs b/src/Admin/AdminSettings.cs index 64de4f083..6941bbc8f 100644 --- a/src/Admin/AdminSettings.cs +++ b/src/Admin/AdminSettings.cs @@ -1,16 +1,15 @@ -namespace Bit.Admin -{ - public class AdminSettings - { - public virtual string Admins { get; set; } - public virtual CloudflareSettings Cloudflare { get; set; } - public int? DeleteTrashDaysAgo { get; set; } +namespace Bit.Admin; - public class CloudflareSettings - { - public string ZoneId { get; set; } - public string AuthEmail { get; set; } - public string AuthKey { get; set; } - } +public class AdminSettings +{ + public virtual string Admins { get; set; } + public virtual CloudflareSettings Cloudflare { get; set; } + public int? DeleteTrashDaysAgo { get; set; } + + public class CloudflareSettings + { + public string ZoneId { get; set; } + public string AuthEmail { get; set; } + public string AuthKey { get; set; } } } diff --git a/src/Admin/Controllers/ErrorController.cs b/src/Admin/Controllers/ErrorController.cs index af6091204..9216537ff 100644 --- a/src/Admin/Controllers/ErrorController.cs +++ b/src/Admin/Controllers/ErrorController.cs @@ -1,24 +1,23 @@ using Microsoft.AspNetCore.Diagnostics; using Microsoft.AspNetCore.Mvc; -namespace Bit.Admin.Controllers -{ - public class ErrorController : Controller - { - [Route("/error")] - public IActionResult Error(int? statusCode = null) - { - var exceptionHandlerPathFeature = HttpContext.Features.Get(); - TempData["Error"] = HttpContext.Features.Get()?.Error.Message; +namespace Bit.Admin.Controllers; - if (exceptionHandlerPathFeature != null) - { - return Redirect(exceptionHandlerPathFeature.Path); - } - else - { - return Redirect("/Home"); - } +public class ErrorController : Controller +{ + [Route("/error")] + public IActionResult Error(int? statusCode = null) + { + var exceptionHandlerPathFeature = HttpContext.Features.Get(); + TempData["Error"] = HttpContext.Features.Get()?.Error.Message; + + if (exceptionHandlerPathFeature != null) + { + return Redirect(exceptionHandlerPathFeature.Path); + } + else + { + return Redirect("/Home"); } } } diff --git a/src/Admin/Controllers/HomeController.cs b/src/Admin/Controllers/HomeController.cs index fe93eef26..5e3b76ebb 100644 --- a/src/Admin/Controllers/HomeController.cs +++ b/src/Admin/Controllers/HomeController.cs @@ -6,109 +6,108 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; using Newtonsoft.Json; -namespace Bit.Admin.Controllers +namespace Bit.Admin.Controllers; + +public class HomeController : Controller { - public class HomeController : Controller + private readonly GlobalSettings _globalSettings; + private readonly HttpClient _httpClient = new HttpClient(); + private readonly ILogger _logger; + + public HomeController(GlobalSettings globalSettings, ILogger logger) { - private readonly GlobalSettings _globalSettings; - private readonly HttpClient _httpClient = new HttpClient(); - private readonly ILogger _logger; - - public HomeController(GlobalSettings globalSettings, ILogger logger) - { - _globalSettings = globalSettings; - _logger = logger; - } - - [Authorize] - public IActionResult Index() - { - return View(new HomeModel - { - GlobalSettings = _globalSettings, - CurrentVersion = Core.Utilities.CoreHelpers.GetVersion() - }); - } - - public IActionResult Error() - { - return View(new ErrorViewModel - { - RequestId = Activity.Current?.Id ?? HttpContext.TraceIdentifier - }); - } - - - public async Task GetLatestVersion(ProjectType project, CancellationToken cancellationToken) - { - var requestUri = $"https://selfhost.bitwarden.com/version.json"; - try - { - var response = await _httpClient.GetAsync(requestUri, cancellationToken); - if (response.IsSuccessStatusCode) - { - var latestVersions = JsonConvert.DeserializeObject(await response.Content.ReadAsStringAsync()); - return project switch - { - ProjectType.Core => new JsonResult(latestVersions.Versions.CoreVersion), - ProjectType.Web => new JsonResult(latestVersions.Versions.WebVersion), - _ => throw new System.NotImplementedException(), - }; - } - } - catch (HttpRequestException e) - { - _logger.LogError(e, $"Error encountered while sending GET request to {requestUri}"); - return new JsonResult("Unable to fetch latest version") { StatusCode = StatusCodes.Status500InternalServerError }; - } - - return new JsonResult("-"); - } - - public async Task GetInstalledWebVersion(CancellationToken cancellationToken) - { - var requestUri = $"{_globalSettings.BaseServiceUri.InternalVault}/version.json"; - try - { - var response = await _httpClient.GetAsync(requestUri, cancellationToken); - if (response.IsSuccessStatusCode) - { - using var jsonDocument = await JsonDocument.ParseAsync(await response.Content.ReadAsStreamAsync(cancellationToken), cancellationToken: cancellationToken); - var root = jsonDocument.RootElement; - return new JsonResult(root.GetProperty("version").GetString()); - } - } - catch (HttpRequestException e) - { - _logger.LogError(e, $"Error encountered while sending GET request to {requestUri}"); - return new JsonResult("Unable to fetch installed version") { StatusCode = StatusCodes.Status500InternalServerError }; - } - - return new JsonResult("-"); - } - - private class LatestVersions - { - [JsonProperty("versions")] - public Versions Versions { get; set; } - } - - private class Versions - { - [JsonProperty("coreVersion")] - public string CoreVersion { get; set; } - - [JsonProperty("webVersion")] - public string WebVersion { get; set; } - - [JsonProperty("keyConnectorVersion")] - public string KeyConnectorVersion { get; set; } - } + _globalSettings = globalSettings; + _logger = logger; } - public enum ProjectType + [Authorize] + public IActionResult Index() { - Core, - Web, + return View(new HomeModel + { + GlobalSettings = _globalSettings, + CurrentVersion = Core.Utilities.CoreHelpers.GetVersion() + }); + } + + public IActionResult Error() + { + return View(new ErrorViewModel + { + RequestId = Activity.Current?.Id ?? HttpContext.TraceIdentifier + }); + } + + + public async Task GetLatestVersion(ProjectType project, CancellationToken cancellationToken) + { + var requestUri = $"https://selfhost.bitwarden.com/version.json"; + try + { + var response = await _httpClient.GetAsync(requestUri, cancellationToken); + if (response.IsSuccessStatusCode) + { + var latestVersions = JsonConvert.DeserializeObject(await response.Content.ReadAsStringAsync()); + return project switch + { + ProjectType.Core => new JsonResult(latestVersions.Versions.CoreVersion), + ProjectType.Web => new JsonResult(latestVersions.Versions.WebVersion), + _ => throw new System.NotImplementedException(), + }; + } + } + catch (HttpRequestException e) + { + _logger.LogError(e, $"Error encountered while sending GET request to {requestUri}"); + return new JsonResult("Unable to fetch latest version") { StatusCode = StatusCodes.Status500InternalServerError }; + } + + return new JsonResult("-"); + } + + public async Task GetInstalledWebVersion(CancellationToken cancellationToken) + { + var requestUri = $"{_globalSettings.BaseServiceUri.InternalVault}/version.json"; + try + { + var response = await _httpClient.GetAsync(requestUri, cancellationToken); + if (response.IsSuccessStatusCode) + { + using var jsonDocument = await JsonDocument.ParseAsync(await response.Content.ReadAsStreamAsync(cancellationToken), cancellationToken: cancellationToken); + var root = jsonDocument.RootElement; + return new JsonResult(root.GetProperty("version").GetString()); + } + } + catch (HttpRequestException e) + { + _logger.LogError(e, $"Error encountered while sending GET request to {requestUri}"); + return new JsonResult("Unable to fetch installed version") { StatusCode = StatusCodes.Status500InternalServerError }; + } + + return new JsonResult("-"); + } + + private class LatestVersions + { + [JsonProperty("versions")] + public Versions Versions { get; set; } + } + + private class Versions + { + [JsonProperty("coreVersion")] + public string CoreVersion { get; set; } + + [JsonProperty("webVersion")] + public string WebVersion { get; set; } + + [JsonProperty("keyConnectorVersion")] + public string KeyConnectorVersion { get; set; } } } + +public enum ProjectType +{ + Core, + Web, +} diff --git a/src/Admin/Controllers/InfoController.cs b/src/Admin/Controllers/InfoController.cs index 7f39da6ed..0c097fde7 100644 --- a/src/Admin/Controllers/InfoController.cs +++ b/src/Admin/Controllers/InfoController.cs @@ -1,21 +1,20 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; -namespace Bit.Admin.Controllers -{ - public class InfoController : Controller - { - [HttpGet("~/alive")] - [HttpGet("~/now")] - public DateTime GetAlive() - { - return DateTime.UtcNow; - } +namespace Bit.Admin.Controllers; - [HttpGet("~/version")] - public JsonResult GetVersion() - { - return Json(CoreHelpers.GetVersion()); - } +public class InfoController : Controller +{ + [HttpGet("~/alive")] + [HttpGet("~/now")] + public DateTime GetAlive() + { + return DateTime.UtcNow; + } + + [HttpGet("~/version")] + public JsonResult GetVersion() + { + return Json(CoreHelpers.GetVersion()); } } diff --git a/src/Admin/Controllers/LoginController.cs b/src/Admin/Controllers/LoginController.cs index a8e3e9dd0..47f9d5b34 100644 --- a/src/Admin/Controllers/LoginController.cs +++ b/src/Admin/Controllers/LoginController.cs @@ -3,91 +3,90 @@ using Bit.Core.Identity; using Microsoft.AspNetCore.Identity; using Microsoft.AspNetCore.Mvc; -namespace Bit.Admin.Controllers +namespace Bit.Admin.Controllers; + +public class LoginController : Controller { - public class LoginController : Controller + private readonly PasswordlessSignInManager _signInManager; + + public LoginController( + PasswordlessSignInManager signInManager) { - private readonly PasswordlessSignInManager _signInManager; + _signInManager = signInManager; + } - public LoginController( - PasswordlessSignInManager signInManager) + public IActionResult Index(string returnUrl = null, int? error = null, int? success = null, + bool accessDenied = false) + { + if (!error.HasValue && accessDenied) { - _signInManager = signInManager; + error = 4; } - public IActionResult Index(string returnUrl = null, int? error = null, int? success = null, - bool accessDenied = false) + return View(new LoginModel { - if (!error.HasValue && accessDenied) - { - error = 4; - } + ReturnUrl = returnUrl, + Error = GetMessage(error), + Success = GetMessage(success) + }); + } - return View(new LoginModel - { - ReturnUrl = returnUrl, - Error = GetMessage(error), - Success = GetMessage(success) - }); - } - - [HttpPost] - [ValidateAntiForgeryToken] - public async Task Index(LoginModel model) + [HttpPost] + [ValidateAntiForgeryToken] + public async Task Index(LoginModel model) + { + if (ModelState.IsValid) { - if (ModelState.IsValid) - { - await _signInManager.PasswordlessSignInAsync(model.Email, model.ReturnUrl); - return RedirectToAction("Index", new - { - success = 3 - }); - } - - return View(model); - } - - public async Task Confirm(string email, string token, string returnUrl) - { - var result = await _signInManager.PasswordlessSignInAsync(email, token, true); - if (!result.Succeeded) - { - return RedirectToAction("Index", new - { - error = 2 - }); - } - - if (!string.IsNullOrWhiteSpace(returnUrl) && Url.IsLocalUrl(returnUrl)) - { - return Redirect(returnUrl); - } - - return RedirectToAction("Index", "Home"); - } - - [HttpPost] - [ValidateAntiForgeryToken] - public async Task Logout() - { - await _signInManager.SignOutAsync(); + await _signInManager.PasswordlessSignInAsync(model.Email, model.ReturnUrl); return RedirectToAction("Index", new { - success = 1 + success = 3 }); } - private string GetMessage(int? messageCode) + return View(model); + } + + public async Task Confirm(string email, string token, string returnUrl) + { + var result = await _signInManager.PasswordlessSignInAsync(email, token, true); + if (!result.Succeeded) { - return messageCode switch + return RedirectToAction("Index", new { - 1 => "You have been logged out.", - 2 => "This login confirmation link is invalid. Try logging in again.", - 3 => "If a valid admin user with this email address exists, " + - "we've sent you an email with a secure link to log in.", - 4 => "Access denied. Please log in.", - _ => null, - }; + error = 2 + }); } + + if (!string.IsNullOrWhiteSpace(returnUrl) && Url.IsLocalUrl(returnUrl)) + { + return Redirect(returnUrl); + } + + return RedirectToAction("Index", "Home"); + } + + [HttpPost] + [ValidateAntiForgeryToken] + public async Task Logout() + { + await _signInManager.SignOutAsync(); + return RedirectToAction("Index", new + { + success = 1 + }); + } + + private string GetMessage(int? messageCode) + { + return messageCode switch + { + 1 => "You have been logged out.", + 2 => "This login confirmation link is invalid. Try logging in again.", + 3 => "If a valid admin user with this email address exists, " + + "we've sent you an email with a secure link to log in.", + 4 => "Access denied. Please log in.", + _ => null, + }; } } diff --git a/src/Admin/Controllers/LogsController.cs b/src/Admin/Controllers/LogsController.cs index feb1a91b2..449c8cc86 100644 --- a/src/Admin/Controllers/LogsController.cs +++ b/src/Admin/Controllers/LogsController.cs @@ -7,87 +7,86 @@ using Microsoft.Azure.Cosmos; using Microsoft.Azure.Cosmos.Linq; using Serilog.Events; -namespace Bit.Admin.Controllers +namespace Bit.Admin.Controllers; + +[Authorize] +[SelfHosted(NotSelfHostedOnly = true)] +public class LogsController : Controller { - [Authorize] - [SelfHosted(NotSelfHostedOnly = true)] - public class LogsController : Controller + private const string Database = "Diagnostics"; + private const string Container = "Logs"; + + private readonly GlobalSettings _globalSettings; + + public LogsController(GlobalSettings globalSettings) { - private const string Database = "Diagnostics"; - private const string Container = "Logs"; + _globalSettings = globalSettings; + } - private readonly GlobalSettings _globalSettings; - - public LogsController(GlobalSettings globalSettings) + public async Task Index(string cursor = null, int count = 50, + LogEventLevel? level = null, string project = null, DateTime? start = null, DateTime? end = null) + { + using (var client = new CosmosClient(_globalSettings.DocumentDb.Uri, + _globalSettings.DocumentDb.Key)) { - _globalSettings = globalSettings; - } + var cosmosContainer = client.GetContainer(Database, Container); + var query = cosmosContainer.GetItemLinqQueryable( + requestOptions: new QueryRequestOptions() + { + MaxItemCount = count + }, + continuationToken: cursor + ).AsQueryable(); - public async Task Index(string cursor = null, int count = 50, - LogEventLevel? level = null, string project = null, DateTime? start = null, DateTime? end = null) - { - using (var client = new CosmosClient(_globalSettings.DocumentDb.Uri, - _globalSettings.DocumentDb.Key)) + if (level.HasValue) { - var cosmosContainer = client.GetContainer(Database, Container); - var query = cosmosContainer.GetItemLinqQueryable( - requestOptions: new QueryRequestOptions() - { - MaxItemCount = count - }, - continuationToken: cursor - ).AsQueryable(); - - if (level.HasValue) - { - query = query.Where(l => l.Level == level.Value.ToString()); - } - if (!string.IsNullOrWhiteSpace(project)) - { - query = query.Where(l => l.Properties != null && l.Properties["Project"] == (object)project); - } - if (start.HasValue) - { - query = query.Where(l => l.Timestamp >= start.Value); - } - if (end.HasValue) - { - query = query.Where(l => l.Timestamp <= end.Value); - } - var feedIterator = query.OrderByDescending(l => l.Timestamp).ToFeedIterator(); - var response = await feedIterator.ReadNextAsync(); - - return View(new LogsModel - { - Level = level, - Project = project, - Start = start, - End = end, - Items = response.ToList(), - Count = count, - Cursor = cursor, - NextCursor = response.ContinuationToken - }); + query = query.Where(l => l.Level == level.Value.ToString()); } - } - - public async Task View(Guid id) - { - using (var client = new CosmosClient(_globalSettings.DocumentDb.Uri, - _globalSettings.DocumentDb.Key)) + if (!string.IsNullOrWhiteSpace(project)) { - var cosmosContainer = client.GetContainer(Database, Container); - var query = cosmosContainer.GetItemLinqQueryable() - .AsQueryable() - .Where(l => l.Id == id.ToString()); - - var response = await query.ToFeedIterator().ReadNextAsync(); - if (response == null || response.Count == 0) - { - return RedirectToAction("Index"); - } - return View(response.First()); + query = query.Where(l => l.Properties != null && l.Properties["Project"] == (object)project); } + if (start.HasValue) + { + query = query.Where(l => l.Timestamp >= start.Value); + } + if (end.HasValue) + { + query = query.Where(l => l.Timestamp <= end.Value); + } + var feedIterator = query.OrderByDescending(l => l.Timestamp).ToFeedIterator(); + var response = await feedIterator.ReadNextAsync(); + + return View(new LogsModel + { + Level = level, + Project = project, + Start = start, + End = end, + Items = response.ToList(), + Count = count, + Cursor = cursor, + NextCursor = response.ContinuationToken + }); + } + } + + public async Task View(Guid id) + { + using (var client = new CosmosClient(_globalSettings.DocumentDb.Uri, + _globalSettings.DocumentDb.Key)) + { + var cosmosContainer = client.GetContainer(Database, Container); + var query = cosmosContainer.GetItemLinqQueryable() + .AsQueryable() + .Where(l => l.Id == id.ToString()); + + var response = await query.ToFeedIterator().ReadNextAsync(); + if (response == null || response.Count == 0) + { + return RedirectToAction("Index"); + } + return View(response.First()); } } } diff --git a/src/Admin/Controllers/OrganizationsController.cs b/src/Admin/Controllers/OrganizationsController.cs index eccc7ced6..76c00d025 100644 --- a/src/Admin/Controllers/OrganizationsController.cs +++ b/src/Admin/Controllers/OrganizationsController.cs @@ -11,207 +11,206 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Admin.Controllers +namespace Bit.Admin.Controllers; + +[Authorize] +public class OrganizationsController : Controller { - [Authorize] - public class OrganizationsController : Controller + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IOrganizationConnectionRepository _organizationConnectionRepository; + private readonly ISelfHostedSyncSponsorshipsCommand _syncSponsorshipsCommand; + private readonly ICipherRepository _cipherRepository; + private readonly ICollectionRepository _collectionRepository; + private readonly IGroupRepository _groupRepository; + private readonly IPolicyRepository _policyRepository; + private readonly IPaymentService _paymentService; + private readonly ILicensingService _licensingService; + private readonly IApplicationCacheService _applicationCacheService; + private readonly GlobalSettings _globalSettings; + private readonly IReferenceEventService _referenceEventService; + private readonly IUserService _userService; + private readonly ILogger _logger; + + public OrganizationsController( + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationConnectionRepository organizationConnectionRepository, + ISelfHostedSyncSponsorshipsCommand syncSponsorshipsCommand, + ICipherRepository cipherRepository, + ICollectionRepository collectionRepository, + IGroupRepository groupRepository, + IPolicyRepository policyRepository, + IPaymentService paymentService, + ILicensingService licensingService, + IApplicationCacheService applicationCacheService, + GlobalSettings globalSettings, + IReferenceEventService referenceEventService, + IUserService userService, + ILogger logger) { - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IOrganizationConnectionRepository _organizationConnectionRepository; - private readonly ISelfHostedSyncSponsorshipsCommand _syncSponsorshipsCommand; - private readonly ICipherRepository _cipherRepository; - private readonly ICollectionRepository _collectionRepository; - private readonly IGroupRepository _groupRepository; - private readonly IPolicyRepository _policyRepository; - private readonly IPaymentService _paymentService; - private readonly ILicensingService _licensingService; - private readonly IApplicationCacheService _applicationCacheService; - private readonly GlobalSettings _globalSettings; - private readonly IReferenceEventService _referenceEventService; - private readonly IUserService _userService; - private readonly ILogger _logger; - - public OrganizationsController( - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IOrganizationConnectionRepository organizationConnectionRepository, - ISelfHostedSyncSponsorshipsCommand syncSponsorshipsCommand, - ICipherRepository cipherRepository, - ICollectionRepository collectionRepository, - IGroupRepository groupRepository, - IPolicyRepository policyRepository, - IPaymentService paymentService, - ILicensingService licensingService, - IApplicationCacheService applicationCacheService, - GlobalSettings globalSettings, - IReferenceEventService referenceEventService, - IUserService userService, - ILogger logger) - { - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _organizationConnectionRepository = organizationConnectionRepository; - _syncSponsorshipsCommand = syncSponsorshipsCommand; - _cipherRepository = cipherRepository; - _collectionRepository = collectionRepository; - _groupRepository = groupRepository; - _policyRepository = policyRepository; - _paymentService = paymentService; - _licensingService = licensingService; - _applicationCacheService = applicationCacheService; - _globalSettings = globalSettings; - _referenceEventService = referenceEventService; - _userService = userService; - _logger = logger; - } - - public async Task Index(string name = null, string userEmail = null, bool? paid = null, - int page = 1, int count = 25) - { - if (page < 1) - { - page = 1; - } - - if (count < 1) - { - count = 1; - } - - var skip = (page - 1) * count; - var organizations = await _organizationRepository.SearchAsync(name, userEmail, paid, skip, count); - return View(new OrganizationsModel - { - Items = organizations as List, - Name = string.IsNullOrWhiteSpace(name) ? null : name, - UserEmail = string.IsNullOrWhiteSpace(userEmail) ? null : userEmail, - Paid = paid, - Page = page, - Count = count, - Action = _globalSettings.SelfHosted ? "View" : "Edit", - SelfHosted = _globalSettings.SelfHosted - }); - } - - public async Task View(Guid id) - { - var organization = await _organizationRepository.GetByIdAsync(id); - if (organization == null) - { - return RedirectToAction("Index"); - } - - var ciphers = await _cipherRepository.GetManyByOrganizationIdAsync(id); - var collections = await _collectionRepository.GetManyByOrganizationIdAsync(id); - IEnumerable groups = null; - if (organization.UseGroups) - { - groups = await _groupRepository.GetManyByOrganizationIdAsync(id); - } - IEnumerable policies = null; - if (organization.UsePolicies) - { - policies = await _policyRepository.GetManyByOrganizationIdAsync(id); - } - var users = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(id); - var billingSyncConnection = _globalSettings.EnableCloudCommunication ? await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(id, OrganizationConnectionType.CloudBillingSync) : null; - return View(new OrganizationViewModel(organization, billingSyncConnection, users, ciphers, collections, groups, policies)); - } - - [SelfHosted(NotSelfHostedOnly = true)] - public async Task Edit(Guid id) - { - var organization = await _organizationRepository.GetByIdAsync(id); - if (organization == null) - { - return RedirectToAction("Index"); - } - - var ciphers = await _cipherRepository.GetManyByOrganizationIdAsync(id); - var collections = await _collectionRepository.GetManyByOrganizationIdAsync(id); - IEnumerable groups = null; - if (organization.UseGroups) - { - groups = await _groupRepository.GetManyByOrganizationIdAsync(id); - } - IEnumerable policies = null; - if (organization.UsePolicies) - { - policies = await _policyRepository.GetManyByOrganizationIdAsync(id); - } - var users = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(id); - var billingInfo = await _paymentService.GetBillingAsync(organization); - var billingSyncConnection = _globalSettings.EnableCloudCommunication ? await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(id, OrganizationConnectionType.CloudBillingSync) : null; - return View(new OrganizationEditModel(organization, users, ciphers, collections, groups, policies, - billingInfo, billingSyncConnection, _globalSettings)); - } - - [HttpPost] - [ValidateAntiForgeryToken] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task Edit(Guid id, OrganizationEditModel model) - { - var organization = await _organizationRepository.GetByIdAsync(id); - model.ToOrganization(organization); - await _organizationRepository.ReplaceAsync(organization); - await _applicationCacheService.UpsertOrganizationAbilityAsync(organization); - await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.OrganizationEditedByAdmin, organization) - { - EventRaisedByUser = _userService.GetUserName(User), - SalesAssistedTrialStarted = model.SalesAssistedTrialStarted, - }); - return RedirectToAction("Edit", new { id }); - } - - [HttpPost] - [ValidateAntiForgeryToken] - public async Task Delete(Guid id) - { - var organization = await _organizationRepository.GetByIdAsync(id); - if (organization != null) - { - await _organizationRepository.DeleteAsync(organization); - await _applicationCacheService.DeleteOrganizationAbilityAsync(organization.Id); - } - - return RedirectToAction("Index"); - } - - public async Task TriggerBillingSync(Guid id) - { - var organization = await _organizationRepository.GetByIdAsync(id); - if (organization == null) - { - return RedirectToAction("Index"); - } - var connection = (await _organizationConnectionRepository.GetEnabledByOrganizationIdTypeAsync(id, OrganizationConnectionType.CloudBillingSync)).FirstOrDefault(); - if (connection != null) - { - try - { - var config = connection.GetConfig(); - await _syncSponsorshipsCommand.SyncOrganization(id, config.CloudOrganizationId, connection); - TempData["ConnectionActivated"] = id; - TempData["ConnectionError"] = null; - } - catch (Exception ex) - { - TempData["ConnectionError"] = ex.Message; - _logger.LogWarning(ex, "Error while attempting to do billing sync for organization with id '{OrganizationId}'", id); - } - - if (_globalSettings.SelfHosted) - { - return RedirectToAction("View", new { id }); - } - else - { - return RedirectToAction("Edit", new { id }); - } - } - return RedirectToAction("Index"); - } - + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _organizationConnectionRepository = organizationConnectionRepository; + _syncSponsorshipsCommand = syncSponsorshipsCommand; + _cipherRepository = cipherRepository; + _collectionRepository = collectionRepository; + _groupRepository = groupRepository; + _policyRepository = policyRepository; + _paymentService = paymentService; + _licensingService = licensingService; + _applicationCacheService = applicationCacheService; + _globalSettings = globalSettings; + _referenceEventService = referenceEventService; + _userService = userService; + _logger = logger; } + + public async Task Index(string name = null, string userEmail = null, bool? paid = null, + int page = 1, int count = 25) + { + if (page < 1) + { + page = 1; + } + + if (count < 1) + { + count = 1; + } + + var skip = (page - 1) * count; + var organizations = await _organizationRepository.SearchAsync(name, userEmail, paid, skip, count); + return View(new OrganizationsModel + { + Items = organizations as List, + Name = string.IsNullOrWhiteSpace(name) ? null : name, + UserEmail = string.IsNullOrWhiteSpace(userEmail) ? null : userEmail, + Paid = paid, + Page = page, + Count = count, + Action = _globalSettings.SelfHosted ? "View" : "Edit", + SelfHosted = _globalSettings.SelfHosted + }); + } + + public async Task View(Guid id) + { + var organization = await _organizationRepository.GetByIdAsync(id); + if (organization == null) + { + return RedirectToAction("Index"); + } + + var ciphers = await _cipherRepository.GetManyByOrganizationIdAsync(id); + var collections = await _collectionRepository.GetManyByOrganizationIdAsync(id); + IEnumerable groups = null; + if (organization.UseGroups) + { + groups = await _groupRepository.GetManyByOrganizationIdAsync(id); + } + IEnumerable policies = null; + if (organization.UsePolicies) + { + policies = await _policyRepository.GetManyByOrganizationIdAsync(id); + } + var users = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(id); + var billingSyncConnection = _globalSettings.EnableCloudCommunication ? await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(id, OrganizationConnectionType.CloudBillingSync) : null; + return View(new OrganizationViewModel(organization, billingSyncConnection, users, ciphers, collections, groups, policies)); + } + + [SelfHosted(NotSelfHostedOnly = true)] + public async Task Edit(Guid id) + { + var organization = await _organizationRepository.GetByIdAsync(id); + if (organization == null) + { + return RedirectToAction("Index"); + } + + var ciphers = await _cipherRepository.GetManyByOrganizationIdAsync(id); + var collections = await _collectionRepository.GetManyByOrganizationIdAsync(id); + IEnumerable groups = null; + if (organization.UseGroups) + { + groups = await _groupRepository.GetManyByOrganizationIdAsync(id); + } + IEnumerable policies = null; + if (organization.UsePolicies) + { + policies = await _policyRepository.GetManyByOrganizationIdAsync(id); + } + var users = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(id); + var billingInfo = await _paymentService.GetBillingAsync(organization); + var billingSyncConnection = _globalSettings.EnableCloudCommunication ? await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(id, OrganizationConnectionType.CloudBillingSync) : null; + return View(new OrganizationEditModel(organization, users, ciphers, collections, groups, policies, + billingInfo, billingSyncConnection, _globalSettings)); + } + + [HttpPost] + [ValidateAntiForgeryToken] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task Edit(Guid id, OrganizationEditModel model) + { + var organization = await _organizationRepository.GetByIdAsync(id); + model.ToOrganization(organization); + await _organizationRepository.ReplaceAsync(organization); + await _applicationCacheService.UpsertOrganizationAbilityAsync(organization); + await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.OrganizationEditedByAdmin, organization) + { + EventRaisedByUser = _userService.GetUserName(User), + SalesAssistedTrialStarted = model.SalesAssistedTrialStarted, + }); + return RedirectToAction("Edit", new { id }); + } + + [HttpPost] + [ValidateAntiForgeryToken] + public async Task Delete(Guid id) + { + var organization = await _organizationRepository.GetByIdAsync(id); + if (organization != null) + { + await _organizationRepository.DeleteAsync(organization); + await _applicationCacheService.DeleteOrganizationAbilityAsync(organization.Id); + } + + return RedirectToAction("Index"); + } + + public async Task TriggerBillingSync(Guid id) + { + var organization = await _organizationRepository.GetByIdAsync(id); + if (organization == null) + { + return RedirectToAction("Index"); + } + var connection = (await _organizationConnectionRepository.GetEnabledByOrganizationIdTypeAsync(id, OrganizationConnectionType.CloudBillingSync)).FirstOrDefault(); + if (connection != null) + { + try + { + var config = connection.GetConfig(); + await _syncSponsorshipsCommand.SyncOrganization(id, config.CloudOrganizationId, connection); + TempData["ConnectionActivated"] = id; + TempData["ConnectionError"] = null; + } + catch (Exception ex) + { + TempData["ConnectionError"] = ex.Message; + _logger.LogWarning(ex, "Error while attempting to do billing sync for organization with id '{OrganizationId}'", id); + } + + if (_globalSettings.SelfHosted) + { + return RedirectToAction("View", new { id }); + } + else + { + return RedirectToAction("Edit", new { id }); + } + } + return RedirectToAction("Index"); + } + } diff --git a/src/Admin/Controllers/ProvidersController.cs b/src/Admin/Controllers/ProvidersController.cs index e0e448499..a141b9fd0 100644 --- a/src/Admin/Controllers/ProvidersController.cs +++ b/src/Admin/Controllers/ProvidersController.cs @@ -7,128 +7,127 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Admin.Controllers +namespace Bit.Admin.Controllers; + +[Authorize] +[SelfHosted(NotSelfHostedOnly = true)] +public class ProvidersController : Controller { - [Authorize] - [SelfHosted(NotSelfHostedOnly = true)] - public class ProvidersController : Controller + private readonly IProviderRepository _providerRepository; + private readonly IProviderUserRepository _providerUserRepository; + private readonly IProviderOrganizationRepository _providerOrganizationRepository; + private readonly GlobalSettings _globalSettings; + private readonly IApplicationCacheService _applicationCacheService; + private readonly IProviderService _providerService; + + public ProvidersController(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository, + IProviderOrganizationRepository providerOrganizationRepository, IProviderService providerService, + GlobalSettings globalSettings, IApplicationCacheService applicationCacheService) { - private readonly IProviderRepository _providerRepository; - private readonly IProviderUserRepository _providerUserRepository; - private readonly IProviderOrganizationRepository _providerOrganizationRepository; - private readonly GlobalSettings _globalSettings; - private readonly IApplicationCacheService _applicationCacheService; - private readonly IProviderService _providerService; + _providerRepository = providerRepository; + _providerUserRepository = providerUserRepository; + _providerOrganizationRepository = providerOrganizationRepository; + _providerService = providerService; + _globalSettings = globalSettings; + _applicationCacheService = applicationCacheService; + } - public ProvidersController(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository, - IProviderOrganizationRepository providerOrganizationRepository, IProviderService providerService, - GlobalSettings globalSettings, IApplicationCacheService applicationCacheService) + public async Task Index(string name = null, string userEmail = null, int page = 1, int count = 25) + { + if (page < 1) { - _providerRepository = providerRepository; - _providerUserRepository = providerUserRepository; - _providerOrganizationRepository = providerOrganizationRepository; - _providerService = providerService; - _globalSettings = globalSettings; - _applicationCacheService = applicationCacheService; + page = 1; } - public async Task Index(string name = null, string userEmail = null, int page = 1, int count = 25) + if (count < 1) { - if (page < 1) - { - page = 1; - } - - if (count < 1) - { - count = 1; - } - - var skip = (page - 1) * count; - var providers = await _providerRepository.SearchAsync(name, userEmail, skip, count); - return View(new ProvidersModel - { - Items = providers as List, - Name = string.IsNullOrWhiteSpace(name) ? null : name, - UserEmail = string.IsNullOrWhiteSpace(userEmail) ? null : userEmail, - Page = page, - Count = count, - Action = _globalSettings.SelfHosted ? "View" : "Edit", - SelfHosted = _globalSettings.SelfHosted - }); + count = 1; } - public IActionResult Create(string ownerEmail = null) + var skip = (page - 1) * count; + var providers = await _providerRepository.SearchAsync(name, userEmail, skip, count); + return View(new ProvidersModel { - return View(new CreateProviderModel - { - OwnerEmail = ownerEmail - }); + Items = providers as List, + Name = string.IsNullOrWhiteSpace(name) ? null : name, + UserEmail = string.IsNullOrWhiteSpace(userEmail) ? null : userEmail, + Page = page, + Count = count, + Action = _globalSettings.SelfHosted ? "View" : "Edit", + SelfHosted = _globalSettings.SelfHosted + }); + } + + public IActionResult Create(string ownerEmail = null) + { + return View(new CreateProviderModel + { + OwnerEmail = ownerEmail + }); + } + + [HttpPost] + [ValidateAntiForgeryToken] + public async Task Create(CreateProviderModel model) + { + if (!ModelState.IsValid) + { + return View(model); } - [HttpPost] - [ValidateAntiForgeryToken] - public async Task Create(CreateProviderModel model) + await _providerService.CreateAsync(model.OwnerEmail); + + return RedirectToAction("Index"); + } + + public async Task View(Guid id) + { + var provider = await _providerRepository.GetByIdAsync(id); + if (provider == null) { - if (!ModelState.IsValid) - { - return View(model); - } - - await _providerService.CreateAsync(model.OwnerEmail); - return RedirectToAction("Index"); } - public async Task View(Guid id) - { - var provider = await _providerRepository.GetByIdAsync(id); - if (provider == null) - { - return RedirectToAction("Index"); - } + var users = await _providerUserRepository.GetManyDetailsByProviderAsync(id); + var providerOrganizations = await _providerOrganizationRepository.GetManyDetailsByProviderAsync(id); + return View(new ProviderViewModel(provider, users, providerOrganizations)); + } - var users = await _providerUserRepository.GetManyDetailsByProviderAsync(id); - var providerOrganizations = await _providerOrganizationRepository.GetManyDetailsByProviderAsync(id); - return View(new ProviderViewModel(provider, users, providerOrganizations)); + [SelfHosted(NotSelfHostedOnly = true)] + public async Task Edit(Guid id) + { + var provider = await _providerRepository.GetByIdAsync(id); + if (provider == null) + { + return RedirectToAction("Index"); } - [SelfHosted(NotSelfHostedOnly = true)] - public async Task Edit(Guid id) - { - var provider = await _providerRepository.GetByIdAsync(id); - if (provider == null) - { - return RedirectToAction("Index"); - } + var users = await _providerUserRepository.GetManyDetailsByProviderAsync(id); + var providerOrganizations = await _providerOrganizationRepository.GetManyDetailsByProviderAsync(id); + return View(new ProviderEditModel(provider, users, providerOrganizations)); + } - var users = await _providerUserRepository.GetManyDetailsByProviderAsync(id); - var providerOrganizations = await _providerOrganizationRepository.GetManyDetailsByProviderAsync(id); - return View(new ProviderEditModel(provider, users, providerOrganizations)); + [HttpPost] + [ValidateAntiForgeryToken] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task Edit(Guid id, ProviderEditModel model) + { + var provider = await _providerRepository.GetByIdAsync(id); + if (provider == null) + { + return RedirectToAction("Index"); } - [HttpPost] - [ValidateAntiForgeryToken] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task Edit(Guid id, ProviderEditModel model) - { - var provider = await _providerRepository.GetByIdAsync(id); - if (provider == null) - { - return RedirectToAction("Index"); - } + model.ToProvider(provider); + await _providerRepository.ReplaceAsync(provider); + await _applicationCacheService.UpsertProviderAbilityAsync(provider); + return RedirectToAction("Edit", new { id }); + } - model.ToProvider(provider); - await _providerRepository.ReplaceAsync(provider); - await _applicationCacheService.UpsertProviderAbilityAsync(provider); - return RedirectToAction("Edit", new { id }); - } - - public async Task ResendInvite(Guid ownerId, Guid providerId) - { - await _providerService.ResendProviderSetupInviteEmailAsync(providerId, ownerId); - TempData["InviteResentTo"] = ownerId; - return RedirectToAction("Edit", new { id = providerId }); - } + public async Task ResendInvite(Guid ownerId, Guid providerId) + { + await _providerService.ResendProviderSetupInviteEmailAsync(providerId, ownerId); + TempData["InviteResentTo"] = ownerId; + return RedirectToAction("Edit", new { id = providerId }); } } diff --git a/src/Admin/Controllers/ToolsController.cs b/src/Admin/Controllers/ToolsController.cs index 9a483c137..9bd6189b3 100644 --- a/src/Admin/Controllers/ToolsController.cs +++ b/src/Admin/Controllers/ToolsController.cs @@ -10,410 +10,373 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Admin.Controllers +namespace Bit.Admin.Controllers; + +[Authorize] +[SelfHosted(NotSelfHostedOnly = true)] +public class ToolsController : Controller { - [Authorize] - [SelfHosted(NotSelfHostedOnly = true)] - public class ToolsController : Controller + private readonly GlobalSettings _globalSettings; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationService _organizationService; + private readonly IUserService _userService; + private readonly ITransactionRepository _transactionRepository; + private readonly IInstallationRepository _installationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IPaymentService _paymentService; + private readonly ITaxRateRepository _taxRateRepository; + private readonly IStripeAdapter _stripeAdapter; + + public ToolsController( + GlobalSettings globalSettings, + IOrganizationRepository organizationRepository, + IOrganizationService organizationService, + IUserService userService, + ITransactionRepository transactionRepository, + IInstallationRepository installationRepository, + IOrganizationUserRepository organizationUserRepository, + ITaxRateRepository taxRateRepository, + IPaymentService paymentService, + IStripeAdapter stripeAdapter) { - private readonly GlobalSettings _globalSettings; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationService _organizationService; - private readonly IUserService _userService; - private readonly ITransactionRepository _transactionRepository; - private readonly IInstallationRepository _installationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IPaymentService _paymentService; - private readonly ITaxRateRepository _taxRateRepository; - private readonly IStripeAdapter _stripeAdapter; + _globalSettings = globalSettings; + _organizationRepository = organizationRepository; + _organizationService = organizationService; + _userService = userService; + _transactionRepository = transactionRepository; + _installationRepository = installationRepository; + _organizationUserRepository = organizationUserRepository; + _taxRateRepository = taxRateRepository; + _paymentService = paymentService; + _stripeAdapter = stripeAdapter; + } - public ToolsController( - GlobalSettings globalSettings, - IOrganizationRepository organizationRepository, - IOrganizationService organizationService, - IUserService userService, - ITransactionRepository transactionRepository, - IInstallationRepository installationRepository, - IOrganizationUserRepository organizationUserRepository, - ITaxRateRepository taxRateRepository, - IPaymentService paymentService, - IStripeAdapter stripeAdapter) + public IActionResult ChargeBraintree() + { + return View(new ChargeBraintreeModel()); + } + + [HttpPost] + [ValidateAntiForgeryToken] + public async Task ChargeBraintree(ChargeBraintreeModel model) + { + if (!ModelState.IsValid) { - _globalSettings = globalSettings; - _organizationRepository = organizationRepository; - _organizationService = organizationService; - _userService = userService; - _transactionRepository = transactionRepository; - _installationRepository = installationRepository; - _organizationUserRepository = organizationUserRepository; - _taxRateRepository = taxRateRepository; - _paymentService = paymentService; - _stripeAdapter = stripeAdapter; - } - - public IActionResult ChargeBraintree() - { - return View(new ChargeBraintreeModel()); - } - - [HttpPost] - [ValidateAntiForgeryToken] - public async Task ChargeBraintree(ChargeBraintreeModel model) - { - if (!ModelState.IsValid) - { - return View(model); - } - - var btGateway = new Braintree.BraintreeGateway - { - Environment = _globalSettings.Braintree.Production ? - Braintree.Environment.PRODUCTION : Braintree.Environment.SANDBOX, - MerchantId = _globalSettings.Braintree.MerchantId, - PublicKey = _globalSettings.Braintree.PublicKey, - PrivateKey = _globalSettings.Braintree.PrivateKey - }; - - var btObjIdField = model.Id[0] == 'o' ? "organization_id" : "user_id"; - var btObjId = new Guid(model.Id.Substring(1, 32)); - - var transactionResult = await btGateway.Transaction.SaleAsync( - new Braintree.TransactionRequest - { - Amount = model.Amount.Value, - CustomerId = model.Id, - Options = new Braintree.TransactionOptionsRequest - { - SubmitForSettlement = true, - PayPal = new Braintree.TransactionOptionsPayPalRequest - { - CustomField = $"{btObjIdField}:{btObjId}" - } - }, - CustomFields = new Dictionary - { - [btObjIdField] = btObjId.ToString() - } - }); - - if (!transactionResult.IsSuccess()) - { - ModelState.AddModelError(string.Empty, "Charge failed. " + - "Refer to Braintree admin portal for more information."); - } - else - { - model.TransactionId = transactionResult.Target.Id; - model.PayPalTransactionId = transactionResult.Target?.PayPalDetails?.CaptureId; - } return View(model); } - public IActionResult CreateTransaction(Guid? organizationId = null, Guid? userId = null) + var btGateway = new Braintree.BraintreeGateway { - return View("CreateUpdateTransaction", new CreateUpdateTransactionModel + Environment = _globalSettings.Braintree.Production ? + Braintree.Environment.PRODUCTION : Braintree.Environment.SANDBOX, + MerchantId = _globalSettings.Braintree.MerchantId, + PublicKey = _globalSettings.Braintree.PublicKey, + PrivateKey = _globalSettings.Braintree.PrivateKey + }; + + var btObjIdField = model.Id[0] == 'o' ? "organization_id" : "user_id"; + var btObjId = new Guid(model.Id.Substring(1, 32)); + + var transactionResult = await btGateway.Transaction.SaleAsync( + new Braintree.TransactionRequest { - OrganizationId = organizationId, - UserId = userId + Amount = model.Amount.Value, + CustomerId = model.Id, + Options = new Braintree.TransactionOptionsRequest + { + SubmitForSettlement = true, + PayPal = new Braintree.TransactionOptionsPayPalRequest + { + CustomField = $"{btObjIdField}:{btObjId}" + } + }, + CustomFields = new Dictionary + { + [btObjIdField] = btObjId.ToString() + } }); + + if (!transactionResult.IsSuccess()) + { + ModelState.AddModelError(string.Empty, "Charge failed. " + + "Refer to Braintree admin portal for more information."); + } + else + { + model.TransactionId = transactionResult.Target.Id; + model.PayPalTransactionId = transactionResult.Target?.PayPalDetails?.CaptureId; + } + return View(model); + } + + public IActionResult CreateTransaction(Guid? organizationId = null, Guid? userId = null) + { + return View("CreateUpdateTransaction", new CreateUpdateTransactionModel + { + OrganizationId = organizationId, + UserId = userId + }); + } + + [HttpPost] + [ValidateAntiForgeryToken] + public async Task CreateTransaction(CreateUpdateTransactionModel model) + { + if (!ModelState.IsValid) + { + return View("CreateUpdateTransaction", model); } - [HttpPost] - [ValidateAntiForgeryToken] - public async Task CreateTransaction(CreateUpdateTransactionModel model) + await _transactionRepository.CreateAsync(model.ToTransaction()); + if (model.UserId.HasValue) { - if (!ModelState.IsValid) - { - return View("CreateUpdateTransaction", model); - } + return RedirectToAction("Edit", "Users", new { id = model.UserId }); + } + else + { + return RedirectToAction("Edit", "Organizations", new { id = model.OrganizationId }); + } + } - await _transactionRepository.CreateAsync(model.ToTransaction()); - if (model.UserId.HasValue) - { - return RedirectToAction("Edit", "Users", new { id = model.UserId }); - } - else - { - return RedirectToAction("Edit", "Organizations", new { id = model.OrganizationId }); - } + public async Task EditTransaction(Guid id) + { + var transaction = await _transactionRepository.GetByIdAsync(id); + if (transaction == null) + { + return RedirectToAction("Index", "Home"); + } + return View("CreateUpdateTransaction", new CreateUpdateTransactionModel(transaction)); + } + + [HttpPost] + [ValidateAntiForgeryToken] + public async Task EditTransaction(Guid id, CreateUpdateTransactionModel model) + { + if (!ModelState.IsValid) + { + return View("CreateUpdateTransaction", model); + } + await _transactionRepository.ReplaceAsync(model.ToTransaction(id)); + if (model.UserId.HasValue) + { + return RedirectToAction("Edit", "Users", new { id = model.UserId }); + } + else + { + return RedirectToAction("Edit", "Organizations", new { id = model.OrganizationId }); + } + } + + public IActionResult PromoteAdmin() + { + return View(); + } + + [HttpPost] + [ValidateAntiForgeryToken] + public async Task PromoteAdmin(PromoteAdminModel model) + { + if (!ModelState.IsValid) + { + return View(model); } - public async Task EditTransaction(Guid id) + var orgUsers = await _organizationUserRepository.GetManyByOrganizationAsync( + model.OrganizationId.Value, null); + var user = orgUsers.FirstOrDefault(u => u.UserId == model.UserId.Value); + if (user == null) { - var transaction = await _transactionRepository.GetByIdAsync(id); - if (transaction == null) - { - return RedirectToAction("Index", "Home"); - } - return View("CreateUpdateTransaction", new CreateUpdateTransactionModel(transaction)); + ModelState.AddModelError(nameof(model.UserId), "User Id not found in this organization."); + } + else if (user.Type != Core.Enums.OrganizationUserType.Admin) + { + ModelState.AddModelError(nameof(model.UserId), "User is not an admin of this organization."); } - [HttpPost] - [ValidateAntiForgeryToken] - public async Task EditTransaction(Guid id, CreateUpdateTransactionModel model) + if (!ModelState.IsValid) { - if (!ModelState.IsValid) - { - return View("CreateUpdateTransaction", model); - } - await _transactionRepository.ReplaceAsync(model.ToTransaction(id)); - if (model.UserId.HasValue) - { - return RedirectToAction("Edit", "Users", new { id = model.UserId }); - } - else - { - return RedirectToAction("Edit", "Organizations", new { id = model.OrganizationId }); - } + return View(model); } - public IActionResult PromoteAdmin() + user.Type = Core.Enums.OrganizationUserType.Owner; + await _organizationUserRepository.ReplaceAsync(user); + return RedirectToAction("Edit", "Organizations", new { id = model.OrganizationId.Value }); + } + + public IActionResult GenerateLicense() + { + return View(); + } + + [HttpPost] + [ValidateAntiForgeryToken] + public async Task GenerateLicense(LicenseModel model) + { + if (!ModelState.IsValid) { - return View(); + return View(model); } - [HttpPost] - [ValidateAntiForgeryToken] - public async Task PromoteAdmin(PromoteAdminModel model) + User user = null; + Organization organization = null; + if (model.UserId.HasValue) { - if (!ModelState.IsValid) - { - return View(model); - } - - var orgUsers = await _organizationUserRepository.GetManyByOrganizationAsync( - model.OrganizationId.Value, null); - var user = orgUsers.FirstOrDefault(u => u.UserId == model.UserId.Value); + user = await _userService.GetUserByIdAsync(model.UserId.Value); if (user == null) { - ModelState.AddModelError(nameof(model.UserId), "User Id not found in this organization."); + ModelState.AddModelError(nameof(model.UserId), "User Id not found."); } - else if (user.Type != Core.Enums.OrganizationUserType.Admin) - { - ModelState.AddModelError(nameof(model.UserId), "User is not an admin of this organization."); - } - - if (!ModelState.IsValid) - { - return View(model); - } - - user.Type = Core.Enums.OrganizationUserType.Owner; - await _organizationUserRepository.ReplaceAsync(user); - return RedirectToAction("Edit", "Organizations", new { id = model.OrganizationId.Value }); } - - public IActionResult GenerateLicense() + else if (model.OrganizationId.HasValue) { - return View(); + organization = await _organizationRepository.GetByIdAsync(model.OrganizationId.Value); + if (organization == null) + { + ModelState.AddModelError(nameof(model.OrganizationId), "Organization not found."); + } + else if (!organization.Enabled) + { + ModelState.AddModelError(nameof(model.OrganizationId), "Organization is disabled."); + } } - - [HttpPost] - [ValidateAntiForgeryToken] - public async Task GenerateLicense(LicenseModel model) + if (model.InstallationId.HasValue) { - if (!ModelState.IsValid) + var installation = await _installationRepository.GetByIdAsync(model.InstallationId.Value); + if (installation == null) { - return View(model); + ModelState.AddModelError(nameof(model.InstallationId), "Installation not found."); } - - User user = null; - Organization organization = null; - if (model.UserId.HasValue) + else if (!installation.Enabled) { - user = await _userService.GetUserByIdAsync(model.UserId.Value); - if (user == null) - { - ModelState.AddModelError(nameof(model.UserId), "User Id not found."); - } - } - else if (model.OrganizationId.HasValue) - { - organization = await _organizationRepository.GetByIdAsync(model.OrganizationId.Value); - if (organization == null) - { - ModelState.AddModelError(nameof(model.OrganizationId), "Organization not found."); - } - else if (!organization.Enabled) - { - ModelState.AddModelError(nameof(model.OrganizationId), "Organization is disabled."); - } - } - if (model.InstallationId.HasValue) - { - var installation = await _installationRepository.GetByIdAsync(model.InstallationId.Value); - if (installation == null) - { - ModelState.AddModelError(nameof(model.InstallationId), "Installation not found."); - } - else if (!installation.Enabled) - { - ModelState.AddModelError(nameof(model.OrganizationId), "Installation is disabled."); - } - } - - if (!ModelState.IsValid) - { - return View(model); - } - - if (organization != null) - { - var license = await _organizationService.GenerateLicenseAsync(organization, - model.InstallationId.Value, model.Version); - var ms = new MemoryStream(); - await JsonSerializer.SerializeAsync(ms, license, JsonHelpers.Indented); - ms.Seek(0, SeekOrigin.Begin); - return File(ms, "text/plain", "bitwarden_organization_license.json"); - } - else if (user != null) - { - var license = await _userService.GenerateLicenseAsync(user, null, model.Version); - var ms = new MemoryStream(); - ms.Seek(0, SeekOrigin.Begin); - await JsonSerializer.SerializeAsync(ms, license, JsonHelpers.Indented); - ms.Seek(0, SeekOrigin.Begin); - return File(ms, "text/plain", "bitwarden_premium_license.json"); - } - else - { - throw new Exception("No license to generate."); + ModelState.AddModelError(nameof(model.OrganizationId), "Installation is disabled."); } } - public async Task TaxRate(int page = 1, int count = 25) + if (!ModelState.IsValid) { - if (page < 1) - { - page = 1; - } - - if (count < 1) - { - count = 1; - } - - var skip = (page - 1) * count; - var rates = await _taxRateRepository.SearchAsync(skip, count); - return View(new TaxRatesModel - { - Items = rates.ToList(), - Page = page, - Count = count - }); - } - - public async Task TaxRateAddEdit(string stripeTaxRateId = null) - { - if (string.IsNullOrWhiteSpace(stripeTaxRateId)) - { - return View(new TaxRateAddEditModel()); - } - - var rate = await _taxRateRepository.GetByIdAsync(stripeTaxRateId); - var model = new TaxRateAddEditModel() - { - StripeTaxRateId = stripeTaxRateId, - Country = rate.Country, - State = rate.State, - PostalCode = rate.PostalCode, - Rate = rate.Rate - }; - return View(model); } - [ValidateAntiForgeryToken] - public async Task TaxRateUpload(IFormFile file) + if (organization != null) { - if (file == null || file.Length == 0) - { - throw new ArgumentNullException(nameof(file)); - } + var license = await _organizationService.GenerateLicenseAsync(organization, + model.InstallationId.Value, model.Version); + var ms = new MemoryStream(); + await JsonSerializer.SerializeAsync(ms, license, JsonHelpers.Indented); + ms.Seek(0, SeekOrigin.Begin); + return File(ms, "text/plain", "bitwarden_organization_license.json"); + } + else if (user != null) + { + var license = await _userService.GenerateLicenseAsync(user, null, model.Version); + var ms = new MemoryStream(); + ms.Seek(0, SeekOrigin.Begin); + await JsonSerializer.SerializeAsync(ms, license, JsonHelpers.Indented); + ms.Seek(0, SeekOrigin.Begin); + return File(ms, "text/plain", "bitwarden_premium_license.json"); + } + else + { + throw new Exception("No license to generate."); + } + } - // Build rates and validate them first before updating DB & Stripe - var taxRateUpdates = new List(); - var currentTaxRates = await _taxRateRepository.GetAllActiveAsync(); - using var reader = new StreamReader(file.OpenReadStream()); - while (!reader.EndOfStream) - { - var line = await reader.ReadLineAsync(); - if (string.IsNullOrWhiteSpace(line)) - { - continue; - } - var taxParts = line.Split(','); - if (taxParts.Length < 2) - { - throw new Exception($"This line is not in the format of ,,,: {line}"); - } - var postalCode = taxParts[0].Trim(); - if (string.IsNullOrWhiteSpace(postalCode)) - { - throw new Exception($"'{line}' is not valid, the first element must contain a postal code."); - } - if (!decimal.TryParse(taxParts[1], out var rate) || rate <= 0M || rate > 100) - { - throw new Exception($"{taxParts[1]} is not a valid rate/decimal for {postalCode}"); - } - var state = taxParts.Length > 2 ? taxParts[2] : null; - var country = (taxParts.Length > 3 ? taxParts[3] : null); - if (string.IsNullOrWhiteSpace(country)) - { - country = "US"; - } - var taxRate = currentTaxRates.FirstOrDefault(r => r.Country == country && r.PostalCode == postalCode) ?? - new TaxRate - { - Country = country, - PostalCode = postalCode, - Active = true, - }; - taxRate.Rate = rate; - taxRate.State = state ?? taxRate.State; - taxRateUpdates.Add(taxRate); - } - - foreach (var taxRate in taxRateUpdates) - { - if (!string.IsNullOrWhiteSpace(taxRate.Id)) - { - await _paymentService.UpdateTaxRateAsync(taxRate); - } - else - { - await _paymentService.CreateTaxRateAsync(taxRate); - } - } - - return RedirectToAction("TaxRate"); + public async Task TaxRate(int page = 1, int count = 25) + { + if (page < 1) + { + page = 1; } - [HttpPost] - [ValidateAntiForgeryToken] - public async Task TaxRateAddEdit(TaxRateAddEditModel model) + if (count < 1) { - var existingRateCheck = await _taxRateRepository.GetByLocationAsync(new TaxRate() { Country = model.Country, PostalCode = model.PostalCode }); - if (existingRateCheck.Any()) + count = 1; + } + + var skip = (page - 1) * count; + var rates = await _taxRateRepository.SearchAsync(skip, count); + return View(new TaxRatesModel + { + Items = rates.ToList(), + Page = page, + Count = count + }); + } + + public async Task TaxRateAddEdit(string stripeTaxRateId = null) + { + if (string.IsNullOrWhiteSpace(stripeTaxRateId)) + { + return View(new TaxRateAddEditModel()); + } + + var rate = await _taxRateRepository.GetByIdAsync(stripeTaxRateId); + var model = new TaxRateAddEditModel() + { + StripeTaxRateId = stripeTaxRateId, + Country = rate.Country, + State = rate.State, + PostalCode = rate.PostalCode, + Rate = rate.Rate + }; + + return View(model); + } + + [ValidateAntiForgeryToken] + public async Task TaxRateUpload(IFormFile file) + { + if (file == null || file.Length == 0) + { + throw new ArgumentNullException(nameof(file)); + } + + // Build rates and validate them first before updating DB & Stripe + var taxRateUpdates = new List(); + var currentTaxRates = await _taxRateRepository.GetAllActiveAsync(); + using var reader = new StreamReader(file.OpenReadStream()); + while (!reader.EndOfStream) + { + var line = await reader.ReadLineAsync(); + if (string.IsNullOrWhiteSpace(line)) { - ModelState.AddModelError(nameof(model.PostalCode), "A tax rate already exists for this Country/Postal Code combination."); + continue; } - - if (!ModelState.IsValid) + var taxParts = line.Split(','); + if (taxParts.Length < 2) { - return View(model); + throw new Exception($"This line is not in the format of ,,,: {line}"); } - - var taxRate = new TaxRate() + var postalCode = taxParts[0].Trim(); + if (string.IsNullOrWhiteSpace(postalCode)) { - Id = model.StripeTaxRateId, - Country = model.Country, - State = model.State, - PostalCode = model.PostalCode, - Rate = model.Rate - }; + throw new Exception($"'{line}' is not valid, the first element must contain a postal code."); + } + if (!decimal.TryParse(taxParts[1], out var rate) || rate <= 0M || rate > 100) + { + throw new Exception($"{taxParts[1]} is not a valid rate/decimal for {postalCode}"); + } + var state = taxParts.Length > 2 ? taxParts[2] : null; + var country = (taxParts.Length > 3 ? taxParts[3] : null); + if (string.IsNullOrWhiteSpace(country)) + { + country = "US"; + } + var taxRate = currentTaxRates.FirstOrDefault(r => r.Country == country && r.PostalCode == postalCode) ?? + new TaxRate + { + Country = country, + PostalCode = postalCode, + Active = true, + }; + taxRate.Rate = rate; + taxRate.State = state ?? taxRate.State; + taxRateUpdates.Add(taxRate); + } - if (!string.IsNullOrWhiteSpace(model.StripeTaxRateId)) + foreach (var taxRate in taxRateUpdates) + { + if (!string.IsNullOrWhiteSpace(taxRate.Id)) { await _paymentService.UpdateTaxRateAsync(taxRate); } @@ -421,139 +384,175 @@ namespace Bit.Admin.Controllers { await _paymentService.CreateTaxRateAsync(taxRate); } - - return RedirectToAction("TaxRate"); } - public async Task TaxRateArchive(string stripeTaxRateId) - { - if (!string.IsNullOrWhiteSpace(stripeTaxRateId)) - { - await _paymentService.ArchiveTaxRateAsync(new TaxRate() { Id = stripeTaxRateId }); - } + return RedirectToAction("TaxRate"); + } - return RedirectToAction("TaxRate"); + [HttpPost] + [ValidateAntiForgeryToken] + public async Task TaxRateAddEdit(TaxRateAddEditModel model) + { + var existingRateCheck = await _taxRateRepository.GetByLocationAsync(new TaxRate() { Country = model.Country, PostalCode = model.PostalCode }); + if (existingRateCheck.Any()) + { + ModelState.AddModelError(nameof(model.PostalCode), "A tax rate already exists for this Country/Postal Code combination."); } - public async Task StripeSubscriptions(StripeSubscriptionListOptions options) + if (!ModelState.IsValid) { - options = options ?? new StripeSubscriptionListOptions(); - options.Limit = 10; - options.Expand = new List() { "data.customer", "data.latest_invoice" }; - options.SelectAll = false; - - var subscriptions = await _stripeAdapter.SubscriptionListAsync(options); - - options.StartingAfter = subscriptions.LastOrDefault()?.Id; - options.EndingBefore = await StripeSubscriptionsGetHasPreviousPage(subscriptions, options) ? - subscriptions.FirstOrDefault()?.Id : - null; - - var model = new StripeSubscriptionsModel() - { - Items = subscriptions.Select(s => new StripeSubscriptionRowModel(s)).ToList(), - Prices = (await _stripeAdapter.PriceListAsync(new Stripe.PriceListOptions() { Limit = 100 })).Data, - TestClocks = await _stripeAdapter.TestClockListAsync(), - Filter = options - }; return View(model); } - [HttpPost] - public async Task StripeSubscriptions([FromForm] StripeSubscriptionsModel model) + var taxRate = new TaxRate() { - if (!ModelState.IsValid) - { - model.Prices = (await _stripeAdapter.PriceListAsync(new Stripe.PriceListOptions() { Limit = 100 })).Data; - model.TestClocks = await _stripeAdapter.TestClockListAsync(); - return View(model); - } + Id = model.StripeTaxRateId, + Country = model.Country, + State = model.State, + PostalCode = model.PostalCode, + Rate = model.Rate + }; - if (model.Action == StripeSubscriptionsAction.Export || model.Action == StripeSubscriptionsAction.BulkCancel) - { - var subscriptions = model.Filter.SelectAll ? - await _stripeAdapter.SubscriptionListAsync(model.Filter) : - model.Items.Where(x => x.Selected).Select(x => x.Subscription); - - if (model.Action == StripeSubscriptionsAction.Export) - { - return StripeSubscriptionsExport(subscriptions); - } - - if (model.Action == StripeSubscriptionsAction.BulkCancel) - { - await StripeSubscriptionsCancel(subscriptions); - } - } - else - { - if (model.Action == StripeSubscriptionsAction.PreviousPage || model.Action == StripeSubscriptionsAction.Search) - { - model.Filter.StartingAfter = null; - } - if (model.Action == StripeSubscriptionsAction.NextPage || model.Action == StripeSubscriptionsAction.Search) - { - model.Filter.EndingBefore = null; - } - } - - - return RedirectToAction("StripeSubscriptions", model.Filter); + if (!string.IsNullOrWhiteSpace(model.StripeTaxRateId)) + { + await _paymentService.UpdateTaxRateAsync(taxRate); + } + else + { + await _paymentService.CreateTaxRateAsync(taxRate); } - // This requires a redundant API call to Stripe because of the way they handle pagination. - // The StartingBefore value has to be infered from the list we get, and isn't supplied by Stripe. - private async Task StripeSubscriptionsGetHasPreviousPage(List subscriptions, StripeSubscriptionListOptions options) + return RedirectToAction("TaxRate"); + } + + public async Task TaxRateArchive(string stripeTaxRateId) + { + if (!string.IsNullOrWhiteSpace(stripeTaxRateId)) { - var hasPreviousPage = false; - if (subscriptions.FirstOrDefault()?.Id != null) - { - var previousPageSearchOptions = new StripeSubscriptionListOptions() - { - EndingBefore = subscriptions.FirstOrDefault().Id, - Limit = 1, - Status = options.Status, - CurrentPeriodEndDate = options.CurrentPeriodEndDate, - CurrentPeriodEndRange = options.CurrentPeriodEndRange, - Price = options.Price - }; - hasPreviousPage = (await _stripeAdapter.SubscriptionListAsync(previousPageSearchOptions)).Count > 0; - } - return hasPreviousPage; + await _paymentService.ArchiveTaxRateAsync(new TaxRate() { Id = stripeTaxRateId }); } - private async Task StripeSubscriptionsCancel(IEnumerable subscriptions) + return RedirectToAction("TaxRate"); + } + + public async Task StripeSubscriptions(StripeSubscriptionListOptions options) + { + options = options ?? new StripeSubscriptionListOptions(); + options.Limit = 10; + options.Expand = new List() { "data.customer", "data.latest_invoice" }; + options.SelectAll = false; + + var subscriptions = await _stripeAdapter.SubscriptionListAsync(options); + + options.StartingAfter = subscriptions.LastOrDefault()?.Id; + options.EndingBefore = await StripeSubscriptionsGetHasPreviousPage(subscriptions, options) ? + subscriptions.FirstOrDefault()?.Id : + null; + + var model = new StripeSubscriptionsModel() { - foreach (var s in subscriptions) + Items = subscriptions.Select(s => new StripeSubscriptionRowModel(s)).ToList(), + Prices = (await _stripeAdapter.PriceListAsync(new Stripe.PriceListOptions() { Limit = 100 })).Data, + TestClocks = await _stripeAdapter.TestClockListAsync(), + Filter = options + }; + return View(model); + } + + [HttpPost] + public async Task StripeSubscriptions([FromForm] StripeSubscriptionsModel model) + { + if (!ModelState.IsValid) + { + model.Prices = (await _stripeAdapter.PriceListAsync(new Stripe.PriceListOptions() { Limit = 100 })).Data; + model.TestClocks = await _stripeAdapter.TestClockListAsync(); + return View(model); + } + + if (model.Action == StripeSubscriptionsAction.Export || model.Action == StripeSubscriptionsAction.BulkCancel) + { + var subscriptions = model.Filter.SelectAll ? + await _stripeAdapter.SubscriptionListAsync(model.Filter) : + model.Items.Where(x => x.Selected).Select(x => x.Subscription); + + if (model.Action == StripeSubscriptionsAction.Export) { - await _stripeAdapter.SubscriptionCancelAsync(s.Id); - if (s.LatestInvoice?.Status == "open") - { - await _stripeAdapter.InvoiceVoidInvoiceAsync(s.LatestInvoiceId); - } + return StripeSubscriptionsExport(subscriptions); + } + + if (model.Action == StripeSubscriptionsAction.BulkCancel) + { + await StripeSubscriptionsCancel(subscriptions); + } + } + else + { + if (model.Action == StripeSubscriptionsAction.PreviousPage || model.Action == StripeSubscriptionsAction.Search) + { + model.Filter.StartingAfter = null; + } + if (model.Action == StripeSubscriptionsAction.NextPage || model.Action == StripeSubscriptionsAction.Search) + { + model.Filter.EndingBefore = null; } } - private FileResult StripeSubscriptionsExport(IEnumerable subscriptions) - { - var fieldsToExport = subscriptions.Select(s => new - { - StripeId = s.Id, - CustomerEmail = s.Customer?.Email, - SubscriptionStatus = s.Status, - InvoiceDueDate = s.CurrentPeriodEnd, - SubscriptionProducts = s.Items?.Data.Select(p => p.Plan.Id) - }); - var options = new JsonSerializerOptions + return RedirectToAction("StripeSubscriptions", model.Filter); + } + + // This requires a redundant API call to Stripe because of the way they handle pagination. + // The StartingBefore value has to be infered from the list we get, and isn't supplied by Stripe. + private async Task StripeSubscriptionsGetHasPreviousPage(List subscriptions, StripeSubscriptionListOptions options) + { + var hasPreviousPage = false; + if (subscriptions.FirstOrDefault()?.Id != null) + { + var previousPageSearchOptions = new StripeSubscriptionListOptions() { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - WriteIndented = true + EndingBefore = subscriptions.FirstOrDefault().Id, + Limit = 1, + Status = options.Status, + CurrentPeriodEndDate = options.CurrentPeriodEndDate, + CurrentPeriodEndRange = options.CurrentPeriodEndRange, + Price = options.Price }; + hasPreviousPage = (await _stripeAdapter.SubscriptionListAsync(previousPageSearchOptions)).Count > 0; + } + return hasPreviousPage; + } - var result = System.Text.Json.JsonSerializer.Serialize(fieldsToExport, options); - var bytes = Encoding.UTF8.GetBytes(result); - return File(bytes, "application/json", "StripeSubscriptionsSearch.json"); + private async Task StripeSubscriptionsCancel(IEnumerable subscriptions) + { + foreach (var s in subscriptions) + { + await _stripeAdapter.SubscriptionCancelAsync(s.Id); + if (s.LatestInvoice?.Status == "open") + { + await _stripeAdapter.InvoiceVoidInvoiceAsync(s.LatestInvoiceId); + } } } + + private FileResult StripeSubscriptionsExport(IEnumerable subscriptions) + { + var fieldsToExport = subscriptions.Select(s => new + { + StripeId = s.Id, + CustomerEmail = s.Customer?.Email, + SubscriptionStatus = s.Status, + InvoiceDueDate = s.CurrentPeriodEnd, + SubscriptionProducts = s.Items?.Data.Select(p => p.Plan.Id) + }); + + var options = new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + WriteIndented = true + }; + + var result = System.Text.Json.JsonSerializer.Serialize(fieldsToExport, options); + var bytes = Encoding.UTF8.GetBytes(result); + return File(bytes, "application/json", "StripeSubscriptionsSearch.json"); + } } diff --git a/src/Admin/Controllers/UsersController.cs b/src/Admin/Controllers/UsersController.cs index e8ea2e0cd..0a4becb69 100644 --- a/src/Admin/Controllers/UsersController.cs +++ b/src/Admin/Controllers/UsersController.cs @@ -7,105 +7,104 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Admin.Controllers +namespace Bit.Admin.Controllers; + +[Authorize] +public class UsersController : Controller { - [Authorize] - public class UsersController : Controller + private readonly IUserRepository _userRepository; + private readonly ICipherRepository _cipherRepository; + private readonly IPaymentService _paymentService; + private readonly GlobalSettings _globalSettings; + + public UsersController( + IUserRepository userRepository, + ICipherRepository cipherRepository, + IPaymentService paymentService, + GlobalSettings globalSettings) { - private readonly IUserRepository _userRepository; - private readonly ICipherRepository _cipherRepository; - private readonly IPaymentService _paymentService; - private readonly GlobalSettings _globalSettings; + _userRepository = userRepository; + _cipherRepository = cipherRepository; + _paymentService = paymentService; + _globalSettings = globalSettings; + } - public UsersController( - IUserRepository userRepository, - ICipherRepository cipherRepository, - IPaymentService paymentService, - GlobalSettings globalSettings) + public async Task Index(string email, int page = 1, int count = 25) + { + if (page < 1) { - _userRepository = userRepository; - _cipherRepository = cipherRepository; - _paymentService = paymentService; - _globalSettings = globalSettings; + page = 1; } - public async Task Index(string email, int page = 1, int count = 25) + if (count < 1) { - if (page < 1) - { - page = 1; - } - - if (count < 1) - { - count = 1; - } - - var skip = (page - 1) * count; - var users = await _userRepository.SearchAsync(email, skip, count); - return View(new UsersModel - { - Items = users as List, - Email = string.IsNullOrWhiteSpace(email) ? null : email, - Page = page, - Count = count, - Action = _globalSettings.SelfHosted ? "View" : "Edit" - }); + count = 1; } - public async Task View(Guid id) + var skip = (page - 1) * count; + var users = await _userRepository.SearchAsync(email, skip, count); + return View(new UsersModel { - var user = await _userRepository.GetByIdAsync(id); - if (user == null) - { - return RedirectToAction("Index"); - } + Items = users as List, + Email = string.IsNullOrWhiteSpace(email) ? null : email, + Page = page, + Count = count, + Action = _globalSettings.SelfHosted ? "View" : "Edit" + }); + } - var ciphers = await _cipherRepository.GetManyByUserIdAsync(id); - return View(new UserViewModel(user, ciphers)); - } - - [SelfHosted(NotSelfHostedOnly = true)] - public async Task Edit(Guid id) + public async Task View(Guid id) + { + var user = await _userRepository.GetByIdAsync(id); + if (user == null) { - var user = await _userRepository.GetByIdAsync(id); - if (user == null) - { - return RedirectToAction("Index"); - } - - var ciphers = await _cipherRepository.GetManyByUserIdAsync(id); - var billingInfo = await _paymentService.GetBillingAsync(user); - return View(new UserEditModel(user, ciphers, billingInfo, _globalSettings)); - } - - [HttpPost] - [ValidateAntiForgeryToken] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task Edit(Guid id, UserEditModel model) - { - var user = await _userRepository.GetByIdAsync(id); - if (user == null) - { - return RedirectToAction("Index"); - } - - model.ToUser(user); - await _userRepository.ReplaceAsync(user); - return RedirectToAction("Edit", new { id }); - } - - [HttpPost] - [ValidateAntiForgeryToken] - public async Task Delete(Guid id) - { - var user = await _userRepository.GetByIdAsync(id); - if (user != null) - { - await _userRepository.DeleteAsync(user); - } - return RedirectToAction("Index"); } + + var ciphers = await _cipherRepository.GetManyByUserIdAsync(id); + return View(new UserViewModel(user, ciphers)); + } + + [SelfHosted(NotSelfHostedOnly = true)] + public async Task Edit(Guid id) + { + var user = await _userRepository.GetByIdAsync(id); + if (user == null) + { + return RedirectToAction("Index"); + } + + var ciphers = await _cipherRepository.GetManyByUserIdAsync(id); + var billingInfo = await _paymentService.GetBillingAsync(user); + return View(new UserEditModel(user, ciphers, billingInfo, _globalSettings)); + } + + [HttpPost] + [ValidateAntiForgeryToken] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task Edit(Guid id, UserEditModel model) + { + var user = await _userRepository.GetByIdAsync(id); + if (user == null) + { + return RedirectToAction("Index"); + } + + model.ToUser(user); + await _userRepository.ReplaceAsync(user); + return RedirectToAction("Edit", new { id }); + } + + [HttpPost] + [ValidateAntiForgeryToken] + public async Task Delete(Guid id) + { + var user = await _userRepository.GetByIdAsync(id); + if (user != null) + { + await _userRepository.DeleteAsync(user); + } + + return RedirectToAction("Index"); } } diff --git a/src/Admin/HostedServices/AmazonSqsBlockIpHostedService.cs b/src/Admin/HostedServices/AmazonSqsBlockIpHostedService.cs index b0222d06f..646da09c5 100644 --- a/src/Admin/HostedServices/AmazonSqsBlockIpHostedService.cs +++ b/src/Admin/HostedServices/AmazonSqsBlockIpHostedService.cs @@ -4,81 +4,80 @@ using Amazon.SQS.Model; using Bit.Core.Settings; using Microsoft.Extensions.Options; -namespace Bit.Admin.HostedServices +namespace Bit.Admin.HostedServices; + +public class AmazonSqsBlockIpHostedService : BlockIpHostedService { - public class AmazonSqsBlockIpHostedService : BlockIpHostedService + private AmazonSQSClient _client; + + public AmazonSqsBlockIpHostedService( + ILogger logger, + IOptions adminSettings, + GlobalSettings globalSettings) + : base(logger, adminSettings, globalSettings) + { } + + public override void Dispose() { - private AmazonSQSClient _client; + _client?.Dispose(); + } - public AmazonSqsBlockIpHostedService( - ILogger logger, - IOptions adminSettings, - GlobalSettings globalSettings) - : base(logger, adminSettings, globalSettings) - { } + protected override async Task ExecuteAsync(CancellationToken cancellationToken) + { + _client = new AmazonSQSClient(_globalSettings.Amazon.AccessKeyId, + _globalSettings.Amazon.AccessKeySecret, RegionEndpoint.GetBySystemName(_globalSettings.Amazon.Region)); + var blockIpQueue = await _client.GetQueueUrlAsync("block-ip", cancellationToken); + var blockIpQueueUrl = blockIpQueue.QueueUrl; + var unblockIpQueue = await _client.GetQueueUrlAsync("unblock-ip", cancellationToken); + var unblockIpQueueUrl = unblockIpQueue.QueueUrl; - public override void Dispose() + while (!cancellationToken.IsCancellationRequested) { - _client?.Dispose(); - } - - protected override async Task ExecuteAsync(CancellationToken cancellationToken) - { - _client = new AmazonSQSClient(_globalSettings.Amazon.AccessKeyId, - _globalSettings.Amazon.AccessKeySecret, RegionEndpoint.GetBySystemName(_globalSettings.Amazon.Region)); - var blockIpQueue = await _client.GetQueueUrlAsync("block-ip", cancellationToken); - var blockIpQueueUrl = blockIpQueue.QueueUrl; - var unblockIpQueue = await _client.GetQueueUrlAsync("unblock-ip", cancellationToken); - var unblockIpQueueUrl = unblockIpQueue.QueueUrl; - - while (!cancellationToken.IsCancellationRequested) + var blockMessageResponse = await _client.ReceiveMessageAsync(new ReceiveMessageRequest { - var blockMessageResponse = await _client.ReceiveMessageAsync(new ReceiveMessageRequest + QueueUrl = blockIpQueueUrl, + MaxNumberOfMessages = 10, + WaitTimeSeconds = 15 + }, cancellationToken); + if (blockMessageResponse.Messages.Any()) + { + foreach (var message in blockMessageResponse.Messages) { - QueueUrl = blockIpQueueUrl, - MaxNumberOfMessages = 10, - WaitTimeSeconds = 15 - }, cancellationToken); - if (blockMessageResponse.Messages.Any()) - { - foreach (var message in blockMessageResponse.Messages) + try { - try - { - await BlockIpAsync(message.Body, cancellationToken); - } - catch (Exception e) - { - _logger.LogError(e, "Failed to block IP."); - } - await _client.DeleteMessageAsync(blockIpQueueUrl, message.ReceiptHandle, cancellationToken); + await BlockIpAsync(message.Body, cancellationToken); } - } - - var unblockMessageResponse = await _client.ReceiveMessageAsync(new ReceiveMessageRequest - { - QueueUrl = unblockIpQueueUrl, - MaxNumberOfMessages = 10, - WaitTimeSeconds = 15 - }, cancellationToken); - if (unblockMessageResponse.Messages.Any()) - { - foreach (var message in unblockMessageResponse.Messages) + catch (Exception e) { - try - { - await UnblockIpAsync(message.Body, cancellationToken); - } - catch (Exception e) - { - _logger.LogError(e, "Failed to unblock IP."); - } - await _client.DeleteMessageAsync(unblockIpQueueUrl, message.ReceiptHandle, cancellationToken); + _logger.LogError(e, "Failed to block IP."); } + await _client.DeleteMessageAsync(blockIpQueueUrl, message.ReceiptHandle, cancellationToken); } - - await Task.Delay(TimeSpan.FromSeconds(15)); } + + var unblockMessageResponse = await _client.ReceiveMessageAsync(new ReceiveMessageRequest + { + QueueUrl = unblockIpQueueUrl, + MaxNumberOfMessages = 10, + WaitTimeSeconds = 15 + }, cancellationToken); + if (unblockMessageResponse.Messages.Any()) + { + foreach (var message in unblockMessageResponse.Messages) + { + try + { + await UnblockIpAsync(message.Body, cancellationToken); + } + catch (Exception e) + { + _logger.LogError(e, "Failed to unblock IP."); + } + await _client.DeleteMessageAsync(unblockIpQueueUrl, message.ReceiptHandle, cancellationToken); + } + } + + await Task.Delay(TimeSpan.FromSeconds(15)); } } } diff --git a/src/Admin/HostedServices/AzureQueueBlockIpHostedService.cs b/src/Admin/HostedServices/AzureQueueBlockIpHostedService.cs index cd96f359a..f1590377e 100644 --- a/src/Admin/HostedServices/AzureQueueBlockIpHostedService.cs +++ b/src/Admin/HostedServices/AzureQueueBlockIpHostedService.cs @@ -2,63 +2,62 @@ using Bit.Core.Settings; using Microsoft.Extensions.Options; -namespace Bit.Admin.HostedServices +namespace Bit.Admin.HostedServices; + +public class AzureQueueBlockIpHostedService : BlockIpHostedService { - public class AzureQueueBlockIpHostedService : BlockIpHostedService + private QueueClient _blockIpQueueClient; + private QueueClient _unblockIpQueueClient; + + public AzureQueueBlockIpHostedService( + ILogger logger, + IOptions adminSettings, + GlobalSettings globalSettings) + : base(logger, adminSettings, globalSettings) + { } + + protected override async Task ExecuteAsync(CancellationToken cancellationToken) { - private QueueClient _blockIpQueueClient; - private QueueClient _unblockIpQueueClient; + _blockIpQueueClient = new QueueClient(_globalSettings.Storage.ConnectionString, "blockip"); + _unblockIpQueueClient = new QueueClient(_globalSettings.Storage.ConnectionString, "unblockip"); - public AzureQueueBlockIpHostedService( - ILogger logger, - IOptions adminSettings, - GlobalSettings globalSettings) - : base(logger, adminSettings, globalSettings) - { } - - protected override async Task ExecuteAsync(CancellationToken cancellationToken) + while (!cancellationToken.IsCancellationRequested) { - _blockIpQueueClient = new QueueClient(_globalSettings.Storage.ConnectionString, "blockip"); - _unblockIpQueueClient = new QueueClient(_globalSettings.Storage.ConnectionString, "unblockip"); - - while (!cancellationToken.IsCancellationRequested) + var blockMessages = await _blockIpQueueClient.ReceiveMessagesAsync(maxMessages: 32); + if (blockMessages.Value?.Any() ?? false) { - var blockMessages = await _blockIpQueueClient.ReceiveMessagesAsync(maxMessages: 32); - if (blockMessages.Value?.Any() ?? false) + foreach (var message in blockMessages.Value) { - foreach (var message in blockMessages.Value) + try { - try - { - await BlockIpAsync(message.MessageText, cancellationToken); - } - catch (Exception e) - { - _logger.LogError(e, "Failed to block IP."); - } - await _blockIpQueueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); + await BlockIpAsync(message.MessageText, cancellationToken); } - } - - var unblockMessages = await _unblockIpQueueClient.ReceiveMessagesAsync(maxMessages: 32); - if (unblockMessages.Value?.Any() ?? false) - { - foreach (var message in unblockMessages.Value) + catch (Exception e) { - try - { - await UnblockIpAsync(message.MessageText, cancellationToken); - } - catch (Exception e) - { - _logger.LogError(e, "Failed to unblock IP."); - } - await _unblockIpQueueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); + _logger.LogError(e, "Failed to block IP."); } + await _blockIpQueueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); } - - await Task.Delay(TimeSpan.FromSeconds(15)); } + + var unblockMessages = await _unblockIpQueueClient.ReceiveMessagesAsync(maxMessages: 32); + if (unblockMessages.Value?.Any() ?? false) + { + foreach (var message in unblockMessages.Value) + { + try + { + await UnblockIpAsync(message.MessageText, cancellationToken); + } + catch (Exception e) + { + _logger.LogError(e, "Failed to unblock IP."); + } + await _unblockIpQueueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); + } + } + + await Task.Delay(TimeSpan.FromSeconds(15)); } } } diff --git a/src/Admin/HostedServices/AzureQueueMailHostedService.cs b/src/Admin/HostedServices/AzureQueueMailHostedService.cs index 6e976f0b7..b2031a405 100644 --- a/src/Admin/HostedServices/AzureQueueMailHostedService.cs +++ b/src/Admin/HostedServices/AzureQueueMailHostedService.cs @@ -6,97 +6,96 @@ using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Utilities; -namespace Bit.Admin.HostedServices +namespace Bit.Admin.HostedServices; + +public class AzureQueueMailHostedService : IHostedService { - public class AzureQueueMailHostedService : IHostedService + private readonly ILogger _logger; + private readonly GlobalSettings _globalSettings; + private readonly IMailService _mailService; + private CancellationTokenSource _cts; + private Task _executingTask; + + private QueueClient _mailQueueClient; + + public AzureQueueMailHostedService( + ILogger logger, + IMailService mailService, + GlobalSettings globalSettings) { - private readonly ILogger _logger; - private readonly GlobalSettings _globalSettings; - private readonly IMailService _mailService; - private CancellationTokenSource _cts; - private Task _executingTask; + _logger = logger; + _mailService = mailService; + _globalSettings = globalSettings; + } - private QueueClient _mailQueueClient; + public Task StartAsync(CancellationToken cancellationToken) + { + _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + _executingTask = ExecuteAsync(_cts.Token); + return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask; + } - public AzureQueueMailHostedService( - ILogger logger, - IMailService mailService, - GlobalSettings globalSettings) + public async Task StopAsync(CancellationToken cancellationToken) + { + if (_executingTask == null) { - _logger = logger; - _mailService = mailService; - _globalSettings = globalSettings; + return; } + _cts.Cancel(); + await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); + cancellationToken.ThrowIfCancellationRequested(); + } - public Task StartAsync(CancellationToken cancellationToken) - { - _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - _executingTask = ExecuteAsync(_cts.Token); - return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask; - } + private async Task ExecuteAsync(CancellationToken cancellationToken) + { + _mailQueueClient = new QueueClient(_globalSettings.Mail.ConnectionString, "mail"); - public async Task StopAsync(CancellationToken cancellationToken) + QueueMessage[] mailMessages; + while (!cancellationToken.IsCancellationRequested) { - if (_executingTask == null) + if (!(mailMessages = await RetrieveMessagesAsync()).Any()) { - return; + await Task.Delay(TimeSpan.FromSeconds(15)); } - _cts.Cancel(); - await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); - cancellationToken.ThrowIfCancellationRequested(); - } - private async Task ExecuteAsync(CancellationToken cancellationToken) - { - _mailQueueClient = new QueueClient(_globalSettings.Mail.ConnectionString, "mail"); - - QueueMessage[] mailMessages; - while (!cancellationToken.IsCancellationRequested) + foreach (var message in mailMessages) { - if (!(mailMessages = await RetrieveMessagesAsync()).Any()) + try { - await Task.Delay(TimeSpan.FromSeconds(15)); - } + using var document = JsonDocument.Parse(message.DecodeMessageText()); + var root = document.RootElement; - foreach (var message in mailMessages) - { - try + if (root.ValueKind == JsonValueKind.Array) { - using var document = JsonDocument.Parse(message.DecodeMessageText()); - var root = document.RootElement; - - if (root.ValueKind == JsonValueKind.Array) + foreach (var mailQueueMessage in root.ToObject>()) { - foreach (var mailQueueMessage in root.ToObject>()) - { - await _mailService.SendEnqueuedMailMessageAsync(mailQueueMessage); - } - } - else if (root.ValueKind == JsonValueKind.Object) - { - var mailQueueMessage = root.ToObject(); await _mailService.SendEnqueuedMailMessageAsync(mailQueueMessage); } } - catch (Exception e) + else if (root.ValueKind == JsonValueKind.Object) { - _logger.LogError(e, "Failed to send email"); - // TODO: retries? + var mailQueueMessage = root.ToObject(); + await _mailService.SendEnqueuedMailMessageAsync(mailQueueMessage); } + } + catch (Exception e) + { + _logger.LogError(e, "Failed to send email"); + // TODO: retries? + } - await _mailQueueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); + await _mailQueueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); - if (cancellationToken.IsCancellationRequested) - { - break; - } + if (cancellationToken.IsCancellationRequested) + { + break; } } } + } - private async Task RetrieveMessagesAsync() - { - return (await _mailQueueClient.ReceiveMessagesAsync(maxMessages: 32))?.Value ?? new QueueMessage[] { }; - } + private async Task RetrieveMessagesAsync() + { + return (await _mailQueueClient.ReceiveMessagesAsync(maxMessages: 32))?.Value ?? new QueueMessage[] { }; } } diff --git a/src/Admin/HostedServices/BlockIpHostedService.cs b/src/Admin/HostedServices/BlockIpHostedService.cs index 17f0c50ce..6a1f58c6b 100644 --- a/src/Admin/HostedServices/BlockIpHostedService.cs +++ b/src/Admin/HostedServices/BlockIpHostedService.cs @@ -1,71 +1,105 @@ using Bit.Core.Settings; using Microsoft.Extensions.Options; -namespace Bit.Admin.HostedServices +namespace Bit.Admin.HostedServices; + +public abstract class BlockIpHostedService : IHostedService, IDisposable { - public abstract class BlockIpHostedService : IHostedService, IDisposable + protected readonly ILogger _logger; + protected readonly GlobalSettings _globalSettings; + private readonly AdminSettings _adminSettings; + + private Task _executingTask; + private CancellationTokenSource _cts; + private HttpClient _httpClient = new HttpClient(); + + public BlockIpHostedService( + ILogger logger, + IOptions adminSettings, + GlobalSettings globalSettings) { - protected readonly ILogger _logger; - protected readonly GlobalSettings _globalSettings; - private readonly AdminSettings _adminSettings; + _logger = logger; + _globalSettings = globalSettings; + _adminSettings = adminSettings?.Value; + } - private Task _executingTask; - private CancellationTokenSource _cts; - private HttpClient _httpClient = new HttpClient(); + public Task StartAsync(CancellationToken cancellationToken) + { + _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + _executingTask = ExecuteAsync(_cts.Token); + return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask; + } - public BlockIpHostedService( - ILogger logger, - IOptions adminSettings, - GlobalSettings globalSettings) + public async Task StopAsync(CancellationToken cancellationToken) + { + if (_executingTask == null) { - _logger = logger; - _globalSettings = globalSettings; - _adminSettings = adminSettings?.Value; + return; } + _cts.Cancel(); + await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); + cancellationToken.ThrowIfCancellationRequested(); + } - public Task StartAsync(CancellationToken cancellationToken) - { - _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - _executingTask = ExecuteAsync(_cts.Token); - return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask; - } + public virtual void Dispose() + { } - public async Task StopAsync(CancellationToken cancellationToken) + protected abstract Task ExecuteAsync(CancellationToken cancellationToken); + + protected async Task BlockIpAsync(string message, CancellationToken cancellationToken) + { + var request = new HttpRequestMessage(); + request.Headers.Accept.Clear(); + request.Headers.Add("X-Auth-Email", _adminSettings.Cloudflare.AuthEmail); + request.Headers.Add("X-Auth-Key", _adminSettings.Cloudflare.AuthKey); + request.Method = HttpMethod.Post; + request.RequestUri = new Uri("https://api.cloudflare.com/" + + $"client/v4/zones/{_adminSettings.Cloudflare.ZoneId}/firewall/access_rules/rules"); + + request.Content = JsonContent.Create(new { - if (_executingTask == null) + mode = "block", + configuration = new { - return; - } - _cts.Cancel(); - await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); - cancellationToken.ThrowIfCancellationRequested(); + target = "ip", + value = message + }, + notes = $"Rate limit abuse on {DateTime.UtcNow.ToString()}." + }); + + var response = await _httpClient.SendAsync(request, cancellationToken); + if (!response.IsSuccessStatusCode) + { + return; } - public virtual void Dispose() - { } - - protected abstract Task ExecuteAsync(CancellationToken cancellationToken); - - protected async Task BlockIpAsync(string message, CancellationToken cancellationToken) + var accessRuleResponse = await response.Content.ReadFromJsonAsync(cancellationToken: cancellationToken); + if (!accessRuleResponse.Success) { + return; + } + + // TODO: Send `accessRuleResponse.Result?.Id` message to unblock queue + } + + protected async Task UnblockIpAsync(string message, CancellationToken cancellationToken) + { + if (string.IsNullOrWhiteSpace(message)) + { + return; + } + + if (message.Contains(".") || message.Contains(":")) + { + // IP address messages var request = new HttpRequestMessage(); request.Headers.Accept.Clear(); request.Headers.Add("X-Auth-Email", _adminSettings.Cloudflare.AuthEmail); request.Headers.Add("X-Auth-Key", _adminSettings.Cloudflare.AuthKey); - request.Method = HttpMethod.Post; + request.Method = HttpMethod.Get; request.RequestUri = new Uri("https://api.cloudflare.com/" + - $"client/v4/zones/{_adminSettings.Cloudflare.ZoneId}/firewall/access_rules/rules"); - - request.Content = JsonContent.Create(new - { - mode = "block", - configuration = new - { - target = "ip", - value = message - }, - notes = $"Rate limit abuse on {DateTime.UtcNow.ToString()}." - }); + $"client/v4/zones/{_adminSettings.Cloudflare.ZoneId}/firewall/access_rules/rules?" + + $"configuration_target=ip&configuration_value={message}"); var response = await _httpClient.SendAsync(request, cancellationToken); if (!response.IsSuccessStatusCode) @@ -73,93 +107,58 @@ namespace Bit.Admin.HostedServices return; } - var accessRuleResponse = await response.Content.ReadFromJsonAsync(cancellationToken: cancellationToken); - if (!accessRuleResponse.Success) + var listResponse = await response.Content.ReadFromJsonAsync(cancellationToken: cancellationToken); + if (!listResponse.Success) { return; } - // TODO: Send `accessRuleResponse.Result?.Id` message to unblock queue - } - - protected async Task UnblockIpAsync(string message, CancellationToken cancellationToken) - { - if (string.IsNullOrWhiteSpace(message)) + foreach (var rule in listResponse.Result) { - return; - } - - if (message.Contains(".") || message.Contains(":")) - { - // IP address messages - var request = new HttpRequestMessage(); - request.Headers.Accept.Clear(); - request.Headers.Add("X-Auth-Email", _adminSettings.Cloudflare.AuthEmail); - request.Headers.Add("X-Auth-Key", _adminSettings.Cloudflare.AuthKey); - request.Method = HttpMethod.Get; - request.RequestUri = new Uri("https://api.cloudflare.com/" + - $"client/v4/zones/{_adminSettings.Cloudflare.ZoneId}/firewall/access_rules/rules?" + - $"configuration_target=ip&configuration_value={message}"); - - var response = await _httpClient.SendAsync(request, cancellationToken); - if (!response.IsSuccessStatusCode) - { - return; - } - - var listResponse = await response.Content.ReadFromJsonAsync(cancellationToken: cancellationToken); - if (!listResponse.Success) - { - return; - } - - foreach (var rule in listResponse.Result) - { - await DeleteAccessRuleAsync(rule.Id, cancellationToken); - } - } - else - { - // Rule Id messages - await DeleteAccessRuleAsync(message, cancellationToken); + await DeleteAccessRuleAsync(rule.Id, cancellationToken); } } - - protected async Task DeleteAccessRuleAsync(string ruleId, CancellationToken cancellationToken) + else { - var request = new HttpRequestMessage(); - request.Headers.Accept.Clear(); - request.Headers.Add("X-Auth-Email", _adminSettings.Cloudflare.AuthEmail); - request.Headers.Add("X-Auth-Key", _adminSettings.Cloudflare.AuthKey); - request.Method = HttpMethod.Delete; - request.RequestUri = new Uri("https://api.cloudflare.com/" + - $"client/v4/zones/{_adminSettings.Cloudflare.ZoneId}/firewall/access_rules/rules/{ruleId}"); - await _httpClient.SendAsync(request, cancellationToken); + // Rule Id messages + await DeleteAccessRuleAsync(message, cancellationToken); } + } - public class ListResponse + protected async Task DeleteAccessRuleAsync(string ruleId, CancellationToken cancellationToken) + { + var request = new HttpRequestMessage(); + request.Headers.Accept.Clear(); + request.Headers.Add("X-Auth-Email", _adminSettings.Cloudflare.AuthEmail); + request.Headers.Add("X-Auth-Key", _adminSettings.Cloudflare.AuthKey); + request.Method = HttpMethod.Delete; + request.RequestUri = new Uri("https://api.cloudflare.com/" + + $"client/v4/zones/{_adminSettings.Cloudflare.ZoneId}/firewall/access_rules/rules/{ruleId}"); + await _httpClient.SendAsync(request, cancellationToken); + } + + public class ListResponse + { + public bool Success { get; set; } + public List Result { get; set; } + } + + public class AccessRuleResponse + { + public bool Success { get; set; } + public AccessRuleResultResponse Result { get; set; } + } + + public class AccessRuleResultResponse + { + public string Id { get; set; } + public string Notes { get; set; } + public ConfigurationResponse Configuration { get; set; } + + public class ConfigurationResponse { - public bool Success { get; set; } - public List Result { get; set; } - } - - public class AccessRuleResponse - { - public bool Success { get; set; } - public AccessRuleResultResponse Result { get; set; } - } - - public class AccessRuleResultResponse - { - public string Id { get; set; } - public string Notes { get; set; } - public ConfigurationResponse Configuration { get; set; } - - public class ConfigurationResponse - { - public string Target { get; set; } - public string Value { get; set; } - } + public string Target { get; set; } + public string Value { get; set; } } } } diff --git a/src/Admin/HostedServices/DatabaseMigrationHostedService.cs b/src/Admin/HostedServices/DatabaseMigrationHostedService.cs index 06cf01428..0f660729e 100644 --- a/src/Admin/HostedServices/DatabaseMigrationHostedService.cs +++ b/src/Admin/HostedServices/DatabaseMigrationHostedService.cs @@ -3,62 +3,61 @@ using Bit.Core.Jobs; using Bit.Core.Settings; using Bit.Migrator; -namespace Bit.Admin.HostedServices +namespace Bit.Admin.HostedServices; + +public class DatabaseMigrationHostedService : IHostedService, IDisposable { - public class DatabaseMigrationHostedService : IHostedService, IDisposable + private readonly GlobalSettings _globalSettings; + private readonly ILogger _logger; + private readonly DbMigrator _dbMigrator; + + public DatabaseMigrationHostedService( + GlobalSettings globalSettings, + ILogger logger, + ILogger migratorLogger, + ILogger listenerLogger) { - private readonly GlobalSettings _globalSettings; - private readonly ILogger _logger; - private readonly DbMigrator _dbMigrator; + _globalSettings = globalSettings; + _logger = logger; + _dbMigrator = new DbMigrator(globalSettings.SqlServer.ConnectionString, migratorLogger); + } - public DatabaseMigrationHostedService( - GlobalSettings globalSettings, - ILogger logger, - ILogger migratorLogger, - ILogger listenerLogger) + public virtual async Task StartAsync(CancellationToken cancellationToken) + { + // Wait 20 seconds to allow database to come online + await Task.Delay(20000); + + var maxMigrationAttempts = 10; + for (var i = 1; i <= maxMigrationAttempts; i++) { - _globalSettings = globalSettings; - _logger = logger; - _dbMigrator = new DbMigrator(globalSettings.SqlServer.ConnectionString, migratorLogger); - } - - public virtual async Task StartAsync(CancellationToken cancellationToken) - { - // Wait 20 seconds to allow database to come online - await Task.Delay(20000); - - var maxMigrationAttempts = 10; - for (var i = 1; i <= maxMigrationAttempts; i++) + try { - try + _dbMigrator.MigrateMsSqlDatabase(true, cancellationToken); + // TODO: Maybe flip a flag somewhere to indicate migration is complete?? + break; + } + catch (SqlException e) + { + if (i >= maxMigrationAttempts) { - _dbMigrator.MigrateMsSqlDatabase(true, cancellationToken); - // TODO: Maybe flip a flag somewhere to indicate migration is complete?? - break; + _logger.LogError(e, "Database failed to migrate."); + throw; } - catch (SqlException e) + else { - if (i >= maxMigrationAttempts) - { - _logger.LogError(e, "Database failed to migrate."); - throw; - } - else - { - _logger.LogError(e, - "Database unavailable for migration. Trying again (attempt #{0})...", i + 1); - await Task.Delay(20000); - } + _logger.LogError(e, + "Database unavailable for migration. Trying again (attempt #{0})...", i + 1); + await Task.Delay(20000); } } } - - public virtual Task StopAsync(CancellationToken cancellationToken) - { - return Task.FromResult(0); - } - - public virtual void Dispose() - { } } + + public virtual Task StopAsync(CancellationToken cancellationToken) + { + return Task.FromResult(0); + } + + public virtual void Dispose() + { } } diff --git a/src/Admin/Jobs/AliveJob.cs b/src/Admin/Jobs/AliveJob.cs index 27d23c342..b97d597e5 100644 --- a/src/Admin/Jobs/AliveJob.cs +++ b/src/Admin/Jobs/AliveJob.cs @@ -3,27 +3,26 @@ using Bit.Core.Jobs; using Bit.Core.Settings; using Quartz; -namespace Bit.Admin.Jobs +namespace Bit.Admin.Jobs; + +public class AliveJob : BaseJob { - public class AliveJob : BaseJob + private readonly GlobalSettings _globalSettings; + private HttpClient _httpClient = new HttpClient(); + + public AliveJob( + GlobalSettings globalSettings, + ILogger logger) + : base(logger) { - private readonly GlobalSettings _globalSettings; - private HttpClient _httpClient = new HttpClient(); + _globalSettings = globalSettings; + } - public AliveJob( - GlobalSettings globalSettings, - ILogger logger) - : base(logger) - { - _globalSettings = globalSettings; - } - - protected async override Task ExecuteJobAsync(IJobExecutionContext context) - { - _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: Keep alive"); - var response = await _httpClient.GetAsync(_globalSettings.BaseServiceUri.Admin); - _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: Keep alive, " + - response.StatusCode); - } + protected async override Task ExecuteJobAsync(IJobExecutionContext context) + { + _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: Keep alive"); + var response = await _httpClient.GetAsync(_globalSettings.BaseServiceUri.Admin); + _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: Keep alive, " + + response.StatusCode); } } diff --git a/src/Admin/Jobs/DatabaseExpiredGrantsJob.cs b/src/Admin/Jobs/DatabaseExpiredGrantsJob.cs index 60ac44828..626eb00d5 100644 --- a/src/Admin/Jobs/DatabaseExpiredGrantsJob.cs +++ b/src/Admin/Jobs/DatabaseExpiredGrantsJob.cs @@ -3,25 +3,24 @@ using Bit.Core.Jobs; using Bit.Core.Repositories; using Quartz; -namespace Bit.Admin.Jobs +namespace Bit.Admin.Jobs; + +public class DatabaseExpiredGrantsJob : BaseJob { - public class DatabaseExpiredGrantsJob : BaseJob + private readonly IMaintenanceRepository _maintenanceRepository; + + public DatabaseExpiredGrantsJob( + IMaintenanceRepository maintenanceRepository, + ILogger logger) + : base(logger) { - private readonly IMaintenanceRepository _maintenanceRepository; + _maintenanceRepository = maintenanceRepository; + } - public DatabaseExpiredGrantsJob( - IMaintenanceRepository maintenanceRepository, - ILogger logger) - : base(logger) - { - _maintenanceRepository = maintenanceRepository; - } - - protected async override Task ExecuteJobAsync(IJobExecutionContext context) - { - _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DeleteExpiredGrantsAsync"); - await _maintenanceRepository.DeleteExpiredGrantsAsync(); - _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DeleteExpiredGrantsAsync"); - } + protected async override Task ExecuteJobAsync(IJobExecutionContext context) + { + _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DeleteExpiredGrantsAsync"); + await _maintenanceRepository.DeleteExpiredGrantsAsync(); + _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DeleteExpiredGrantsAsync"); } } diff --git a/src/Admin/Jobs/DatabaseExpiredSponsorshipsJob.cs b/src/Admin/Jobs/DatabaseExpiredSponsorshipsJob.cs index 609351e9f..7a00445fd 100644 --- a/src/Admin/Jobs/DatabaseExpiredSponsorshipsJob.cs +++ b/src/Admin/Jobs/DatabaseExpiredSponsorshipsJob.cs @@ -4,36 +4,35 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Quartz; -namespace Bit.Admin.Jobs +namespace Bit.Admin.Jobs; + +public class DatabaseExpiredSponsorshipsJob : BaseJob { - public class DatabaseExpiredSponsorshipsJob : BaseJob + private GlobalSettings _globalSettings; + private readonly IMaintenanceRepository _maintenanceRepository; + + public DatabaseExpiredSponsorshipsJob( + IMaintenanceRepository maintenanceRepository, + ILogger logger, + GlobalSettings globalSettings) + : base(logger) { - private GlobalSettings _globalSettings; - private readonly IMaintenanceRepository _maintenanceRepository; + _maintenanceRepository = maintenanceRepository; + _globalSettings = globalSettings; + } - public DatabaseExpiredSponsorshipsJob( - IMaintenanceRepository maintenanceRepository, - ILogger logger, - GlobalSettings globalSettings) - : base(logger) + protected override async Task ExecuteJobAsync(IJobExecutionContext context) + { + if (_globalSettings.SelfHosted && !_globalSettings.EnableCloudCommunication) { - _maintenanceRepository = maintenanceRepository; - _globalSettings = globalSettings; + return; } + _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DeleteExpiredSponsorshipsAsync"); - protected override async Task ExecuteJobAsync(IJobExecutionContext context) - { - if (_globalSettings.SelfHosted && !_globalSettings.EnableCloudCommunication) - { - return; - } - _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DeleteExpiredSponsorshipsAsync"); + // allow a 90 day grace period before deleting + var deleteDate = DateTime.UtcNow.AddDays(-90); - // allow a 90 day grace period before deleting - var deleteDate = DateTime.UtcNow.AddDays(-90); - - await _maintenanceRepository.DeleteExpiredSponsorshipsAsync(deleteDate); - _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DeleteExpiredSponsorshipsAsync"); - } + await _maintenanceRepository.DeleteExpiredSponsorshipsAsync(deleteDate); + _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DeleteExpiredSponsorshipsAsync"); } } diff --git a/src/Admin/Jobs/DatabaseRebuildlIndexesJob.cs b/src/Admin/Jobs/DatabaseRebuildlIndexesJob.cs index 24e05043a..78e48bb6f 100644 --- a/src/Admin/Jobs/DatabaseRebuildlIndexesJob.cs +++ b/src/Admin/Jobs/DatabaseRebuildlIndexesJob.cs @@ -3,25 +3,24 @@ using Bit.Core.Jobs; using Bit.Core.Repositories; using Quartz; -namespace Bit.Admin.Jobs +namespace Bit.Admin.Jobs; + +public class DatabaseRebuildlIndexesJob : BaseJob { - public class DatabaseRebuildlIndexesJob : BaseJob + private readonly IMaintenanceRepository _maintenanceRepository; + + public DatabaseRebuildlIndexesJob( + IMaintenanceRepository maintenanceRepository, + ILogger logger) + : base(logger) { - private readonly IMaintenanceRepository _maintenanceRepository; + _maintenanceRepository = maintenanceRepository; + } - public DatabaseRebuildlIndexesJob( - IMaintenanceRepository maintenanceRepository, - ILogger logger) - : base(logger) - { - _maintenanceRepository = maintenanceRepository; - } - - protected async override Task ExecuteJobAsync(IJobExecutionContext context) - { - _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: RebuildIndexesAsync"); - await _maintenanceRepository.RebuildIndexesAsync(); - _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: RebuildIndexesAsync"); - } + protected async override Task ExecuteJobAsync(IJobExecutionContext context) + { + _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: RebuildIndexesAsync"); + await _maintenanceRepository.RebuildIndexesAsync(); + _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: RebuildIndexesAsync"); } } diff --git a/src/Admin/Jobs/DatabaseUpdateStatisticsJob.cs b/src/Admin/Jobs/DatabaseUpdateStatisticsJob.cs index 4a03d08fd..14c13918b 100644 --- a/src/Admin/Jobs/DatabaseUpdateStatisticsJob.cs +++ b/src/Admin/Jobs/DatabaseUpdateStatisticsJob.cs @@ -3,28 +3,27 @@ using Bit.Core.Jobs; using Bit.Core.Repositories; using Quartz; -namespace Bit.Admin.Jobs +namespace Bit.Admin.Jobs; + +public class DatabaseUpdateStatisticsJob : BaseJob { - public class DatabaseUpdateStatisticsJob : BaseJob + private readonly IMaintenanceRepository _maintenanceRepository; + + public DatabaseUpdateStatisticsJob( + IMaintenanceRepository maintenanceRepository, + ILogger logger) + : base(logger) { - private readonly IMaintenanceRepository _maintenanceRepository; + _maintenanceRepository = maintenanceRepository; + } - public DatabaseUpdateStatisticsJob( - IMaintenanceRepository maintenanceRepository, - ILogger logger) - : base(logger) - { - _maintenanceRepository = maintenanceRepository; - } - - protected async override Task ExecuteJobAsync(IJobExecutionContext context) - { - _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: UpdateStatisticsAsync"); - await _maintenanceRepository.UpdateStatisticsAsync(); - _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: UpdateStatisticsAsync"); - _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DisableCipherAutoStatsAsync"); - await _maintenanceRepository.DisableCipherAutoStatsAsync(); - _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DisableCipherAutoStatsAsync"); - } + protected async override Task ExecuteJobAsync(IJobExecutionContext context) + { + _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: UpdateStatisticsAsync"); + await _maintenanceRepository.UpdateStatisticsAsync(); + _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: UpdateStatisticsAsync"); + _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DisableCipherAutoStatsAsync"); + await _maintenanceRepository.DisableCipherAutoStatsAsync(); + _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DisableCipherAutoStatsAsync"); } } diff --git a/src/Admin/Jobs/DeleteCiphersJob.cs b/src/Admin/Jobs/DeleteCiphersJob.cs index a5a92aa28..ecf13401e 100644 --- a/src/Admin/Jobs/DeleteCiphersJob.cs +++ b/src/Admin/Jobs/DeleteCiphersJob.cs @@ -4,34 +4,33 @@ using Bit.Core.Repositories; using Microsoft.Extensions.Options; using Quartz; -namespace Bit.Admin.Jobs +namespace Bit.Admin.Jobs; + +public class DeleteCiphersJob : BaseJob { - public class DeleteCiphersJob : BaseJob + private readonly ICipherRepository _cipherRepository; + private readonly AdminSettings _adminSettings; + + public DeleteCiphersJob( + ICipherRepository cipherRepository, + IOptions adminSettings, + ILogger logger) + : base(logger) { - private readonly ICipherRepository _cipherRepository; - private readonly AdminSettings _adminSettings; + _cipherRepository = cipherRepository; + _adminSettings = adminSettings?.Value; + } - public DeleteCiphersJob( - ICipherRepository cipherRepository, - IOptions adminSettings, - ILogger logger) - : base(logger) + protected async override Task ExecuteJobAsync(IJobExecutionContext context) + { + _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DeleteDeletedAsync"); + var deleteDate = DateTime.UtcNow.AddDays(-30); + var daysAgoSetting = (_adminSettings?.DeleteTrashDaysAgo).GetValueOrDefault(); + if (daysAgoSetting > 0) { - _cipherRepository = cipherRepository; - _adminSettings = adminSettings?.Value; - } - - protected async override Task ExecuteJobAsync(IJobExecutionContext context) - { - _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: DeleteDeletedAsync"); - var deleteDate = DateTime.UtcNow.AddDays(-30); - var daysAgoSetting = (_adminSettings?.DeleteTrashDaysAgo).GetValueOrDefault(); - if (daysAgoSetting > 0) - { - deleteDate = DateTime.UtcNow.AddDays(-1 * daysAgoSetting); - } - await _cipherRepository.DeleteDeletedAsync(deleteDate); - _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DeleteDeletedAsync"); + deleteDate = DateTime.UtcNow.AddDays(-1 * daysAgoSetting); } + await _cipherRepository.DeleteDeletedAsync(deleteDate); + _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: DeleteDeletedAsync"); } } diff --git a/src/Admin/Jobs/DeleteSendsJob.cs b/src/Admin/Jobs/DeleteSendsJob.cs index 814840fc4..9f3ed96ef 100644 --- a/src/Admin/Jobs/DeleteSendsJob.cs +++ b/src/Admin/Jobs/DeleteSendsJob.cs @@ -4,38 +4,37 @@ using Bit.Core.Repositories; using Bit.Core.Services; using Quartz; -namespace Bit.Admin.Jobs +namespace Bit.Admin.Jobs; + +public class DeleteSendsJob : BaseJob { - public class DeleteSendsJob : BaseJob + private readonly ISendRepository _sendRepository; + private readonly IServiceProvider _serviceProvider; + + public DeleteSendsJob( + ISendRepository sendRepository, + IServiceProvider serviceProvider, + ILogger logger) + : base(logger) { - private readonly ISendRepository _sendRepository; - private readonly IServiceProvider _serviceProvider; + _sendRepository = sendRepository; + _serviceProvider = serviceProvider; + } - public DeleteSendsJob( - ISendRepository sendRepository, - IServiceProvider serviceProvider, - ILogger logger) - : base(logger) + protected async override Task ExecuteJobAsync(IJobExecutionContext context) + { + var sends = await _sendRepository.GetManyByDeletionDateAsync(DateTime.UtcNow); + _logger.LogInformation(Constants.BypassFiltersEventId, "Deleting {0} sends.", sends.Count); + if (!sends.Any()) { - _sendRepository = sendRepository; - _serviceProvider = serviceProvider; + return; } - - protected async override Task ExecuteJobAsync(IJobExecutionContext context) + using (var scope = _serviceProvider.CreateScope()) { - var sends = await _sendRepository.GetManyByDeletionDateAsync(DateTime.UtcNow); - _logger.LogInformation(Constants.BypassFiltersEventId, "Deleting {0} sends.", sends.Count); - if (!sends.Any()) + var sendService = scope.ServiceProvider.GetRequiredService(); + foreach (var send in sends) { - return; - } - using (var scope = _serviceProvider.CreateScope()) - { - var sendService = scope.ServiceProvider.GetRequiredService(); - foreach (var send in sends) - { - await sendService.DeleteSendAsync(send); - } + await sendService.DeleteSendAsync(send); } } } diff --git a/src/Admin/Jobs/JobsHostedService.cs b/src/Admin/Jobs/JobsHostedService.cs index 01ac66a84..53b5c0566 100644 --- a/src/Admin/Jobs/JobsHostedService.cs +++ b/src/Admin/Jobs/JobsHostedService.cs @@ -3,94 +3,93 @@ using Bit.Core.Jobs; using Bit.Core.Settings; using Quartz; -namespace Bit.Admin.Jobs +namespace Bit.Admin.Jobs; + +public class JobsHostedService : BaseJobsHostedService { - public class JobsHostedService : BaseJobsHostedService + public JobsHostedService( + GlobalSettings globalSettings, + IServiceProvider serviceProvider, + ILogger logger, + ILogger listenerLogger) + : base(globalSettings, serviceProvider, logger, listenerLogger) { } + + public override async Task StartAsync(CancellationToken cancellationToken) { - public JobsHostedService( - GlobalSettings globalSettings, - IServiceProvider serviceProvider, - ILogger logger, - ILogger listenerLogger) - : base(globalSettings, serviceProvider, logger, listenerLogger) { } - - public override async Task StartAsync(CancellationToken cancellationToken) + var timeZone = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? + TimeZoneInfo.FindSystemTimeZoneById("Eastern Standard Time") : + TimeZoneInfo.FindSystemTimeZoneById("America/New_York"); + if (_globalSettings.SelfHosted) { - var timeZone = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? - TimeZoneInfo.FindSystemTimeZoneById("Eastern Standard Time") : - TimeZoneInfo.FindSystemTimeZoneById("America/New_York"); - if (_globalSettings.SelfHosted) - { - timeZone = TimeZoneInfo.Local; - } - - var everyTopOfTheHourTrigger = TriggerBuilder.Create() - .WithIdentity("EveryTopOfTheHourTrigger") - .StartNow() - .WithCronSchedule("0 0 * * * ?") - .Build(); - var everyFiveMinutesTrigger = TriggerBuilder.Create() - .WithIdentity("EveryFiveMinutesTrigger") - .StartNow() - .WithCronSchedule("0 */5 * * * ?") - .Build(); - var everyFridayAt10pmTrigger = TriggerBuilder.Create() - .WithIdentity("EveryFridayAt10pmTrigger") - .StartNow() - .WithCronSchedule("0 0 22 ? * FRI", x => x.InTimeZone(timeZone)) - .Build(); - var everySaturdayAtMidnightTrigger = TriggerBuilder.Create() - .WithIdentity("EverySaturdayAtMidnightTrigger") - .StartNow() - .WithCronSchedule("0 0 0 ? * SAT", x => x.InTimeZone(timeZone)) - .Build(); - var everySundayAtMidnightTrigger = TriggerBuilder.Create() - .WithIdentity("EverySundayAtMidnightTrigger") - .StartNow() - .WithCronSchedule("0 0 0 ? * SUN", x => x.InTimeZone(timeZone)) - .Build(); - var everyMondayAtMidnightTrigger = TriggerBuilder.Create() - .WithIdentity("EveryMondayAtMidnightTrigger") - .StartNow() - .WithCronSchedule("0 0 0 ? * MON", x => x.InTimeZone(timeZone)) - .Build(); - var everyDayAtMidnightUtc = TriggerBuilder.Create() - .WithIdentity("EveryDayAtMidnightUtc") - .StartNow() - .WithCronSchedule("0 0 0 * * ?") - .Build(); - - var jobs = new List> - { - new Tuple(typeof(DeleteSendsJob), everyFiveMinutesTrigger), - new Tuple(typeof(DatabaseExpiredGrantsJob), everyFridayAt10pmTrigger), - new Tuple(typeof(DatabaseUpdateStatisticsJob), everySaturdayAtMidnightTrigger), - new Tuple(typeof(DatabaseRebuildlIndexesJob), everySundayAtMidnightTrigger), - new Tuple(typeof(DeleteCiphersJob), everyDayAtMidnightUtc), - new Tuple(typeof(DatabaseExpiredSponsorshipsJob), everyMondayAtMidnightTrigger) - }; - - if (!_globalSettings.SelfHosted) - { - jobs.Add(new Tuple(typeof(AliveJob), everyTopOfTheHourTrigger)); - } - - Jobs = jobs; - await base.StartAsync(cancellationToken); + timeZone = TimeZoneInfo.Local; } - public static void AddJobsServices(IServiceCollection services, bool selfHosted) + var everyTopOfTheHourTrigger = TriggerBuilder.Create() + .WithIdentity("EveryTopOfTheHourTrigger") + .StartNow() + .WithCronSchedule("0 0 * * * ?") + .Build(); + var everyFiveMinutesTrigger = TriggerBuilder.Create() + .WithIdentity("EveryFiveMinutesTrigger") + .StartNow() + .WithCronSchedule("0 */5 * * * ?") + .Build(); + var everyFridayAt10pmTrigger = TriggerBuilder.Create() + .WithIdentity("EveryFridayAt10pmTrigger") + .StartNow() + .WithCronSchedule("0 0 22 ? * FRI", x => x.InTimeZone(timeZone)) + .Build(); + var everySaturdayAtMidnightTrigger = TriggerBuilder.Create() + .WithIdentity("EverySaturdayAtMidnightTrigger") + .StartNow() + .WithCronSchedule("0 0 0 ? * SAT", x => x.InTimeZone(timeZone)) + .Build(); + var everySundayAtMidnightTrigger = TriggerBuilder.Create() + .WithIdentity("EverySundayAtMidnightTrigger") + .StartNow() + .WithCronSchedule("0 0 0 ? * SUN", x => x.InTimeZone(timeZone)) + .Build(); + var everyMondayAtMidnightTrigger = TriggerBuilder.Create() + .WithIdentity("EveryMondayAtMidnightTrigger") + .StartNow() + .WithCronSchedule("0 0 0 ? * MON", x => x.InTimeZone(timeZone)) + .Build(); + var everyDayAtMidnightUtc = TriggerBuilder.Create() + .WithIdentity("EveryDayAtMidnightUtc") + .StartNow() + .WithCronSchedule("0 0 0 * * ?") + .Build(); + + var jobs = new List> { - if (!selfHosted) - { - services.AddTransient(); - } - services.AddTransient(); - services.AddTransient(); - services.AddTransient(); - services.AddTransient(); - services.AddTransient(); - services.AddTransient(); + new Tuple(typeof(DeleteSendsJob), everyFiveMinutesTrigger), + new Tuple(typeof(DatabaseExpiredGrantsJob), everyFridayAt10pmTrigger), + new Tuple(typeof(DatabaseUpdateStatisticsJob), everySaturdayAtMidnightTrigger), + new Tuple(typeof(DatabaseRebuildlIndexesJob), everySundayAtMidnightTrigger), + new Tuple(typeof(DeleteCiphersJob), everyDayAtMidnightUtc), + new Tuple(typeof(DatabaseExpiredSponsorshipsJob), everyMondayAtMidnightTrigger) + }; + + if (!_globalSettings.SelfHosted) + { + jobs.Add(new Tuple(typeof(AliveJob), everyTopOfTheHourTrigger)); } + + Jobs = jobs; + await base.StartAsync(cancellationToken); + } + + public static void AddJobsServices(IServiceCollection services, bool selfHosted) + { + if (!selfHosted) + { + services.AddTransient(); + } + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); } } diff --git a/src/Admin/Models/BillingInformationModel.cs b/src/Admin/Models/BillingInformationModel.cs index 1457a0851..a90ec7955 100644 --- a/src/Admin/Models/BillingInformationModel.cs +++ b/src/Admin/Models/BillingInformationModel.cs @@ -1,11 +1,10 @@ using Bit.Core.Models.Business; -namespace Bit.Admin.Models +namespace Bit.Admin.Models; + +public class BillingInformationModel { - public class BillingInformationModel - { - public BillingInfo BillingInfo { get; set; } - public Guid? UserId { get; set; } - public Guid? OrganizationId { get; set; } - } + public BillingInfo BillingInfo { get; set; } + public Guid? UserId { get; set; } + public Guid? OrganizationId { get; set; } } diff --git a/src/Admin/Models/ChargeBraintreeModel.cs b/src/Admin/Models/ChargeBraintreeModel.cs index b7adba8f1..2ba06cb98 100644 --- a/src/Admin/Models/ChargeBraintreeModel.cs +++ b/src/Admin/Models/ChargeBraintreeModel.cs @@ -1,27 +1,26 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Admin.Models -{ - public class ChargeBraintreeModel : IValidatableObject - { - [Required] - [Display(Name = "Braintree Customer Id")] - public string Id { get; set; } - [Required] - [Display(Name = "Amount")] - public decimal? Amount { get; set; } - public string TransactionId { get; set; } - public string PayPalTransactionId { get; set; } +namespace Bit.Admin.Models; - public IEnumerable Validate(ValidationContext validationContext) +public class ChargeBraintreeModel : IValidatableObject +{ + [Required] + [Display(Name = "Braintree Customer Id")] + public string Id { get; set; } + [Required] + [Display(Name = "Amount")] + public decimal? Amount { get; set; } + public string TransactionId { get; set; } + public string PayPalTransactionId { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (Id != null) { - if (Id != null) + if (Id.Length != 36 || (Id[0] != 'o' && Id[0] != 'u') || + !Guid.TryParse(Id.Substring(1, 32), out var guid)) { - if (Id.Length != 36 || (Id[0] != 'o' && Id[0] != 'u') || - !Guid.TryParse(Id.Substring(1, 32), out var guid)) - { - yield return new ValidationResult("Customer Id is not a valid format."); - } + yield return new ValidationResult("Customer Id is not a valid format."); } } } diff --git a/src/Admin/Models/CreateProviderModel.cs b/src/Admin/Models/CreateProviderModel.cs index 582c388af..9bcbf1f75 100644 --- a/src/Admin/Models/CreateProviderModel.cs +++ b/src/Admin/Models/CreateProviderModel.cs @@ -1,13 +1,12 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Admin.Models -{ - public class CreateProviderModel - { - public CreateProviderModel() { } +namespace Bit.Admin.Models; - [Display(Name = "Owner Email")] - [Required] - public string OwnerEmail { get; set; } - } +public class CreateProviderModel +{ + public CreateProviderModel() { } + + [Display(Name = "Owner Email")] + [Required] + public string OwnerEmail { get; set; } } diff --git a/src/Admin/Models/CreateUpdateTransactionModel.cs b/src/Admin/Models/CreateUpdateTransactionModel.cs index 0ab1f0dc8..8004546f9 100644 --- a/src/Admin/Models/CreateUpdateTransactionModel.cs +++ b/src/Admin/Models/CreateUpdateTransactionModel.cs @@ -2,77 +2,76 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Admin.Models +namespace Bit.Admin.Models; + +public class CreateUpdateTransactionModel : IValidatableObject { - public class CreateUpdateTransactionModel : IValidatableObject + public CreateUpdateTransactionModel() { } + + public CreateUpdateTransactionModel(Transaction transaction) { - public CreateUpdateTransactionModel() { } + Edit = true; + UserId = transaction.UserId; + OrganizationId = transaction.OrganizationId; + Amount = transaction.Amount; + RefundedAmount = transaction.RefundedAmount; + Refunded = transaction.Refunded.GetValueOrDefault(); + Details = transaction.Details; + Date = transaction.CreationDate; + PaymentMethod = transaction.PaymentMethodType; + Gateway = transaction.Gateway; + GatewayId = transaction.GatewayId; + Type = transaction.Type; + } - public CreateUpdateTransactionModel(Transaction transaction) + public bool Edit { get; set; } + + [Display(Name = "User Id")] + public Guid? UserId { get; set; } + [Display(Name = "Organization Id")] + public Guid? OrganizationId { get; set; } + [Required] + public decimal? Amount { get; set; } + [Display(Name = "Refunded Amount")] + public decimal? RefundedAmount { get; set; } + public bool Refunded { get; set; } + [Required] + public string Details { get; set; } + [Required] + public DateTime? Date { get; set; } + [Display(Name = "Payment Method")] + public PaymentMethodType? PaymentMethod { get; set; } + public GatewayType? Gateway { get; set; } + [Display(Name = "Gateway Id")] + public string GatewayId { get; set; } + [Required] + public TransactionType? Type { get; set; } + + + public IEnumerable Validate(ValidationContext validationContext) + { + if ((!UserId.HasValue && !OrganizationId.HasValue) || (UserId.HasValue && OrganizationId.HasValue)) { - Edit = true; - UserId = transaction.UserId; - OrganizationId = transaction.OrganizationId; - Amount = transaction.Amount; - RefundedAmount = transaction.RefundedAmount; - Refunded = transaction.Refunded.GetValueOrDefault(); - Details = transaction.Details; - Date = transaction.CreationDate; - PaymentMethod = transaction.PaymentMethodType; - Gateway = transaction.Gateway; - GatewayId = transaction.GatewayId; - Type = transaction.Type; - } - - public bool Edit { get; set; } - - [Display(Name = "User Id")] - public Guid? UserId { get; set; } - [Display(Name = "Organization Id")] - public Guid? OrganizationId { get; set; } - [Required] - public decimal? Amount { get; set; } - [Display(Name = "Refunded Amount")] - public decimal? RefundedAmount { get; set; } - public bool Refunded { get; set; } - [Required] - public string Details { get; set; } - [Required] - public DateTime? Date { get; set; } - [Display(Name = "Payment Method")] - public PaymentMethodType? PaymentMethod { get; set; } - public GatewayType? Gateway { get; set; } - [Display(Name = "Gateway Id")] - public string GatewayId { get; set; } - [Required] - public TransactionType? Type { get; set; } - - - public IEnumerable Validate(ValidationContext validationContext) - { - if ((!UserId.HasValue && !OrganizationId.HasValue) || (UserId.HasValue && OrganizationId.HasValue)) - { - yield return new ValidationResult("Must provide either User Id, or Organization Id."); - } - } - - public Transaction ToTransaction(Guid? id = null) - { - return new Transaction - { - Id = id.GetValueOrDefault(), - UserId = UserId, - OrganizationId = OrganizationId, - Amount = Amount.Value, - RefundedAmount = RefundedAmount, - Refunded = Refunded ? true : (bool?)null, - Details = Details, - CreationDate = Date.Value, - PaymentMethodType = PaymentMethod, - Gateway = Gateway, - GatewayId = GatewayId, - Type = Type.Value - }; + yield return new ValidationResult("Must provide either User Id, or Organization Id."); } } + + public Transaction ToTransaction(Guid? id = null) + { + return new Transaction + { + Id = id.GetValueOrDefault(), + UserId = UserId, + OrganizationId = OrganizationId, + Amount = Amount.Value, + RefundedAmount = RefundedAmount, + Refunded = Refunded ? true : (bool?)null, + Details = Details, + CreationDate = Date.Value, + PaymentMethodType = PaymentMethod, + Gateway = Gateway, + GatewayId = GatewayId, + Type = Type.Value + }; + } } diff --git a/src/Admin/Models/CursorPagedModel.cs b/src/Admin/Models/CursorPagedModel.cs index 59d13e268..35a4de922 100644 --- a/src/Admin/Models/CursorPagedModel.cs +++ b/src/Admin/Models/CursorPagedModel.cs @@ -1,10 +1,9 @@ -namespace Bit.Admin.Models +namespace Bit.Admin.Models; + +public class CursorPagedModel { - public class CursorPagedModel - { - public List Items { get; set; } - public int Count { get; set; } - public string Cursor { get; set; } - public string NextCursor { get; set; } - } + public List Items { get; set; } + public int Count { get; set; } + public string Cursor { get; set; } + public string NextCursor { get; set; } } diff --git a/src/Admin/Models/ErrorViewModel.cs b/src/Admin/Models/ErrorViewModel.cs index 7a448776d..3b24a1ece 100644 --- a/src/Admin/Models/ErrorViewModel.cs +++ b/src/Admin/Models/ErrorViewModel.cs @@ -1,9 +1,8 @@ -namespace Bit.Admin.Models -{ - public class ErrorViewModel - { - public string RequestId { get; set; } +namespace Bit.Admin.Models; - public bool ShowRequestId => !string.IsNullOrEmpty(RequestId); - } +public class ErrorViewModel +{ + public string RequestId { get; set; } + + public bool ShowRequestId => !string.IsNullOrEmpty(RequestId); } diff --git a/src/Admin/Models/HomeModel.cs b/src/Admin/Models/HomeModel.cs index 1bdebbe02..900a04e41 100644 --- a/src/Admin/Models/HomeModel.cs +++ b/src/Admin/Models/HomeModel.cs @@ -1,10 +1,9 @@ using Bit.Core.Settings; -namespace Bit.Admin.Models +namespace Bit.Admin.Models; + +public class HomeModel { - public class HomeModel - { - public string CurrentVersion { get; set; } - public GlobalSettings GlobalSettings { get; set; } - } + public string CurrentVersion { get; set; } + public GlobalSettings GlobalSettings { get; set; } } diff --git a/src/Admin/Models/LicenseModel.cs b/src/Admin/Models/LicenseModel.cs index 47d34ad18..b0fd91201 100644 --- a/src/Admin/Models/LicenseModel.cs +++ b/src/Admin/Models/LicenseModel.cs @@ -1,35 +1,34 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Admin.Models +namespace Bit.Admin.Models; + +public class LicenseModel : IValidatableObject { - public class LicenseModel : IValidatableObject + [Display(Name = "User Id")] + public Guid? UserId { get; set; } + [Display(Name = "Organization Id")] + public Guid? OrganizationId { get; set; } + [Display(Name = "Installation Id")] + public Guid? InstallationId { get; set; } + [Required] + [Display(Name = "Version")] + public int Version { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) { - [Display(Name = "User Id")] - public Guid? UserId { get; set; } - [Display(Name = "Organization Id")] - public Guid? OrganizationId { get; set; } - [Display(Name = "Installation Id")] - public Guid? InstallationId { get; set; } - [Required] - [Display(Name = "Version")] - public int Version { get; set; } - - public IEnumerable Validate(ValidationContext validationContext) + if (UserId.HasValue && OrganizationId.HasValue) { - if (UserId.HasValue && OrganizationId.HasValue) - { - yield return new ValidationResult("Use either User Id or Organization Id. Not both."); - } + yield return new ValidationResult("Use either User Id or Organization Id. Not both."); + } - if (!UserId.HasValue && !OrganizationId.HasValue) - { - yield return new ValidationResult("User Id or Organization Id is required."); - } + if (!UserId.HasValue && !OrganizationId.HasValue) + { + yield return new ValidationResult("User Id or Organization Id is required."); + } - if (OrganizationId.HasValue && !InstallationId.HasValue) - { - yield return new ValidationResult("Installation Id is required for organization licenses."); - } + if (OrganizationId.HasValue && !InstallationId.HasValue) + { + yield return new ValidationResult("Installation Id is required for organization licenses."); } } } diff --git a/src/Admin/Models/LogModel.cs b/src/Admin/Models/LogModel.cs index 3e0437998..8967025d1 100644 --- a/src/Admin/Models/LogModel.cs +++ b/src/Admin/Models/LogModel.cs @@ -1,55 +1,54 @@ using Microsoft.Azure.Documents; using Newtonsoft.Json.Linq; -namespace Bit.Admin.Models +namespace Bit.Admin.Models; + +public class LogModel : Resource { - public class LogModel : Resource - { - public long EventIdHash { get; set; } - public string Level { get; set; } - public string Message { get; set; } - public string MessageTruncated => Message.Length > 200 ? $"{Message.Substring(0, 200)}..." : Message; - public string MessageTemplate { get; set; } - public IDictionary Properties { get; set; } - public string Project => Properties?.ContainsKey("Project") ?? false ? Properties["Project"].ToString() : null; - } + public long EventIdHash { get; set; } + public string Level { get; set; } + public string Message { get; set; } + public string MessageTruncated => Message.Length > 200 ? $"{Message.Substring(0, 200)}..." : Message; + public string MessageTemplate { get; set; } + public IDictionary Properties { get; set; } + public string Project => Properties?.ContainsKey("Project") ?? false ? Properties["Project"].ToString() : null; +} - public class LogDetailsModel : LogModel - { - public JObject Exception { get; set; } +public class LogDetailsModel : LogModel +{ + public JObject Exception { get; set; } - public string ExceptionToString(JObject e) + public string ExceptionToString(JObject e) + { + if (e == null) { - if (e == null) - { - return null; - } - - var val = string.Empty; - if (e["Message"] != null && e["Message"].ToObject() != null) - { - val += "Message:\n"; - val += e["Message"] + "\n"; - } - - if (e["StackTrace"] != null && e["StackTrace"].ToObject() != null) - { - val += "\nStack Trace:\n"; - val += e["StackTrace"]; - } - else if (e["StackTraceString"] != null && e["StackTraceString"].ToObject() != null) - { - val += "\nStack Trace String:\n"; - val += e["StackTraceString"]; - } - - if (e["InnerException"] != null && e["InnerException"].ToObject() != null) - { - val += "\n\n=== Inner Exception ===\n\n"; - val += ExceptionToString(e["InnerException"].ToObject()); - } - - return val; + return null; } + + var val = string.Empty; + if (e["Message"] != null && e["Message"].ToObject() != null) + { + val += "Message:\n"; + val += e["Message"] + "\n"; + } + + if (e["StackTrace"] != null && e["StackTrace"].ToObject() != null) + { + val += "\nStack Trace:\n"; + val += e["StackTrace"]; + } + else if (e["StackTraceString"] != null && e["StackTraceString"].ToObject() != null) + { + val += "\nStack Trace String:\n"; + val += e["StackTraceString"]; + } + + if (e["InnerException"] != null && e["InnerException"].ToObject() != null) + { + val += "\n\n=== Inner Exception ===\n\n"; + val += ExceptionToString(e["InnerException"].ToObject()); + } + + return val; } } diff --git a/src/Admin/Models/LoginModel.cs b/src/Admin/Models/LoginModel.cs index fa77ddfe1..7f147874b 100644 --- a/src/Admin/Models/LoginModel.cs +++ b/src/Admin/Models/LoginModel.cs @@ -1,14 +1,13 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Admin.Models +namespace Bit.Admin.Models; + +public class LoginModel { - public class LoginModel - { - [Required] - [EmailAddress] - public string Email { get; set; } - public string ReturnUrl { get; set; } - public string Error { get; set; } - public string Success { get; set; } - } + [Required] + [EmailAddress] + public string Email { get; set; } + public string ReturnUrl { get; set; } + public string Error { get; set; } + public string Success { get; set; } } diff --git a/src/Admin/Models/LogsModel.cs b/src/Admin/Models/LogsModel.cs index d274aa9be..c5527a319 100644 --- a/src/Admin/Models/LogsModel.cs +++ b/src/Admin/Models/LogsModel.cs @@ -1,12 +1,11 @@ using Serilog.Events; -namespace Bit.Admin.Models +namespace Bit.Admin.Models; + +public class LogsModel : CursorPagedModel { - public class LogsModel : CursorPagedModel - { - public LogEventLevel? Level { get; set; } - public string Project { get; set; } - public DateTime? Start { get; set; } - public DateTime? End { get; set; } - } + public LogEventLevel? Level { get; set; } + public string Project { get; set; } + public DateTime? Start { get; set; } + public DateTime? End { get; set; } } diff --git a/src/Admin/Models/OrganizationEditModel.cs b/src/Admin/Models/OrganizationEditModel.cs index bf0d6c8d5..4a6fdde5e 100644 --- a/src/Admin/Models/OrganizationEditModel.cs +++ b/src/Admin/Models/OrganizationEditModel.cs @@ -6,148 +6,147 @@ using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Settings; using Bit.Core.Utilities; -namespace Bit.Admin.Models +namespace Bit.Admin.Models; + +public class OrganizationEditModel : OrganizationViewModel { - public class OrganizationEditModel : OrganizationViewModel + public OrganizationEditModel() { } + + public OrganizationEditModel(Organization org, IEnumerable orgUsers, + IEnumerable ciphers, IEnumerable collections, IEnumerable groups, + IEnumerable policies, BillingInfo billingInfo, IEnumerable connections, + GlobalSettings globalSettings) + : base(org, connections, orgUsers, ciphers, collections, groups, policies) { - public OrganizationEditModel() { } + BillingInfo = billingInfo; + BraintreeMerchantId = globalSettings.Braintree.MerchantId; - public OrganizationEditModel(Organization org, IEnumerable orgUsers, - IEnumerable ciphers, IEnumerable collections, IEnumerable groups, - IEnumerable policies, BillingInfo billingInfo, IEnumerable connections, - GlobalSettings globalSettings) - : base(org, connections, orgUsers, ciphers, collections, groups, policies) - { - BillingInfo = billingInfo; - BraintreeMerchantId = globalSettings.Braintree.MerchantId; + Name = org.Name; + BusinessName = org.BusinessName; + BillingEmail = org.BillingEmail; + PlanType = org.PlanType; + Plan = org.Plan; + Seats = org.Seats; + MaxAutoscaleSeats = org.MaxAutoscaleSeats; + MaxCollections = org.MaxCollections; + UsePolicies = org.UsePolicies; + UseSso = org.UseSso; + UseKeyConnector = org.UseKeyConnector; + UseScim = org.UseScim; + UseGroups = org.UseGroups; + UseDirectory = org.UseDirectory; + UseEvents = org.UseEvents; + UseTotp = org.UseTotp; + Use2fa = org.Use2fa; + UseApi = org.UseApi; + UseResetPassword = org.UseResetPassword; + SelfHost = org.SelfHost; + UsersGetPremium = org.UsersGetPremium; + MaxStorageGb = org.MaxStorageGb; + Gateway = org.Gateway; + GatewayCustomerId = org.GatewayCustomerId; + GatewaySubscriptionId = org.GatewaySubscriptionId; + Enabled = org.Enabled; + LicenseKey = org.LicenseKey; + ExpirationDate = org.ExpirationDate; + } - Name = org.Name; - BusinessName = org.BusinessName; - BillingEmail = org.BillingEmail; - PlanType = org.PlanType; - Plan = org.Plan; - Seats = org.Seats; - MaxAutoscaleSeats = org.MaxAutoscaleSeats; - MaxCollections = org.MaxCollections; - UsePolicies = org.UsePolicies; - UseSso = org.UseSso; - UseKeyConnector = org.UseKeyConnector; - UseScim = org.UseScim; - UseGroups = org.UseGroups; - UseDirectory = org.UseDirectory; - UseEvents = org.UseEvents; - UseTotp = org.UseTotp; - Use2fa = org.Use2fa; - UseApi = org.UseApi; - UseResetPassword = org.UseResetPassword; - SelfHost = org.SelfHost; - UsersGetPremium = org.UsersGetPremium; - MaxStorageGb = org.MaxStorageGb; - Gateway = org.Gateway; - GatewayCustomerId = org.GatewayCustomerId; - GatewaySubscriptionId = org.GatewaySubscriptionId; - Enabled = org.Enabled; - LicenseKey = org.LicenseKey; - ExpirationDate = org.ExpirationDate; - } + public BillingInfo BillingInfo { get; set; } + public string RandomLicenseKey => CoreHelpers.SecureRandomString(20); + public string FourteenDayExpirationDate => DateTime.Now.AddDays(14).ToString("yyyy-MM-ddTHH:mm"); + public string BraintreeMerchantId { get; set; } - public BillingInfo BillingInfo { get; set; } - public string RandomLicenseKey => CoreHelpers.SecureRandomString(20); - public string FourteenDayExpirationDate => DateTime.Now.AddDays(14).ToString("yyyy-MM-ddTHH:mm"); - public string BraintreeMerchantId { get; set; } + [Required] + [Display(Name = "Name")] + public string Name { get; set; } + [Display(Name = "Business Name")] + public string BusinessName { get; set; } + [Display(Name = "Billing Email")] + public string BillingEmail { get; set; } + [Required] + [Display(Name = "Plan")] + public PlanType? PlanType { get; set; } + [Required] + [Display(Name = "Plan Name")] + public string Plan { get; set; } + [Display(Name = "Seats")] + public int? Seats { get; set; } + [Display(Name = "Max. Autoscale Seats")] + public int? MaxAutoscaleSeats { get; set; } + [Display(Name = "Max. Collections")] + public short? MaxCollections { get; set; } + [Display(Name = "Policies")] + public bool UsePolicies { get; set; } + [Display(Name = "SSO")] + public bool UseSso { get; set; } + [Display(Name = "Key Connector with Customer Encryption")] + public bool UseKeyConnector { get; set; } + [Display(Name = "Groups")] + public bool UseGroups { get; set; } + [Display(Name = "Directory")] + public bool UseDirectory { get; set; } + [Display(Name = "Events")] + public bool UseEvents { get; set; } + [Display(Name = "TOTP")] + public bool UseTotp { get; set; } + [Display(Name = "2FA")] + public bool Use2fa { get; set; } + [Display(Name = "API")] + public bool UseApi { get; set; } + [Display(Name = "Reset Password")] + public bool UseResetPassword { get; set; } + [Display(Name = "SCIM")] + public bool UseScim { get; set; } + [Display(Name = "Self Host")] + public bool SelfHost { get; set; } + [Display(Name = "Users Get Premium")] + public bool UsersGetPremium { get; set; } + [Display(Name = "Max. Storage GB")] + public short? MaxStorageGb { get; set; } + [Display(Name = "Gateway")] + public GatewayType? Gateway { get; set; } + [Display(Name = "Gateway Customer Id")] + public string GatewayCustomerId { get; set; } + [Display(Name = "Gateway Subscription Id")] + public string GatewaySubscriptionId { get; set; } + [Display(Name = "Enabled")] + public bool Enabled { get; set; } + [Display(Name = "License Key")] + public string LicenseKey { get; set; } + [Display(Name = "Expiration Date")] + public DateTime? ExpirationDate { get; set; } + public bool SalesAssistedTrialStarted { get; set; } - [Required] - [Display(Name = "Name")] - public string Name { get; set; } - [Display(Name = "Business Name")] - public string BusinessName { get; set; } - [Display(Name = "Billing Email")] - public string BillingEmail { get; set; } - [Required] - [Display(Name = "Plan")] - public PlanType? PlanType { get; set; } - [Required] - [Display(Name = "Plan Name")] - public string Plan { get; set; } - [Display(Name = "Seats")] - public int? Seats { get; set; } - [Display(Name = "Max. Autoscale Seats")] - public int? MaxAutoscaleSeats { get; set; } - [Display(Name = "Max. Collections")] - public short? MaxCollections { get; set; } - [Display(Name = "Policies")] - public bool UsePolicies { get; set; } - [Display(Name = "SSO")] - public bool UseSso { get; set; } - [Display(Name = "Key Connector with Customer Encryption")] - public bool UseKeyConnector { get; set; } - [Display(Name = "Groups")] - public bool UseGroups { get; set; } - [Display(Name = "Directory")] - public bool UseDirectory { get; set; } - [Display(Name = "Events")] - public bool UseEvents { get; set; } - [Display(Name = "TOTP")] - public bool UseTotp { get; set; } - [Display(Name = "2FA")] - public bool Use2fa { get; set; } - [Display(Name = "API")] - public bool UseApi { get; set; } - [Display(Name = "Reset Password")] - public bool UseResetPassword { get; set; } - [Display(Name = "SCIM")] - public bool UseScim { get; set; } - [Display(Name = "Self Host")] - public bool SelfHost { get; set; } - [Display(Name = "Users Get Premium")] - public bool UsersGetPremium { get; set; } - [Display(Name = "Max. Storage GB")] - public short? MaxStorageGb { get; set; } - [Display(Name = "Gateway")] - public GatewayType? Gateway { get; set; } - [Display(Name = "Gateway Customer Id")] - public string GatewayCustomerId { get; set; } - [Display(Name = "Gateway Subscription Id")] - public string GatewaySubscriptionId { get; set; } - [Display(Name = "Enabled")] - public bool Enabled { get; set; } - [Display(Name = "License Key")] - public string LicenseKey { get; set; } - [Display(Name = "Expiration Date")] - public DateTime? ExpirationDate { get; set; } - public bool SalesAssistedTrialStarted { get; set; } - - public Organization ToOrganization(Organization existingOrganization) - { - existingOrganization.Name = Name; - existingOrganization.BusinessName = BusinessName; - existingOrganization.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim(); - existingOrganization.PlanType = PlanType.Value; - existingOrganization.Plan = Plan; - existingOrganization.Seats = Seats; - existingOrganization.MaxCollections = MaxCollections; - existingOrganization.UsePolicies = UsePolicies; - existingOrganization.UseSso = UseSso; - existingOrganization.UseKeyConnector = UseKeyConnector; - existingOrganization.UseScim = UseScim; - existingOrganization.UseGroups = UseGroups; - existingOrganization.UseDirectory = UseDirectory; - existingOrganization.UseEvents = UseEvents; - existingOrganization.UseTotp = UseTotp; - existingOrganization.Use2fa = Use2fa; - existingOrganization.UseApi = UseApi; - existingOrganization.UseResetPassword = UseResetPassword; - existingOrganization.SelfHost = SelfHost; - existingOrganization.UsersGetPremium = UsersGetPremium; - existingOrganization.MaxStorageGb = MaxStorageGb; - existingOrganization.Gateway = Gateway; - existingOrganization.GatewayCustomerId = GatewayCustomerId; - existingOrganization.GatewaySubscriptionId = GatewaySubscriptionId; - existingOrganization.Enabled = Enabled; - existingOrganization.LicenseKey = LicenseKey; - existingOrganization.ExpirationDate = ExpirationDate; - existingOrganization.MaxAutoscaleSeats = MaxAutoscaleSeats; - return existingOrganization; - } + public Organization ToOrganization(Organization existingOrganization) + { + existingOrganization.Name = Name; + existingOrganization.BusinessName = BusinessName; + existingOrganization.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim(); + existingOrganization.PlanType = PlanType.Value; + existingOrganization.Plan = Plan; + existingOrganization.Seats = Seats; + existingOrganization.MaxCollections = MaxCollections; + existingOrganization.UsePolicies = UsePolicies; + existingOrganization.UseSso = UseSso; + existingOrganization.UseKeyConnector = UseKeyConnector; + existingOrganization.UseScim = UseScim; + existingOrganization.UseGroups = UseGroups; + existingOrganization.UseDirectory = UseDirectory; + existingOrganization.UseEvents = UseEvents; + existingOrganization.UseTotp = UseTotp; + existingOrganization.Use2fa = Use2fa; + existingOrganization.UseApi = UseApi; + existingOrganization.UseResetPassword = UseResetPassword; + existingOrganization.SelfHost = SelfHost; + existingOrganization.UsersGetPremium = UsersGetPremium; + existingOrganization.MaxStorageGb = MaxStorageGb; + existingOrganization.Gateway = Gateway; + existingOrganization.GatewayCustomerId = GatewayCustomerId; + existingOrganization.GatewaySubscriptionId = GatewaySubscriptionId; + existingOrganization.Enabled = Enabled; + existingOrganization.LicenseKey = LicenseKey; + existingOrganization.ExpirationDate = ExpirationDate; + existingOrganization.MaxAutoscaleSeats = MaxAutoscaleSeats; + return existingOrganization; } } diff --git a/src/Admin/Models/OrganizationViewModel.cs b/src/Admin/Models/OrganizationViewModel.cs index c17f273a7..5a487cd03 100644 --- a/src/Admin/Models/OrganizationViewModel.cs +++ b/src/Admin/Models/OrganizationViewModel.cs @@ -2,49 +2,48 @@ using Bit.Core.Enums; using Bit.Core.Models.Data.Organizations.OrganizationUsers; -namespace Bit.Admin.Models +namespace Bit.Admin.Models; + +public class OrganizationViewModel { - public class OrganizationViewModel + public OrganizationViewModel() { } + + public OrganizationViewModel(Organization org, IEnumerable connections, + IEnumerable orgUsers, IEnumerable ciphers, IEnumerable collections, + IEnumerable groups, IEnumerable policies) { - public OrganizationViewModel() { } - - public OrganizationViewModel(Organization org, IEnumerable connections, - IEnumerable orgUsers, IEnumerable ciphers, IEnumerable collections, - IEnumerable groups, IEnumerable policies) - { - Organization = org; - Connections = connections ?? Enumerable.Empty(); - HasPublicPrivateKeys = org.PublicKey != null && org.PrivateKey != null; - UserInvitedCount = orgUsers.Count(u => u.Status == OrganizationUserStatusType.Invited); - UserAcceptedCount = orgUsers.Count(u => u.Status == OrganizationUserStatusType.Accepted); - UserConfirmedCount = orgUsers.Count(u => u.Status == OrganizationUserStatusType.Confirmed); - UserCount = orgUsers.Count(); - CipherCount = ciphers.Count(); - CollectionCount = collections.Count(); - GroupCount = groups?.Count() ?? 0; - PolicyCount = policies?.Count() ?? 0; - Owners = string.Join(", ", - orgUsers - .Where(u => u.Type == OrganizationUserType.Owner && u.Status == OrganizationUserStatusType.Confirmed) - .Select(u => u.Email)); - Admins = string.Join(", ", - orgUsers - .Where(u => u.Type == OrganizationUserType.Admin && u.Status == OrganizationUserStatusType.Confirmed) - .Select(u => u.Email)); - } - - public Organization Organization { get; set; } - public IEnumerable Connections { get; set; } - public string Owners { get; set; } - public string Admins { get; set; } - public int UserInvitedCount { get; set; } - public int UserConfirmedCount { get; set; } - public int UserAcceptedCount { get; set; } - public int UserCount { get; set; } - public int CipherCount { get; set; } - public int CollectionCount { get; set; } - public int GroupCount { get; set; } - public int PolicyCount { get; set; } - public bool HasPublicPrivateKeys { get; set; } + Organization = org; + Connections = connections ?? Enumerable.Empty(); + HasPublicPrivateKeys = org.PublicKey != null && org.PrivateKey != null; + UserInvitedCount = orgUsers.Count(u => u.Status == OrganizationUserStatusType.Invited); + UserAcceptedCount = orgUsers.Count(u => u.Status == OrganizationUserStatusType.Accepted); + UserConfirmedCount = orgUsers.Count(u => u.Status == OrganizationUserStatusType.Confirmed); + UserCount = orgUsers.Count(); + CipherCount = ciphers.Count(); + CollectionCount = collections.Count(); + GroupCount = groups?.Count() ?? 0; + PolicyCount = policies?.Count() ?? 0; + Owners = string.Join(", ", + orgUsers + .Where(u => u.Type == OrganizationUserType.Owner && u.Status == OrganizationUserStatusType.Confirmed) + .Select(u => u.Email)); + Admins = string.Join(", ", + orgUsers + .Where(u => u.Type == OrganizationUserType.Admin && u.Status == OrganizationUserStatusType.Confirmed) + .Select(u => u.Email)); } + + public Organization Organization { get; set; } + public IEnumerable Connections { get; set; } + public string Owners { get; set; } + public string Admins { get; set; } + public int UserInvitedCount { get; set; } + public int UserConfirmedCount { get; set; } + public int UserAcceptedCount { get; set; } + public int UserCount { get; set; } + public int CipherCount { get; set; } + public int CollectionCount { get; set; } + public int GroupCount { get; set; } + public int PolicyCount { get; set; } + public bool HasPublicPrivateKeys { get; set; } } diff --git a/src/Admin/Models/OrganizationsModel.cs b/src/Admin/Models/OrganizationsModel.cs index da2eb20d6..706377f8e 100644 --- a/src/Admin/Models/OrganizationsModel.cs +++ b/src/Admin/Models/OrganizationsModel.cs @@ -1,13 +1,12 @@ using Bit.Core.Entities; -namespace Bit.Admin.Models +namespace Bit.Admin.Models; + +public class OrganizationsModel : PagedModel { - public class OrganizationsModel : PagedModel - { - public string Name { get; set; } - public string UserEmail { get; set; } - public bool? Paid { get; set; } - public string Action { get; set; } - public bool SelfHosted { get; set; } - } + public string Name { get; set; } + public string UserEmail { get; set; } + public bool? Paid { get; set; } + public string Action { get; set; } + public bool SelfHosted { get; set; } } diff --git a/src/Admin/Models/PagedModel.cs b/src/Admin/Models/PagedModel.cs index ac4f2e84d..4c9c8e171 100644 --- a/src/Admin/Models/PagedModel.cs +++ b/src/Admin/Models/PagedModel.cs @@ -1,11 +1,10 @@ -namespace Bit.Admin.Models +namespace Bit.Admin.Models; + +public abstract class PagedModel { - public abstract class PagedModel - { - public List Items { get; set; } - public int Page { get; set; } - public int Count { get; set; } - public int? PreviousPage => Page < 2 ? (int?)null : Page - 1; - public int? NextPage => Items.Count < Count ? (int?)null : Page + 1; - } + public List Items { get; set; } + public int Page { get; set; } + public int Count { get; set; } + public int? PreviousPage => Page < 2 ? (int?)null : Page - 1; + public int? NextPage => Items.Count < Count ? (int?)null : Page + 1; } diff --git a/src/Admin/Models/PromoteAdminModel.cs b/src/Admin/Models/PromoteAdminModel.cs index 0beae6bd8..bc076d6ab 100644 --- a/src/Admin/Models/PromoteAdminModel.cs +++ b/src/Admin/Models/PromoteAdminModel.cs @@ -1,14 +1,13 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Admin.Models +namespace Bit.Admin.Models; + +public class PromoteAdminModel { - public class PromoteAdminModel - { - [Required] - [Display(Name = "Admin User Id")] - public Guid? UserId { get; set; } - [Required] - [Display(Name = "Organization Id")] - public Guid? OrganizationId { get; set; } - } + [Required] + [Display(Name = "Admin User Id")] + public Guid? UserId { get; set; } + [Required] + [Display(Name = "Organization Id")] + public Guid? OrganizationId { get; set; } } diff --git a/src/Admin/Models/ProviderEditModel.cs b/src/Admin/Models/ProviderEditModel.cs index 578d0ff22..92b2f89e9 100644 --- a/src/Admin/Models/ProviderEditModel.cs +++ b/src/Admin/Models/ProviderEditModel.cs @@ -2,33 +2,32 @@ using Bit.Core.Entities.Provider; using Bit.Core.Models.Data; -namespace Bit.Admin.Models +namespace Bit.Admin.Models; + +public class ProviderEditModel : ProviderViewModel { - public class ProviderEditModel : ProviderViewModel + public ProviderEditModel() { } + + public ProviderEditModel(Provider provider, IEnumerable providerUsers, IEnumerable organizations) + : base(provider, providerUsers, organizations) { - public ProviderEditModel() { } + Name = provider.Name; + BusinessName = provider.BusinessName; + BillingEmail = provider.BillingEmail; + } - public ProviderEditModel(Provider provider, IEnumerable providerUsers, IEnumerable organizations) - : base(provider, providerUsers, organizations) - { - Name = provider.Name; - BusinessName = provider.BusinessName; - BillingEmail = provider.BillingEmail; - } + [Display(Name = "Billing Email")] + public string BillingEmail { get; set; } + [Display(Name = "Business Name")] + public string BusinessName { get; set; } + public string Name { get; set; } + [Display(Name = "Events")] - [Display(Name = "Billing Email")] - public string BillingEmail { get; set; } - [Display(Name = "Business Name")] - public string BusinessName { get; set; } - public string Name { get; set; } - [Display(Name = "Events")] - - public Provider ToProvider(Provider existingProvider) - { - existingProvider.Name = Name; - existingProvider.BusinessName = BusinessName; - existingProvider.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim(); - return existingProvider; - } + public Provider ToProvider(Provider existingProvider) + { + existingProvider.Name = Name; + existingProvider.BusinessName = BusinessName; + existingProvider.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim(); + return existingProvider; } } diff --git a/src/Admin/Models/ProviderViewModel.cs b/src/Admin/Models/ProviderViewModel.cs index 05fae3c9c..766101e88 100644 --- a/src/Admin/Models/ProviderViewModel.cs +++ b/src/Admin/Models/ProviderViewModel.cs @@ -2,24 +2,23 @@ using Bit.Core.Enums.Provider; using Bit.Core.Models.Data; -namespace Bit.Admin.Models +namespace Bit.Admin.Models; + +public class ProviderViewModel { - public class ProviderViewModel + public ProviderViewModel() { } + + public ProviderViewModel(Provider provider, IEnumerable providerUsers, IEnumerable organizations) { - public ProviderViewModel() { } + Provider = provider; + UserCount = providerUsers.Count(); + ProviderAdmins = providerUsers.Where(u => u.Type == ProviderUserType.ProviderAdmin); - public ProviderViewModel(Provider provider, IEnumerable providerUsers, IEnumerable organizations) - { - Provider = provider; - UserCount = providerUsers.Count(); - ProviderAdmins = providerUsers.Where(u => u.Type == ProviderUserType.ProviderAdmin); - - ProviderOrganizations = organizations.Where(o => o.ProviderId == provider.Id); - } - - public int UserCount { get; set; } - public Provider Provider { get; set; } - public IEnumerable ProviderAdmins { get; set; } - public IEnumerable ProviderOrganizations { get; set; } + ProviderOrganizations = organizations.Where(o => o.ProviderId == provider.Id); } + + public int UserCount { get; set; } + public Provider Provider { get; set; } + public IEnumerable ProviderAdmins { get; set; } + public IEnumerable ProviderOrganizations { get; set; } } diff --git a/src/Admin/Models/ProvidersModel.cs b/src/Admin/Models/ProvidersModel.cs index 02509593d..dccf4a4d7 100644 --- a/src/Admin/Models/ProvidersModel.cs +++ b/src/Admin/Models/ProvidersModel.cs @@ -1,13 +1,12 @@ using Bit.Core.Entities.Provider; -namespace Bit.Admin.Models +namespace Bit.Admin.Models; + +public class ProvidersModel : PagedModel { - public class ProvidersModel : PagedModel - { - public string Name { get; set; } - public string UserEmail { get; set; } - public bool? Paid { get; set; } - public string Action { get; set; } - public bool SelfHosted { get; set; } - } + public string Name { get; set; } + public string UserEmail { get; set; } + public bool? Paid { get; set; } + public string Action { get; set; } + public bool SelfHosted { get; set; } } diff --git a/src/Admin/Models/StripeSubscriptionsModel.cs b/src/Admin/Models/StripeSubscriptionsModel.cs index 3e30d63d5..99e9c5b77 100644 --- a/src/Admin/Models/StripeSubscriptionsModel.cs +++ b/src/Admin/Models/StripeSubscriptionsModel.cs @@ -1,43 +1,42 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Models.BitStripe; -namespace Bit.Admin.Models +namespace Bit.Admin.Models; + +public class StripeSubscriptionRowModel { - public class StripeSubscriptionRowModel - { - public Stripe.Subscription Subscription { get; set; } - public bool Selected { get; set; } + public Stripe.Subscription Subscription { get; set; } + public bool Selected { get; set; } - public StripeSubscriptionRowModel() { } - public StripeSubscriptionRowModel(Stripe.Subscription subscription) - { - Subscription = subscription; - } + public StripeSubscriptionRowModel() { } + public StripeSubscriptionRowModel(Stripe.Subscription subscription) + { + Subscription = subscription; } +} - public enum StripeSubscriptionsAction - { - Search, - PreviousPage, - NextPage, - Export, - BulkCancel - } +public enum StripeSubscriptionsAction +{ + Search, + PreviousPage, + NextPage, + Export, + BulkCancel +} - public class StripeSubscriptionsModel : IValidatableObject +public class StripeSubscriptionsModel : IValidatableObject +{ + public List Items { get; set; } + public StripeSubscriptionsAction Action { get; set; } = StripeSubscriptionsAction.Search; + public string Message { get; set; } + public List Prices { get; set; } + public List TestClocks { get; set; } + public StripeSubscriptionListOptions Filter { get; set; } = new StripeSubscriptionListOptions(); + public IEnumerable Validate(ValidationContext validationContext) { - public List Items { get; set; } - public StripeSubscriptionsAction Action { get; set; } = StripeSubscriptionsAction.Search; - public string Message { get; set; } - public List Prices { get; set; } - public List TestClocks { get; set; } - public StripeSubscriptionListOptions Filter { get; set; } = new StripeSubscriptionListOptions(); - public IEnumerable Validate(ValidationContext validationContext) + if (Action == StripeSubscriptionsAction.BulkCancel && Filter.Status != "unpaid") { - if (Action == StripeSubscriptionsAction.BulkCancel && Filter.Status != "unpaid") - { - yield return new ValidationResult("Bulk cancel is currently only supported for unpaid subscriptions"); - } + yield return new ValidationResult("Bulk cancel is currently only supported for unpaid subscriptions"); } } } diff --git a/src/Admin/Models/TaxRateAddEditModel.cs b/src/Admin/Models/TaxRateAddEditModel.cs index e55ec87c6..bfa87d7cc 100644 --- a/src/Admin/Models/TaxRateAddEditModel.cs +++ b/src/Admin/Models/TaxRateAddEditModel.cs @@ -1,11 +1,10 @@ -namespace Bit.Admin.Models +namespace Bit.Admin.Models; + +public class TaxRateAddEditModel { - public class TaxRateAddEditModel - { - public string StripeTaxRateId { get; set; } - public string Country { get; set; } - public string State { get; set; } - public string PostalCode { get; set; } - public decimal Rate { get; set; } - } + public string StripeTaxRateId { get; set; } + public string Country { get; set; } + public string State { get; set; } + public string PostalCode { get; set; } + public decimal Rate { get; set; } } diff --git a/src/Admin/Models/TaxRatesModel.cs b/src/Admin/Models/TaxRatesModel.cs index 92564d82f..0af073f38 100644 --- a/src/Admin/Models/TaxRatesModel.cs +++ b/src/Admin/Models/TaxRatesModel.cs @@ -1,9 +1,8 @@ using Bit.Core.Entities; -namespace Bit.Admin.Models +namespace Bit.Admin.Models; + +public class TaxRatesModel : PagedModel { - public class TaxRatesModel : PagedModel - { - public string Message { get; set; } - } + public string Message { get; set; } } diff --git a/src/Admin/Models/UserEditModel.cs b/src/Admin/Models/UserEditModel.cs index 5b789c73d..d7ef56f08 100644 --- a/src/Admin/Models/UserEditModel.cs +++ b/src/Admin/Models/UserEditModel.cs @@ -4,71 +4,70 @@ using Bit.Core.Models.Business; using Bit.Core.Settings; using Bit.Core.Utilities; -namespace Bit.Admin.Models +namespace Bit.Admin.Models; + +public class UserEditModel : UserViewModel { - public class UserEditModel : UserViewModel + public UserEditModel() { } + + public UserEditModel(User user, IEnumerable ciphers, BillingInfo billingInfo, + GlobalSettings globalSettings) + : base(user, ciphers) { - public UserEditModel() { } + BillingInfo = billingInfo; + BraintreeMerchantId = globalSettings.Braintree.MerchantId; - public UserEditModel(User user, IEnumerable ciphers, BillingInfo billingInfo, - GlobalSettings globalSettings) - : base(user, ciphers) - { - BillingInfo = billingInfo; - BraintreeMerchantId = globalSettings.Braintree.MerchantId; + Name = user.Name; + Email = user.Email; + EmailVerified = user.EmailVerified; + Premium = user.Premium; + MaxStorageGb = user.MaxStorageGb; + Gateway = user.Gateway; + GatewayCustomerId = user.GatewayCustomerId; + GatewaySubscriptionId = user.GatewaySubscriptionId; + LicenseKey = user.LicenseKey; + PremiumExpirationDate = user.PremiumExpirationDate; + } - Name = user.Name; - Email = user.Email; - EmailVerified = user.EmailVerified; - Premium = user.Premium; - MaxStorageGb = user.MaxStorageGb; - Gateway = user.Gateway; - GatewayCustomerId = user.GatewayCustomerId; - GatewaySubscriptionId = user.GatewaySubscriptionId; - LicenseKey = user.LicenseKey; - PremiumExpirationDate = user.PremiumExpirationDate; - } + public BillingInfo BillingInfo { get; set; } + public string RandomLicenseKey => CoreHelpers.SecureRandomString(20); + public string OneYearExpirationDate => DateTime.Now.AddYears(1).ToString("yyyy-MM-ddTHH:mm"); + public string BraintreeMerchantId { get; set; } - public BillingInfo BillingInfo { get; set; } - public string RandomLicenseKey => CoreHelpers.SecureRandomString(20); - public string OneYearExpirationDate => DateTime.Now.AddYears(1).ToString("yyyy-MM-ddTHH:mm"); - public string BraintreeMerchantId { get; set; } + [Display(Name = "Name")] + public string Name { get; set; } + [Required] + [Display(Name = "Email")] + public string Email { get; set; } + [Display(Name = "Email Verified")] + public bool EmailVerified { get; set; } + [Display(Name = "Premium")] + public bool Premium { get; set; } + [Display(Name = "Max. Storage GB")] + public short? MaxStorageGb { get; set; } + [Display(Name = "Gateway")] + public Core.Enums.GatewayType? Gateway { get; set; } + [Display(Name = "Gateway Customer Id")] + public string GatewayCustomerId { get; set; } + [Display(Name = "Gateway Subscription Id")] + public string GatewaySubscriptionId { get; set; } + [Display(Name = "License Key")] + public string LicenseKey { get; set; } + [Display(Name = "Premium Expiration Date")] + public DateTime? PremiumExpirationDate { get; set; } - [Display(Name = "Name")] - public string Name { get; set; } - [Required] - [Display(Name = "Email")] - public string Email { get; set; } - [Display(Name = "Email Verified")] - public bool EmailVerified { get; set; } - [Display(Name = "Premium")] - public bool Premium { get; set; } - [Display(Name = "Max. Storage GB")] - public short? MaxStorageGb { get; set; } - [Display(Name = "Gateway")] - public Core.Enums.GatewayType? Gateway { get; set; } - [Display(Name = "Gateway Customer Id")] - public string GatewayCustomerId { get; set; } - [Display(Name = "Gateway Subscription Id")] - public string GatewaySubscriptionId { get; set; } - [Display(Name = "License Key")] - public string LicenseKey { get; set; } - [Display(Name = "Premium Expiration Date")] - public DateTime? PremiumExpirationDate { get; set; } - - public User ToUser(User existingUser) - { - existingUser.Name = Name; - existingUser.Email = Email; - existingUser.EmailVerified = EmailVerified; - existingUser.Premium = Premium; - existingUser.MaxStorageGb = MaxStorageGb; - existingUser.Gateway = Gateway; - existingUser.GatewayCustomerId = GatewayCustomerId; - existingUser.GatewaySubscriptionId = GatewaySubscriptionId; - existingUser.LicenseKey = LicenseKey; - existingUser.PremiumExpirationDate = PremiumExpirationDate; - return existingUser; - } + public User ToUser(User existingUser) + { + existingUser.Name = Name; + existingUser.Email = Email; + existingUser.EmailVerified = EmailVerified; + existingUser.Premium = Premium; + existingUser.MaxStorageGb = MaxStorageGb; + existingUser.Gateway = Gateway; + existingUser.GatewayCustomerId = GatewayCustomerId; + existingUser.GatewaySubscriptionId = GatewaySubscriptionId; + existingUser.LicenseKey = LicenseKey; + existingUser.PremiumExpirationDate = PremiumExpirationDate; + return existingUser; } } diff --git a/src/Admin/Models/UserViewModel.cs b/src/Admin/Models/UserViewModel.cs index adc8fb268..f493f68f2 100644 --- a/src/Admin/Models/UserViewModel.cs +++ b/src/Admin/Models/UserViewModel.cs @@ -1,18 +1,17 @@ using Bit.Core.Entities; -namespace Bit.Admin.Models +namespace Bit.Admin.Models; + +public class UserViewModel { - public class UserViewModel + public UserViewModel() { } + + public UserViewModel(User user, IEnumerable ciphers) { - public UserViewModel() { } - - public UserViewModel(User user, IEnumerable ciphers) - { - User = user; - CipherCount = ciphers.Count(); - } - - public User User { get; set; } - public int CipherCount { get; set; } + User = user; + CipherCount = ciphers.Count(); } + + public User User { get; set; } + public int CipherCount { get; set; } } diff --git a/src/Admin/Models/UsersModel.cs b/src/Admin/Models/UsersModel.cs index 1215a9555..0a54e318d 100644 --- a/src/Admin/Models/UsersModel.cs +++ b/src/Admin/Models/UsersModel.cs @@ -1,10 +1,9 @@ using Bit.Core.Entities; -namespace Bit.Admin.Models +namespace Bit.Admin.Models; + +public class UsersModel : PagedModel { - public class UsersModel : PagedModel - { - public string Email { get; set; } - public string Action { get; set; } - } + public string Email { get; set; } + public string Action { get; set; } } diff --git a/src/Admin/Program.cs b/src/Admin/Program.cs index d8a55e7b6..f5bc877ab 100644 --- a/src/Admin/Program.cs +++ b/src/Admin/Program.cs @@ -1,37 +1,36 @@ using Bit.Core.Utilities; using Serilog.Events; -namespace Bit.Admin +namespace Bit.Admin; + +public class Program { - public class Program + public static void Main(string[] args) { - public static void Main(string[] args) - { - Host - .CreateDefaultBuilder(args) - .ConfigureCustomAppConfiguration(args) - .ConfigureWebHostDefaults(webBuilder => + Host + .CreateDefaultBuilder(args) + .ConfigureCustomAppConfiguration(args) + .ConfigureWebHostDefaults(webBuilder => + { + webBuilder.ConfigureKestrel(o => { - webBuilder.ConfigureKestrel(o => + o.Limits.MaxRequestLineSize = 20_000; + }); + webBuilder.UseStartup(); + webBuilder.ConfigureLogging((hostingContext, logging) => + logging.AddSerilog(hostingContext, e => + { + var context = e.Properties["SourceContext"].ToString(); + if (e.Properties.ContainsKey("RequestPath") && + !string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) && + (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) { - o.Limits.MaxRequestLineSize = 20_000; - }); - webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, e => - { - var context = e.Properties["SourceContext"].ToString(); - if (e.Properties.ContainsKey("RequestPath") && - !string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) && - (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) - { - return false; - } - return e.Level >= LogEventLevel.Error; - })); - }) - .Build() - .Run(); - } + return false; + } + return e.Level >= LogEventLevel.Error; + })); + }) + .Build() + .Run(); } } diff --git a/src/Admin/Startup.cs b/src/Admin/Startup.cs index ea8485c79..37645873e 100644 --- a/src/Admin/Startup.cs +++ b/src/Admin/Startup.cs @@ -11,128 +11,127 @@ using Stripe; using Bit.Commercial.Core.Utilities; #endif -namespace Bit.Admin +namespace Bit.Admin; + +public class Startup { - public class Startup + public Startup(IWebHostEnvironment env, IConfiguration configuration) { - public Startup(IWebHostEnvironment env, IConfiguration configuration) + CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); + Configuration = configuration; + Environment = env; + } + + public IConfiguration Configuration { get; private set; } + public IWebHostEnvironment Environment { get; set; } + + public void ConfigureServices(IServiceCollection services) + { + // Options + services.AddOptions(); + + // Settings + var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); + services.Configure(Configuration.GetSection("AdminSettings")); + + // Data Protection + services.AddCustomDataProtectionServices(Environment, globalSettings); + + // Stripe Billing + StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey; + StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries; + + // Repositories + services.AddSqlServerRepositories(globalSettings); + + // Context + services.AddScoped(); + + // Identity + services.AddPasswordlessIdentityServices(globalSettings); + services.Configure(options => { - CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); - Configuration = configuration; - Environment = env; + options.ValidationInterval = TimeSpan.FromMinutes(5); + }); + if (globalSettings.SelfHosted) + { + services.ConfigureApplicationCookie(options => + { + options.Cookie.Path = "/admin"; + }); } - public IConfiguration Configuration { get; private set; } - public IWebHostEnvironment Environment { get; set; } - - public void ConfigureServices(IServiceCollection services) - { - // Options - services.AddOptions(); - - // Settings - var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); - services.Configure(Configuration.GetSection("AdminSettings")); - - // Data Protection - services.AddCustomDataProtectionServices(Environment, globalSettings); - - // Stripe Billing - StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey; - StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries; - - // Repositories - services.AddSqlServerRepositories(globalSettings); - - // Context - services.AddScoped(); - - // Identity - services.AddPasswordlessIdentityServices(globalSettings); - services.Configure(options => - { - options.ValidationInterval = TimeSpan.FromMinutes(5); - }); - if (globalSettings.SelfHosted) - { - services.ConfigureApplicationCookie(options => - { - options.Cookie.Path = "/admin"; - }); - } - - // Services - services.AddBaseServices(globalSettings); - services.AddDefaultServices(globalSettings); + // Services + services.AddBaseServices(globalSettings); + services.AddDefaultServices(globalSettings); #if OSS - services.AddOosServices(); + services.AddOosServices(); #else - services.AddCommCoreServices(); + services.AddCommCoreServices(); #endif - // Mvc - services.AddMvc(config => - { - config.Filters.Add(new LoggingExceptionHandlerFilterAttribute()); - }); - services.Configure(options => options.LowercaseUrls = true); - - // Jobs service - Jobs.JobsHostedService.AddJobsServices(services, globalSettings.SelfHosted); - services.AddHostedService(); - if (globalSettings.SelfHosted) - { - services.AddHostedService(); - } - else - { - if (CoreHelpers.SettingHasValue(globalSettings.Storage.ConnectionString)) - { - services.AddHostedService(); - } - else if (CoreHelpers.SettingHasValue(globalSettings.Amazon?.AccessKeySecret)) - { - services.AddHostedService(); - } - if (CoreHelpers.SettingHasValue(globalSettings.Mail.ConnectionString)) - { - services.AddHostedService(); - } - } - } - - public void Configure( - IApplicationBuilder app, - IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, - GlobalSettings globalSettings) + // Mvc + services.AddMvc(config => { - app.UseSerilog(env, appLifetime, globalSettings); + config.Filters.Add(new LoggingExceptionHandlerFilterAttribute()); + }); + services.Configure(options => options.LowercaseUrls = true); - // Add general security headers - app.UseMiddleware(); - - if (globalSettings.SelfHosted) + // Jobs service + Jobs.JobsHostedService.AddJobsServices(services, globalSettings.SelfHosted); + services.AddHostedService(); + if (globalSettings.SelfHosted) + { + services.AddHostedService(); + } + else + { + if (CoreHelpers.SettingHasValue(globalSettings.Storage.ConnectionString)) { - app.UsePathBase("/admin"); - app.UseForwardedHeaders(globalSettings); + services.AddHostedService(); } - - if (env.IsDevelopment()) + else if (CoreHelpers.SettingHasValue(globalSettings.Amazon?.AccessKeySecret)) { - app.UseDeveloperExceptionPage(); + services.AddHostedService(); } - else + if (CoreHelpers.SettingHasValue(globalSettings.Mail.ConnectionString)) { - app.UseExceptionHandler("/error"); + services.AddHostedService(); } - - app.UseStaticFiles(); - app.UseRouting(); - app.UseAuthentication(); - app.UseAuthorization(); - app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); } } + + public void Configure( + IApplicationBuilder app, + IWebHostEnvironment env, + IHostApplicationLifetime appLifetime, + GlobalSettings globalSettings) + { + app.UseSerilog(env, appLifetime, globalSettings); + + // Add general security headers + app.UseMiddleware(); + + if (globalSettings.SelfHosted) + { + app.UsePathBase("/admin"); + app.UseForwardedHeaders(globalSettings); + } + + if (env.IsDevelopment()) + { + app.UseDeveloperExceptionPage(); + } + else + { + app.UseExceptionHandler("/error"); + } + + app.UseStaticFiles(); + app.UseRouting(); + app.UseAuthentication(); + app.UseAuthorization(); + app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); + } } diff --git a/src/Admin/TagHelpers/ActivePageTagHelper.cs b/src/Admin/TagHelpers/ActivePageTagHelper.cs index 6e400383d..a148e3cdf 100644 --- a/src/Admin/TagHelpers/ActivePageTagHelper.cs +++ b/src/Admin/TagHelpers/ActivePageTagHelper.cs @@ -3,72 +3,71 @@ using Microsoft.AspNetCore.Mvc.Rendering; using Microsoft.AspNetCore.Mvc.ViewFeatures; using Microsoft.AspNetCore.Razor.TagHelpers; -namespace Bit.Admin.TagHelpers +namespace Bit.Admin.TagHelpers; + +[HtmlTargetElement("li", Attributes = ActiveControllerName)] +[HtmlTargetElement("li", Attributes = ActiveActionName)] +public class ActivePageTagHelper : TagHelper { - [HtmlTargetElement("li", Attributes = ActiveControllerName)] - [HtmlTargetElement("li", Attributes = ActiveActionName)] - public class ActivePageTagHelper : TagHelper + private const string ActiveControllerName = "active-controller"; + private const string ActiveActionName = "active-action"; + + private readonly IHtmlGenerator _generator; + + public ActivePageTagHelper(IHtmlGenerator generator) { - private const string ActiveControllerName = "active-controller"; - private const string ActiveActionName = "active-action"; + _generator = generator; + } - private readonly IHtmlGenerator _generator; + [HtmlAttributeNotBound] + [ViewContext] + public ViewContext ViewContext { get; set; } + [HtmlAttributeName(ActiveControllerName)] + public string ActiveController { get; set; } + [HtmlAttributeName(ActiveActionName)] + public string ActiveAction { get; set; } - public ActivePageTagHelper(IHtmlGenerator generator) + public override void Process(TagHelperContext context, TagHelperOutput output) + { + if (context == null) { - _generator = generator; + throw new ArgumentNullException(nameof(context)); } - [HtmlAttributeNotBound] - [ViewContext] - public ViewContext ViewContext { get; set; } - [HtmlAttributeName(ActiveControllerName)] - public string ActiveController { get; set; } - [HtmlAttributeName(ActiveActionName)] - public string ActiveAction { get; set; } - - public override void Process(TagHelperContext context, TagHelperOutput output) + if (output == null) { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - if (output == null) - { - throw new ArgumentNullException(nameof(output)); - } - - if (ActiveAction == null && ActiveController == null) - { - return; - } - - var descriptor = ViewContext.ActionDescriptor as ControllerActionDescriptor; - if (descriptor == null) - { - return; - } - - var controllerMatch = ActiveMatch(ActiveController, descriptor.ControllerName); - var actionMatch = ActiveMatch(ActiveAction, descriptor.ActionName); - if (controllerMatch && actionMatch) - { - var classValue = "active"; - if (output.Attributes["class"] != null) - { - classValue += " " + output.Attributes["class"].Value; - output.Attributes.Remove(output.Attributes["class"]); - } - - output.Attributes.Add("class", classValue); - } + throw new ArgumentNullException(nameof(output)); } - private bool ActiveMatch(string route, string descriptor) + if (ActiveAction == null && ActiveController == null) { - return route == null || route == "*" || - route.Split(',').Any(c => c.Trim().ToLower() == descriptor.ToLower()); + return; + } + + var descriptor = ViewContext.ActionDescriptor as ControllerActionDescriptor; + if (descriptor == null) + { + return; + } + + var controllerMatch = ActiveMatch(ActiveController, descriptor.ControllerName); + var actionMatch = ActiveMatch(ActiveAction, descriptor.ActionName); + if (controllerMatch && actionMatch) + { + var classValue = "active"; + if (output.Attributes["class"] != null) + { + classValue += " " + output.Attributes["class"].Value; + output.Attributes.Remove(output.Attributes["class"]); + } + + output.Attributes.Add("class", classValue); } } + + private bool ActiveMatch(string route, string descriptor) + { + return route == null || route == "*" || + route.Split(',').Any(c => c.Trim().ToLower() == descriptor.ToLower()); + } } diff --git a/src/Admin/TagHelpers/OptionSelectedTagHelper.cs b/src/Admin/TagHelpers/OptionSelectedTagHelper.cs index 190d3d1cc..3dc9562a0 100644 --- a/src/Admin/TagHelpers/OptionSelectedTagHelper.cs +++ b/src/Admin/TagHelpers/OptionSelectedTagHelper.cs @@ -1,43 +1,42 @@ using Microsoft.AspNetCore.Mvc.ViewFeatures; using Microsoft.AspNetCore.Razor.TagHelpers; -namespace Bit.Admin.TagHelpers +namespace Bit.Admin.TagHelpers; + +[HtmlTargetElement("option", Attributes = SelectedName)] +public class OptionSelectedTagHelper : TagHelper { - [HtmlTargetElement("option", Attributes = SelectedName)] - public class OptionSelectedTagHelper : TagHelper + private const string SelectedName = "asp-selected"; + + private readonly IHtmlGenerator _generator; + + public OptionSelectedTagHelper(IHtmlGenerator generator) { - private const string SelectedName = "asp-selected"; + _generator = generator; + } - private readonly IHtmlGenerator _generator; + [HtmlAttributeName(SelectedName)] + public bool Selected { get; set; } - public OptionSelectedTagHelper(IHtmlGenerator generator) + public override void Process(TagHelperContext context, TagHelperOutput output) + { + if (context == null) { - _generator = generator; + throw new ArgumentNullException(nameof(context)); } - [HtmlAttributeName(SelectedName)] - public bool Selected { get; set; } - - public override void Process(TagHelperContext context, TagHelperOutput output) + if (output == null) { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } + throw new ArgumentNullException(nameof(output)); + } - if (output == null) - { - throw new ArgumentNullException(nameof(output)); - } - - if (Selected) - { - output.Attributes.Add("selected", "selected"); - } - else - { - output.Attributes.RemoveAll("selected"); - } + if (Selected) + { + output.Attributes.Add("selected", "selected"); + } + else + { + output.Attributes.RemoveAll("selected"); } } } diff --git a/src/Api/Controllers/AccountsBillingController.cs b/src/Api/Controllers/AccountsBillingController.cs index bc012e7b3..9e480301f 100644 --- a/src/Api/Controllers/AccountsBillingController.cs +++ b/src/Api/Controllers/AccountsBillingController.cs @@ -4,49 +4,48 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("accounts/billing")] +[Authorize("Application")] +public class AccountsBillingController : Controller { - [Route("accounts/billing")] - [Authorize("Application")] - public class AccountsBillingController : Controller + private readonly IPaymentService _paymentService; + private readonly IUserService _userService; + + public AccountsBillingController( + IPaymentService paymentService, + IUserService userService) { - private readonly IPaymentService _paymentService; - private readonly IUserService _userService; + _paymentService = paymentService; + _userService = userService; + } - public AccountsBillingController( - IPaymentService paymentService, - IUserService userService) + [HttpGet("history")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task GetBillingHistory() + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) { - _paymentService = paymentService; - _userService = userService; + throw new UnauthorizedAccessException(); } - [HttpGet("history")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task GetBillingHistory() - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } + var billingInfo = await _paymentService.GetBillingHistoryAsync(user); + return new BillingHistoryResponseModel(billingInfo); + } - var billingInfo = await _paymentService.GetBillingHistoryAsync(user); - return new BillingHistoryResponseModel(billingInfo); + [HttpGet("payment-method")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task GetPaymentMethod() + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); } - [HttpGet("payment-method")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task GetPaymentMethod() - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var billingInfo = await _paymentService.GetBillingBalanceAndSourceAsync(user); - return new BillingPaymentResponseModel(billingInfo); - } + var billingInfo = await _paymentService.GetBillingBalanceAndSourceAsync(user); + return new BillingPaymentResponseModel(billingInfo); } } diff --git a/src/Api/Controllers/AccountsController.cs b/src/Api/Controllers/AccountsController.cs index 41708d3d2..74aa469c9 100644 --- a/src/Api/Controllers/AccountsController.cs +++ b/src/Api/Controllers/AccountsController.cs @@ -18,833 +18,511 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("accounts")] +[Authorize("Application")] +public class AccountsController : Controller { - [Route("accounts")] - [Authorize("Application")] - public class AccountsController : Controller + private readonly GlobalSettings _globalSettings; + private readonly ICipherRepository _cipherRepository; + private readonly IFolderRepository _folderRepository; + private readonly IOrganizationService _organizationService; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IProviderUserRepository _providerUserRepository; + private readonly IPaymentService _paymentService; + private readonly IUserRepository _userRepository; + private readonly IUserService _userService; + private readonly ISendRepository _sendRepository; + private readonly ISendService _sendService; + + public AccountsController( + GlobalSettings globalSettings, + ICipherRepository cipherRepository, + IFolderRepository folderRepository, + IOrganizationService organizationService, + IOrganizationUserRepository organizationUserRepository, + IProviderUserRepository providerUserRepository, + IPaymentService paymentService, + IUserRepository userRepository, + IUserService userService, + ISendRepository sendRepository, + ISendService sendService) { - private readonly GlobalSettings _globalSettings; - private readonly ICipherRepository _cipherRepository; - private readonly IFolderRepository _folderRepository; - private readonly IOrganizationService _organizationService; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IProviderUserRepository _providerUserRepository; - private readonly IPaymentService _paymentService; - private readonly IUserRepository _userRepository; - private readonly IUserService _userService; - private readonly ISendRepository _sendRepository; - private readonly ISendService _sendService; + _cipherRepository = cipherRepository; + _folderRepository = folderRepository; + _globalSettings = globalSettings; + _organizationService = organizationService; + _organizationUserRepository = organizationUserRepository; + _providerUserRepository = providerUserRepository; + _paymentService = paymentService; + _userRepository = userRepository; + _userService = userService; + _sendRepository = sendRepository; + _sendService = sendService; + } - public AccountsController( - GlobalSettings globalSettings, - ICipherRepository cipherRepository, - IFolderRepository folderRepository, - IOrganizationService organizationService, - IOrganizationUserRepository organizationUserRepository, - IProviderUserRepository providerUserRepository, - IPaymentService paymentService, - IUserRepository userRepository, - IUserService userService, - ISendRepository sendRepository, - ISendService sendService) + #region DEPRECATED (Moved to Identity Service) + + [Obsolete("2022-01-12 Moved to Identity, left for backwards compatability with older clients")] + [HttpPost("prelogin")] + [AllowAnonymous] + public async Task PostPrelogin([FromBody] PreloginRequestModel model) + { + var kdfInformation = await _userRepository.GetKdfInformationByEmailAsync(model.Email); + if (kdfInformation == null) { - _cipherRepository = cipherRepository; - _folderRepository = folderRepository; - _globalSettings = globalSettings; - _organizationService = organizationService; - _organizationUserRepository = organizationUserRepository; - _providerUserRepository = providerUserRepository; - _paymentService = paymentService; - _userRepository = userRepository; - _userService = userService; - _sendRepository = sendRepository; - _sendService = sendService; + kdfInformation = new UserKdfInformation + { + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 100000, + }; + } + return new PreloginResponseModel(kdfInformation); + } + + [Obsolete("2022-01-12 Moved to Identity, left for backwards compatability with older clients")] + [HttpPost("register")] + [AllowAnonymous] + [CaptchaProtected] + public async Task PostRegister([FromBody] RegisterRequestModel model) + { + var result = await _userService.RegisterUserAsync(model.ToUser(), model.MasterPasswordHash, + model.Token, model.OrganizationUserId); + if (result.Succeeded) + { + return; } - #region DEPRECATED (Moved to Identity Service) - - [Obsolete("2022-01-12 Moved to Identity, left for backwards compatability with older clients")] - [HttpPost("prelogin")] - [AllowAnonymous] - public async Task PostPrelogin([FromBody] PreloginRequestModel model) + foreach (var error in result.Errors.Where(e => e.Code != "DuplicateUserName")) { - var kdfInformation = await _userRepository.GetKdfInformationByEmailAsync(model.Email); - if (kdfInformation == null) - { - kdfInformation = new UserKdfInformation - { - Kdf = KdfType.PBKDF2_SHA256, - KdfIterations = 100000, - }; - } - return new PreloginResponseModel(kdfInformation); + ModelState.AddModelError(string.Empty, error.Description); } - [Obsolete("2022-01-12 Moved to Identity, left for backwards compatability with older clients")] - [HttpPost("register")] - [AllowAnonymous] - [CaptchaProtected] - public async Task PostRegister([FromBody] RegisterRequestModel model) + await Task.Delay(2000); + throw new BadRequestException(ModelState); + } + + #endregion + + [HttpPost("password-hint")] + [AllowAnonymous] + public async Task PostPasswordHint([FromBody] PasswordHintRequestModel model) + { + await _userService.SendMasterPasswordHintAsync(model.Email); + } + + [HttpPost("email-token")] + public async Task PostEmailToken([FromBody] EmailTokenRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) { - var result = await _userService.RegisterUserAsync(model.ToUser(), model.MasterPasswordHash, - model.Token, model.OrganizationUserId); - if (result.Succeeded) - { - return; - } + throw new UnauthorizedAccessException(); + } - foreach (var error in result.Errors.Where(e => e.Code != "DuplicateUserName")) - { - ModelState.AddModelError(string.Empty, error.Description); - } + if (user.UsesKeyConnector) + { + throw new BadRequestException("You cannot change your email when using Key Connector."); + } + if (!await _userService.CheckPasswordAsync(user, model.MasterPasswordHash)) + { await Task.Delay(2000); - throw new BadRequestException(ModelState); + throw new BadRequestException("MasterPasswordHash", "Invalid password."); } - #endregion + await _userService.InitiateEmailChangeAsync(user, model.NewEmail); + } - [HttpPost("password-hint")] - [AllowAnonymous] - public async Task PostPasswordHint([FromBody] PasswordHintRequestModel model) + [HttpPost("email")] + public async Task PostEmail([FromBody] EmailRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) { - await _userService.SendMasterPasswordHintAsync(model.Email); + throw new UnauthorizedAccessException(); } - [HttpPost("email-token")] - public async Task PostEmailToken([FromBody] EmailTokenRequestModel model) + if (user.UsesKeyConnector) { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (user.UsesKeyConnector) - { - throw new BadRequestException("You cannot change your email when using Key Connector."); - } - - if (!await _userService.CheckPasswordAsync(user, model.MasterPasswordHash)) - { - await Task.Delay(2000); - throw new BadRequestException("MasterPasswordHash", "Invalid password."); - } - - await _userService.InitiateEmailChangeAsync(user, model.NewEmail); + throw new BadRequestException("You cannot change your email when using Key Connector."); } - [HttpPost("email")] - public async Task PostEmail([FromBody] EmailRequestModel model) + var result = await _userService.ChangeEmailAsync(user, model.MasterPasswordHash, model.NewEmail, + model.NewMasterPasswordHash, model.Token, model.Key); + if (result.Succeeded) { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (user.UsesKeyConnector) - { - throw new BadRequestException("You cannot change your email when using Key Connector."); - } - - var result = await _userService.ChangeEmailAsync(user, model.MasterPasswordHash, model.NewEmail, - model.NewMasterPasswordHash, model.Token, model.Key); - if (result.Succeeded) - { - return; - } - - foreach (var error in result.Errors) - { - ModelState.AddModelError(string.Empty, error.Description); - } - - await Task.Delay(2000); - throw new BadRequestException(ModelState); + return; } - [HttpPost("verify-email")] - public async Task PostVerifyEmail() + foreach (var error in result.Errors) { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - await _userService.SendEmailVerificationAsync(user); + ModelState.AddModelError(string.Empty, error.Description); } - [HttpPost("verify-email-token")] - [AllowAnonymous] - public async Task PostVerifyEmailToken([FromBody] VerifyEmailRequestModel model) + await Task.Delay(2000); + throw new BadRequestException(ModelState); + } + + [HttpPost("verify-email")] + public async Task PostVerifyEmail() + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) { - var user = await _userService.GetUserByIdAsync(new Guid(model.UserId)); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - var result = await _userService.ConfirmEmailAsync(user, model.Token); - if (result.Succeeded) - { - return; - } - - foreach (var error in result.Errors) - { - ModelState.AddModelError(string.Empty, error.Description); - } - - await Task.Delay(2000); - throw new BadRequestException(ModelState); + throw new UnauthorizedAccessException(); } - [HttpPost("password")] - public async Task PostPassword([FromBody] PasswordRequestModel model) + await _userService.SendEmailVerificationAsync(user); + } + + [HttpPost("verify-email-token")] + [AllowAnonymous] + public async Task PostVerifyEmailToken([FromBody] VerifyEmailRequestModel model) + { + var user = await _userService.GetUserByIdAsync(new Guid(model.UserId)); + if (user == null) { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var result = await _userService.ChangePasswordAsync(user, model.MasterPasswordHash, - model.NewMasterPasswordHash, model.MasterPasswordHint, model.Key); - if (result.Succeeded) - { - return; - } - - foreach (var error in result.Errors) - { - ModelState.AddModelError(string.Empty, error.Description); - } - - await Task.Delay(2000); - throw new BadRequestException(ModelState); + throw new UnauthorizedAccessException(); + } + var result = await _userService.ConfirmEmailAsync(user, model.Token); + if (result.Succeeded) + { + return; } - [HttpPost("set-password")] - public async Task PostSetPasswordAsync([FromBody] SetPasswordRequestModel model) + foreach (var error in result.Errors) { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var result = await _userService.SetPasswordAsync(model.ToUser(user), model.MasterPasswordHash, model.Key, - model.OrgIdentifier); - if (result.Succeeded) - { - return; - } - - foreach (var error in result.Errors) - { - ModelState.AddModelError(string.Empty, error.Description); - } - - throw new BadRequestException(ModelState); + ModelState.AddModelError(string.Empty, error.Description); } - [HttpPost("verify-password")] - public async Task PostVerifyPassword([FromBody] SecretVerificationRequestModel model) + await Task.Delay(2000); + throw new BadRequestException(ModelState); + } + + [HttpPost("password")] + public async Task PostPassword([FromBody] PasswordRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (await _userService.CheckPasswordAsync(user, model.MasterPasswordHash)) - { - return; - } - - ModelState.AddModelError(nameof(model.MasterPasswordHash), "Invalid password."); - await Task.Delay(2000); - throw new BadRequestException(ModelState); + throw new UnauthorizedAccessException(); } - [HttpPost("set-key-connector-key")] - public async Task PostSetKeyConnectorKeyAsync([FromBody] SetKeyConnectorKeyRequestModel model) + var result = await _userService.ChangePasswordAsync(user, model.MasterPasswordHash, + model.NewMasterPasswordHash, model.MasterPasswordHint, model.Key); + if (result.Succeeded) { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var result = await _userService.SetKeyConnectorKeyAsync(model.ToUser(user), model.Key, model.OrgIdentifier); - if (result.Succeeded) - { - return; - } - - foreach (var error in result.Errors) - { - ModelState.AddModelError(string.Empty, error.Description); - } - - throw new BadRequestException(ModelState); + return; } - [HttpPost("convert-to-key-connector")] - public async Task PostConvertToKeyConnector() + foreach (var error in result.Errors) { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var result = await _userService.ConvertToKeyConnectorAsync(user); - if (result.Succeeded) - { - return; - } - - foreach (var error in result.Errors) - { - ModelState.AddModelError(string.Empty, error.Description); - } - - throw new BadRequestException(ModelState); + ModelState.AddModelError(string.Empty, error.Description); } - [HttpPost("kdf")] - public async Task PostKdf([FromBody] KdfRequestModel model) + await Task.Delay(2000); + throw new BadRequestException(ModelState); + } + + [HttpPost("set-password")] + public async Task PostSetPasswordAsync([FromBody] SetPasswordRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var result = await _userService.ChangeKdfAsync(user, model.MasterPasswordHash, - model.NewMasterPasswordHash, model.Key, model.Kdf.Value, model.KdfIterations.Value); - if (result.Succeeded) - { - return; - } - - foreach (var error in result.Errors) - { - ModelState.AddModelError(string.Empty, error.Description); - } - - await Task.Delay(2000); - throw new BadRequestException(ModelState); + throw new UnauthorizedAccessException(); } - [HttpPost("key")] - public async Task PostKey([FromBody] UpdateKeyRequestModel model) + var result = await _userService.SetPasswordAsync(model.ToUser(user), model.MasterPasswordHash, model.Key, + model.OrgIdentifier); + if (result.Succeeded) { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var ciphers = new List(); - if (model.Ciphers.Any()) - { - var existingCiphers = await _cipherRepository.GetManyByUserIdAsync(user.Id); - ciphers.AddRange(existingCiphers - .Join(model.Ciphers, c => c.Id, c => c.Id, (existing, c) => c.ToCipher(existing))); - } - - var folders = new List(); - if (model.Folders.Any()) - { - var existingFolders = await _folderRepository.GetManyByUserIdAsync(user.Id); - folders.AddRange(existingFolders - .Join(model.Folders, f => f.Id, f => f.Id, (existing, f) => f.ToFolder(existing))); - } - - var sends = new List(); - if (model.Sends?.Any() == true) - { - var existingSends = await _sendRepository.GetManyByUserIdAsync(user.Id); - sends.AddRange(existingSends - .Join(model.Sends, s => s.Id, s => s.Id, (existing, s) => s.ToSend(existing, _sendService))); - } - - var result = await _userService.UpdateKeyAsync( - user, - model.MasterPasswordHash, - model.Key, - model.PrivateKey, - ciphers, - folders, - sends); - - if (result.Succeeded) - { - return; - } - - foreach (var error in result.Errors) - { - ModelState.AddModelError(string.Empty, error.Description); - } - - await Task.Delay(2000); - throw new BadRequestException(ModelState); + return; } - [HttpPost("security-stamp")] - public async Task PostSecurityStamp([FromBody] SecretVerificationRequestModel model) + foreach (var error in result.Errors) { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var result = await _userService.RefreshSecurityStampAsync(user, model.Secret); - if (result.Succeeded) - { - return; - } - - foreach (var error in result.Errors) - { - ModelState.AddModelError(string.Empty, error.Description); - } - - await Task.Delay(2000); - throw new BadRequestException(ModelState); + ModelState.AddModelError(string.Empty, error.Description); } - [HttpGet("profile")] - public async Task GetProfile() - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } + throw new BadRequestException(ModelState); + } - var organizationUserDetails = await _organizationUserRepository.GetManyDetailsByUserAsync(user.Id, - OrganizationUserStatusType.Confirmed); - var providerUserDetails = await _providerUserRepository.GetManyDetailsByUserAsync(user.Id, + [HttpPost("verify-password")] + public async Task PostVerifyPassword([FromBody] SecretVerificationRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (await _userService.CheckPasswordAsync(user, model.MasterPasswordHash)) + { + return; + } + + ModelState.AddModelError(nameof(model.MasterPasswordHash), "Invalid password."); + await Task.Delay(2000); + throw new BadRequestException(ModelState); + } + + [HttpPost("set-key-connector-key")] + public async Task PostSetKeyConnectorKeyAsync([FromBody] SetKeyConnectorKeyRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var result = await _userService.SetKeyConnectorKeyAsync(model.ToUser(user), model.Key, model.OrgIdentifier); + if (result.Succeeded) + { + return; + } + + foreach (var error in result.Errors) + { + ModelState.AddModelError(string.Empty, error.Description); + } + + throw new BadRequestException(ModelState); + } + + [HttpPost("convert-to-key-connector")] + public async Task PostConvertToKeyConnector() + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var result = await _userService.ConvertToKeyConnectorAsync(user); + if (result.Succeeded) + { + return; + } + + foreach (var error in result.Errors) + { + ModelState.AddModelError(string.Empty, error.Description); + } + + throw new BadRequestException(ModelState); + } + + [HttpPost("kdf")] + public async Task PostKdf([FromBody] KdfRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var result = await _userService.ChangeKdfAsync(user, model.MasterPasswordHash, + model.NewMasterPasswordHash, model.Key, model.Kdf.Value, model.KdfIterations.Value); + if (result.Succeeded) + { + return; + } + + foreach (var error in result.Errors) + { + ModelState.AddModelError(string.Empty, error.Description); + } + + await Task.Delay(2000); + throw new BadRequestException(ModelState); + } + + [HttpPost("key")] + public async Task PostKey([FromBody] UpdateKeyRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var ciphers = new List(); + if (model.Ciphers.Any()) + { + var existingCiphers = await _cipherRepository.GetManyByUserIdAsync(user.Id); + ciphers.AddRange(existingCiphers + .Join(model.Ciphers, c => c.Id, c => c.Id, (existing, c) => c.ToCipher(existing))); + } + + var folders = new List(); + if (model.Folders.Any()) + { + var existingFolders = await _folderRepository.GetManyByUserIdAsync(user.Id); + folders.AddRange(existingFolders + .Join(model.Folders, f => f.Id, f => f.Id, (existing, f) => f.ToFolder(existing))); + } + + var sends = new List(); + if (model.Sends?.Any() == true) + { + var existingSends = await _sendRepository.GetManyByUserIdAsync(user.Id); + sends.AddRange(existingSends + .Join(model.Sends, s => s.Id, s => s.Id, (existing, s) => s.ToSend(existing, _sendService))); + } + + var result = await _userService.UpdateKeyAsync( + user, + model.MasterPasswordHash, + model.Key, + model.PrivateKey, + ciphers, + folders, + sends); + + if (result.Succeeded) + { + return; + } + + foreach (var error in result.Errors) + { + ModelState.AddModelError(string.Empty, error.Description); + } + + await Task.Delay(2000); + throw new BadRequestException(ModelState); + } + + [HttpPost("security-stamp")] + public async Task PostSecurityStamp([FromBody] SecretVerificationRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var result = await _userService.RefreshSecurityStampAsync(user, model.Secret); + if (result.Succeeded) + { + return; + } + + foreach (var error in result.Errors) + { + ModelState.AddModelError(string.Empty, error.Description); + } + + await Task.Delay(2000); + throw new BadRequestException(ModelState); + } + + [HttpGet("profile")] + public async Task GetProfile() + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var organizationUserDetails = await _organizationUserRepository.GetManyDetailsByUserAsync(user.Id, + OrganizationUserStatusType.Confirmed); + var providerUserDetails = await _providerUserRepository.GetManyDetailsByUserAsync(user.Id, + ProviderUserStatusType.Confirmed); + var providerUserOrganizationDetails = + await _providerUserRepository.GetManyOrganizationDetailsByUserAsync(user.Id, ProviderUserStatusType.Confirmed); - var providerUserOrganizationDetails = - await _providerUserRepository.GetManyOrganizationDetailsByUserAsync(user.Id, - ProviderUserStatusType.Confirmed); - var response = new ProfileResponseModel(user, organizationUserDetails, providerUserDetails, - providerUserOrganizationDetails, await _userService.TwoFactorIsEnabledAsync(user), await _userService.HasPremiumFromOrganization(user)); - return response; + var response = new ProfileResponseModel(user, organizationUserDetails, providerUserDetails, + providerUserOrganizationDetails, await _userService.TwoFactorIsEnabledAsync(user), await _userService.HasPremiumFromOrganization(user)); + return response; + } + + [HttpGet("organizations")] + public async Task> GetOrganizations() + { + var userId = _userService.GetProperUserId(User); + var organizationUserDetails = await _organizationUserRepository.GetManyDetailsByUserAsync(userId.Value, + OrganizationUserStatusType.Confirmed); + var responseData = organizationUserDetails.Select(o => new ProfileOrganizationResponseModel(o)); + return new ListResponseModel(responseData); + } + + [HttpPut("profile")] + [HttpPost("profile")] + public async Task PutProfile([FromBody] UpdateProfileRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); } - [HttpGet("organizations")] - public async Task> GetOrganizations() + await _userService.SaveUserAsync(model.ToUser(user)); + var response = new ProfileResponseModel(user, null, null, null, await _userService.TwoFactorIsEnabledAsync(user), await _userService.HasPremiumFromOrganization(user)); + return response; + } + + [HttpGet("revision-date")] + public async Task GetAccountRevisionDate() + { + var userId = _userService.GetProperUserId(User); + long? revisionDate = null; + if (userId.HasValue) { - var userId = _userService.GetProperUserId(User); - var organizationUserDetails = await _organizationUserRepository.GetManyDetailsByUserAsync(userId.Value, - OrganizationUserStatusType.Confirmed); - var responseData = organizationUserDetails.Select(o => new ProfileOrganizationResponseModel(o)); - return new ListResponseModel(responseData); + var date = await _userService.GetAccountRevisionDateByIdAsync(userId.Value); + revisionDate = CoreHelpers.ToEpocMilliseconds(date); } - [HttpPut("profile")] - [HttpPost("profile")] - public async Task PutProfile([FromBody] UpdateProfileRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } + return revisionDate; + } - await _userService.SaveUserAsync(model.ToUser(user)); - var response = new ProfileResponseModel(user, null, null, null, await _userService.TwoFactorIsEnabledAsync(user), await _userService.HasPremiumFromOrganization(user)); - return response; + [HttpPost("keys")] + public async Task PostKeys([FromBody] KeysRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); } - [HttpGet("revision-date")] - public async Task GetAccountRevisionDate() - { - var userId = _userService.GetProperUserId(User); - long? revisionDate = null; - if (userId.HasValue) - { - var date = await _userService.GetAccountRevisionDateByIdAsync(userId.Value); - revisionDate = CoreHelpers.ToEpocMilliseconds(date); - } + await _userService.SaveUserAsync(model.ToUser(user)); + return new KeysResponseModel(user); + } - return revisionDate; + [HttpGet("keys")] + public async Task GetKeys() + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); } - [HttpPost("keys")] - public async Task PostKeys([FromBody] KeysRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } + return new KeysResponseModel(user); + } - await _userService.SaveUserAsync(model.ToUser(user)); - return new KeysResponseModel(user); + [HttpDelete] + [HttpPost("delete")] + public async Task Delete([FromBody] SecretVerificationRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); } - [HttpGet("keys")] - public async Task GetKeys() + if (!await _userService.VerifySecretAsync(user, model.Secret)) { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - return new KeysResponseModel(user); - } - - [HttpDelete] - [HttpPost("delete")] - public async Task Delete([FromBody] SecretVerificationRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (!await _userService.VerifySecretAsync(user, model.Secret)) - { - ModelState.AddModelError(string.Empty, "User verification failed."); - await Task.Delay(2000); - } - else - { - var result = await _userService.DeleteAsync(user); - if (result.Succeeded) - { - return; - } - - foreach (var error in result.Errors) - { - ModelState.AddModelError(string.Empty, error.Description); - } - } - - throw new BadRequestException(ModelState); - } - - [AllowAnonymous] - [HttpPost("delete-recover")] - public async Task PostDeleteRecover([FromBody] DeleteRecoverRequestModel model) - { - await _userService.SendDeleteConfirmationAsync(model.Email); - } - - [HttpPost("delete-recover-token")] - [AllowAnonymous] - public async Task PostDeleteRecoverToken([FromBody] VerifyDeleteRecoverRequestModel model) - { - var user = await _userService.GetUserByIdAsync(new Guid(model.UserId)); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var result = await _userService.DeleteAsync(user, model.Token); - if (result.Succeeded) - { - return; - } - - foreach (var error in result.Errors) - { - ModelState.AddModelError(string.Empty, error.Description); - } - + ModelState.AddModelError(string.Empty, "User verification failed."); await Task.Delay(2000); - throw new BadRequestException(ModelState); } - - [HttpPost("iap-check")] - public async Task PostIapCheck([FromBody] IapCheckRequestModel model) + else { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - await _userService.IapCheckAsync(user, model.PaymentMethodType.Value); - } - - [HttpPost("premium")] - public async Task PostPremium(PremiumRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var valid = model.Validate(_globalSettings); - UserLicense license = null; - if (valid && _globalSettings.SelfHosted) - { - license = await ApiHelpers.ReadJsonFileFromBody(HttpContext, model.License); - } - - if (!valid && !_globalSettings.SelfHosted && string.IsNullOrWhiteSpace(model.Country)) - { - throw new BadRequestException("Country is required."); - } - - if (!valid || (_globalSettings.SelfHosted && license == null)) - { - throw new BadRequestException("Invalid license."); - } - - var result = await _userService.SignUpPremiumAsync(user, model.PaymentToken, - model.PaymentMethodType.Value, model.AdditionalStorageGb.GetValueOrDefault(0), license, - new TaxInfo - { - BillingAddressCountry = model.Country, - BillingAddressPostalCode = model.PostalCode, - }); - var profile = new ProfileResponseModel(user, null, null, null, await _userService.TwoFactorIsEnabledAsync(user), await _userService.HasPremiumFromOrganization(user)); - return new PaymentResponseModel - { - UserProfile = profile, - PaymentIntentClientSecret = result.Item2, - Success = result.Item1 - }; - } - - [Obsolete("2022-04-01 Use separate Billing History/Payment APIs, left for backwards compatability with older clients")] - [HttpGet("billing")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task GetBilling() - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var billingInfo = await _paymentService.GetBillingAsync(user); - return new BillingResponseModel(billingInfo); - } - - [HttpGet("subscription")] - public async Task GetSubscription() - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (!_globalSettings.SelfHosted && user.Gateway != null) - { - var subscriptionInfo = await _paymentService.GetSubscriptionAsync(user); - var license = await _userService.GenerateLicenseAsync(user, subscriptionInfo); - return new SubscriptionResponseModel(user, subscriptionInfo, license); - } - else if (!_globalSettings.SelfHosted) - { - var license = await _userService.GenerateLicenseAsync(user); - return new SubscriptionResponseModel(user, license); - } - else - { - return new SubscriptionResponseModel(user); - } - } - - [HttpPost("payment")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostPayment([FromBody] PaymentRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - await _userService.ReplacePaymentMethodAsync(user, model.PaymentToken, model.PaymentMethodType.Value, - new TaxInfo - { - BillingAddressCountry = model.Country, - BillingAddressPostalCode = model.PostalCode, - }); - } - - [HttpPost("storage")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostStorage([FromBody] StorageRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var result = await _userService.AdjustStorageAsync(user, model.StorageGbAdjustment.Value); - return new PaymentResponseModel - { - Success = true, - PaymentIntentClientSecret = result - }; - } - - [HttpPost("license")] - [SelfHosted(SelfHostedOnly = true)] - public async Task PostLicense(LicenseRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var license = await ApiHelpers.ReadJsonFileFromBody(HttpContext, model.License); - if (license == null) - { - throw new BadRequestException("Invalid license"); - } - - await _userService.UpdateLicenseAsync(user, license); - } - - [HttpPost("cancel-premium")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostCancel() - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - await _userService.CancelPremiumAsync(user); - } - - [HttpPost("reinstate-premium")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostReinstate() - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - await _userService.ReinstatePremiumAsync(user); - } - - [HttpGet("tax")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task GetTaxInfo() - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var taxInfo = await _paymentService.GetTaxInfoAsync(user); - return new TaxInfoResponseModel(taxInfo); - } - - [HttpPut("tax")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PutTaxInfo([FromBody] TaxInfoUpdateRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var taxInfo = new TaxInfo - { - BillingAddressPostalCode = model.PostalCode, - BillingAddressCountry = model.Country, - }; - await _paymentService.SaveTaxInfoAsync(user, taxInfo); - } - - [HttpDelete("sso/{organizationId}")] - public async Task DeleteSsoUser(string organizationId) - { - var userId = _userService.GetProperUserId(User); - if (!userId.HasValue) - { - throw new NotFoundException(); - } - - await _organizationService.DeleteSsoUserAsync(userId.Value, new Guid(organizationId)); - } - - [HttpGet("sso/user-identifier")] - public async Task GetSsoUserIdentifier() - { - var user = await _userService.GetUserByPrincipalAsync(User); - var token = await _userService.GenerateSignInTokenAsync(user, TokenPurposes.LinkSso); - var userIdentifier = $"{user.Id},{token}"; - return userIdentifier; - } - - [HttpPost("api-key")] - public async Task ApiKey([FromBody] SecretVerificationRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (!await _userService.VerifySecretAsync(user, model.Secret)) - { - await Task.Delay(2000); - throw new BadRequestException(string.Empty, "User verification failed."); - } - - return new ApiKeyResponseModel(user); - } - - [HttpPost("rotate-api-key")] - public async Task RotateApiKey([FromBody] SecretVerificationRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (!await _userService.VerifySecretAsync(user, model.Secret)) - { - await Task.Delay(2000); - throw new BadRequestException(string.Empty, "User verification failed."); - } - - await _userService.RotateApiKeyAsync(user); - var response = new ApiKeyResponseModel(user); - return response; - } - - [HttpPut("update-temp-password")] - public async Task PutUpdateTempPasswordAsync([FromBody] UpdateTempPasswordRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var result = await _userService.UpdateTempPasswordAsync(user, model.NewMasterPasswordHash, model.Key, model.MasterPasswordHint); + var result = await _userService.DeleteAsync(user); if (result.Succeeded) { return; @@ -854,36 +532,357 @@ namespace Bit.Api.Controllers { ModelState.AddModelError(string.Empty, error.Description); } - - throw new BadRequestException(ModelState); } - [HttpPost("request-otp")] - public async Task PostRequestOTP() - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user is not { UsesKeyConnector: true }) - { - throw new UnauthorizedAccessException(); - } + throw new BadRequestException(ModelState); + } - await _userService.SendOTPAsync(user); + [AllowAnonymous] + [HttpPost("delete-recover")] + public async Task PostDeleteRecover([FromBody] DeleteRecoverRequestModel model) + { + await _userService.SendDeleteConfirmationAsync(model.Email); + } + + [HttpPost("delete-recover-token")] + [AllowAnonymous] + public async Task PostDeleteRecoverToken([FromBody] VerifyDeleteRecoverRequestModel model) + { + var user = await _userService.GetUserByIdAsync(new Guid(model.UserId)); + if (user == null) + { + throw new UnauthorizedAccessException(); } - [HttpPost("verify-otp")] - public async Task VerifyOTP([FromBody] VerifyOTPRequestModel model) + var result = await _userService.DeleteAsync(user, model.Token); + if (result.Succeeded) { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user is not { UsesKeyConnector: true }) - { - throw new UnauthorizedAccessException(); - } + return; + } - if (!await _userService.VerifyOTPAsync(user, model.OTP)) + foreach (var error in result.Errors) + { + ModelState.AddModelError(string.Empty, error.Description); + } + + await Task.Delay(2000); + throw new BadRequestException(ModelState); + } + + [HttpPost("iap-check")] + public async Task PostIapCheck([FromBody] IapCheckRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + await _userService.IapCheckAsync(user, model.PaymentMethodType.Value); + } + + [HttpPost("premium")] + public async Task PostPremium(PremiumRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var valid = model.Validate(_globalSettings); + UserLicense license = null; + if (valid && _globalSettings.SelfHosted) + { + license = await ApiHelpers.ReadJsonFileFromBody(HttpContext, model.License); + } + + if (!valid && !_globalSettings.SelfHosted && string.IsNullOrWhiteSpace(model.Country)) + { + throw new BadRequestException("Country is required."); + } + + if (!valid || (_globalSettings.SelfHosted && license == null)) + { + throw new BadRequestException("Invalid license."); + } + + var result = await _userService.SignUpPremiumAsync(user, model.PaymentToken, + model.PaymentMethodType.Value, model.AdditionalStorageGb.GetValueOrDefault(0), license, + new TaxInfo { - await Task.Delay(2000); - throw new BadRequestException("Token", "Invalid token"); - } + BillingAddressCountry = model.Country, + BillingAddressPostalCode = model.PostalCode, + }); + var profile = new ProfileResponseModel(user, null, null, null, await _userService.TwoFactorIsEnabledAsync(user), await _userService.HasPremiumFromOrganization(user)); + return new PaymentResponseModel + { + UserProfile = profile, + PaymentIntentClientSecret = result.Item2, + Success = result.Item1 + }; + } + + [Obsolete("2022-04-01 Use separate Billing History/Payment APIs, left for backwards compatability with older clients")] + [HttpGet("billing")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task GetBilling() + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var billingInfo = await _paymentService.GetBillingAsync(user); + return new BillingResponseModel(billingInfo); + } + + [HttpGet("subscription")] + public async Task GetSubscription() + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (!_globalSettings.SelfHosted && user.Gateway != null) + { + var subscriptionInfo = await _paymentService.GetSubscriptionAsync(user); + var license = await _userService.GenerateLicenseAsync(user, subscriptionInfo); + return new SubscriptionResponseModel(user, subscriptionInfo, license); + } + else if (!_globalSettings.SelfHosted) + { + var license = await _userService.GenerateLicenseAsync(user); + return new SubscriptionResponseModel(user, license); + } + else + { + return new SubscriptionResponseModel(user); + } + } + + [HttpPost("payment")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostPayment([FromBody] PaymentRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + await _userService.ReplacePaymentMethodAsync(user, model.PaymentToken, model.PaymentMethodType.Value, + new TaxInfo + { + BillingAddressCountry = model.Country, + BillingAddressPostalCode = model.PostalCode, + }); + } + + [HttpPost("storage")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostStorage([FromBody] StorageRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var result = await _userService.AdjustStorageAsync(user, model.StorageGbAdjustment.Value); + return new PaymentResponseModel + { + Success = true, + PaymentIntentClientSecret = result + }; + } + + [HttpPost("license")] + [SelfHosted(SelfHostedOnly = true)] + public async Task PostLicense(LicenseRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var license = await ApiHelpers.ReadJsonFileFromBody(HttpContext, model.License); + if (license == null) + { + throw new BadRequestException("Invalid license"); + } + + await _userService.UpdateLicenseAsync(user, license); + } + + [HttpPost("cancel-premium")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostCancel() + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + await _userService.CancelPremiumAsync(user); + } + + [HttpPost("reinstate-premium")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostReinstate() + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + await _userService.ReinstatePremiumAsync(user); + } + + [HttpGet("tax")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task GetTaxInfo() + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var taxInfo = await _paymentService.GetTaxInfoAsync(user); + return new TaxInfoResponseModel(taxInfo); + } + + [HttpPut("tax")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PutTaxInfo([FromBody] TaxInfoUpdateRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var taxInfo = new TaxInfo + { + BillingAddressPostalCode = model.PostalCode, + BillingAddressCountry = model.Country, + }; + await _paymentService.SaveTaxInfoAsync(user, taxInfo); + } + + [HttpDelete("sso/{organizationId}")] + public async Task DeleteSsoUser(string organizationId) + { + var userId = _userService.GetProperUserId(User); + if (!userId.HasValue) + { + throw new NotFoundException(); + } + + await _organizationService.DeleteSsoUserAsync(userId.Value, new Guid(organizationId)); + } + + [HttpGet("sso/user-identifier")] + public async Task GetSsoUserIdentifier() + { + var user = await _userService.GetUserByPrincipalAsync(User); + var token = await _userService.GenerateSignInTokenAsync(user, TokenPurposes.LinkSso); + var userIdentifier = $"{user.Id},{token}"; + return userIdentifier; + } + + [HttpPost("api-key")] + public async Task ApiKey([FromBody] SecretVerificationRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (!await _userService.VerifySecretAsync(user, model.Secret)) + { + await Task.Delay(2000); + throw new BadRequestException(string.Empty, "User verification failed."); + } + + return new ApiKeyResponseModel(user); + } + + [HttpPost("rotate-api-key")] + public async Task RotateApiKey([FromBody] SecretVerificationRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (!await _userService.VerifySecretAsync(user, model.Secret)) + { + await Task.Delay(2000); + throw new BadRequestException(string.Empty, "User verification failed."); + } + + await _userService.RotateApiKeyAsync(user); + var response = new ApiKeyResponseModel(user); + return response; + } + + [HttpPut("update-temp-password")] + public async Task PutUpdateTempPasswordAsync([FromBody] UpdateTempPasswordRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var result = await _userService.UpdateTempPasswordAsync(user, model.NewMasterPasswordHash, model.Key, model.MasterPasswordHint); + if (result.Succeeded) + { + return; + } + + foreach (var error in result.Errors) + { + ModelState.AddModelError(string.Empty, error.Description); + } + + throw new BadRequestException(ModelState); + } + + [HttpPost("request-otp")] + public async Task PostRequestOTP() + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user is not { UsesKeyConnector: true }) + { + throw new UnauthorizedAccessException(); + } + + await _userService.SendOTPAsync(user); + } + + [HttpPost("verify-otp")] + public async Task VerifyOTP([FromBody] VerifyOTPRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user is not { UsesKeyConnector: true }) + { + throw new UnauthorizedAccessException(); + } + + if (!await _userService.VerifyOTPAsync(user, model.OTP)) + { + await Task.Delay(2000); + throw new BadRequestException("Token", "Invalid token"); } } } diff --git a/src/Api/Controllers/CiphersController.cs b/src/Api/Controllers/CiphersController.cs index f5831acaa..5b059a332 100644 --- a/src/Api/Controllers/CiphersController.cs +++ b/src/Api/Controllers/CiphersController.cs @@ -18,789 +18,788 @@ using Core.Models.Data; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("ciphers")] +[Authorize("Application")] +public class CiphersController : Controller { - [Route("ciphers")] - [Authorize("Application")] - public class CiphersController : Controller + private readonly ICipherRepository _cipherRepository; + private readonly ICollectionCipherRepository _collectionCipherRepository; + private readonly ICipherService _cipherService; + private readonly IUserService _userService; + private readonly IAttachmentStorageService _attachmentStorageService; + private readonly IProviderService _providerService; + private readonly ICurrentContext _currentContext; + private readonly ILogger _logger; + private readonly GlobalSettings _globalSettings; + + public CiphersController( + ICipherRepository cipherRepository, + ICollectionCipherRepository collectionCipherRepository, + ICipherService cipherService, + IUserService userService, + IAttachmentStorageService attachmentStorageService, + IProviderService providerService, + ICurrentContext currentContext, + ILogger logger, + GlobalSettings globalSettings) { - private readonly ICipherRepository _cipherRepository; - private readonly ICollectionCipherRepository _collectionCipherRepository; - private readonly ICipherService _cipherService; - private readonly IUserService _userService; - private readonly IAttachmentStorageService _attachmentStorageService; - private readonly IProviderService _providerService; - private readonly ICurrentContext _currentContext; - private readonly ILogger _logger; - private readonly GlobalSettings _globalSettings; + _cipherRepository = cipherRepository; + _collectionCipherRepository = collectionCipherRepository; + _cipherService = cipherService; + _userService = userService; + _attachmentStorageService = attachmentStorageService; + _providerService = providerService; + _currentContext = currentContext; + _logger = logger; + _globalSettings = globalSettings; + } - public CiphersController( - ICipherRepository cipherRepository, - ICollectionCipherRepository collectionCipherRepository, - ICipherService cipherService, - IUserService userService, - IAttachmentStorageService attachmentStorageService, - IProviderService providerService, - ICurrentContext currentContext, - ILogger logger, - GlobalSettings globalSettings) + [HttpGet("{id}")] + public async Task Get(string id) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); + if (cipher == null) { - _cipherRepository = cipherRepository; - _collectionCipherRepository = collectionCipherRepository; - _cipherService = cipherService; - _userService = userService; - _attachmentStorageService = attachmentStorageService; - _providerService = providerService; - _currentContext = currentContext; - _logger = logger; - _globalSettings = globalSettings; + throw new NotFoundException(); } - [HttpGet("{id}")] - public async Task Get(string id) - { - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); - if (cipher == null) - { - throw new NotFoundException(); - } + return new CipherResponseModel(cipher, _globalSettings); + } - return new CipherResponseModel(cipher, _globalSettings); + [HttpGet("{id}/admin")] + public async Task GetAdmin(string id) + { + var cipher = await _cipherRepository.GetOrganizationDetailsByIdAsync(new Guid(id)); + if (cipher == null || !cipher.OrganizationId.HasValue || + !await _currentContext.ViewAllCollections(cipher.OrganizationId.Value)) + { + throw new NotFoundException(); } - [HttpGet("{id}/admin")] - public async Task GetAdmin(string id) - { - var cipher = await _cipherRepository.GetOrganizationDetailsByIdAsync(new Guid(id)); - if (cipher == null || !cipher.OrganizationId.HasValue || - !await _currentContext.ViewAllCollections(cipher.OrganizationId.Value)) - { - throw new NotFoundException(); - } + return new CipherMiniResponseModel(cipher, _globalSettings, cipher.OrganizationUseTotp); + } - return new CipherMiniResponseModel(cipher, _globalSettings, cipher.OrganizationUseTotp); + [HttpGet("{id}/full-details")] + [HttpGet("{id}/details")] + public async Task GetDetails(string id) + { + var userId = _userService.GetProperUserId(User).Value; + var cipherId = new Guid(id); + var cipher = await _cipherRepository.GetByIdAsync(cipherId, userId); + if (cipher == null) + { + throw new NotFoundException(); } - [HttpGet("{id}/full-details")] - [HttpGet("{id}/details")] - public async Task GetDetails(string id) - { - var userId = _userService.GetProperUserId(User).Value; - var cipherId = new Guid(id); - var cipher = await _cipherRepository.GetByIdAsync(cipherId, userId); - if (cipher == null) - { - throw new NotFoundException(); - } + var collectionCiphers = await _collectionCipherRepository.GetManyByUserIdCipherIdAsync(userId, cipherId); + return new CipherDetailsResponseModel(cipher, _globalSettings, collectionCiphers); + } - var collectionCiphers = await _collectionCipherRepository.GetManyByUserIdCipherIdAsync(userId, cipherId); - return new CipherDetailsResponseModel(cipher, _globalSettings, collectionCiphers); + [HttpGet("")] + public async Task> Get() + { + var userId = _userService.GetProperUserId(User).Value; + var hasOrgs = _currentContext.Organizations?.Any() ?? false; + // TODO: Use hasOrgs proper for cipher listing here? + var ciphers = await _cipherRepository.GetManyByUserIdAsync(userId, true || hasOrgs); + Dictionary> collectionCiphersGroupDict = null; + if (hasOrgs) + { + var collectionCiphers = await _collectionCipherRepository.GetManyByUserIdAsync(userId); + collectionCiphersGroupDict = collectionCiphers.GroupBy(c => c.CipherId).ToDictionary(s => s.Key); } - [HttpGet("")] - public async Task> Get() - { - var userId = _userService.GetProperUserId(User).Value; - var hasOrgs = _currentContext.Organizations?.Any() ?? false; - // TODO: Use hasOrgs proper for cipher listing here? - var ciphers = await _cipherRepository.GetManyByUserIdAsync(userId, true || hasOrgs); - Dictionary> collectionCiphersGroupDict = null; - if (hasOrgs) - { - var collectionCiphers = await _collectionCipherRepository.GetManyByUserIdAsync(userId); - collectionCiphersGroupDict = collectionCiphers.GroupBy(c => c.CipherId).ToDictionary(s => s.Key); - } + var responses = ciphers.Select(c => new CipherDetailsResponseModel(c, _globalSettings, + collectionCiphersGroupDict)).ToList(); + return new ListResponseModel(responses); + } - var responses = ciphers.Select(c => new CipherDetailsResponseModel(c, _globalSettings, - collectionCiphersGroupDict)).ToList(); - return new ListResponseModel(responses); + [HttpPost("")] + public async Task Post([FromBody] CipherRequestModel model) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = model.ToCipherDetails(userId); + if (cipher.OrganizationId.HasValue && !await _currentContext.OrganizationUser(cipher.OrganizationId.Value)) + { + throw new NotFoundException(); } - [HttpPost("")] - public async Task Post([FromBody] CipherRequestModel model) - { - var userId = _userService.GetProperUserId(User).Value; - var cipher = model.ToCipherDetails(userId); - if (cipher.OrganizationId.HasValue && !await _currentContext.OrganizationUser(cipher.OrganizationId.Value)) - { - throw new NotFoundException(); - } + await _cipherService.SaveDetailsAsync(cipher, userId, model.LastKnownRevisionDate, null, cipher.OrganizationId.HasValue); + var response = new CipherResponseModel(cipher, _globalSettings); + return response; + } - await _cipherService.SaveDetailsAsync(cipher, userId, model.LastKnownRevisionDate, null, cipher.OrganizationId.HasValue); - var response = new CipherResponseModel(cipher, _globalSettings); - return response; + [HttpPost("create")] + public async Task PostCreate([FromBody] CipherCreateRequestModel model) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = model.Cipher.ToCipherDetails(userId); + if (cipher.OrganizationId.HasValue && !await _currentContext.OrganizationUser(cipher.OrganizationId.Value)) + { + throw new NotFoundException(); } - [HttpPost("create")] - public async Task PostCreate([FromBody] CipherCreateRequestModel model) - { - var userId = _userService.GetProperUserId(User).Value; - var cipher = model.Cipher.ToCipherDetails(userId); - if (cipher.OrganizationId.HasValue && !await _currentContext.OrganizationUser(cipher.OrganizationId.Value)) - { - throw new NotFoundException(); - } + await _cipherService.SaveDetailsAsync(cipher, userId, model.Cipher.LastKnownRevisionDate, model.CollectionIds, cipher.OrganizationId.HasValue); + var response = new CipherResponseModel(cipher, _globalSettings); + return response; + } - await _cipherService.SaveDetailsAsync(cipher, userId, model.Cipher.LastKnownRevisionDate, model.CollectionIds, cipher.OrganizationId.HasValue); - var response = new CipherResponseModel(cipher, _globalSettings); - return response; + [HttpPost("admin")] + public async Task PostAdmin([FromBody] CipherCreateRequestModel model) + { + var cipher = model.Cipher.ToOrganizationCipher(); + if (!await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) + { + throw new NotFoundException(); } - [HttpPost("admin")] - public async Task PostAdmin([FromBody] CipherCreateRequestModel model) + var userId = _userService.GetProperUserId(User).Value; + await _cipherService.SaveAsync(cipher, userId, model.Cipher.LastKnownRevisionDate, model.CollectionIds, true, false); + + var response = new CipherMiniResponseModel(cipher, _globalSettings, false); + return response; + } + + [HttpPut("{id}")] + [HttpPost("{id}")] + public async Task Put(Guid id, [FromBody] CipherRequestModel model) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(id, userId); + if (cipher == null) { - var cipher = model.Cipher.ToOrganizationCipher(); - if (!await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) - { - throw new NotFoundException(); - } - - var userId = _userService.GetProperUserId(User).Value; - await _cipherService.SaveAsync(cipher, userId, model.Cipher.LastKnownRevisionDate, model.CollectionIds, true, false); - - var response = new CipherMiniResponseModel(cipher, _globalSettings, false); - return response; + throw new NotFoundException(); } - [HttpPut("{id}")] - [HttpPost("{id}")] - public async Task Put(Guid id, [FromBody] CipherRequestModel model) + var collectionIds = (await _collectionCipherRepository.GetManyByUserIdCipherIdAsync(userId, id)).Select(c => c.CollectionId).ToList(); + var modelOrgId = string.IsNullOrWhiteSpace(model.OrganizationId) ? + (Guid?)null : new Guid(model.OrganizationId); + if (cipher.OrganizationId != modelOrgId) { - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(id, userId); - if (cipher == null) - { - throw new NotFoundException(); - } - - var collectionIds = (await _collectionCipherRepository.GetManyByUserIdCipherIdAsync(userId, id)).Select(c => c.CollectionId).ToList(); - var modelOrgId = string.IsNullOrWhiteSpace(model.OrganizationId) ? - (Guid?)null : new Guid(model.OrganizationId); - if (cipher.OrganizationId != modelOrgId) - { - throw new BadRequestException("Organization mismatch. Re-sync if you recently moved this item, " + - "then try again."); - } - - await _cipherService.SaveDetailsAsync(model.ToCipherDetails(cipher), userId, model.LastKnownRevisionDate, collectionIds); - - var response = new CipherResponseModel(cipher, _globalSettings); - return response; + throw new BadRequestException("Organization mismatch. Re-sync if you recently moved this item, " + + "then try again."); } - [HttpPut("{id}/admin")] - [HttpPost("{id}/admin")] - public async Task PutAdmin(Guid id, [FromBody] CipherRequestModel model) + await _cipherService.SaveDetailsAsync(model.ToCipherDetails(cipher), userId, model.LastKnownRevisionDate, collectionIds); + + var response = new CipherResponseModel(cipher, _globalSettings); + return response; + } + + [HttpPut("{id}/admin")] + [HttpPost("{id}/admin")] + public async Task PutAdmin(Guid id, [FromBody] CipherRequestModel model) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetOrganizationDetailsByIdAsync(id); + if (cipher == null || !cipher.OrganizationId.HasValue || + !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) { - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetOrganizationDetailsByIdAsync(id); - if (cipher == null || !cipher.OrganizationId.HasValue || - !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) - { - throw new NotFoundException(); - } - - var collectionIds = (await _collectionCipherRepository.GetManyByUserIdCipherIdAsync(userId, id)).Select(c => c.CollectionId).ToList(); - // object cannot be a descendant of CipherDetails, so let's clone it. - var cipherClone = model.ToCipher(cipher).Clone(); - await _cipherService.SaveAsync(cipherClone, userId, model.LastKnownRevisionDate, collectionIds, true, false); - - var response = new CipherMiniResponseModel(cipherClone, _globalSettings, cipher.OrganizationUseTotp); - return response; + throw new NotFoundException(); } - [HttpGet("organization-details")] - public async Task> GetOrganizationCollections( - string organizationId) + var collectionIds = (await _collectionCipherRepository.GetManyByUserIdCipherIdAsync(userId, id)).Select(c => c.CollectionId).ToList(); + // object cannot be a descendant of CipherDetails, so let's clone it. + var cipherClone = model.ToCipher(cipher).Clone(); + await _cipherService.SaveAsync(cipherClone, userId, model.LastKnownRevisionDate, collectionIds, true, false); + + var response = new CipherMiniResponseModel(cipherClone, _globalSettings, cipher.OrganizationUseTotp); + return response; + } + + [HttpGet("organization-details")] + public async Task> GetOrganizationCollections( + string organizationId) + { + var userId = _userService.GetProperUserId(User).Value; + var orgIdGuid = new Guid(organizationId); + + (IEnumerable orgCiphers, Dictionary> collectionCiphersGroupDict) = await _cipherService.GetOrganizationCiphers(userId, orgIdGuid); + + var responses = orgCiphers.Select(c => new CipherMiniDetailsResponseModel(c, _globalSettings, + collectionCiphersGroupDict, c.OrganizationUseTotp)); + + return new ListResponseModel(responses); + } + + [HttpPost("import")] + public async Task PostImport([FromBody] ImportCiphersRequestModel model) + { + if (!_globalSettings.SelfHosted && + (model.Ciphers.Count() > 6000 || model.FolderRelationships.Count() > 6000 || + model.Folders.Count() > 1000)) { - var userId = _userService.GetProperUserId(User).Value; - var orgIdGuid = new Guid(organizationId); - - (IEnumerable orgCiphers, Dictionary> collectionCiphersGroupDict) = await _cipherService.GetOrganizationCiphers(userId, orgIdGuid); - - var responses = orgCiphers.Select(c => new CipherMiniDetailsResponseModel(c, _globalSettings, - collectionCiphersGroupDict, c.OrganizationUseTotp)); - - return new ListResponseModel(responses); + throw new BadRequestException("You cannot import this much data at once."); } - [HttpPost("import")] - public async Task PostImport([FromBody] ImportCiphersRequestModel model) - { - if (!_globalSettings.SelfHosted && - (model.Ciphers.Count() > 6000 || model.FolderRelationships.Count() > 6000 || - model.Folders.Count() > 1000)) - { - throw new BadRequestException("You cannot import this much data at once."); - } + var userId = _userService.GetProperUserId(User).Value; + var folders = model.Folders.Select(f => f.ToFolder(userId)).ToList(); + var ciphers = model.Ciphers.Select(c => c.ToCipherDetails(userId, false)).ToList(); + await _cipherService.ImportCiphersAsync(folders, ciphers, model.FolderRelationships); + } - var userId = _userService.GetProperUserId(User).Value; - var folders = model.Folders.Select(f => f.ToFolder(userId)).ToList(); - var ciphers = model.Ciphers.Select(c => c.ToCipherDetails(userId, false)).ToList(); - await _cipherService.ImportCiphersAsync(folders, ciphers, model.FolderRelationships); + [HttpPost("import-organization")] + public async Task PostImport([FromQuery] string organizationId, + [FromBody] ImportOrganizationCiphersRequestModel model) + { + if (!_globalSettings.SelfHosted && + (model.Ciphers.Count() > 6000 || model.CollectionRelationships.Count() > 12000 || + model.Collections.Count() > 1000)) + { + throw new BadRequestException("You cannot import this much data at once."); } - [HttpPost("import-organization")] - public async Task PostImport([FromQuery] string organizationId, - [FromBody] ImportOrganizationCiphersRequestModel model) + var orgId = new Guid(organizationId); + if (!await _currentContext.AccessImportExport(orgId)) { - if (!_globalSettings.SelfHosted && - (model.Ciphers.Count() > 6000 || model.CollectionRelationships.Count() > 12000 || - model.Collections.Count() > 1000)) + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User).Value; + var collections = model.Collections.Select(c => c.ToCollection(orgId)).ToList(); + var ciphers = model.Ciphers.Select(l => l.ToOrganizationCipherDetails(orgId)).ToList(); + await _cipherService.ImportCiphersAsync(collections, ciphers, model.CollectionRelationships, userId); + } + + [HttpPut("{id}/partial")] + [HttpPost("{id}/partial")] + public async Task PutPartial(string id, [FromBody] CipherPartialRequestModel model) + { + var userId = _userService.GetProperUserId(User).Value; + var folderId = string.IsNullOrWhiteSpace(model.FolderId) ? null : (Guid?)new Guid(model.FolderId); + await _cipherRepository.UpdatePartialAsync(new Guid(id), userId, folderId, model.Favorite); + } + + [HttpPut("{id}/share")] + [HttpPost("{id}/share")] + public async Task PutShare(string id, [FromBody] CipherShareRequestModel model) + { + var userId = _userService.GetProperUserId(User).Value; + var cipherId = new Guid(id); + var cipher = await _cipherRepository.GetByIdAsync(cipherId); + if (cipher == null || cipher.UserId != userId || + !await _currentContext.OrganizationUser(new Guid(model.Cipher.OrganizationId))) + { + throw new NotFoundException(); + } + + var original = cipher.Clone(); + await _cipherService.ShareAsync(original, model.Cipher.ToCipher(cipher), new Guid(model.Cipher.OrganizationId), + model.CollectionIds.Select(c => new Guid(c)), userId, model.Cipher.LastKnownRevisionDate); + + var sharedCipher = await _cipherRepository.GetByIdAsync(cipherId, userId); + var response = new CipherResponseModel(sharedCipher, _globalSettings); + return response; + } + + [HttpPut("{id}/collections")] + [HttpPost("{id}/collections")] + public async Task PutCollections(string id, [FromBody] CipherCollectionsRequestModel model) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); + if (cipher == null || !cipher.OrganizationId.HasValue || + !await _currentContext.OrganizationUser(cipher.OrganizationId.Value)) + { + throw new NotFoundException(); + } + + await _cipherService.SaveCollectionsAsync(cipher, + model.CollectionIds.Select(c => new Guid(c)), userId, false); + } + + [HttpPut("{id}/collections-admin")] + [HttpPost("{id}/collections-admin")] + public async Task PutCollectionsAdmin(string id, [FromBody] CipherCollectionsRequestModel model) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(new Guid(id)); + if (cipher == null || !cipher.OrganizationId.HasValue || + !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) + { + throw new NotFoundException(); + } + + await _cipherService.SaveCollectionsAsync(cipher, + model.CollectionIds.Select(c => new Guid(c)), userId, true); + } + + [HttpDelete("{id}")] + [HttpPost("{id}/delete")] + public async Task Delete(string id) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); + if (cipher == null) + { + throw new NotFoundException(); + } + + await _cipherService.DeleteAsync(cipher, userId); + } + + [HttpDelete("{id}/admin")] + [HttpPost("{id}/delete-admin")] + public async Task DeleteAdmin(string id) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(new Guid(id)); + if (cipher == null || !cipher.OrganizationId.HasValue || + !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) + { + throw new NotFoundException(); + } + + await _cipherService.DeleteAsync(cipher, userId, true); + } + + [HttpDelete("")] + [HttpPost("delete")] + public async Task DeleteMany([FromBody] CipherBulkDeleteRequestModel model) + { + if (!_globalSettings.SelfHosted && model.Ids.Count() > 500) + { + throw new BadRequestException("You can only delete up to 500 items at a time. " + + "Consider using the \"Purge Vault\" option instead."); + } + + var userId = _userService.GetProperUserId(User).Value; + await _cipherService.DeleteManyAsync(model.Ids.Select(i => new Guid(i)), userId); + } + + [HttpDelete("admin")] + [HttpPost("delete-admin")] + public async Task DeleteManyAdmin([FromBody] CipherBulkDeleteRequestModel model) + { + if (!_globalSettings.SelfHosted && model.Ids.Count() > 500) + { + throw new BadRequestException("You can only delete up to 500 items at a time. " + + "Consider using the \"Purge Vault\" option instead."); + } + + if (model == null || string.IsNullOrWhiteSpace(model.OrganizationId) || + !await _currentContext.EditAnyCollection(new Guid(model.OrganizationId))) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User).Value; + await _cipherService.DeleteManyAsync(model.Ids.Select(i => new Guid(i)), userId, new Guid(model.OrganizationId), true); + } + + [HttpPut("{id}/delete")] + public async Task PutDelete(string id) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); + if (cipher == null) + { + throw new NotFoundException(); + } + await _cipherService.SoftDeleteAsync(cipher, userId); + } + + [HttpPut("{id}/delete-admin")] + public async Task PutDeleteAdmin(string id) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(new Guid(id)); + if (cipher == null || !cipher.OrganizationId.HasValue || + !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) + { + throw new NotFoundException(); + } + + await _cipherService.SoftDeleteAsync(cipher, userId, true); + } + + [HttpPut("delete")] + public async Task PutDeleteMany([FromBody] CipherBulkDeleteRequestModel model) + { + if (!_globalSettings.SelfHosted && model.Ids.Count() > 500) + { + throw new BadRequestException("You can only delete up to 500 items at a time."); + } + + var userId = _userService.GetProperUserId(User).Value; + await _cipherService.SoftDeleteManyAsync(model.Ids.Select(i => new Guid(i)), userId); + } + + [HttpPut("delete-admin")] + public async Task PutDeleteManyAdmin([FromBody] CipherBulkDeleteRequestModel model) + { + if (!_globalSettings.SelfHosted && model.Ids.Count() > 500) + { + throw new BadRequestException("You can only delete up to 500 items at a time."); + } + + if (model == null || string.IsNullOrWhiteSpace(model.OrganizationId) || + !await _currentContext.EditAnyCollection(new Guid(model.OrganizationId))) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User).Value; + await _cipherService.SoftDeleteManyAsync(model.Ids.Select(i => new Guid(i)), userId, new Guid(model.OrganizationId), true); + } + + [HttpPut("{id}/restore")] + public async Task PutRestore(string id) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); + if (cipher == null) + { + throw new NotFoundException(); + } + + await _cipherService.RestoreAsync(cipher, userId); + return new CipherResponseModel(cipher, _globalSettings); + } + + [HttpPut("{id}/restore-admin")] + public async Task PutRestoreAdmin(string id) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetOrganizationDetailsByIdAsync(new Guid(id)); + if (cipher == null || !cipher.OrganizationId.HasValue || + !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) + { + throw new NotFoundException(); + } + + await _cipherService.RestoreAsync(cipher, userId, true); + return new CipherMiniResponseModel(cipher, _globalSettings, cipher.OrganizationUseTotp); + } + + [HttpPut("restore")] + public async Task> PutRestoreMany([FromBody] CipherBulkRestoreRequestModel model) + { + if (!_globalSettings.SelfHosted && model.Ids.Count() > 500) + { + throw new BadRequestException("You can only restore up to 500 items at a time."); + } + + var userId = _userService.GetProperUserId(User).Value; + var cipherIdsToRestore = new HashSet(model.Ids.Select(i => new Guid(i))); + + var ciphers = await _cipherRepository.GetManyByUserIdAsync(userId); + var restoringCiphers = ciphers.Where(c => cipherIdsToRestore.Contains(c.Id) && c.Edit); + + await _cipherService.RestoreManyAsync(restoringCiphers, userId); + var responses = restoringCiphers.Select(c => new CipherResponseModel(c, _globalSettings)); + return new ListResponseModel(responses); + } + + [HttpPut("move")] + [HttpPost("move")] + public async Task MoveMany([FromBody] CipherBulkMoveRequestModel model) + { + if (!_globalSettings.SelfHosted && model.Ids.Count() > 500) + { + throw new BadRequestException("You can only move up to 500 items at a time."); + } + + var userId = _userService.GetProperUserId(User).Value; + await _cipherService.MoveManyAsync(model.Ids.Select(i => new Guid(i)), + string.IsNullOrWhiteSpace(model.FolderId) ? (Guid?)null : new Guid(model.FolderId), userId); + } + + [HttpPut("share")] + [HttpPost("share")] + public async Task PutShareMany([FromBody] CipherBulkShareRequestModel model) + { + var organizationId = new Guid(model.Ciphers.First().OrganizationId); + if (!await _currentContext.OrganizationUser(organizationId)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User).Value; + var ciphers = await _cipherRepository.GetManyByUserIdAsync(userId, false); + var ciphersDict = ciphers.ToDictionary(c => c.Id); + + var shareCiphers = new List<(Cipher, DateTime?)>(); + foreach (var cipher in model.Ciphers) + { + if (!ciphersDict.ContainsKey(cipher.Id.Value)) { - throw new BadRequestException("You cannot import this much data at once."); + throw new BadRequestException("Trying to move ciphers that you do not own."); } + shareCiphers.Add((cipher.ToCipher(ciphersDict[cipher.Id.Value]), cipher.LastKnownRevisionDate)); + } + + await _cipherService.ShareManyAsync(shareCiphers, organizationId, + model.CollectionIds.Select(c => new Guid(c)), userId); + } + + [HttpPost("purge")] + public async Task PostPurge([FromBody] SecretVerificationRequestModel model, string organizationId = null) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (!await _userService.VerifySecretAsync(user, model.Secret)) + { + ModelState.AddModelError(string.Empty, "User verification failed."); + await Task.Delay(2000); + throw new BadRequestException(ModelState); + } + + if (string.IsNullOrWhiteSpace(organizationId)) + { + await _cipherRepository.DeleteByUserIdAsync(user.Id); + } + else + { var orgId = new Guid(organizationId); - if (!await _currentContext.AccessImportExport(orgId)) + if (!await _currentContext.EditAnyCollection(orgId)) { throw new NotFoundException(); } + await _cipherService.PurgeAsync(orgId); + } + } - var userId = _userService.GetProperUserId(User).Value; - var collections = model.Collections.Select(c => c.ToCollection(orgId)).ToList(); - var ciphers = model.Ciphers.Select(l => l.ToOrganizationCipherDetails(orgId)).ToList(); - await _cipherService.ImportCiphersAsync(collections, ciphers, model.CollectionRelationships, userId); + [HttpPost("{id}/attachment/v2")] + public async Task PostAttachment(string id, [FromBody] AttachmentRequestModel request) + { + var idGuid = new Guid(id); + var userId = _userService.GetProperUserId(User).Value; + var cipher = request.AdminRequest ? + await _cipherRepository.GetOrganizationDetailsByIdAsync(idGuid) : + await _cipherRepository.GetByIdAsync(idGuid, userId); + + if (cipher == null || (request.AdminRequest && (!cipher.OrganizationId.HasValue || + !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)))) + { + throw new NotFoundException(); } - [HttpPut("{id}/partial")] - [HttpPost("{id}/partial")] - public async Task PutPartial(string id, [FromBody] CipherPartialRequestModel model) + if (request.FileSize > CipherService.MAX_FILE_SIZE) { - var userId = _userService.GetProperUserId(User).Value; - var folderId = string.IsNullOrWhiteSpace(model.FolderId) ? null : (Guid?)new Guid(model.FolderId); - await _cipherRepository.UpdatePartialAsync(new Guid(id), userId, folderId, model.Favorite); + throw new BadRequestException($"Max file size is {CipherService.MAX_FILE_SIZE_READABLE}."); } - [HttpPut("{id}/share")] - [HttpPost("{id}/share")] - public async Task PutShare(string id, [FromBody] CipherShareRequestModel model) + var (attachmentId, uploadUrl) = await _cipherService.CreateAttachmentForDelayedUploadAsync(cipher, + request.Key, request.FileName, request.FileSize, request.AdminRequest, userId); + return new AttachmentUploadDataResponseModel { - var userId = _userService.GetProperUserId(User).Value; - var cipherId = new Guid(id); - var cipher = await _cipherRepository.GetByIdAsync(cipherId); - if (cipher == null || cipher.UserId != userId || - !await _currentContext.OrganizationUser(new Guid(model.Cipher.OrganizationId))) - { - throw new NotFoundException(); - } + AttachmentId = attachmentId, + Url = uploadUrl, + FileUploadType = _attachmentStorageService.FileUploadType, + CipherResponse = request.AdminRequest ? null : new CipherResponseModel((CipherDetails)cipher, _globalSettings), + CipherMiniResponse = request.AdminRequest ? new CipherMiniResponseModel(cipher, _globalSettings, cipher.OrganizationUseTotp) : null, + }; + } - var original = cipher.Clone(); - await _cipherService.ShareAsync(original, model.Cipher.ToCipher(cipher), new Guid(model.Cipher.OrganizationId), - model.CollectionIds.Select(c => new Guid(c)), userId, model.Cipher.LastKnownRevisionDate); + [HttpGet("{id}/attachment/{attachmentId}/renew")] + public async Task RenewFileUploadUrl(string id, string attachmentId) + { + var userId = _userService.GetProperUserId(User).Value; + var cipherId = new Guid(id); + var cipher = await _cipherRepository.GetByIdAsync(cipherId, userId); + var attachments = cipher?.GetAttachments(); - var sharedCipher = await _cipherRepository.GetByIdAsync(cipherId, userId); - var response = new CipherResponseModel(sharedCipher, _globalSettings); - return response; + if (attachments == null || !attachments.ContainsKey(attachmentId) || attachments[attachmentId].Validated) + { + throw new NotFoundException(); } - [HttpPut("{id}/collections")] - [HttpPost("{id}/collections")] - public async Task PutCollections(string id, [FromBody] CipherCollectionsRequestModel model) + return new AttachmentUploadDataResponseModel { - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); - if (cipher == null || !cipher.OrganizationId.HasValue || - !await _currentContext.OrganizationUser(cipher.OrganizationId.Value)) - { - throw new NotFoundException(); - } + Url = await _attachmentStorageService.GetAttachmentUploadUrlAsync(cipher, attachments[attachmentId]), + FileUploadType = _attachmentStorageService.FileUploadType, + }; + } - await _cipherService.SaveCollectionsAsync(cipher, - model.CollectionIds.Select(c => new Guid(c)), userId, false); + [HttpPost("{id}/attachment/{attachmentId}")] + [SelfHosted(SelfHostedOnly = true)] + [RequestSizeLimit(Constants.FileSize501mb)] + [DisableFormValueModelBinding] + public async Task PostFileForExistingAttachment(string id, string attachmentId) + { + if (!Request?.ContentType.Contains("multipart/") ?? true) + { + throw new BadRequestException("Invalid content."); } - [HttpPut("{id}/collections-admin")] - [HttpPost("{id}/collections-admin")] - public async Task PutCollectionsAdmin(string id, [FromBody] CipherCollectionsRequestModel model) + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); + var attachments = cipher?.GetAttachments(); + if (attachments == null || !attachments.ContainsKey(attachmentId)) { - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id)); - if (cipher == null || !cipher.OrganizationId.HasValue || - !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) - { - throw new NotFoundException(); - } + throw new NotFoundException(); + } + var attachmentData = attachments[attachmentId]; - await _cipherService.SaveCollectionsAsync(cipher, - model.CollectionIds.Select(c => new Guid(c)), userId, true); + await Request.GetFileAsync(async (stream) => + { + await _cipherService.UploadFileForExistingAttachmentAsync(stream, cipher, attachmentData); + }); + } + + [HttpPost("{id}/attachment")] + [Obsolete("Deprecated Attachments API", false)] + [RequestSizeLimit(Constants.FileSize101mb)] + [DisableFormValueModelBinding] + public async Task PostAttachment(string id) + { + ValidateAttachment(); + + var idGuid = new Guid(id); + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(idGuid, userId); + if (cipher == null) + { + throw new NotFoundException(); } - [HttpDelete("{id}")] - [HttpPost("{id}/delete")] - public async Task Delete(string id) + await Request.GetFileAsync(async (stream, fileName, key) => { - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); - if (cipher == null) - { - throw new NotFoundException(); - } + await _cipherService.CreateAttachmentAsync(cipher, stream, fileName, key, + Request.ContentLength.GetValueOrDefault(0), userId); + }); - await _cipherService.DeleteAsync(cipher, userId); + return new CipherResponseModel(cipher, _globalSettings); + } + + [HttpPost("{id}/attachment-admin")] + [RequestSizeLimit(Constants.FileSize101mb)] + [DisableFormValueModelBinding] + public async Task PostAttachmentAdmin(string id) + { + ValidateAttachment(); + + var idGuid = new Guid(id); + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetOrganizationDetailsByIdAsync(idGuid); + if (cipher == null || !cipher.OrganizationId.HasValue || + !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) + { + throw new NotFoundException(); } - [HttpDelete("{id}/admin")] - [HttpPost("{id}/delete-admin")] - public async Task DeleteAdmin(string id) + await Request.GetFileAsync(async (stream, fileName, key) => { - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id)); - if (cipher == null || !cipher.OrganizationId.HasValue || - !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) - { - throw new NotFoundException(); - } + await _cipherService.CreateAttachmentAsync(cipher, stream, fileName, key, + Request.ContentLength.GetValueOrDefault(0), userId, true); + }); - await _cipherService.DeleteAsync(cipher, userId, true); + return new CipherMiniResponseModel(cipher, _globalSettings, cipher.OrganizationUseTotp); + } + + [HttpGet("{id}/attachment/{attachmentId}")] + public async Task GetAttachmentData(string id, string attachmentId) + { + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); + var result = await _cipherService.GetAttachmentDownloadDataAsync(cipher, attachmentId); + return new AttachmentResponseModel(result); + } + + [HttpPost("{id}/attachment/{attachmentId}/share")] + [RequestSizeLimit(Constants.FileSize101mb)] + [DisableFormValueModelBinding] + public async Task PostAttachmentShare(string id, string attachmentId, Guid organizationId) + { + ValidateAttachment(); + + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(new Guid(id)); + if (cipher == null || cipher.UserId != userId || !await _currentContext.OrganizationUser(organizationId)) + { + throw new NotFoundException(); } - [HttpDelete("")] - [HttpPost("delete")] - public async Task DeleteMany([FromBody] CipherBulkDeleteRequestModel model) + await Request.GetFileAsync(async (stream, fileName, key) => { - if (!_globalSettings.SelfHosted && model.Ids.Count() > 500) - { - throw new BadRequestException("You can only delete up to 500 items at a time. " + - "Consider using the \"Purge Vault\" option instead."); - } + await _cipherService.CreateAttachmentShareAsync(cipher, stream, + Request.ContentLength.GetValueOrDefault(0), attachmentId, organizationId); + }); + } - var userId = _userService.GetProperUserId(User).Value; - await _cipherService.DeleteManyAsync(model.Ids.Select(i => new Guid(i)), userId); + [HttpDelete("{id}/attachment/{attachmentId}")] + [HttpPost("{id}/attachment/{attachmentId}/delete")] + public async Task DeleteAttachment(string id, string attachmentId) + { + var idGuid = new Guid(id); + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(idGuid, userId); + if (cipher == null) + { + throw new NotFoundException(); } - [HttpDelete("admin")] - [HttpPost("delete-admin")] - public async Task DeleteManyAdmin([FromBody] CipherBulkDeleteRequestModel model) + await _cipherService.DeleteAttachmentAsync(cipher, attachmentId, userId, false); + } + + [HttpDelete("{id}/attachment/{attachmentId}/admin")] + [HttpPost("{id}/attachment/{attachmentId}/delete-admin")] + public async Task DeleteAttachmentAdmin(string id, string attachmentId) + { + var idGuid = new Guid(id); + var userId = _userService.GetProperUserId(User).Value; + var cipher = await _cipherRepository.GetByIdAsync(idGuid); + if (cipher == null || !cipher.OrganizationId.HasValue || + !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) { - if (!_globalSettings.SelfHosted && model.Ids.Count() > 500) - { - throw new BadRequestException("You can only delete up to 500 items at a time. " + - "Consider using the \"Purge Vault\" option instead."); - } - - if (model == null || string.IsNullOrWhiteSpace(model.OrganizationId) || - !await _currentContext.EditAnyCollection(new Guid(model.OrganizationId))) - { - throw new NotFoundException(); - } - - var userId = _userService.GetProperUserId(User).Value; - await _cipherService.DeleteManyAsync(model.Ids.Select(i => new Guid(i)), userId, new Guid(model.OrganizationId), true); + throw new NotFoundException(); } - [HttpPut("{id}/delete")] - public async Task PutDelete(string id) + await _cipherService.DeleteAttachmentAsync(cipher, attachmentId, userId, true); + } + + [AllowAnonymous] + [HttpPost("attachment/validate/azure")] + public async Task AzureValidateFile() + { + return await ApiHelpers.HandleAzureEvents(Request, new Dictionary> { - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); - if (cipher == null) { - throw new NotFoundException(); - } - await _cipherService.SoftDeleteAsync(cipher, userId); - } - - [HttpPut("{id}/delete-admin")] - public async Task PutDeleteAdmin(string id) - { - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id)); - if (cipher == null || !cipher.OrganizationId.HasValue || - !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) - { - throw new NotFoundException(); - } - - await _cipherService.SoftDeleteAsync(cipher, userId, true); - } - - [HttpPut("delete")] - public async Task PutDeleteMany([FromBody] CipherBulkDeleteRequestModel model) - { - if (!_globalSettings.SelfHosted && model.Ids.Count() > 500) - { - throw new BadRequestException("You can only delete up to 500 items at a time."); - } - - var userId = _userService.GetProperUserId(User).Value; - await _cipherService.SoftDeleteManyAsync(model.Ids.Select(i => new Guid(i)), userId); - } - - [HttpPut("delete-admin")] - public async Task PutDeleteManyAdmin([FromBody] CipherBulkDeleteRequestModel model) - { - if (!_globalSettings.SelfHosted && model.Ids.Count() > 500) - { - throw new BadRequestException("You can only delete up to 500 items at a time."); - } - - if (model == null || string.IsNullOrWhiteSpace(model.OrganizationId) || - !await _currentContext.EditAnyCollection(new Guid(model.OrganizationId))) - { - throw new NotFoundException(); - } - - var userId = _userService.GetProperUserId(User).Value; - await _cipherService.SoftDeleteManyAsync(model.Ids.Select(i => new Guid(i)), userId, new Guid(model.OrganizationId), true); - } - - [HttpPut("{id}/restore")] - public async Task PutRestore(string id) - { - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); - if (cipher == null) - { - throw new NotFoundException(); - } - - await _cipherService.RestoreAsync(cipher, userId); - return new CipherResponseModel(cipher, _globalSettings); - } - - [HttpPut("{id}/restore-admin")] - public async Task PutRestoreAdmin(string id) - { - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetOrganizationDetailsByIdAsync(new Guid(id)); - if (cipher == null || !cipher.OrganizationId.HasValue || - !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) - { - throw new NotFoundException(); - } - - await _cipherService.RestoreAsync(cipher, userId, true); - return new CipherMiniResponseModel(cipher, _globalSettings, cipher.OrganizationUseTotp); - } - - [HttpPut("restore")] - public async Task> PutRestoreMany([FromBody] CipherBulkRestoreRequestModel model) - { - if (!_globalSettings.SelfHosted && model.Ids.Count() > 500) - { - throw new BadRequestException("You can only restore up to 500 items at a time."); - } - - var userId = _userService.GetProperUserId(User).Value; - var cipherIdsToRestore = new HashSet(model.Ids.Select(i => new Guid(i))); - - var ciphers = await _cipherRepository.GetManyByUserIdAsync(userId); - var restoringCiphers = ciphers.Where(c => cipherIdsToRestore.Contains(c.Id) && c.Edit); - - await _cipherService.RestoreManyAsync(restoringCiphers, userId); - var responses = restoringCiphers.Select(c => new CipherResponseModel(c, _globalSettings)); - return new ListResponseModel(responses); - } - - [HttpPut("move")] - [HttpPost("move")] - public async Task MoveMany([FromBody] CipherBulkMoveRequestModel model) - { - if (!_globalSettings.SelfHosted && model.Ids.Count() > 500) - { - throw new BadRequestException("You can only move up to 500 items at a time."); - } - - var userId = _userService.GetProperUserId(User).Value; - await _cipherService.MoveManyAsync(model.Ids.Select(i => new Guid(i)), - string.IsNullOrWhiteSpace(model.FolderId) ? (Guid?)null : new Guid(model.FolderId), userId); - } - - [HttpPut("share")] - [HttpPost("share")] - public async Task PutShareMany([FromBody] CipherBulkShareRequestModel model) - { - var organizationId = new Guid(model.Ciphers.First().OrganizationId); - if (!await _currentContext.OrganizationUser(organizationId)) - { - throw new NotFoundException(); - } - - var userId = _userService.GetProperUserId(User).Value; - var ciphers = await _cipherRepository.GetManyByUserIdAsync(userId, false); - var ciphersDict = ciphers.ToDictionary(c => c.Id); - - var shareCiphers = new List<(Cipher, DateTime?)>(); - foreach (var cipher in model.Ciphers) - { - if (!ciphersDict.ContainsKey(cipher.Id.Value)) + "Microsoft.Storage.BlobCreated", async (eventGridEvent) => { - throw new BadRequestException("Trying to move ciphers that you do not own."); - } - - shareCiphers.Add((cipher.ToCipher(ciphersDict[cipher.Id.Value]), cipher.LastKnownRevisionDate)); - } - - await _cipherService.ShareManyAsync(shareCiphers, organizationId, - model.CollectionIds.Select(c => new Guid(c)), userId); - } - - [HttpPost("purge")] - public async Task PostPurge([FromBody] SecretVerificationRequestModel model, string organizationId = null) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (!await _userService.VerifySecretAsync(user, model.Secret)) - { - ModelState.AddModelError(string.Empty, "User verification failed."); - await Task.Delay(2000); - throw new BadRequestException(ModelState); - } - - if (string.IsNullOrWhiteSpace(organizationId)) - { - await _cipherRepository.DeleteByUserIdAsync(user.Id); - } - else - { - var orgId = new Guid(organizationId); - if (!await _currentContext.EditAnyCollection(orgId)) - { - throw new NotFoundException(); - } - await _cipherService.PurgeAsync(orgId); - } - } - - [HttpPost("{id}/attachment/v2")] - public async Task PostAttachment(string id, [FromBody] AttachmentRequestModel request) - { - var idGuid = new Guid(id); - var userId = _userService.GetProperUserId(User).Value; - var cipher = request.AdminRequest ? - await _cipherRepository.GetOrganizationDetailsByIdAsync(idGuid) : - await _cipherRepository.GetByIdAsync(idGuid, userId); - - if (cipher == null || (request.AdminRequest && (!cipher.OrganizationId.HasValue || - !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)))) - { - throw new NotFoundException(); - } - - if (request.FileSize > CipherService.MAX_FILE_SIZE) - { - throw new BadRequestException($"Max file size is {CipherService.MAX_FILE_SIZE_READABLE}."); - } - - var (attachmentId, uploadUrl) = await _cipherService.CreateAttachmentForDelayedUploadAsync(cipher, - request.Key, request.FileName, request.FileSize, request.AdminRequest, userId); - return new AttachmentUploadDataResponseModel - { - AttachmentId = attachmentId, - Url = uploadUrl, - FileUploadType = _attachmentStorageService.FileUploadType, - CipherResponse = request.AdminRequest ? null : new CipherResponseModel((CipherDetails)cipher, _globalSettings), - CipherMiniResponse = request.AdminRequest ? new CipherMiniResponseModel(cipher, _globalSettings, cipher.OrganizationUseTotp) : null, - }; - } - - [HttpGet("{id}/attachment/{attachmentId}/renew")] - public async Task RenewFileUploadUrl(string id, string attachmentId) - { - var userId = _userService.GetProperUserId(User).Value; - var cipherId = new Guid(id); - var cipher = await _cipherRepository.GetByIdAsync(cipherId, userId); - var attachments = cipher?.GetAttachments(); - - if (attachments == null || !attachments.ContainsKey(attachmentId) || attachments[attachmentId].Validated) - { - throw new NotFoundException(); - } - - return new AttachmentUploadDataResponseModel - { - Url = await _attachmentStorageService.GetAttachmentUploadUrlAsync(cipher, attachments[attachmentId]), - FileUploadType = _attachmentStorageService.FileUploadType, - }; - } - - [HttpPost("{id}/attachment/{attachmentId}")] - [SelfHosted(SelfHostedOnly = true)] - [RequestSizeLimit(Constants.FileSize501mb)] - [DisableFormValueModelBinding] - public async Task PostFileForExistingAttachment(string id, string attachmentId) - { - if (!Request?.ContentType.Contains("multipart/") ?? true) - { - throw new BadRequestException("Invalid content."); - } - - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); - var attachments = cipher?.GetAttachments(); - if (attachments == null || !attachments.ContainsKey(attachmentId)) - { - throw new NotFoundException(); - } - var attachmentData = attachments[attachmentId]; - - await Request.GetFileAsync(async (stream) => - { - await _cipherService.UploadFileForExistingAttachmentAsync(stream, cipher, attachmentData); - }); - } - - [HttpPost("{id}/attachment")] - [Obsolete("Deprecated Attachments API", false)] - [RequestSizeLimit(Constants.FileSize101mb)] - [DisableFormValueModelBinding] - public async Task PostAttachment(string id) - { - ValidateAttachment(); - - var idGuid = new Guid(id); - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(idGuid, userId); - if (cipher == null) - { - throw new NotFoundException(); - } - - await Request.GetFileAsync(async (stream, fileName, key) => - { - await _cipherService.CreateAttachmentAsync(cipher, stream, fileName, key, - Request.ContentLength.GetValueOrDefault(0), userId); - }); - - return new CipherResponseModel(cipher, _globalSettings); - } - - [HttpPost("{id}/attachment-admin")] - [RequestSizeLimit(Constants.FileSize101mb)] - [DisableFormValueModelBinding] - public async Task PostAttachmentAdmin(string id) - { - ValidateAttachment(); - - var idGuid = new Guid(id); - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetOrganizationDetailsByIdAsync(idGuid); - if (cipher == null || !cipher.OrganizationId.HasValue || - !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) - { - throw new NotFoundException(); - } - - await Request.GetFileAsync(async (stream, fileName, key) => - { - await _cipherService.CreateAttachmentAsync(cipher, stream, fileName, key, - Request.ContentLength.GetValueOrDefault(0), userId, true); - }); - - return new CipherMiniResponseModel(cipher, _globalSettings, cipher.OrganizationUseTotp); - } - - [HttpGet("{id}/attachment/{attachmentId}")] - public async Task GetAttachmentData(string id, string attachmentId) - { - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id), userId); - var result = await _cipherService.GetAttachmentDownloadDataAsync(cipher, attachmentId); - return new AttachmentResponseModel(result); - } - - [HttpPost("{id}/attachment/{attachmentId}/share")] - [RequestSizeLimit(Constants.FileSize101mb)] - [DisableFormValueModelBinding] - public async Task PostAttachmentShare(string id, string attachmentId, Guid organizationId) - { - ValidateAttachment(); - - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id)); - if (cipher == null || cipher.UserId != userId || !await _currentContext.OrganizationUser(organizationId)) - { - throw new NotFoundException(); - } - - await Request.GetFileAsync(async (stream, fileName, key) => - { - await _cipherService.CreateAttachmentShareAsync(cipher, stream, - Request.ContentLength.GetValueOrDefault(0), attachmentId, organizationId); - }); - } - - [HttpDelete("{id}/attachment/{attachmentId}")] - [HttpPost("{id}/attachment/{attachmentId}/delete")] - public async Task DeleteAttachment(string id, string attachmentId) - { - var idGuid = new Guid(id); - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(idGuid, userId); - if (cipher == null) - { - throw new NotFoundException(); - } - - await _cipherService.DeleteAttachmentAsync(cipher, attachmentId, userId, false); - } - - [HttpDelete("{id}/attachment/{attachmentId}/admin")] - [HttpPost("{id}/attachment/{attachmentId}/delete-admin")] - public async Task DeleteAttachmentAdmin(string id, string attachmentId) - { - var idGuid = new Guid(id); - var userId = _userService.GetProperUserId(User).Value; - var cipher = await _cipherRepository.GetByIdAsync(idGuid); - if (cipher == null || !cipher.OrganizationId.HasValue || - !await _currentContext.EditAnyCollection(cipher.OrganizationId.Value)) - { - throw new NotFoundException(); - } - - await _cipherService.DeleteAttachmentAsync(cipher, attachmentId, userId, true); - } - - [AllowAnonymous] - [HttpPost("attachment/validate/azure")] - public async Task AzureValidateFile() - { - return await ApiHelpers.HandleAzureEvents(Request, new Dictionary> - { - { - "Microsoft.Storage.BlobCreated", async (eventGridEvent) => + try { - try + var blobName = eventGridEvent.Subject.Split($"{AzureAttachmentStorageService.EventGridEnabledContainerName}/blobs/")[1]; + var (cipherId, organizationId, attachmentId) = AzureAttachmentStorageService.IdentifiersFromBlobName(blobName); + var cipher = await _cipherRepository.GetByIdAsync(new Guid(cipherId)); + var attachments = cipher?.GetAttachments() ?? new Dictionary(); + + if (cipher == null || !attachments.ContainsKey(attachmentId) || attachments[attachmentId].Validated) { - var blobName = eventGridEvent.Subject.Split($"{AzureAttachmentStorageService.EventGridEnabledContainerName}/blobs/")[1]; - var (cipherId, organizationId, attachmentId) = AzureAttachmentStorageService.IdentifiersFromBlobName(blobName); - var cipher = await _cipherRepository.GetByIdAsync(new Guid(cipherId)); - var attachments = cipher?.GetAttachments() ?? new Dictionary(); - - if (cipher == null || !attachments.ContainsKey(attachmentId) || attachments[attachmentId].Validated) + if (_attachmentStorageService is AzureSendFileStorageService azureFileStorageService) { - if (_attachmentStorageService is AzureSendFileStorageService azureFileStorageService) - { - await azureFileStorageService.DeleteBlobAsync(blobName); - } - - return; + await azureFileStorageService.DeleteBlobAsync(blobName); } - await _cipherService.ValidateCipherAttachmentFile(cipher, attachments[attachmentId]); - } - catch (Exception e) - { - _logger.LogError(e, $"Uncaught exception occurred while handling event grid event: {JsonSerializer.Serialize(eventGridEvent)}"); return; } + + await _cipherService.ValidateCipherAttachmentFile(cipher, attachments[attachmentId]); + } + catch (Exception e) + { + _logger.LogError(e, $"Uncaught exception occurred while handling event grid event: {JsonSerializer.Serialize(eventGridEvent)}"); + return; } } - }); - } - - private void ValidateAttachment() - { - if (!Request?.ContentType.Contains("multipart/") ?? true) - { - throw new BadRequestException("Invalid content."); } + }); + } + + private void ValidateAttachment() + { + if (!Request?.ContentType.Contains("multipart/") ?? true) + { + throw new BadRequestException("Invalid content."); } } } diff --git a/src/Api/Controllers/CollectionsController.cs b/src/Api/Controllers/CollectionsController.cs index 548ff80d4..14d4e95b2 100644 --- a/src/Api/Controllers/CollectionsController.cs +++ b/src/Api/Controllers/CollectionsController.cs @@ -8,261 +8,260 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("organizations/{orgId}/collections")] +[Authorize("Application")] +public class CollectionsController : Controller { - [Route("organizations/{orgId}/collections")] - [Authorize("Application")] - public class CollectionsController : Controller + private readonly ICollectionRepository _collectionRepository; + private readonly ICollectionService _collectionService; + private readonly IUserService _userService; + private readonly ICurrentContext _currentContext; + + public CollectionsController( + ICollectionRepository collectionRepository, + ICollectionService collectionService, + IUserService userService, + ICurrentContext currentContext) { - private readonly ICollectionRepository _collectionRepository; - private readonly ICollectionService _collectionService; - private readonly IUserService _userService; - private readonly ICurrentContext _currentContext; + _collectionRepository = collectionRepository; + _collectionService = collectionService; + _userService = userService; + _currentContext = currentContext; + } - public CollectionsController( - ICollectionRepository collectionRepository, - ICollectionService collectionService, - IUserService userService, - ICurrentContext currentContext) + [HttpGet("{id}")] + public async Task Get(Guid orgId, Guid id) + { + if (!await CanViewCollectionAsync(orgId, id)) { - _collectionRepository = collectionRepository; - _collectionService = collectionService; - _userService = userService; - _currentContext = currentContext; + throw new NotFoundException(); } - [HttpGet("{id}")] - public async Task Get(Guid orgId, Guid id) + var collection = await GetCollectionAsync(id, orgId); + return new CollectionResponseModel(collection); + } + + [HttpGet("{id}/details")] + public async Task GetDetails(Guid orgId, Guid id) + { + if (!await ViewAtLeastOneCollectionAsync(orgId) && !await _currentContext.ManageUsers(orgId)) { - if (!await CanViewCollectionAsync(orgId, id)) + throw new NotFoundException(); + } + + if (await _currentContext.ViewAllCollections(orgId)) + { + var collectionDetails = await _collectionRepository.GetByIdWithGroupsAsync(id); + if (collectionDetails?.Item1 == null || collectionDetails.Item1.OrganizationId != orgId) { throw new NotFoundException(); } - - var collection = await GetCollectionAsync(id, orgId); - return new CollectionResponseModel(collection); + return new CollectionGroupDetailsResponseModel(collectionDetails.Item1, collectionDetails.Item2); } - - [HttpGet("{id}/details")] - public async Task GetDetails(Guid orgId, Guid id) + else { - if (!await ViewAtLeastOneCollectionAsync(orgId) && !await _currentContext.ManageUsers(orgId)) + var collectionDetails = await _collectionRepository.GetByIdWithGroupsAsync(id, + _currentContext.UserId.Value); + if (collectionDetails?.Item1 == null || collectionDetails.Item1.OrganizationId != orgId) { throw new NotFoundException(); } - - if (await _currentContext.ViewAllCollections(orgId)) - { - var collectionDetails = await _collectionRepository.GetByIdWithGroupsAsync(id); - if (collectionDetails?.Item1 == null || collectionDetails.Item1.OrganizationId != orgId) - { - throw new NotFoundException(); - } - return new CollectionGroupDetailsResponseModel(collectionDetails.Item1, collectionDetails.Item2); - } - else - { - var collectionDetails = await _collectionRepository.GetByIdWithGroupsAsync(id, - _currentContext.UserId.Value); - if (collectionDetails?.Item1 == null || collectionDetails.Item1.OrganizationId != orgId) - { - throw new NotFoundException(); - } - return new CollectionGroupDetailsResponseModel(collectionDetails.Item1, collectionDetails.Item2); - } - } - - [HttpGet("")] - public async Task> Get(Guid orgId) - { - IEnumerable orgCollections = await _collectionService.GetOrganizationCollections(orgId); - - var responses = orgCollections.Select(c => new CollectionResponseModel(c)); - return new ListResponseModel(responses); - } - - [HttpGet("~/collections")] - public async Task> GetUser() - { - var collections = await _collectionRepository.GetManyByUserIdAsync( - _userService.GetProperUserId(User).Value); - var responses = collections.Select(c => new CollectionDetailsResponseModel(c)); - return new ListResponseModel(responses); - } - - [HttpGet("{id}/users")] - public async Task> GetUsers(Guid orgId, Guid id) - { - var collection = await GetCollectionAsync(id, orgId); - var collectionUsers = await _collectionRepository.GetManyUsersByIdAsync(collection.Id); - var responses = collectionUsers.Select(cu => new SelectionReadOnlyResponseModel(cu)); - return responses; - } - - [HttpPost("")] - public async Task Post(Guid orgId, [FromBody] CollectionRequestModel model) - { - var collection = model.ToCollection(orgId); - - if (!await CanCreateCollection(orgId, collection.Id) && - !await CanEditCollectionAsync(orgId, collection.Id)) - { - throw new NotFoundException(); - } - - var assignUserToCollection = !(await _currentContext.EditAnyCollection(orgId)) && - await _currentContext.EditAssignedCollections(orgId); - - await _collectionService.SaveAsync(collection, model.Groups?.Select(g => g.ToSelectionReadOnly()), - assignUserToCollection ? _currentContext.UserId : null); - return new CollectionResponseModel(collection); - } - - [HttpPut("{id}")] - [HttpPost("{id}")] - public async Task Put(Guid orgId, Guid id, [FromBody] CollectionRequestModel model) - { - if (!await CanEditCollectionAsync(orgId, id)) - { - throw new NotFoundException(); - } - - var collection = await GetCollectionAsync(id, orgId); - await _collectionService.SaveAsync(model.ToCollection(collection), - model.Groups?.Select(g => g.ToSelectionReadOnly())); - return new CollectionResponseModel(collection); - } - - [HttpPut("{id}/users")] - public async Task PutUsers(Guid orgId, Guid id, [FromBody] IEnumerable model) - { - if (!await CanEditCollectionAsync(orgId, id)) - { - throw new NotFoundException(); - } - - var collection = await GetCollectionAsync(id, orgId); - await _collectionRepository.UpdateUsersAsync(collection.Id, model?.Select(g => g.ToSelectionReadOnly())); - } - - [HttpDelete("{id}")] - [HttpPost("{id}/delete")] - public async Task Delete(Guid orgId, Guid id) - { - if (!await CanDeleteCollectionAsync(orgId, id)) - { - throw new NotFoundException(); - } - - var collection = await GetCollectionAsync(id, orgId); - await _collectionService.DeleteAsync(collection); - } - - [HttpDelete("{id}/user/{orgUserId}")] - [HttpPost("{id}/delete-user/{orgUserId}")] - public async Task Delete(string orgId, string id, string orgUserId) - { - var collection = await GetCollectionAsync(new Guid(id), new Guid(orgId)); - await _collectionService.DeleteUserAsync(collection, new Guid(orgUserId)); - } - - private async Task GetCollectionAsync(Guid id, Guid orgId) - { - Collection collection = default; - if (await _currentContext.ViewAllCollections(orgId)) - { - collection = await _collectionRepository.GetByIdAsync(id); - } - else if (await _currentContext.ViewAssignedCollections(orgId)) - { - collection = await _collectionRepository.GetByIdAsync(id, _currentContext.UserId.Value); - } - - if (collection == null || collection.OrganizationId != orgId) - { - throw new NotFoundException(); - } - - return collection; - } - - - private async Task CanCreateCollection(Guid orgId, Guid collectionId) - { - if (collectionId != default) - { - return false; - } - - return await _currentContext.CreateNewCollections(orgId); - } - - private async Task CanEditCollectionAsync(Guid orgId, Guid collectionId) - { - if (collectionId == default) - { - return false; - } - - if (await _currentContext.EditAnyCollection(orgId)) - { - return true; - } - - if (await _currentContext.EditAssignedCollections(orgId)) - { - var collectionDetails = await _collectionRepository.GetByIdAsync(collectionId, _currentContext.UserId.Value); - return collectionDetails != null; - } - - return false; - } - - private async Task CanDeleteCollectionAsync(Guid orgId, Guid collectionId) - { - if (collectionId == default) - { - return false; - } - - if (await _currentContext.DeleteAnyCollection(orgId)) - { - return true; - } - - if (await _currentContext.DeleteAssignedCollections(orgId)) - { - var collectionDetails = await _collectionRepository.GetByIdAsync(collectionId, _currentContext.UserId.Value); - return collectionDetails != null; - } - - return false; - } - - private async Task CanViewCollectionAsync(Guid orgId, Guid collectionId) - { - if (collectionId == default) - { - return false; - } - - if (await _currentContext.ViewAllCollections(orgId)) - { - return true; - } - - if (await _currentContext.ViewAssignedCollections(orgId)) - { - var collectionDetails = await _collectionRepository.GetByIdAsync(collectionId, _currentContext.UserId.Value); - return collectionDetails != null; - } - - return false; - } - - private async Task ViewAtLeastOneCollectionAsync(Guid orgId) - { - return await _currentContext.ViewAllCollections(orgId) || await _currentContext.ViewAssignedCollections(orgId); + return new CollectionGroupDetailsResponseModel(collectionDetails.Item1, collectionDetails.Item2); } } + + [HttpGet("")] + public async Task> Get(Guid orgId) + { + IEnumerable orgCollections = await _collectionService.GetOrganizationCollections(orgId); + + var responses = orgCollections.Select(c => new CollectionResponseModel(c)); + return new ListResponseModel(responses); + } + + [HttpGet("~/collections")] + public async Task> GetUser() + { + var collections = await _collectionRepository.GetManyByUserIdAsync( + _userService.GetProperUserId(User).Value); + var responses = collections.Select(c => new CollectionDetailsResponseModel(c)); + return new ListResponseModel(responses); + } + + [HttpGet("{id}/users")] + public async Task> GetUsers(Guid orgId, Guid id) + { + var collection = await GetCollectionAsync(id, orgId); + var collectionUsers = await _collectionRepository.GetManyUsersByIdAsync(collection.Id); + var responses = collectionUsers.Select(cu => new SelectionReadOnlyResponseModel(cu)); + return responses; + } + + [HttpPost("")] + public async Task Post(Guid orgId, [FromBody] CollectionRequestModel model) + { + var collection = model.ToCollection(orgId); + + if (!await CanCreateCollection(orgId, collection.Id) && + !await CanEditCollectionAsync(orgId, collection.Id)) + { + throw new NotFoundException(); + } + + var assignUserToCollection = !(await _currentContext.EditAnyCollection(orgId)) && + await _currentContext.EditAssignedCollections(orgId); + + await _collectionService.SaveAsync(collection, model.Groups?.Select(g => g.ToSelectionReadOnly()), + assignUserToCollection ? _currentContext.UserId : null); + return new CollectionResponseModel(collection); + } + + [HttpPut("{id}")] + [HttpPost("{id}")] + public async Task Put(Guid orgId, Guid id, [FromBody] CollectionRequestModel model) + { + if (!await CanEditCollectionAsync(orgId, id)) + { + throw new NotFoundException(); + } + + var collection = await GetCollectionAsync(id, orgId); + await _collectionService.SaveAsync(model.ToCollection(collection), + model.Groups?.Select(g => g.ToSelectionReadOnly())); + return new CollectionResponseModel(collection); + } + + [HttpPut("{id}/users")] + public async Task PutUsers(Guid orgId, Guid id, [FromBody] IEnumerable model) + { + if (!await CanEditCollectionAsync(orgId, id)) + { + throw new NotFoundException(); + } + + var collection = await GetCollectionAsync(id, orgId); + await _collectionRepository.UpdateUsersAsync(collection.Id, model?.Select(g => g.ToSelectionReadOnly())); + } + + [HttpDelete("{id}")] + [HttpPost("{id}/delete")] + public async Task Delete(Guid orgId, Guid id) + { + if (!await CanDeleteCollectionAsync(orgId, id)) + { + throw new NotFoundException(); + } + + var collection = await GetCollectionAsync(id, orgId); + await _collectionService.DeleteAsync(collection); + } + + [HttpDelete("{id}/user/{orgUserId}")] + [HttpPost("{id}/delete-user/{orgUserId}")] + public async Task Delete(string orgId, string id, string orgUserId) + { + var collection = await GetCollectionAsync(new Guid(id), new Guid(orgId)); + await _collectionService.DeleteUserAsync(collection, new Guid(orgUserId)); + } + + private async Task GetCollectionAsync(Guid id, Guid orgId) + { + Collection collection = default; + if (await _currentContext.ViewAllCollections(orgId)) + { + collection = await _collectionRepository.GetByIdAsync(id); + } + else if (await _currentContext.ViewAssignedCollections(orgId)) + { + collection = await _collectionRepository.GetByIdAsync(id, _currentContext.UserId.Value); + } + + if (collection == null || collection.OrganizationId != orgId) + { + throw new NotFoundException(); + } + + return collection; + } + + + private async Task CanCreateCollection(Guid orgId, Guid collectionId) + { + if (collectionId != default) + { + return false; + } + + return await _currentContext.CreateNewCollections(orgId); + } + + private async Task CanEditCollectionAsync(Guid orgId, Guid collectionId) + { + if (collectionId == default) + { + return false; + } + + if (await _currentContext.EditAnyCollection(orgId)) + { + return true; + } + + if (await _currentContext.EditAssignedCollections(orgId)) + { + var collectionDetails = await _collectionRepository.GetByIdAsync(collectionId, _currentContext.UserId.Value); + return collectionDetails != null; + } + + return false; + } + + private async Task CanDeleteCollectionAsync(Guid orgId, Guid collectionId) + { + if (collectionId == default) + { + return false; + } + + if (await _currentContext.DeleteAnyCollection(orgId)) + { + return true; + } + + if (await _currentContext.DeleteAssignedCollections(orgId)) + { + var collectionDetails = await _collectionRepository.GetByIdAsync(collectionId, _currentContext.UserId.Value); + return collectionDetails != null; + } + + return false; + } + + private async Task CanViewCollectionAsync(Guid orgId, Guid collectionId) + { + if (collectionId == default) + { + return false; + } + + if (await _currentContext.ViewAllCollections(orgId)) + { + return true; + } + + if (await _currentContext.ViewAssignedCollections(orgId)) + { + var collectionDetails = await _collectionRepository.GetByIdAsync(collectionId, _currentContext.UserId.Value); + return collectionDetails != null; + } + + return false; + } + + private async Task ViewAtLeastOneCollectionAsync(Guid orgId) + { + return await _currentContext.ViewAllCollections(orgId) || await _currentContext.ViewAssignedCollections(orgId); + } } diff --git a/src/Api/Controllers/DevicesController.cs b/src/Api/Controllers/DevicesController.cs index 8bfa5d7b0..77fb34c64 100644 --- a/src/Api/Controllers/DevicesController.cs +++ b/src/Api/Controllers/DevicesController.cs @@ -7,124 +7,123 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("devices")] +[Authorize("Application")] +public class DevicesController : Controller { - [Route("devices")] - [Authorize("Application")] - public class DevicesController : Controller + private readonly IDeviceRepository _deviceRepository; + private readonly IDeviceService _deviceService; + private readonly IUserService _userService; + + public DevicesController( + IDeviceRepository deviceRepository, + IDeviceService deviceService, + IUserService userService) { - private readonly IDeviceRepository _deviceRepository; - private readonly IDeviceService _deviceService; - private readonly IUserService _userService; + _deviceRepository = deviceRepository; + _deviceService = deviceService; + _userService = userService; + } - public DevicesController( - IDeviceRepository deviceRepository, - IDeviceService deviceService, - IUserService userService) + [HttpGet("{id}")] + public async Task Get(string id) + { + var device = await _deviceRepository.GetByIdAsync(new Guid(id), _userService.GetProperUserId(User).Value); + if (device == null) { - _deviceRepository = deviceRepository; - _deviceService = deviceService; - _userService = userService; + throw new NotFoundException(); } - [HttpGet("{id}")] - public async Task Get(string id) - { - var device = await _deviceRepository.GetByIdAsync(new Guid(id), _userService.GetProperUserId(User).Value); - if (device == null) - { - throw new NotFoundException(); - } + var response = new DeviceResponseModel(device); + return response; + } - var response = new DeviceResponseModel(device); - return response; + [HttpGet("identifier/{identifier}")] + public async Task GetByIdentifier(string identifier) + { + var device = await _deviceRepository.GetByIdentifierAsync(identifier, _userService.GetProperUserId(User).Value); + if (device == null) + { + throw new NotFoundException(); } - [HttpGet("identifier/{identifier}")] - public async Task GetByIdentifier(string identifier) - { - var device = await _deviceRepository.GetByIdentifierAsync(identifier, _userService.GetProperUserId(User).Value); - if (device == null) - { - throw new NotFoundException(); - } + var response = new DeviceResponseModel(device); + return response; + } - var response = new DeviceResponseModel(device); - return response; + [HttpGet("")] + public async Task> Get() + { + ICollection devices = await _deviceRepository.GetManyByUserIdAsync(_userService.GetProperUserId(User).Value); + var responses = devices.Select(d => new DeviceResponseModel(d)); + return new ListResponseModel(responses); + } + + [HttpPost("")] + public async Task Post([FromBody] DeviceRequestModel model) + { + var device = model.ToDevice(_userService.GetProperUserId(User)); + await _deviceService.SaveAsync(device); + + var response = new DeviceResponseModel(device); + return response; + } + + [HttpPut("{id}")] + [HttpPost("{id}")] + public async Task Put(string id, [FromBody] DeviceRequestModel model) + { + var device = await _deviceRepository.GetByIdAsync(new Guid(id), _userService.GetProperUserId(User).Value); + if (device == null) + { + throw new NotFoundException(); } - [HttpGet("")] - public async Task> Get() + await _deviceService.SaveAsync(model.ToDevice(device)); + + var response = new DeviceResponseModel(device); + return response; + } + + [HttpPut("identifier/{identifier}/token")] + [HttpPost("identifier/{identifier}/token")] + public async Task PutToken(string identifier, [FromBody] DeviceTokenRequestModel model) + { + var device = await _deviceRepository.GetByIdentifierAsync(identifier, _userService.GetProperUserId(User).Value); + if (device == null) { - ICollection devices = await _deviceRepository.GetManyByUserIdAsync(_userService.GetProperUserId(User).Value); - var responses = devices.Select(d => new DeviceResponseModel(d)); - return new ListResponseModel(responses); + throw new NotFoundException(); } - [HttpPost("")] - public async Task Post([FromBody] DeviceRequestModel model) - { - var device = model.ToDevice(_userService.GetProperUserId(User)); - await _deviceService.SaveAsync(device); + await _deviceService.SaveAsync(model.ToDevice(device)); + } - var response = new DeviceResponseModel(device); - return response; + [AllowAnonymous] + [HttpPut("identifier/{identifier}/clear-token")] + [HttpPost("identifier/{identifier}/clear-token")] + public async Task PutClearToken(string identifier) + { + var device = await _deviceRepository.GetByIdentifierAsync(identifier); + if (device == null) + { + throw new NotFoundException(); } - [HttpPut("{id}")] - [HttpPost("{id}")] - public async Task Put(string id, [FromBody] DeviceRequestModel model) + await _deviceService.ClearTokenAsync(device); + } + + [HttpDelete("{id}")] + [HttpPost("{id}/delete")] + public async Task Delete(string id) + { + var device = await _deviceRepository.GetByIdAsync(new Guid(id), _userService.GetProperUserId(User).Value); + if (device == null) { - var device = await _deviceRepository.GetByIdAsync(new Guid(id), _userService.GetProperUserId(User).Value); - if (device == null) - { - throw new NotFoundException(); - } - - await _deviceService.SaveAsync(model.ToDevice(device)); - - var response = new DeviceResponseModel(device); - return response; + throw new NotFoundException(); } - [HttpPut("identifier/{identifier}/token")] - [HttpPost("identifier/{identifier}/token")] - public async Task PutToken(string identifier, [FromBody] DeviceTokenRequestModel model) - { - var device = await _deviceRepository.GetByIdentifierAsync(identifier, _userService.GetProperUserId(User).Value); - if (device == null) - { - throw new NotFoundException(); - } - - await _deviceService.SaveAsync(model.ToDevice(device)); - } - - [AllowAnonymous] - [HttpPut("identifier/{identifier}/clear-token")] - [HttpPost("identifier/{identifier}/clear-token")] - public async Task PutClearToken(string identifier) - { - var device = await _deviceRepository.GetByIdentifierAsync(identifier); - if (device == null) - { - throw new NotFoundException(); - } - - await _deviceService.ClearTokenAsync(device); - } - - [HttpDelete("{id}")] - [HttpPost("{id}/delete")] - public async Task Delete(string id) - { - var device = await _deviceRepository.GetByIdAsync(new Guid(id), _userService.GetProperUserId(User).Value); - if (device == null) - { - throw new NotFoundException(); - } - - await _deviceService.DeleteAsync(device); - } + await _deviceService.DeleteAsync(device); } } diff --git a/src/Api/Controllers/EmergencyAccessController.cs b/src/Api/Controllers/EmergencyAccessController.cs index 4e8ac834d..b2eb997b4 100644 --- a/src/Api/Controllers/EmergencyAccessController.cs +++ b/src/Api/Controllers/EmergencyAccessController.cs @@ -9,170 +9,169 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("emergency-access")] +[Authorize("Application")] +public class EmergencyAccessController : Controller { - [Route("emergency-access")] - [Authorize("Application")] - public class EmergencyAccessController : Controller + private readonly IUserService _userService; + private readonly IEmergencyAccessRepository _emergencyAccessRepository; + private readonly IEmergencyAccessService _emergencyAccessService; + private readonly IGlobalSettings _globalSettings; + + public EmergencyAccessController( + IUserService userService, + IEmergencyAccessRepository emergencyAccessRepository, + IEmergencyAccessService emergencyAccessService, + IGlobalSettings globalSettings) { - private readonly IUserService _userService; - private readonly IEmergencyAccessRepository _emergencyAccessRepository; - private readonly IEmergencyAccessService _emergencyAccessService; - private readonly IGlobalSettings _globalSettings; + _userService = userService; + _emergencyAccessRepository = emergencyAccessRepository; + _emergencyAccessService = emergencyAccessService; + _globalSettings = globalSettings; + } - public EmergencyAccessController( - IUserService userService, - IEmergencyAccessRepository emergencyAccessRepository, - IEmergencyAccessService emergencyAccessService, - IGlobalSettings globalSettings) + [HttpGet("trusted")] + public async Task> GetContacts() + { + var userId = _userService.GetProperUserId(User); + var granteeDetails = await _emergencyAccessRepository.GetManyDetailsByGrantorIdAsync(userId.Value); + + var responses = granteeDetails.Select(d => + new EmergencyAccessGranteeDetailsResponseModel(d)); + + return new ListResponseModel(responses); + } + + [HttpGet("granted")] + public async Task> GetGrantees() + { + var userId = _userService.GetProperUserId(User); + var granteeDetails = await _emergencyAccessRepository.GetManyDetailsByGranteeIdAsync(userId.Value); + + var responses = granteeDetails.Select(d => new EmergencyAccessGrantorDetailsResponseModel(d)); + + return new ListResponseModel(responses); + } + + [HttpGet("{id}")] + public async Task Get(Guid id) + { + var userId = _userService.GetProperUserId(User); + var result = await _emergencyAccessService.GetAsync(id, userId.Value); + return new EmergencyAccessGranteeDetailsResponseModel(result); + } + + [HttpGet("{id}/policies")] + public async Task> Policies(Guid id) + { + var user = await _userService.GetUserByPrincipalAsync(User); + var policies = await _emergencyAccessService.GetPoliciesAsync(id, user); + var responses = policies.Select(policy => new PolicyResponseModel(policy)); + return new ListResponseModel(responses); + } + + [HttpPut("{id}")] + [HttpPost("{id}")] + public async Task Put(Guid id, [FromBody] EmergencyAccessUpdateRequestModel model) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); + if (emergencyAccess == null) { - _userService = userService; - _emergencyAccessRepository = emergencyAccessRepository; - _emergencyAccessService = emergencyAccessService; - _globalSettings = globalSettings; + throw new NotFoundException(); } - [HttpGet("trusted")] - public async Task> GetContacts() - { - var userId = _userService.GetProperUserId(User); - var granteeDetails = await _emergencyAccessRepository.GetManyDetailsByGrantorIdAsync(userId.Value); + var user = await _userService.GetUserByPrincipalAsync(User); + await _emergencyAccessService.SaveAsync(model.ToEmergencyAccess(emergencyAccess), user); + } - var responses = granteeDetails.Select(d => - new EmergencyAccessGranteeDetailsResponseModel(d)); + [HttpDelete("{id}")] + [HttpPost("{id}/delete")] + public async Task Delete(Guid id) + { + var userId = _userService.GetProperUserId(User); + await _emergencyAccessService.DeleteAsync(id, userId.Value); + } - return new ListResponseModel(responses); - } + [HttpPost("invite")] + public async Task Invite([FromBody] EmergencyAccessInviteRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + await _emergencyAccessService.InviteAsync(user, model.Email, model.Type.Value, model.WaitTimeDays); + } - [HttpGet("granted")] - public async Task> GetGrantees() - { - var userId = _userService.GetProperUserId(User); - var granteeDetails = await _emergencyAccessRepository.GetManyDetailsByGranteeIdAsync(userId.Value); + [HttpPost("{id}/reinvite")] + public async Task Reinvite(Guid id) + { + var user = await _userService.GetUserByPrincipalAsync(User); + await _emergencyAccessService.ResendInviteAsync(user, id); + } - var responses = granteeDetails.Select(d => new EmergencyAccessGrantorDetailsResponseModel(d)); + [HttpPost("{id}/accept")] + public async Task Accept(Guid id, [FromBody] OrganizationUserAcceptRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + await _emergencyAccessService.AcceptUserAsync(id, user, model.Token, _userService); + } - return new ListResponseModel(responses); - } + [HttpPost("{id}/confirm")] + public async Task Confirm(Guid id, [FromBody] OrganizationUserConfirmRequestModel model) + { + var userId = _userService.GetProperUserId(User); + await _emergencyAccessService.ConfirmUserAsync(id, model.Key, userId.Value); + } - [HttpGet("{id}")] - public async Task Get(Guid id) - { - var userId = _userService.GetProperUserId(User); - var result = await _emergencyAccessService.GetAsync(id, userId.Value); - return new EmergencyAccessGranteeDetailsResponseModel(result); - } + [HttpPost("{id}/initiate")] + public async Task Initiate(Guid id) + { + var user = await _userService.GetUserByPrincipalAsync(User); + await _emergencyAccessService.InitiateAsync(id, user); + } - [HttpGet("{id}/policies")] - public async Task> Policies(Guid id) - { - var user = await _userService.GetUserByPrincipalAsync(User); - var policies = await _emergencyAccessService.GetPoliciesAsync(id, user); - var responses = policies.Select(policy => new PolicyResponseModel(policy)); - return new ListResponseModel(responses); - } + [HttpPost("{id}/approve")] + public async Task Accept(Guid id) + { + var user = await _userService.GetUserByPrincipalAsync(User); + await _emergencyAccessService.ApproveAsync(id, user); + } - [HttpPut("{id}")] - [HttpPost("{id}")] - public async Task Put(Guid id, [FromBody] EmergencyAccessUpdateRequestModel model) - { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); - if (emergencyAccess == null) - { - throw new NotFoundException(); - } + [HttpPost("{id}/reject")] + public async Task Reject(Guid id) + { + var user = await _userService.GetUserByPrincipalAsync(User); + await _emergencyAccessService.RejectAsync(id, user); + } - var user = await _userService.GetUserByPrincipalAsync(User); - await _emergencyAccessService.SaveAsync(model.ToEmergencyAccess(emergencyAccess), user); - } + [HttpPost("{id}/takeover")] + public async Task Takeover(Guid id) + { + var user = await _userService.GetUserByPrincipalAsync(User); + var (result, grantor) = await _emergencyAccessService.TakeoverAsync(id, user); + return new EmergencyAccessTakeoverResponseModel(result, grantor); + } - [HttpDelete("{id}")] - [HttpPost("{id}/delete")] - public async Task Delete(Guid id) - { - var userId = _userService.GetProperUserId(User); - await _emergencyAccessService.DeleteAsync(id, userId.Value); - } + [HttpPost("{id}/password")] + public async Task Password(Guid id, [FromBody] EmergencyAccessPasswordRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + await _emergencyAccessService.PasswordAsync(id, user, model.NewMasterPasswordHash, model.Key); + } - [HttpPost("invite")] - public async Task Invite([FromBody] EmergencyAccessInviteRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - await _emergencyAccessService.InviteAsync(user, model.Email, model.Type.Value, model.WaitTimeDays); - } + [HttpPost("{id}/view")] + public async Task ViewCiphers(Guid id) + { + var user = await _userService.GetUserByPrincipalAsync(User); + var viewResult = await _emergencyAccessService.ViewAsync(id, user); + return new EmergencyAccessViewResponseModel(_globalSettings, viewResult.EmergencyAccess, viewResult.Ciphers); + } - [HttpPost("{id}/reinvite")] - public async Task Reinvite(Guid id) - { - var user = await _userService.GetUserByPrincipalAsync(User); - await _emergencyAccessService.ResendInviteAsync(user, id); - } - - [HttpPost("{id}/accept")] - public async Task Accept(Guid id, [FromBody] OrganizationUserAcceptRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - await _emergencyAccessService.AcceptUserAsync(id, user, model.Token, _userService); - } - - [HttpPost("{id}/confirm")] - public async Task Confirm(Guid id, [FromBody] OrganizationUserConfirmRequestModel model) - { - var userId = _userService.GetProperUserId(User); - await _emergencyAccessService.ConfirmUserAsync(id, model.Key, userId.Value); - } - - [HttpPost("{id}/initiate")] - public async Task Initiate(Guid id) - { - var user = await _userService.GetUserByPrincipalAsync(User); - await _emergencyAccessService.InitiateAsync(id, user); - } - - [HttpPost("{id}/approve")] - public async Task Accept(Guid id) - { - var user = await _userService.GetUserByPrincipalAsync(User); - await _emergencyAccessService.ApproveAsync(id, user); - } - - [HttpPost("{id}/reject")] - public async Task Reject(Guid id) - { - var user = await _userService.GetUserByPrincipalAsync(User); - await _emergencyAccessService.RejectAsync(id, user); - } - - [HttpPost("{id}/takeover")] - public async Task Takeover(Guid id) - { - var user = await _userService.GetUserByPrincipalAsync(User); - var (result, grantor) = await _emergencyAccessService.TakeoverAsync(id, user); - return new EmergencyAccessTakeoverResponseModel(result, grantor); - } - - [HttpPost("{id}/password")] - public async Task Password(Guid id, [FromBody] EmergencyAccessPasswordRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - await _emergencyAccessService.PasswordAsync(id, user, model.NewMasterPasswordHash, model.Key); - } - - [HttpPost("{id}/view")] - public async Task ViewCiphers(Guid id) - { - var user = await _userService.GetUserByPrincipalAsync(User); - var viewResult = await _emergencyAccessService.ViewAsync(id, user); - return new EmergencyAccessViewResponseModel(_globalSettings, viewResult.EmergencyAccess, viewResult.Ciphers); - } - - [HttpGet("{id}/{cipherId}/attachment/{attachmentId}")] - public async Task GetAttachmentData(Guid id, Guid cipherId, string attachmentId) - { - var user = await _userService.GetUserByPrincipalAsync(User); - var result = - await _emergencyAccessService.GetAttachmentDownloadAsync(id, cipherId, attachmentId, user); - return new AttachmentResponseModel(result); - } + [HttpGet("{id}/{cipherId}/attachment/{attachmentId}")] + public async Task GetAttachmentData(Guid id, Guid cipherId, string attachmentId) + { + var user = await _userService.GetUserByPrincipalAsync(User); + var result = + await _emergencyAccessService.GetAttachmentDownloadAsync(id, cipherId, attachmentId, user); + return new AttachmentResponseModel(result); } } diff --git a/src/Api/Controllers/EventsController.cs b/src/Api/Controllers/EventsController.cs index ad657692e..4fd1496b0 100644 --- a/src/Api/Controllers/EventsController.cs +++ b/src/Api/Controllers/EventsController.cs @@ -7,171 +7,170 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("events")] +[Authorize("Application")] +public class EventsController : Controller { - [Route("events")] - [Authorize("Application")] - public class EventsController : Controller + private readonly IUserService _userService; + private readonly ICipherRepository _cipherRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IProviderUserRepository _providerUserRepository; + private readonly IEventRepository _eventRepository; + private readonly ICurrentContext _currentContext; + + public EventsController( + IUserService userService, + ICipherRepository cipherRepository, + IOrganizationUserRepository organizationUserRepository, + IProviderUserRepository providerUserRepository, + IEventRepository eventRepository, + ICurrentContext currentContext) { - private readonly IUserService _userService; - private readonly ICipherRepository _cipherRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IProviderUserRepository _providerUserRepository; - private readonly IEventRepository _eventRepository; - private readonly ICurrentContext _currentContext; + _userService = userService; + _cipherRepository = cipherRepository; + _organizationUserRepository = organizationUserRepository; + _providerUserRepository = providerUserRepository; + _eventRepository = eventRepository; + _currentContext = currentContext; + } - public EventsController( - IUserService userService, - ICipherRepository cipherRepository, - IOrganizationUserRepository organizationUserRepository, - IProviderUserRepository providerUserRepository, - IEventRepository eventRepository, - ICurrentContext currentContext) + [HttpGet("")] + public async Task> GetUser( + [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) + { + var dateRange = GetDateRange(start, end); + var userId = _userService.GetProperUserId(User).Value; + var result = await _eventRepository.GetManyByUserAsync(userId, dateRange.Item1, dateRange.Item2, + new PageOptions { ContinuationToken = continuationToken }); + var responses = result.Data.Select(e => new EventResponseModel(e)); + return new ListResponseModel(responses, result.ContinuationToken); + } + + [HttpGet("~/ciphers/{id}/events")] + public async Task> GetCipher(string id, + [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) + { + var cipher = await _cipherRepository.GetByIdAsync(new Guid(id)); + if (cipher == null) { - _userService = userService; - _cipherRepository = cipherRepository; - _organizationUserRepository = organizationUserRepository; - _providerUserRepository = providerUserRepository; - _eventRepository = eventRepository; - _currentContext = currentContext; + throw new NotFoundException(); } - [HttpGet("")] - public async Task> GetUser( - [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) + var canView = false; + if (cipher.OrganizationId.HasValue) + { + canView = await _currentContext.AccessEventLogs(cipher.OrganizationId.Value); + } + else if (cipher.UserId.HasValue) { - var dateRange = GetDateRange(start, end); var userId = _userService.GetProperUserId(User).Value; - var result = await _eventRepository.GetManyByUserAsync(userId, dateRange.Item1, dateRange.Item2, - new PageOptions { ContinuationToken = continuationToken }); - var responses = result.Data.Select(e => new EventResponseModel(e)); - return new ListResponseModel(responses, result.ContinuationToken); + canView = userId == cipher.UserId.Value; } - [HttpGet("~/ciphers/{id}/events")] - public async Task> GetCipher(string id, - [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) + if (!canView) { - var cipher = await _cipherRepository.GetByIdAsync(new Guid(id)); - if (cipher == null) - { - throw new NotFoundException(); - } - - var canView = false; - if (cipher.OrganizationId.HasValue) - { - canView = await _currentContext.AccessEventLogs(cipher.OrganizationId.Value); - } - else if (cipher.UserId.HasValue) - { - var userId = _userService.GetProperUserId(User).Value; - canView = userId == cipher.UserId.Value; - } - - if (!canView) - { - throw new NotFoundException(); - } - - var dateRange = GetDateRange(start, end); - var result = await _eventRepository.GetManyByCipherAsync(cipher, dateRange.Item1, dateRange.Item2, - new PageOptions { ContinuationToken = continuationToken }); - var responses = result.Data.Select(e => new EventResponseModel(e)); - return new ListResponseModel(responses, result.ContinuationToken); + throw new NotFoundException(); } - [HttpGet("~/organizations/{id}/events")] - public async Task> GetOrganization(string id, - [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) + var dateRange = GetDateRange(start, end); + var result = await _eventRepository.GetManyByCipherAsync(cipher, dateRange.Item1, dateRange.Item2, + new PageOptions { ContinuationToken = continuationToken }); + var responses = result.Data.Select(e => new EventResponseModel(e)); + return new ListResponseModel(responses, result.ContinuationToken); + } + + [HttpGet("~/organizations/{id}/events")] + public async Task> GetOrganization(string id, + [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) + { + var orgId = new Guid(id); + if (!await _currentContext.AccessEventLogs(orgId)) { - var orgId = new Guid(id); - if (!await _currentContext.AccessEventLogs(orgId)) - { - throw new NotFoundException(); - } - - var dateRange = GetDateRange(start, end); - var result = await _eventRepository.GetManyByOrganizationAsync(orgId, dateRange.Item1, dateRange.Item2, - new PageOptions { ContinuationToken = continuationToken }); - var responses = result.Data.Select(e => new EventResponseModel(e)); - return new ListResponseModel(responses, result.ContinuationToken); + throw new NotFoundException(); } - [HttpGet("~/organizations/{orgId}/users/{id}/events")] - public async Task> GetOrganizationUser(string orgId, string id, - [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) + var dateRange = GetDateRange(start, end); + var result = await _eventRepository.GetManyByOrganizationAsync(orgId, dateRange.Item1, dateRange.Item2, + new PageOptions { ContinuationToken = continuationToken }); + var responses = result.Data.Select(e => new EventResponseModel(e)); + return new ListResponseModel(responses, result.ContinuationToken); + } + + [HttpGet("~/organizations/{orgId}/users/{id}/events")] + public async Task> GetOrganizationUser(string orgId, string id, + [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) + { + var organizationUser = await _organizationUserRepository.GetByIdAsync(new Guid(id)); + if (organizationUser == null || !organizationUser.UserId.HasValue || + !await _currentContext.AccessEventLogs(organizationUser.OrganizationId)) { - var organizationUser = await _organizationUserRepository.GetByIdAsync(new Guid(id)); - if (organizationUser == null || !organizationUser.UserId.HasValue || - !await _currentContext.AccessEventLogs(organizationUser.OrganizationId)) - { - throw new NotFoundException(); - } - - var dateRange = GetDateRange(start, end); - var result = await _eventRepository.GetManyByOrganizationActingUserAsync(organizationUser.OrganizationId, - organizationUser.UserId.Value, dateRange.Item1, dateRange.Item2, - new PageOptions { ContinuationToken = continuationToken }); - var responses = result.Data.Select(e => new EventResponseModel(e)); - return new ListResponseModel(responses, result.ContinuationToken); + throw new NotFoundException(); } - [HttpGet("~/providers/{providerId:guid}/events")] - public async Task> GetProvider(Guid providerId, - [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) + var dateRange = GetDateRange(start, end); + var result = await _eventRepository.GetManyByOrganizationActingUserAsync(organizationUser.OrganizationId, + organizationUser.UserId.Value, dateRange.Item1, dateRange.Item2, + new PageOptions { ContinuationToken = continuationToken }); + var responses = result.Data.Select(e => new EventResponseModel(e)); + return new ListResponseModel(responses, result.ContinuationToken); + } + + [HttpGet("~/providers/{providerId:guid}/events")] + public async Task> GetProvider(Guid providerId, + [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) + { + if (!_currentContext.ProviderAccessEventLogs(providerId)) { - if (!_currentContext.ProviderAccessEventLogs(providerId)) - { - throw new NotFoundException(); - } - - var dateRange = GetDateRange(start, end); - var result = await _eventRepository.GetManyByProviderAsync(providerId, dateRange.Item1, dateRange.Item2, - new PageOptions { ContinuationToken = continuationToken }); - var responses = result.Data.Select(e => new EventResponseModel(e)); - return new ListResponseModel(responses, result.ContinuationToken); + throw new NotFoundException(); } - [HttpGet("~/providers/{providerId:guid}/users/{id:guid}/events")] - public async Task> GetProviderUser(Guid providerId, Guid id, - [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) + var dateRange = GetDateRange(start, end); + var result = await _eventRepository.GetManyByProviderAsync(providerId, dateRange.Item1, dateRange.Item2, + new PageOptions { ContinuationToken = continuationToken }); + var responses = result.Data.Select(e => new EventResponseModel(e)); + return new ListResponseModel(responses, result.ContinuationToken); + } + + [HttpGet("~/providers/{providerId:guid}/users/{id:guid}/events")] + public async Task> GetProviderUser(Guid providerId, Guid id, + [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) + { + var providerUser = await _providerUserRepository.GetByIdAsync(id); + if (providerUser == null || !providerUser.UserId.HasValue || + !_currentContext.ProviderAccessEventLogs(providerUser.ProviderId)) { - var providerUser = await _providerUserRepository.GetByIdAsync(id); - if (providerUser == null || !providerUser.UserId.HasValue || - !_currentContext.ProviderAccessEventLogs(providerUser.ProviderId)) - { - throw new NotFoundException(); - } - - var dateRange = GetDateRange(start, end); - var result = await _eventRepository.GetManyByProviderActingUserAsync(providerUser.ProviderId, - providerUser.UserId.Value, dateRange.Item1, dateRange.Item2, - new PageOptions { ContinuationToken = continuationToken }); - var responses = result.Data.Select(e => new EventResponseModel(e)); - return new ListResponseModel(responses, result.ContinuationToken); + throw new NotFoundException(); } - private Tuple GetDateRange(DateTime? start, DateTime? end) + var dateRange = GetDateRange(start, end); + var result = await _eventRepository.GetManyByProviderActingUserAsync(providerUser.ProviderId, + providerUser.UserId.Value, dateRange.Item1, dateRange.Item2, + new PageOptions { ContinuationToken = continuationToken }); + var responses = result.Data.Select(e => new EventResponseModel(e)); + return new ListResponseModel(responses, result.ContinuationToken); + } + + private Tuple GetDateRange(DateTime? start, DateTime? end) + { + if (!end.HasValue || !start.HasValue) { - if (!end.HasValue || !start.HasValue) - { - end = DateTime.UtcNow.Date.AddDays(1).AddMilliseconds(-1); - start = DateTime.UtcNow.Date.AddDays(-30); - } - else if (start.Value > end.Value) - { - var newEnd = start; - start = end; - end = newEnd; - } - - if ((end.Value - start.Value) > TimeSpan.FromDays(367)) - { - throw new BadRequestException("Range too large."); - } - - return new Tuple(start.Value, end.Value); + end = DateTime.UtcNow.Date.AddDays(1).AddMilliseconds(-1); + start = DateTime.UtcNow.Date.AddDays(-30); } + else if (start.Value > end.Value) + { + var newEnd = start; + start = end; + end = newEnd; + } + + if ((end.Value - start.Value) > TimeSpan.FromDays(367)) + { + throw new BadRequestException("Range too large."); + } + + return new Tuple(start.Value, end.Value); } } diff --git a/src/Api/Controllers/FoldersController.cs b/src/Api/Controllers/FoldersController.cs index 752856302..b387809ec 100644 --- a/src/Api/Controllers/FoldersController.cs +++ b/src/Api/Controllers/FoldersController.cs @@ -6,84 +6,83 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("folders")] +[Authorize("Application")] +public class FoldersController : Controller { - [Route("folders")] - [Authorize("Application")] - public class FoldersController : Controller + private readonly IFolderRepository _folderRepository; + private readonly ICipherService _cipherService; + private readonly IUserService _userService; + + public FoldersController( + IFolderRepository folderRepository, + ICipherService cipherService, + IUserService userService) { - private readonly IFolderRepository _folderRepository; - private readonly ICipherService _cipherService; - private readonly IUserService _userService; + _folderRepository = folderRepository; + _cipherService = cipherService; + _userService = userService; + } - public FoldersController( - IFolderRepository folderRepository, - ICipherService cipherService, - IUserService userService) + [HttpGet("{id}")] + public async Task Get(string id) + { + var userId = _userService.GetProperUserId(User).Value; + var folder = await _folderRepository.GetByIdAsync(new Guid(id), userId); + if (folder == null) { - _folderRepository = folderRepository; - _cipherService = cipherService; - _userService = userService; + throw new NotFoundException(); } - [HttpGet("{id}")] - public async Task Get(string id) - { - var userId = _userService.GetProperUserId(User).Value; - var folder = await _folderRepository.GetByIdAsync(new Guid(id), userId); - if (folder == null) - { - throw new NotFoundException(); - } + return new FolderResponseModel(folder); + } - return new FolderResponseModel(folder); + [HttpGet("")] + public async Task> Get() + { + var userId = _userService.GetProperUserId(User).Value; + var folders = await _folderRepository.GetManyByUserIdAsync(userId); + var responses = folders.Select(f => new FolderResponseModel(f)); + return new ListResponseModel(responses); + } + + [HttpPost("")] + public async Task Post([FromBody] FolderRequestModel model) + { + var userId = _userService.GetProperUserId(User).Value; + var folder = model.ToFolder(_userService.GetProperUserId(User).Value); + await _cipherService.SaveFolderAsync(folder); + return new FolderResponseModel(folder); + } + + [HttpPut("{id}")] + [HttpPost("{id}")] + public async Task Put(string id, [FromBody] FolderRequestModel model) + { + var userId = _userService.GetProperUserId(User).Value; + var folder = await _folderRepository.GetByIdAsync(new Guid(id), userId); + if (folder == null) + { + throw new NotFoundException(); } - [HttpGet("")] - public async Task> Get() + await _cipherService.SaveFolderAsync(model.ToFolder(folder)); + return new FolderResponseModel(folder); + } + + [HttpDelete("{id}")] + [HttpPost("{id}/delete")] + public async Task Delete(string id) + { + var userId = _userService.GetProperUserId(User).Value; + var folder = await _folderRepository.GetByIdAsync(new Guid(id), userId); + if (folder == null) { - var userId = _userService.GetProperUserId(User).Value; - var folders = await _folderRepository.GetManyByUserIdAsync(userId); - var responses = folders.Select(f => new FolderResponseModel(f)); - return new ListResponseModel(responses); + throw new NotFoundException(); } - [HttpPost("")] - public async Task Post([FromBody] FolderRequestModel model) - { - var userId = _userService.GetProperUserId(User).Value; - var folder = model.ToFolder(_userService.GetProperUserId(User).Value); - await _cipherService.SaveFolderAsync(folder); - return new FolderResponseModel(folder); - } - - [HttpPut("{id}")] - [HttpPost("{id}")] - public async Task Put(string id, [FromBody] FolderRequestModel model) - { - var userId = _userService.GetProperUserId(User).Value; - var folder = await _folderRepository.GetByIdAsync(new Guid(id), userId); - if (folder == null) - { - throw new NotFoundException(); - } - - await _cipherService.SaveFolderAsync(model.ToFolder(folder)); - return new FolderResponseModel(folder); - } - - [HttpDelete("{id}")] - [HttpPost("{id}/delete")] - public async Task Delete(string id) - { - var userId = _userService.GetProperUserId(User).Value; - var folder = await _folderRepository.GetByIdAsync(new Guid(id), userId); - if (folder == null) - { - throw new NotFoundException(); - } - - await _cipherService.DeleteFolderAsync(folder); - } + await _cipherService.DeleteFolderAsync(folder); } } diff --git a/src/Api/Controllers/GroupsController.cs b/src/Api/Controllers/GroupsController.cs index fc226c4d1..d38ba03bc 100644 --- a/src/Api/Controllers/GroupsController.cs +++ b/src/Api/Controllers/GroupsController.cs @@ -7,146 +7,145 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("organizations/{orgId}/groups")] +[Authorize("Application")] +public class GroupsController : Controller { - [Route("organizations/{orgId}/groups")] - [Authorize("Application")] - public class GroupsController : Controller + private readonly IGroupRepository _groupRepository; + private readonly IGroupService _groupService; + private readonly ICurrentContext _currentContext; + + public GroupsController( + IGroupRepository groupRepository, + IGroupService groupService, + ICurrentContext currentContext) { - private readonly IGroupRepository _groupRepository; - private readonly IGroupService _groupService; - private readonly ICurrentContext _currentContext; + _groupRepository = groupRepository; + _groupService = groupService; + _currentContext = currentContext; + } - public GroupsController( - IGroupRepository groupRepository, - IGroupService groupService, - ICurrentContext currentContext) + [HttpGet("{id}")] + public async Task Get(string orgId, string id) + { + var group = await _groupRepository.GetByIdAsync(new Guid(id)); + if (group == null || !await _currentContext.ManageGroups(group.OrganizationId)) { - _groupRepository = groupRepository; - _groupService = groupService; - _currentContext = currentContext; + throw new NotFoundException(); } - [HttpGet("{id}")] - public async Task Get(string orgId, string id) - { - var group = await _groupRepository.GetByIdAsync(new Guid(id)); - if (group == null || !await _currentContext.ManageGroups(group.OrganizationId)) - { - throw new NotFoundException(); - } + return new GroupResponseModel(group); + } - return new GroupResponseModel(group); + [HttpGet("{id}/details")] + public async Task GetDetails(string orgId, string id) + { + var groupDetails = await _groupRepository.GetByIdWithCollectionsAsync(new Guid(id)); + if (groupDetails?.Item1 == null || !await _currentContext.ManageGroups(groupDetails.Item1.OrganizationId)) + { + throw new NotFoundException(); } - [HttpGet("{id}/details")] - public async Task GetDetails(string orgId, string id) - { - var groupDetails = await _groupRepository.GetByIdWithCollectionsAsync(new Guid(id)); - if (groupDetails?.Item1 == null || !await _currentContext.ManageGroups(groupDetails.Item1.OrganizationId)) - { - throw new NotFoundException(); - } + return new GroupDetailsResponseModel(groupDetails.Item1, groupDetails.Item2); + } - return new GroupDetailsResponseModel(groupDetails.Item1, groupDetails.Item2); + [HttpGet("")] + public async Task> Get(string orgId) + { + var orgIdGuid = new Guid(orgId); + var canAccess = await _currentContext.ManageGroups(orgIdGuid) || + await _currentContext.ViewAssignedCollections(orgIdGuid) || + await _currentContext.ViewAllCollections(orgIdGuid) || + await _currentContext.ManageUsers(orgIdGuid); + + if (!canAccess) + { + throw new NotFoundException(); } - [HttpGet("")] - public async Task> Get(string orgId) + var groups = await _groupRepository.GetManyByOrganizationIdAsync(orgIdGuid); + var responses = groups.Select(g => new GroupResponseModel(g)); + return new ListResponseModel(responses); + } + + [HttpGet("{id}/users")] + public async Task> GetUsers(string orgId, string id) + { + var idGuid = new Guid(id); + var group = await _groupRepository.GetByIdAsync(idGuid); + if (group == null || !await _currentContext.ManageGroups(group.OrganizationId)) { - var orgIdGuid = new Guid(orgId); - var canAccess = await _currentContext.ManageGroups(orgIdGuid) || - await _currentContext.ViewAssignedCollections(orgIdGuid) || - await _currentContext.ViewAllCollections(orgIdGuid) || - await _currentContext.ManageUsers(orgIdGuid); - - if (!canAccess) - { - throw new NotFoundException(); - } - - var groups = await _groupRepository.GetManyByOrganizationIdAsync(orgIdGuid); - var responses = groups.Select(g => new GroupResponseModel(g)); - return new ListResponseModel(responses); + throw new NotFoundException(); } - [HttpGet("{id}/users")] - public async Task> GetUsers(string orgId, string id) - { - var idGuid = new Guid(id); - var group = await _groupRepository.GetByIdAsync(idGuid); - if (group == null || !await _currentContext.ManageGroups(group.OrganizationId)) - { - throw new NotFoundException(); - } + var groupIds = await _groupRepository.GetManyUserIdsByIdAsync(idGuid); + return groupIds; + } - var groupIds = await _groupRepository.GetManyUserIdsByIdAsync(idGuid); - return groupIds; + [HttpPost("")] + public async Task Post(string orgId, [FromBody] GroupRequestModel model) + { + var orgIdGuid = new Guid(orgId); + if (!await _currentContext.ManageGroups(orgIdGuid)) + { + throw new NotFoundException(); } - [HttpPost("")] - public async Task Post(string orgId, [FromBody] GroupRequestModel model) - { - var orgIdGuid = new Guid(orgId); - if (!await _currentContext.ManageGroups(orgIdGuid)) - { - throw new NotFoundException(); - } + var group = model.ToGroup(orgIdGuid); + await _groupService.SaveAsync(group, model.Collections?.Select(c => c.ToSelectionReadOnly())); + return new GroupResponseModel(group); + } - var group = model.ToGroup(orgIdGuid); - await _groupService.SaveAsync(group, model.Collections?.Select(c => c.ToSelectionReadOnly())); - return new GroupResponseModel(group); + [HttpPut("{id}")] + [HttpPost("{id}")] + public async Task Put(string orgId, string id, [FromBody] GroupRequestModel model) + { + var group = await _groupRepository.GetByIdAsync(new Guid(id)); + if (group == null || !await _currentContext.ManageGroups(group.OrganizationId)) + { + throw new NotFoundException(); } - [HttpPut("{id}")] - [HttpPost("{id}")] - public async Task Put(string orgId, string id, [FromBody] GroupRequestModel model) - { - var group = await _groupRepository.GetByIdAsync(new Guid(id)); - if (group == null || !await _currentContext.ManageGroups(group.OrganizationId)) - { - throw new NotFoundException(); - } + await _groupService.SaveAsync(model.ToGroup(group), model.Collections?.Select(c => c.ToSelectionReadOnly())); + return new GroupResponseModel(group); + } - await _groupService.SaveAsync(model.ToGroup(group), model.Collections?.Select(c => c.ToSelectionReadOnly())); - return new GroupResponseModel(group); + [HttpPut("{id}/users")] + public async Task PutUsers(string orgId, string id, [FromBody] IEnumerable model) + { + var group = await _groupRepository.GetByIdAsync(new Guid(id)); + if (group == null || !await _currentContext.ManageGroups(group.OrganizationId)) + { + throw new NotFoundException(); + } + await _groupRepository.UpdateUsersAsync(group.Id, model); + } + + [HttpDelete("{id}")] + [HttpPost("{id}/delete")] + public async Task Delete(string orgId, string id) + { + var group = await _groupRepository.GetByIdAsync(new Guid(id)); + if (group == null || !await _currentContext.ManageGroups(group.OrganizationId)) + { + throw new NotFoundException(); } - [HttpPut("{id}/users")] - public async Task PutUsers(string orgId, string id, [FromBody] IEnumerable model) + await _groupService.DeleteAsync(group); + } + + [HttpDelete("{id}/user/{orgUserId}")] + [HttpPost("{id}/delete-user/{orgUserId}")] + public async Task Delete(string orgId, string id, string orgUserId) + { + var group = await _groupRepository.GetByIdAsync(new Guid(id)); + if (group == null || !await _currentContext.ManageGroups(group.OrganizationId)) { - var group = await _groupRepository.GetByIdAsync(new Guid(id)); - if (group == null || !await _currentContext.ManageGroups(group.OrganizationId)) - { - throw new NotFoundException(); - } - await _groupRepository.UpdateUsersAsync(group.Id, model); + throw new NotFoundException(); } - [HttpDelete("{id}")] - [HttpPost("{id}/delete")] - public async Task Delete(string orgId, string id) - { - var group = await _groupRepository.GetByIdAsync(new Guid(id)); - if (group == null || !await _currentContext.ManageGroups(group.OrganizationId)) - { - throw new NotFoundException(); - } - - await _groupService.DeleteAsync(group); - } - - [HttpDelete("{id}/user/{orgUserId}")] - [HttpPost("{id}/delete-user/{orgUserId}")] - public async Task Delete(string orgId, string id, string orgUserId) - { - var group = await _groupRepository.GetByIdAsync(new Guid(id)); - if (group == null || !await _currentContext.ManageGroups(group.OrganizationId)) - { - throw new NotFoundException(); - } - - await _groupService.DeleteUserAsync(group, new Guid(orgUserId)); - } + await _groupService.DeleteUserAsync(group, new Guid(orgUserId)); } } diff --git a/src/Api/Controllers/HibpController.cs b/src/Api/Controllers/HibpController.cs index 3b94901b0..517ffb5ef 100644 --- a/src/Api/Controllers/HibpController.cs +++ b/src/Api/Controllers/HibpController.cs @@ -8,91 +8,90 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("hibp")] +[Authorize("Application")] +public class HibpController : Controller { - [Route("hibp")] - [Authorize("Application")] - public class HibpController : Controller + private const string HibpBreachApi = "https://haveibeenpwned.com/api/v3/breachedaccount/{0}" + + "?truncateResponse=false&includeUnverified=false"; + private static HttpClient _httpClient; + + private readonly IUserService _userService; + private readonly ICurrentContext _currentContext; + private readonly GlobalSettings _globalSettings; + private readonly string _userAgent; + + static HibpController() { - private const string HibpBreachApi = "https://haveibeenpwned.com/api/v3/breachedaccount/{0}" + - "?truncateResponse=false&includeUnverified=false"; - private static HttpClient _httpClient; + _httpClient = new HttpClient(); + } - private readonly IUserService _userService; - private readonly ICurrentContext _currentContext; - private readonly GlobalSettings _globalSettings; - private readonly string _userAgent; + public HibpController( + IUserService userService, + ICurrentContext currentContext, + GlobalSettings globalSettings) + { + _userService = userService; + _currentContext = currentContext; + _globalSettings = globalSettings; + _userAgent = _globalSettings.SelfHosted ? "Bitwarden Self-Hosted" : "Bitwarden"; + } - static HibpController() + [HttpGet("breach")] + public async Task Get(string username) + { + return await SendAsync(WebUtility.UrlEncode(username), true); + } + + private async Task SendAsync(string username, bool retry) + { + if (!CoreHelpers.SettingHasValue(_globalSettings.HibpApiKey)) { - _httpClient = new HttpClient(); + throw new BadRequestException("HaveIBeenPwned API key not set."); } - - public HibpController( - IUserService userService, - ICurrentContext currentContext, - GlobalSettings globalSettings) + var request = new HttpRequestMessage(HttpMethod.Get, string.Format(HibpBreachApi, username)); + request.Headers.Add("hibp-api-key", _globalSettings.HibpApiKey); + request.Headers.Add("hibp-client-id", GetClientId()); + request.Headers.Add("User-Agent", _userAgent); + var response = await _httpClient.SendAsync(request); + if (response.IsSuccessStatusCode) { - _userService = userService; - _currentContext = currentContext; - _globalSettings = globalSettings; - _userAgent = _globalSettings.SelfHosted ? "Bitwarden Self-Hosted" : "Bitwarden"; + var data = await response.Content.ReadAsStringAsync(); + return Content(data, "application/json"); } - - [HttpGet("breach")] - public async Task Get(string username) + else if (response.StatusCode == HttpStatusCode.NotFound) { - return await SendAsync(WebUtility.UrlEncode(username), true); + return new NotFoundResult(); } - - private async Task SendAsync(string username, bool retry) + else if (response.StatusCode == HttpStatusCode.TooManyRequests && retry) { - if (!CoreHelpers.SettingHasValue(_globalSettings.HibpApiKey)) + var delay = 2000; + if (response.Headers.Contains("retry-after")) { - throw new BadRequestException("HaveIBeenPwned API key not set."); - } - var request = new HttpRequestMessage(HttpMethod.Get, string.Format(HibpBreachApi, username)); - request.Headers.Add("hibp-api-key", _globalSettings.HibpApiKey); - request.Headers.Add("hibp-client-id", GetClientId()); - request.Headers.Add("User-Agent", _userAgent); - var response = await _httpClient.SendAsync(request); - if (response.IsSuccessStatusCode) - { - var data = await response.Content.ReadAsStringAsync(); - return Content(data, "application/json"); - } - else if (response.StatusCode == HttpStatusCode.NotFound) - { - return new NotFoundResult(); - } - else if (response.StatusCode == HttpStatusCode.TooManyRequests && retry) - { - var delay = 2000; - if (response.Headers.Contains("retry-after")) + var vals = response.Headers.GetValues("retry-after"); + if (vals.Any() && int.TryParse(vals.FirstOrDefault(), out var secDelay)) { - var vals = response.Headers.GetValues("retry-after"); - if (vals.Any() && int.TryParse(vals.FirstOrDefault(), out var secDelay)) - { - delay = (secDelay * 1000) + 200; - } + delay = (secDelay * 1000) + 200; } - await Task.Delay(delay); - return await SendAsync(username, false); - } - else - { - throw new BadRequestException("Request failed. Status code: " + response.StatusCode); } + await Task.Delay(delay); + return await SendAsync(username, false); } - - private string GetClientId() + else { - var userId = _userService.GetProperUserId(User).Value; - using (var sha256 = SHA256.Create()) - { - var hash = sha256.ComputeHash(userId.ToByteArray()); - return Convert.ToBase64String(hash); - } + throw new BadRequestException("Request failed. Status code: " + response.StatusCode); + } + } + + private string GetClientId() + { + var userId = _userService.GetProperUserId(User).Value; + using (var sha256 = SHA256.Create()) + { + var hash = sha256.ComputeHash(userId.ToByteArray()); + return Convert.ToBase64String(hash); } } } diff --git a/src/Api/Controllers/InfoController.cs b/src/Api/Controllers/InfoController.cs index 206ba6810..739f9f425 100644 --- a/src/Api/Controllers/InfoController.cs +++ b/src/Api/Controllers/InfoController.cs @@ -1,35 +1,34 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +public class InfoController : Controller { - public class InfoController : Controller + [HttpGet("~/alive")] + [HttpGet("~/now")] + public DateTime GetAlive() { - [HttpGet("~/alive")] - [HttpGet("~/now")] - public DateTime GetAlive() - { - return DateTime.UtcNow; - } + return DateTime.UtcNow; + } - [HttpGet("~/version")] - public JsonResult GetVersion() - { - return Json(CoreHelpers.GetVersion()); - } + [HttpGet("~/version")] + public JsonResult GetVersion() + { + return Json(CoreHelpers.GetVersion()); + } - [HttpGet("~/ip")] - public JsonResult Ip() + [HttpGet("~/ip")] + public JsonResult Ip() + { + var headerSet = new HashSet { "x-forwarded-for", "cf-connecting-ip", "client-ip" }; + var headers = HttpContext.Request?.Headers + .Where(h => headerSet.Contains(h.Key.ToLower())) + .ToDictionary(h => h.Key); + return new JsonResult(new { - var headerSet = new HashSet { "x-forwarded-for", "cf-connecting-ip", "client-ip" }; - var headers = HttpContext.Request?.Headers - .Where(h => headerSet.Contains(h.Key.ToLower())) - .ToDictionary(h => h.Key); - return new JsonResult(new - { - Ip = HttpContext.Connection?.RemoteIpAddress?.ToString(), - Headers = headers, - }); - } + Ip = HttpContext.Connection?.RemoteIpAddress?.ToString(), + Headers = headers, + }); } } diff --git a/src/Api/Controllers/InstallationsController.cs b/src/Api/Controllers/InstallationsController.cs index c75468b47..a2eeebab3 100644 --- a/src/Api/Controllers/InstallationsController.cs +++ b/src/Api/Controllers/InstallationsController.cs @@ -6,40 +6,39 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("installations")] +[SelfHosted(NotSelfHostedOnly = true)] +public class InstallationsController : Controller { - [Route("installations")] - [SelfHosted(NotSelfHostedOnly = true)] - public class InstallationsController : Controller + private readonly IInstallationRepository _installationRepository; + + public InstallationsController( + IInstallationRepository installationRepository) { - private readonly IInstallationRepository _installationRepository; + _installationRepository = installationRepository; + } - public InstallationsController( - IInstallationRepository installationRepository) + [HttpGet("{id}")] + [AllowAnonymous] + public async Task Get(Guid id) + { + var installation = await _installationRepository.GetByIdAsync(id); + if (installation == null) { - _installationRepository = installationRepository; + throw new NotFoundException(); } - [HttpGet("{id}")] - [AllowAnonymous] - public async Task Get(Guid id) - { - var installation = await _installationRepository.GetByIdAsync(id); - if (installation == null) - { - throw new NotFoundException(); - } + return new InstallationResponseModel(installation, false); + } - return new InstallationResponseModel(installation, false); - } - - [HttpPost("")] - [AllowAnonymous] - public async Task Post([FromBody] InstallationRequestModel model) - { - var installation = model.ToInstallation(); - await _installationRepository.CreateAsync(installation); - return new InstallationResponseModel(installation, true); - } + [HttpPost("")] + [AllowAnonymous] + public async Task Post([FromBody] InstallationRequestModel model) + { + var installation = model.ToInstallation(); + await _installationRepository.CreateAsync(installation); + return new InstallationResponseModel(installation, true); } } diff --git a/src/Api/Controllers/LicensesController.cs b/src/Api/Controllers/LicensesController.cs index 4de079885..63ed82479 100644 --- a/src/Api/Controllers/LicensesController.cs +++ b/src/Api/Controllers/LicensesController.cs @@ -7,70 +7,69 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("licenses")] +[Authorize("Licensing")] +[SelfHosted(NotSelfHostedOnly = true)] +public class LicensesController : Controller { - [Route("licenses")] - [Authorize("Licensing")] - [SelfHosted(NotSelfHostedOnly = true)] - public class LicensesController : Controller + private readonly ILicensingService _licensingService; + private readonly IUserRepository _userRepository; + private readonly IUserService _userService; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationService _organizationService; + private readonly ICurrentContext _currentContext; + + public LicensesController( + ILicensingService licensingService, + IUserRepository userRepository, + IUserService userService, + IOrganizationRepository organizationRepository, + IOrganizationService organizationService, + ICurrentContext currentContext) { - private readonly ILicensingService _licensingService; - private readonly IUserRepository _userRepository; - private readonly IUserService _userService; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationService _organizationService; - private readonly ICurrentContext _currentContext; + _licensingService = licensingService; + _userRepository = userRepository; + _userService = userService; + _organizationRepository = organizationRepository; + _organizationService = organizationService; + _currentContext = currentContext; + } - public LicensesController( - ILicensingService licensingService, - IUserRepository userRepository, - IUserService userService, - IOrganizationRepository organizationRepository, - IOrganizationService organizationService, - ICurrentContext currentContext) + [HttpGet("user/{id}")] + public async Task GetUser(string id, [FromQuery] string key) + { + var user = await _userRepository.GetByIdAsync(new Guid(id)); + if (user == null) { - _licensingService = licensingService; - _userRepository = userRepository; - _userService = userService; - _organizationRepository = organizationRepository; - _organizationService = organizationService; - _currentContext = currentContext; + return null; + } + else if (!user.LicenseKey.Equals(key)) + { + await Task.Delay(2000); + throw new BadRequestException("Invalid license key."); } - [HttpGet("user/{id}")] - public async Task GetUser(string id, [FromQuery] string key) - { - var user = await _userRepository.GetByIdAsync(new Guid(id)); - if (user == null) - { - return null; - } - else if (!user.LicenseKey.Equals(key)) - { - await Task.Delay(2000); - throw new BadRequestException("Invalid license key."); - } + var license = await _userService.GenerateLicenseAsync(user, null); + return license; + } - var license = await _userService.GenerateLicenseAsync(user, null); - return license; + [HttpGet("organization/{id}")] + public async Task GetOrganization(string id, [FromQuery] string key) + { + var org = await _organizationRepository.GetByIdAsync(new Guid(id)); + if (org == null) + { + return null; + } + else if (!org.LicenseKey.Equals(key)) + { + await Task.Delay(2000); + throw new BadRequestException("Invalid license key."); } - [HttpGet("organization/{id}")] - public async Task GetOrganization(string id, [FromQuery] string key) - { - var org = await _organizationRepository.GetByIdAsync(new Guid(id)); - if (org == null) - { - return null; - } - else if (!org.LicenseKey.Equals(key)) - { - await Task.Delay(2000); - throw new BadRequestException("Invalid license key."); - } - - var license = await _organizationService.GenerateLicenseAsync(org, _currentContext.InstallationId.Value); - return license; - } + var license = await _organizationService.GenerateLicenseAsync(org, _currentContext.InstallationId.Value); + return license; } } diff --git a/src/Api/Controllers/MiscController.cs b/src/Api/Controllers/MiscController.cs index edd4dfde4..6f23a27fb 100644 --- a/src/Api/Controllers/MiscController.cs +++ b/src/Api/Controllers/MiscController.cs @@ -5,42 +5,41 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; using Stripe; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +public class MiscController : Controller { - public class MiscController : Controller + private readonly BitPayClient _bitPayClient; + private readonly GlobalSettings _globalSettings; + + public MiscController( + BitPayClient bitPayClient, + GlobalSettings globalSettings) { - private readonly BitPayClient _bitPayClient; - private readonly GlobalSettings _globalSettings; + _bitPayClient = bitPayClient; + _globalSettings = globalSettings; + } - public MiscController( - BitPayClient bitPayClient, - GlobalSettings globalSettings) - { - _bitPayClient = bitPayClient; - _globalSettings = globalSettings; - } + [Authorize("Application")] + [HttpPost("~/bitpay-invoice")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostBitPayInvoice([FromBody] BitPayInvoiceRequestModel model) + { + var invoice = await _bitPayClient.CreateInvoiceAsync(model.ToBitpayInvoice(_globalSettings)); + return invoice.Url; + } - [Authorize("Application")] - [HttpPost("~/bitpay-invoice")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostBitPayInvoice([FromBody] BitPayInvoiceRequestModel model) + [Authorize("Application")] + [HttpPost("~/setup-payment")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostSetupPayment() + { + var options = new SetupIntentCreateOptions { - var invoice = await _bitPayClient.CreateInvoiceAsync(model.ToBitpayInvoice(_globalSettings)); - return invoice.Url; - } - - [Authorize("Application")] - [HttpPost("~/setup-payment")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostSetupPayment() - { - var options = new SetupIntentCreateOptions - { - Usage = "off_session" - }; - var service = new SetupIntentService(); - var setupIntent = await service.CreateAsync(options); - return setupIntent.ClientSecret; - } + Usage = "off_session" + }; + var service = new SetupIntentService(); + var setupIntent = await service.CreateAsync(options); + return setupIntent.ClientSecret; } } diff --git a/src/Api/Controllers/OrganizationConnectionsController.cs b/src/Api/Controllers/OrganizationConnectionsController.cs index 83f7a6ed7..73754dba7 100644 --- a/src/Api/Controllers/OrganizationConnectionsController.cs +++ b/src/Api/Controllers/OrganizationConnectionsController.cs @@ -12,199 +12,198 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Authorize("Application")] +[Route("organizations/connections")] +public class OrganizationConnectionsController : Controller { - [Authorize("Application")] - [Route("organizations/connections")] - public class OrganizationConnectionsController : Controller + private readonly ICreateOrganizationConnectionCommand _createOrganizationConnectionCommand; + private readonly IUpdateOrganizationConnectionCommand _updateOrganizationConnectionCommand; + private readonly IDeleteOrganizationConnectionCommand _deleteOrganizationConnectionCommand; + private readonly IOrganizationConnectionRepository _organizationConnectionRepository; + private readonly ICurrentContext _currentContext; + private readonly IGlobalSettings _globalSettings; + private readonly ILicensingService _licensingService; + + public OrganizationConnectionsController( + ICreateOrganizationConnectionCommand createOrganizationConnectionCommand, + IUpdateOrganizationConnectionCommand updateOrganizationConnectionCommand, + IDeleteOrganizationConnectionCommand deleteOrganizationConnectionCommand, + IOrganizationConnectionRepository organizationConnectionRepository, + ICurrentContext currentContext, + IGlobalSettings globalSettings, + ILicensingService licensingService) { - private readonly ICreateOrganizationConnectionCommand _createOrganizationConnectionCommand; - private readonly IUpdateOrganizationConnectionCommand _updateOrganizationConnectionCommand; - private readonly IDeleteOrganizationConnectionCommand _deleteOrganizationConnectionCommand; - private readonly IOrganizationConnectionRepository _organizationConnectionRepository; - private readonly ICurrentContext _currentContext; - private readonly IGlobalSettings _globalSettings; - private readonly ILicensingService _licensingService; + _createOrganizationConnectionCommand = createOrganizationConnectionCommand; + _updateOrganizationConnectionCommand = updateOrganizationConnectionCommand; + _deleteOrganizationConnectionCommand = deleteOrganizationConnectionCommand; + _organizationConnectionRepository = organizationConnectionRepository; + _currentContext = currentContext; + _globalSettings = globalSettings; + _licensingService = licensingService; + } - public OrganizationConnectionsController( - ICreateOrganizationConnectionCommand createOrganizationConnectionCommand, - IUpdateOrganizationConnectionCommand updateOrganizationConnectionCommand, - IDeleteOrganizationConnectionCommand deleteOrganizationConnectionCommand, - IOrganizationConnectionRepository organizationConnectionRepository, - ICurrentContext currentContext, - IGlobalSettings globalSettings, - ILicensingService licensingService) + [HttpGet("enabled")] + public bool ConnectionsEnabled() + { + return _globalSettings.SelfHosted && _globalSettings.EnableCloudCommunication; + } + + [HttpPost] + public async Task CreateConnection([FromBody] OrganizationConnectionRequestModel model) + { + if (!await HasPermissionAsync(model?.OrganizationId)) { - _createOrganizationConnectionCommand = createOrganizationConnectionCommand; - _updateOrganizationConnectionCommand = updateOrganizationConnectionCommand; - _deleteOrganizationConnectionCommand = deleteOrganizationConnectionCommand; - _organizationConnectionRepository = organizationConnectionRepository; - _currentContext = currentContext; - _globalSettings = globalSettings; - _licensingService = licensingService; + throw new BadRequestException($"You do not have permission to create a connection of type {model.Type}."); } - [HttpGet("enabled")] - public bool ConnectionsEnabled() + if (await HasConnectionTypeAsync(model, null, model.Type)) { - return _globalSettings.SelfHosted && _globalSettings.EnableCloudCommunication; + throw new BadRequestException($"The requested organization already has a connection of type {model.Type}. Only one of each connection type may exist per organization."); } - [HttpPost] - public async Task CreateConnection([FromBody] OrganizationConnectionRequestModel model) + switch (model.Type) { - if (!await HasPermissionAsync(model?.OrganizationId)) - { - throw new BadRequestException($"You do not have permission to create a connection of type {model.Type}."); - } - - if (await HasConnectionTypeAsync(model, null, model.Type)) - { - throw new BadRequestException($"The requested organization already has a connection of type {model.Type}. Only one of each connection type may exist per organization."); - } - - switch (model.Type) - { - case OrganizationConnectionType.CloudBillingSync: - return await CreateOrUpdateOrganizationConnectionAsync(null, model, ValidateBillingSyncConfig); - case OrganizationConnectionType.Scim: - return await CreateOrUpdateOrganizationConnectionAsync(null, model); - default: - throw new BadRequestException($"Unknown Organization connection Type: {model.Type}"); - } - } - - [HttpPut("{organizationConnectionId}")] - public async Task UpdateConnection(Guid organizationConnectionId, [FromBody] OrganizationConnectionRequestModel model) - { - var existingOrganizationConnection = await _organizationConnectionRepository.GetByIdAsync(organizationConnectionId); - if (existingOrganizationConnection == null) - { - throw new NotFoundException(); - } - - if (!await HasPermissionAsync(model?.OrganizationId, model?.Type)) - { - throw new BadRequestException("You do not have permission to update this connection."); - } - - if (await HasConnectionTypeAsync(model, organizationConnectionId, model.Type)) - { - throw new BadRequestException($"The requested organization already has a connection of type {model.Type}. Only one of each connection type may exist per organization."); - } - - switch (model.Type) - { - case OrganizationConnectionType.CloudBillingSync: - return await CreateOrUpdateOrganizationConnectionAsync(organizationConnectionId, model); - case OrganizationConnectionType.Scim: - return await CreateOrUpdateOrganizationConnectionAsync(organizationConnectionId, model); - default: - throw new BadRequestException($"Unkown Organization connection Type: {model.Type}"); - } - } - - [HttpGet("{organizationId}/{type}")] - public async Task GetConnection(Guid organizationId, OrganizationConnectionType type) - { - if (!await HasPermissionAsync(organizationId, type)) - { - throw new BadRequestException($"You do not have permission to retrieve a connection of type {type}."); - } - - var connections = await GetConnectionsAsync(organizationId, type); - var connection = connections.FirstOrDefault(c => c.Type == type); - - switch (type) - { - case OrganizationConnectionType.CloudBillingSync: - if (!_globalSettings.SelfHosted) - { - throw new BadRequestException($"Cannot get a {type} connection outside of a self-hosted instance."); - } - return new OrganizationConnectionResponseModel(connection, typeof(BillingSyncConfig)); - case OrganizationConnectionType.Scim: - return new OrganizationConnectionResponseModel(connection, typeof(ScimConfig)); - default: - throw new BadRequestException($"Unkown Organization connection Type: {type}"); - } - } - - [HttpDelete("{organizationConnectionId}")] - [HttpPost("{organizationConnectionId}/delete")] - public async Task DeleteConnection(Guid organizationConnectionId) - { - var connection = await _organizationConnectionRepository.GetByIdAsync(organizationConnectionId); - - if (connection == null) - { - throw new NotFoundException(); - } - - if (!await HasPermissionAsync(connection.OrganizationId, connection.Type)) - { - throw new BadRequestException($"You do not have permission to remove this connection of type {connection.Type}."); - } - - await _deleteOrganizationConnectionCommand.DeleteAsync(connection); - } - - private async Task> GetConnectionsAsync(Guid organizationId, OrganizationConnectionType type) => - await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(organizationId, type); - - private async Task HasConnectionTypeAsync(OrganizationConnectionRequestModel model, Guid? connectionId, - OrganizationConnectionType type) - { - var existingConnections = await GetConnectionsAsync(model.OrganizationId, type); - - return existingConnections.Any(c => c.Type == model.Type && (!connectionId.HasValue || c.Id != connectionId.Value)); - } - - private async Task HasPermissionAsync(Guid? organizationId, OrganizationConnectionType? type = null) - { - if (!organizationId.HasValue) - { - return false; - } - return type switch - { - OrganizationConnectionType.Scim => await _currentContext.ManageScim(organizationId.Value), - _ => await _currentContext.OrganizationOwner(organizationId.Value), - }; - } - - private async Task ValidateBillingSyncConfig(OrganizationConnectionRequestModel typedModel) - { - if (!_globalSettings.SelfHosted) - { - throw new BadRequestException($"Cannot create a {typedModel.Type} connection outside of a self-hosted instance."); - } - var license = await _licensingService.ReadOrganizationLicenseAsync(typedModel.OrganizationId); - if (!_licensingService.VerifyLicense(license)) - { - throw new BadRequestException("Cannot verify license file."); - } - typedModel.ParsedConfig.CloudOrganizationId = license.Id; - } - - private async Task CreateOrUpdateOrganizationConnectionAsync( - Guid? organizationConnectionId, - OrganizationConnectionRequestModel model, - Func, Task> validateAction = null) - where T : new() - { - var typedModel = new OrganizationConnectionRequestModel(model); - if (validateAction != null) - { - await validateAction(typedModel); - } - - var data = typedModel.ToData(organizationConnectionId); - var connection = organizationConnectionId.HasValue - ? await _updateOrganizationConnectionCommand.UpdateAsync(data) - : await _createOrganizationConnectionCommand.CreateAsync(data); - - return new OrganizationConnectionResponseModel(connection, typeof(T)); + case OrganizationConnectionType.CloudBillingSync: + return await CreateOrUpdateOrganizationConnectionAsync(null, model, ValidateBillingSyncConfig); + case OrganizationConnectionType.Scim: + return await CreateOrUpdateOrganizationConnectionAsync(null, model); + default: + throw new BadRequestException($"Unknown Organization connection Type: {model.Type}"); } } + + [HttpPut("{organizationConnectionId}")] + public async Task UpdateConnection(Guid organizationConnectionId, [FromBody] OrganizationConnectionRequestModel model) + { + var existingOrganizationConnection = await _organizationConnectionRepository.GetByIdAsync(organizationConnectionId); + if (existingOrganizationConnection == null) + { + throw new NotFoundException(); + } + + if (!await HasPermissionAsync(model?.OrganizationId, model?.Type)) + { + throw new BadRequestException("You do not have permission to update this connection."); + } + + if (await HasConnectionTypeAsync(model, organizationConnectionId, model.Type)) + { + throw new BadRequestException($"The requested organization already has a connection of type {model.Type}. Only one of each connection type may exist per organization."); + } + + switch (model.Type) + { + case OrganizationConnectionType.CloudBillingSync: + return await CreateOrUpdateOrganizationConnectionAsync(organizationConnectionId, model); + case OrganizationConnectionType.Scim: + return await CreateOrUpdateOrganizationConnectionAsync(organizationConnectionId, model); + default: + throw new BadRequestException($"Unkown Organization connection Type: {model.Type}"); + } + } + + [HttpGet("{organizationId}/{type}")] + public async Task GetConnection(Guid organizationId, OrganizationConnectionType type) + { + if (!await HasPermissionAsync(organizationId, type)) + { + throw new BadRequestException($"You do not have permission to retrieve a connection of type {type}."); + } + + var connections = await GetConnectionsAsync(organizationId, type); + var connection = connections.FirstOrDefault(c => c.Type == type); + + switch (type) + { + case OrganizationConnectionType.CloudBillingSync: + if (!_globalSettings.SelfHosted) + { + throw new BadRequestException($"Cannot get a {type} connection outside of a self-hosted instance."); + } + return new OrganizationConnectionResponseModel(connection, typeof(BillingSyncConfig)); + case OrganizationConnectionType.Scim: + return new OrganizationConnectionResponseModel(connection, typeof(ScimConfig)); + default: + throw new BadRequestException($"Unkown Organization connection Type: {type}"); + } + } + + [HttpDelete("{organizationConnectionId}")] + [HttpPost("{organizationConnectionId}/delete")] + public async Task DeleteConnection(Guid organizationConnectionId) + { + var connection = await _organizationConnectionRepository.GetByIdAsync(organizationConnectionId); + + if (connection == null) + { + throw new NotFoundException(); + } + + if (!await HasPermissionAsync(connection.OrganizationId, connection.Type)) + { + throw new BadRequestException($"You do not have permission to remove this connection of type {connection.Type}."); + } + + await _deleteOrganizationConnectionCommand.DeleteAsync(connection); + } + + private async Task> GetConnectionsAsync(Guid organizationId, OrganizationConnectionType type) => + await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(organizationId, type); + + private async Task HasConnectionTypeAsync(OrganizationConnectionRequestModel model, Guid? connectionId, + OrganizationConnectionType type) + { + var existingConnections = await GetConnectionsAsync(model.OrganizationId, type); + + return existingConnections.Any(c => c.Type == model.Type && (!connectionId.HasValue || c.Id != connectionId.Value)); + } + + private async Task HasPermissionAsync(Guid? organizationId, OrganizationConnectionType? type = null) + { + if (!organizationId.HasValue) + { + return false; + } + return type switch + { + OrganizationConnectionType.Scim => await _currentContext.ManageScim(organizationId.Value), + _ => await _currentContext.OrganizationOwner(organizationId.Value), + }; + } + + private async Task ValidateBillingSyncConfig(OrganizationConnectionRequestModel typedModel) + { + if (!_globalSettings.SelfHosted) + { + throw new BadRequestException($"Cannot create a {typedModel.Type} connection outside of a self-hosted instance."); + } + var license = await _licensingService.ReadOrganizationLicenseAsync(typedModel.OrganizationId); + if (!_licensingService.VerifyLicense(license)) + { + throw new BadRequestException("Cannot verify license file."); + } + typedModel.ParsedConfig.CloudOrganizationId = license.Id; + } + + private async Task CreateOrUpdateOrganizationConnectionAsync( + Guid? organizationConnectionId, + OrganizationConnectionRequestModel model, + Func, Task> validateAction = null) + where T : new() + { + var typedModel = new OrganizationConnectionRequestModel(model); + if (validateAction != null) + { + await validateAction(typedModel); + } + + var data = typedModel.ToData(organizationConnectionId); + var connection = organizationConnectionId.HasValue + ? await _updateOrganizationConnectionCommand.UpdateAsync(data) + : await _createOrganizationConnectionCommand.CreateAsync(data); + + return new OrganizationConnectionResponseModel(connection, typeof(T)); + } } diff --git a/src/Api/Controllers/OrganizationExportController.cs b/src/Api/Controllers/OrganizationExportController.cs index dd04fe009..f2a2265f9 100644 --- a/src/Api/Controllers/OrganizationExportController.cs +++ b/src/Api/Controllers/OrganizationExportController.cs @@ -6,59 +6,58 @@ using Core.Models.Data; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("organizations/{organizationId}")] +[Authorize("Application")] +public class OrganizationExportController : Controller { - [Route("organizations/{organizationId}")] - [Authorize("Application")] - public class OrganizationExportController : Controller + private readonly IUserService _userService; + private readonly ICollectionService _collectionService; + private readonly ICipherService _cipherService; + private readonly GlobalSettings _globalSettings; + + public OrganizationExportController( + ICipherService cipherService, + ICollectionService collectionService, + IUserService userService, + GlobalSettings globalSettings) { - private readonly IUserService _userService; - private readonly ICollectionService _collectionService; - private readonly ICipherService _cipherService; - private readonly GlobalSettings _globalSettings; + _cipherService = cipherService; + _collectionService = collectionService; + _userService = userService; + _globalSettings = globalSettings; + } - public OrganizationExportController( - ICipherService cipherService, - ICollectionService collectionService, - IUserService userService, - GlobalSettings globalSettings) + [HttpGet("export")] + public async Task Export(Guid organizationId) + { + var userId = _userService.GetProperUserId(User).Value; + + IEnumerable orgCollections = await _collectionService.GetOrganizationCollections(organizationId); + (IEnumerable orgCiphers, Dictionary> collectionCiphersGroupDict) = await _cipherService.GetOrganizationCiphers(userId, organizationId); + + var result = new OrganizationExportResponseModel { - _cipherService = cipherService; - _collectionService = collectionService; - _userService = userService; - _globalSettings = globalSettings; - } + Collections = GetOrganizationCollectionsResponse(orgCollections), + Ciphers = GetOrganizationCiphersResponse(orgCiphers, collectionCiphersGroupDict) + }; - [HttpGet("export")] - public async Task Export(Guid organizationId) - { - var userId = _userService.GetProperUserId(User).Value; + return result; + } - IEnumerable orgCollections = await _collectionService.GetOrganizationCollections(organizationId); - (IEnumerable orgCiphers, Dictionary> collectionCiphersGroupDict) = await _cipherService.GetOrganizationCiphers(userId, organizationId); + private ListResponseModel GetOrganizationCollectionsResponse(IEnumerable orgCollections) + { + var collections = orgCollections.Select(c => new CollectionResponseModel(c)); + return new ListResponseModel(collections); + } - var result = new OrganizationExportResponseModel - { - Collections = GetOrganizationCollectionsResponse(orgCollections), - Ciphers = GetOrganizationCiphersResponse(orgCiphers, collectionCiphersGroupDict) - }; + private ListResponseModel GetOrganizationCiphersResponse(IEnumerable orgCiphers, + Dictionary> collectionCiphersGroupDict) + { + var responses = orgCiphers.Select(c => new CipherMiniDetailsResponseModel(c, _globalSettings, + collectionCiphersGroupDict, c.OrganizationUseTotp)); - return result; - } - - private ListResponseModel GetOrganizationCollectionsResponse(IEnumerable orgCollections) - { - var collections = orgCollections.Select(c => new CollectionResponseModel(c)); - return new ListResponseModel(collections); - } - - private ListResponseModel GetOrganizationCiphersResponse(IEnumerable orgCiphers, - Dictionary> collectionCiphersGroupDict) - { - var responses = orgCiphers.Select(c => new CipherMiniDetailsResponseModel(c, _globalSettings, - collectionCiphersGroupDict, c.OrganizationUseTotp)); - - return new ListResponseModel(responses); - } + return new ListResponseModel(responses); } } diff --git a/src/Api/Controllers/OrganizationSponsorshipsController.cs b/src/Api/Controllers/OrganizationSponsorshipsController.cs index ae7be386c..fc5d38db1 100644 --- a/src/Api/Controllers/OrganizationSponsorshipsController.cs +++ b/src/Api/Controllers/OrganizationSponsorshipsController.cs @@ -12,180 +12,179 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("organization/sponsorship")] +public class OrganizationSponsorshipsController : Controller { - [Route("organization/sponsorship")] - public class OrganizationSponsorshipsController : Controller + private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IValidateRedemptionTokenCommand _validateRedemptionTokenCommand; + private readonly IValidateBillingSyncKeyCommand _validateBillingSyncKeyCommand; + private readonly ICreateSponsorshipCommand _createSponsorshipCommand; + private readonly ISendSponsorshipOfferCommand _sendSponsorshipOfferCommand; + private readonly ISetUpSponsorshipCommand _setUpSponsorshipCommand; + private readonly IRevokeSponsorshipCommand _revokeSponsorshipCommand; + private readonly IRemoveSponsorshipCommand _removeSponsorshipCommand; + private readonly ICloudSyncSponsorshipsCommand _syncSponsorshipsCommand; + private readonly ICurrentContext _currentContext; + private readonly IUserService _userService; + + public OrganizationSponsorshipsController( + IOrganizationSponsorshipRepository organizationSponsorshipRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IValidateRedemptionTokenCommand validateRedemptionTokenCommand, + IValidateBillingSyncKeyCommand validateBillingSyncKeyCommand, + ICreateSponsorshipCommand createSponsorshipCommand, + ISendSponsorshipOfferCommand sendSponsorshipOfferCommand, + ISetUpSponsorshipCommand setUpSponsorshipCommand, + IRevokeSponsorshipCommand revokeSponsorshipCommand, + IRemoveSponsorshipCommand removeSponsorshipCommand, + ICloudSyncSponsorshipsCommand syncSponsorshipsCommand, + IUserService userService, + ICurrentContext currentContext) { - private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IValidateRedemptionTokenCommand _validateRedemptionTokenCommand; - private readonly IValidateBillingSyncKeyCommand _validateBillingSyncKeyCommand; - private readonly ICreateSponsorshipCommand _createSponsorshipCommand; - private readonly ISendSponsorshipOfferCommand _sendSponsorshipOfferCommand; - private readonly ISetUpSponsorshipCommand _setUpSponsorshipCommand; - private readonly IRevokeSponsorshipCommand _revokeSponsorshipCommand; - private readonly IRemoveSponsorshipCommand _removeSponsorshipCommand; - private readonly ICloudSyncSponsorshipsCommand _syncSponsorshipsCommand; - private readonly ICurrentContext _currentContext; - private readonly IUserService _userService; - - public OrganizationSponsorshipsController( - IOrganizationSponsorshipRepository organizationSponsorshipRepository, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IValidateRedemptionTokenCommand validateRedemptionTokenCommand, - IValidateBillingSyncKeyCommand validateBillingSyncKeyCommand, - ICreateSponsorshipCommand createSponsorshipCommand, - ISendSponsorshipOfferCommand sendSponsorshipOfferCommand, - ISetUpSponsorshipCommand setUpSponsorshipCommand, - IRevokeSponsorshipCommand revokeSponsorshipCommand, - IRemoveSponsorshipCommand removeSponsorshipCommand, - ICloudSyncSponsorshipsCommand syncSponsorshipsCommand, - IUserService userService, - ICurrentContext currentContext) - { - _organizationSponsorshipRepository = organizationSponsorshipRepository; - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _validateRedemptionTokenCommand = validateRedemptionTokenCommand; - _validateBillingSyncKeyCommand = validateBillingSyncKeyCommand; - _createSponsorshipCommand = createSponsorshipCommand; - _sendSponsorshipOfferCommand = sendSponsorshipOfferCommand; - _setUpSponsorshipCommand = setUpSponsorshipCommand; - _revokeSponsorshipCommand = revokeSponsorshipCommand; - _removeSponsorshipCommand = removeSponsorshipCommand; - _syncSponsorshipsCommand = syncSponsorshipsCommand; - _userService = userService; - _currentContext = currentContext; - } - - [Authorize("Application")] - [HttpPost("{sponsoringOrgId}/families-for-enterprise")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task CreateSponsorship(Guid sponsoringOrgId, [FromBody] OrganizationSponsorshipCreateRequestModel model) - { - var sponsoringOrg = await _organizationRepository.GetByIdAsync(sponsoringOrgId); - - var sponsorship = await _createSponsorshipCommand.CreateSponsorshipAsync( - sponsoringOrg, - await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default), - model.PlanSponsorshipType, model.SponsoredEmail, model.FriendlyName); - await _sendSponsorshipOfferCommand.SendSponsorshipOfferAsync(sponsorship, sponsoringOrg.Name); - } - - [Authorize("Application")] - [HttpPost("{sponsoringOrgId}/families-for-enterprise/resend")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task ResendSponsorshipOffer(Guid sponsoringOrgId) - { - var sponsoringOrgUser = await _organizationUserRepository - .GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default); - - await _sendSponsorshipOfferCommand.SendSponsorshipOfferAsync( - await _organizationRepository.GetByIdAsync(sponsoringOrgId), - sponsoringOrgUser, - await _organizationSponsorshipRepository - .GetBySponsoringOrganizationUserIdAsync(sponsoringOrgUser.Id)); - } - - [Authorize("Application")] - [HttpPost("validate-token")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PreValidateSponsorshipToken([FromQuery] string sponsorshipToken) - { - return (await _validateRedemptionTokenCommand.ValidateRedemptionTokenAsync(sponsorshipToken, (await CurrentUser).Email)).valid; - } - - [Authorize("Application")] - [HttpPost("redeem")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task RedeemSponsorship([FromQuery] string sponsorshipToken, [FromBody] OrganizationSponsorshipRedeemRequestModel model) - { - var (valid, sponsorship) = await _validateRedemptionTokenCommand.ValidateRedemptionTokenAsync(sponsorshipToken, (await CurrentUser).Email); - - if (!valid) - { - throw new BadRequestException("Failed to parse sponsorship token."); - } - - if (!await _currentContext.OrganizationOwner(model.SponsoredOrganizationId)) - { - throw new BadRequestException("Can only redeem sponsorship for an organization you own."); - } - - await _setUpSponsorshipCommand.SetUpSponsorshipAsync( - sponsorship, - await _organizationRepository.GetByIdAsync(model.SponsoredOrganizationId)); - } - - [Authorize("Installation")] - [HttpPost("sync")] - public async Task Sync([FromBody] OrganizationSponsorshipSyncRequestModel model) - { - var sponsoringOrg = await _organizationRepository.GetByIdAsync(model.SponsoringOrganizationCloudId); - if (!await _validateBillingSyncKeyCommand.ValidateBillingSyncKeyAsync(sponsoringOrg, model.BillingSyncKey)) - { - throw new BadRequestException("Invalid Billing Sync Key"); - } - - var (syncResponseData, offersToSend) = await _syncSponsorshipsCommand.SyncOrganization(sponsoringOrg, model.ToOrganizationSponsorshipSync().SponsorshipsBatch); - await _sendSponsorshipOfferCommand.BulkSendSponsorshipOfferAsync(sponsoringOrg.Name, offersToSend); - return new OrganizationSponsorshipSyncResponseModel(syncResponseData); - } - - [Authorize("Application")] - [HttpDelete("{sponsoringOrganizationId}")] - [HttpPost("{sponsoringOrganizationId}/delete")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task RevokeSponsorship(Guid sponsoringOrganizationId) - { - - var orgUser = await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrganizationId, _currentContext.UserId ?? default); - if (_currentContext.UserId != orgUser?.UserId) - { - throw new BadRequestException("Can only revoke a sponsorship you granted."); - } - - var existingOrgSponsorship = await _organizationSponsorshipRepository - .GetBySponsoringOrganizationUserIdAsync(orgUser.Id); - - await _revokeSponsorshipCommand.RevokeSponsorshipAsync(existingOrgSponsorship); - } - - [Authorize("Application")] - [HttpDelete("sponsored/{sponsoredOrgId}")] - [HttpPost("sponsored/{sponsoredOrgId}/remove")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task RemoveSponsorship(Guid sponsoredOrgId) - { - - if (!await _currentContext.OrganizationOwner(sponsoredOrgId)) - { - throw new BadRequestException("Only the owner of an organization can remove sponsorship."); - } - - var existingOrgSponsorship = await _organizationSponsorshipRepository - .GetBySponsoredOrganizationIdAsync(sponsoredOrgId); - - await _removeSponsorshipCommand.RemoveSponsorshipAsync(existingOrgSponsorship); - } - - [HttpGet("{sponsoringOrgId}/sync-status")] - public async Task GetSyncStatus(Guid sponsoringOrgId) - { - var sponsoringOrg = await _organizationRepository.GetByIdAsync(sponsoringOrgId); - - if (!await _currentContext.OrganizationOwner(sponsoringOrg.Id)) - { - throw new NotFoundException(); - } - - var lastSyncDate = await _organizationSponsorshipRepository.GetLatestSyncDateBySponsoringOrganizationIdAsync(sponsoringOrg.Id); - return new OrganizationSponsorshipSyncStatusResponseModel(lastSyncDate); - } - - private Task CurrentUser => _userService.GetUserByIdAsync(_currentContext.UserId.Value); + _organizationSponsorshipRepository = organizationSponsorshipRepository; + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _validateRedemptionTokenCommand = validateRedemptionTokenCommand; + _validateBillingSyncKeyCommand = validateBillingSyncKeyCommand; + _createSponsorshipCommand = createSponsorshipCommand; + _sendSponsorshipOfferCommand = sendSponsorshipOfferCommand; + _setUpSponsorshipCommand = setUpSponsorshipCommand; + _revokeSponsorshipCommand = revokeSponsorshipCommand; + _removeSponsorshipCommand = removeSponsorshipCommand; + _syncSponsorshipsCommand = syncSponsorshipsCommand; + _userService = userService; + _currentContext = currentContext; } + + [Authorize("Application")] + [HttpPost("{sponsoringOrgId}/families-for-enterprise")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task CreateSponsorship(Guid sponsoringOrgId, [FromBody] OrganizationSponsorshipCreateRequestModel model) + { + var sponsoringOrg = await _organizationRepository.GetByIdAsync(sponsoringOrgId); + + var sponsorship = await _createSponsorshipCommand.CreateSponsorshipAsync( + sponsoringOrg, + await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default), + model.PlanSponsorshipType, model.SponsoredEmail, model.FriendlyName); + await _sendSponsorshipOfferCommand.SendSponsorshipOfferAsync(sponsorship, sponsoringOrg.Name); + } + + [Authorize("Application")] + [HttpPost("{sponsoringOrgId}/families-for-enterprise/resend")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task ResendSponsorshipOffer(Guid sponsoringOrgId) + { + var sponsoringOrgUser = await _organizationUserRepository + .GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default); + + await _sendSponsorshipOfferCommand.SendSponsorshipOfferAsync( + await _organizationRepository.GetByIdAsync(sponsoringOrgId), + sponsoringOrgUser, + await _organizationSponsorshipRepository + .GetBySponsoringOrganizationUserIdAsync(sponsoringOrgUser.Id)); + } + + [Authorize("Application")] + [HttpPost("validate-token")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PreValidateSponsorshipToken([FromQuery] string sponsorshipToken) + { + return (await _validateRedemptionTokenCommand.ValidateRedemptionTokenAsync(sponsorshipToken, (await CurrentUser).Email)).valid; + } + + [Authorize("Application")] + [HttpPost("redeem")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task RedeemSponsorship([FromQuery] string sponsorshipToken, [FromBody] OrganizationSponsorshipRedeemRequestModel model) + { + var (valid, sponsorship) = await _validateRedemptionTokenCommand.ValidateRedemptionTokenAsync(sponsorshipToken, (await CurrentUser).Email); + + if (!valid) + { + throw new BadRequestException("Failed to parse sponsorship token."); + } + + if (!await _currentContext.OrganizationOwner(model.SponsoredOrganizationId)) + { + throw new BadRequestException("Can only redeem sponsorship for an organization you own."); + } + + await _setUpSponsorshipCommand.SetUpSponsorshipAsync( + sponsorship, + await _organizationRepository.GetByIdAsync(model.SponsoredOrganizationId)); + } + + [Authorize("Installation")] + [HttpPost("sync")] + public async Task Sync([FromBody] OrganizationSponsorshipSyncRequestModel model) + { + var sponsoringOrg = await _organizationRepository.GetByIdAsync(model.SponsoringOrganizationCloudId); + if (!await _validateBillingSyncKeyCommand.ValidateBillingSyncKeyAsync(sponsoringOrg, model.BillingSyncKey)) + { + throw new BadRequestException("Invalid Billing Sync Key"); + } + + var (syncResponseData, offersToSend) = await _syncSponsorshipsCommand.SyncOrganization(sponsoringOrg, model.ToOrganizationSponsorshipSync().SponsorshipsBatch); + await _sendSponsorshipOfferCommand.BulkSendSponsorshipOfferAsync(sponsoringOrg.Name, offersToSend); + return new OrganizationSponsorshipSyncResponseModel(syncResponseData); + } + + [Authorize("Application")] + [HttpDelete("{sponsoringOrganizationId}")] + [HttpPost("{sponsoringOrganizationId}/delete")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task RevokeSponsorship(Guid sponsoringOrganizationId) + { + + var orgUser = await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrganizationId, _currentContext.UserId ?? default); + if (_currentContext.UserId != orgUser?.UserId) + { + throw new BadRequestException("Can only revoke a sponsorship you granted."); + } + + var existingOrgSponsorship = await _organizationSponsorshipRepository + .GetBySponsoringOrganizationUserIdAsync(orgUser.Id); + + await _revokeSponsorshipCommand.RevokeSponsorshipAsync(existingOrgSponsorship); + } + + [Authorize("Application")] + [HttpDelete("sponsored/{sponsoredOrgId}")] + [HttpPost("sponsored/{sponsoredOrgId}/remove")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task RemoveSponsorship(Guid sponsoredOrgId) + { + + if (!await _currentContext.OrganizationOwner(sponsoredOrgId)) + { + throw new BadRequestException("Only the owner of an organization can remove sponsorship."); + } + + var existingOrgSponsorship = await _organizationSponsorshipRepository + .GetBySponsoredOrganizationIdAsync(sponsoredOrgId); + + await _removeSponsorshipCommand.RemoveSponsorshipAsync(existingOrgSponsorship); + } + + [HttpGet("{sponsoringOrgId}/sync-status")] + public async Task GetSyncStatus(Guid sponsoringOrgId) + { + var sponsoringOrg = await _organizationRepository.GetByIdAsync(sponsoringOrgId); + + if (!await _currentContext.OrganizationOwner(sponsoringOrg.Id)) + { + throw new NotFoundException(); + } + + var lastSyncDate = await _organizationSponsorshipRepository.GetLatestSyncDateBySponsoringOrganizationIdAsync(sponsoringOrg.Id); + return new OrganizationSponsorshipSyncStatusResponseModel(lastSyncDate); + } + + private Task CurrentUser => _userService.GetUserByIdAsync(_currentContext.UserId.Value); } diff --git a/src/Api/Controllers/OrganizationUsersController.cs b/src/Api/Controllers/OrganizationUsersController.cs index b1e5451eb..64340e3ed 100644 --- a/src/Api/Controllers/OrganizationUsersController.cs +++ b/src/Api/Controllers/OrganizationUsersController.cs @@ -13,460 +13,459 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("organizations/{orgId}/users")] +[Authorize("Application")] +public class OrganizationUsersController : Controller { - [Route("organizations/{orgId}/users")] - [Authorize("Application")] - public class OrganizationUsersController : Controller + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IOrganizationService _organizationService; + private readonly ICollectionRepository _collectionRepository; + private readonly IGroupRepository _groupRepository; + private readonly IUserService _userService; + private readonly IPolicyRepository _policyRepository; + private readonly ICurrentContext _currentContext; + + public OrganizationUsersController( + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationService organizationService, + ICollectionRepository collectionRepository, + IGroupRepository groupRepository, + IUserService userService, + IPolicyRepository policyRepository, + ICurrentContext currentContext) { - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IOrganizationService _organizationService; - private readonly ICollectionRepository _collectionRepository; - private readonly IGroupRepository _groupRepository; - private readonly IUserService _userService; - private readonly IPolicyRepository _policyRepository; - private readonly ICurrentContext _currentContext; + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _organizationService = organizationService; + _collectionRepository = collectionRepository; + _groupRepository = groupRepository; + _userService = userService; + _policyRepository = policyRepository; + _currentContext = currentContext; + } - public OrganizationUsersController( - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IOrganizationService organizationService, - ICollectionRepository collectionRepository, - IGroupRepository groupRepository, - IUserService userService, - IPolicyRepository policyRepository, - ICurrentContext currentContext) + [HttpGet("{id}")] + public async Task Get(string orgId, string id) + { + var organizationUser = await _organizationUserRepository.GetByIdWithCollectionsAsync(new Guid(id)); + if (organizationUser == null || !await _currentContext.ManageUsers(organizationUser.Item1.OrganizationId)) { - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _organizationService = organizationService; - _collectionRepository = collectionRepository; - _groupRepository = groupRepository; - _userService = userService; - _policyRepository = policyRepository; - _currentContext = currentContext; + throw new NotFoundException(); } - [HttpGet("{id}")] - public async Task Get(string orgId, string id) - { - var organizationUser = await _organizationUserRepository.GetByIdWithCollectionsAsync(new Guid(id)); - if (organizationUser == null || !await _currentContext.ManageUsers(organizationUser.Item1.OrganizationId)) - { - throw new NotFoundException(); - } + return new OrganizationUserDetailsResponseModel(organizationUser.Item1, organizationUser.Item2); + } - return new OrganizationUserDetailsResponseModel(organizationUser.Item1, organizationUser.Item2); + [HttpGet("")] + public async Task> Get(string orgId) + { + var orgGuidId = new Guid(orgId); + if (!await _currentContext.ViewAllCollections(orgGuidId) && + !await _currentContext.ViewAssignedCollections(orgGuidId) && + !await _currentContext.ManageGroups(orgGuidId) && + !await _currentContext.ManageUsers(orgGuidId)) + { + throw new NotFoundException(); } - [HttpGet("")] - public async Task> Get(string orgId) - { - var orgGuidId = new Guid(orgId); - if (!await _currentContext.ViewAllCollections(orgGuidId) && - !await _currentContext.ViewAssignedCollections(orgGuidId) && - !await _currentContext.ManageGroups(orgGuidId) && - !await _currentContext.ManageUsers(orgGuidId)) - { - throw new NotFoundException(); - } + var organizationUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(orgGuidId); + var responseTasks = organizationUsers.Select(async o => new OrganizationUserUserDetailsResponseModel(o, + await _userService.TwoFactorIsEnabledAsync(o))); + var responses = await Task.WhenAll(responseTasks); + return new ListResponseModel(responses); + } - var organizationUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(orgGuidId); - var responseTasks = organizationUsers.Select(async o => new OrganizationUserUserDetailsResponseModel(o, - await _userService.TwoFactorIsEnabledAsync(o))); - var responses = await Task.WhenAll(responseTasks); - return new ListResponseModel(responses); + [HttpGet("{id}/groups")] + public async Task> GetGroups(string orgId, string id) + { + var organizationUser = await _organizationUserRepository.GetByIdAsync(new Guid(id)); + if (organizationUser == null || (!await _currentContext.ManageGroups(organizationUser.OrganizationId) && + !await _currentContext.ManageUsers(organizationUser.OrganizationId))) + { + throw new NotFoundException(); } - [HttpGet("{id}/groups")] - public async Task> GetGroups(string orgId, string id) - { - var organizationUser = await _organizationUserRepository.GetByIdAsync(new Guid(id)); - if (organizationUser == null || (!await _currentContext.ManageGroups(organizationUser.OrganizationId) && - !await _currentContext.ManageUsers(organizationUser.OrganizationId))) - { - throw new NotFoundException(); - } + var groupIds = await _groupRepository.GetManyIdsByUserIdAsync(organizationUser.Id); + var responses = groupIds.Select(g => g.ToString()); + return responses; + } - var groupIds = await _groupRepository.GetManyIdsByUserIdAsync(organizationUser.Id); - var responses = groupIds.Select(g => g.ToString()); - return responses; + [HttpGet("{id}/reset-password-details")] + public async Task GetResetPasswordDetails(string orgId, string id) + { + // Make sure the calling user can reset passwords for this org + var orgGuidId = new Guid(orgId); + if (!await _currentContext.ManageResetPassword(orgGuidId)) + { + throw new NotFoundException(); } - [HttpGet("{id}/reset-password-details")] - public async Task GetResetPasswordDetails(string orgId, string id) + var organizationUser = await _organizationUserRepository.GetByIdAsync(new Guid(id)); + if (organizationUser == null || !organizationUser.UserId.HasValue) { - // Make sure the calling user can reset passwords for this org - var orgGuidId = new Guid(orgId); - if (!await _currentContext.ManageResetPassword(orgGuidId)) - { - throw new NotFoundException(); - } - - var organizationUser = await _organizationUserRepository.GetByIdAsync(new Guid(id)); - if (organizationUser == null || !organizationUser.UserId.HasValue) - { - throw new NotFoundException(); - } - - // Retrieve data necessary for response (KDF, KDF Iterations, ResetPasswordKey) - // TODO Reset Password - Revisit this and create SPROC to reduce DB calls - var user = await _userService.GetUserByIdAsync(organizationUser.UserId.Value); - if (user == null) - { - throw new NotFoundException(); - } - - // Retrieve Encrypted Private Key from organization - var org = await _organizationRepository.GetByIdAsync(orgGuidId); - if (org == null) - { - throw new NotFoundException(); - } - - return new OrganizationUserResetPasswordDetailsResponseModel(new OrganizationUserResetPasswordDetails(organizationUser, user, org)); + throw new NotFoundException(); } - [HttpPost("invite")] - public async Task Invite(string orgId, [FromBody] OrganizationUserInviteRequestModel model) + // Retrieve data necessary for response (KDF, KDF Iterations, ResetPasswordKey) + // TODO Reset Password - Revisit this and create SPROC to reduce DB calls + var user = await _userService.GetUserByIdAsync(organizationUser.UserId.Value); + if (user == null) { - var orgGuidId = new Guid(orgId); - if (!await _currentContext.ManageUsers(orgGuidId)) - { - throw new NotFoundException(); - } - - var userId = _userService.GetProperUserId(User); - var result = await _organizationService.InviteUsersAsync(orgGuidId, userId.Value, - new (OrganizationUserInvite, string)[] { (new OrganizationUserInvite(model.ToData()), null) }); + throw new NotFoundException(); } - [HttpPost("reinvite")] - public async Task> BulkReinvite(string orgId, [FromBody] OrganizationUserBulkRequestModel model) + // Retrieve Encrypted Private Key from organization + var org = await _organizationRepository.GetByIdAsync(orgGuidId); + if (org == null) { - var orgGuidId = new Guid(orgId); - if (!await _currentContext.ManageUsers(orgGuidId)) - { - throw new NotFoundException(); - } - - var userId = _userService.GetProperUserId(User); - var result = await _organizationService.ResendInvitesAsync(orgGuidId, userId.Value, model.Ids); - return new ListResponseModel( - result.Select(t => new OrganizationUserBulkResponseModel(t.Item1.Id, t.Item2))); + throw new NotFoundException(); } - [HttpPost("{id}/reinvite")] - public async Task Reinvite(string orgId, string id) - { - var orgGuidId = new Guid(orgId); - if (!await _currentContext.ManageUsers(orgGuidId)) - { - throw new NotFoundException(); - } + return new OrganizationUserResetPasswordDetailsResponseModel(new OrganizationUserResetPasswordDetails(organizationUser, user, org)); + } - var userId = _userService.GetProperUserId(User); - await _organizationService.ResendInviteAsync(orgGuidId, userId.Value, new Guid(id)); + [HttpPost("invite")] + public async Task Invite(string orgId, [FromBody] OrganizationUserInviteRequestModel model) + { + var orgGuidId = new Guid(orgId); + if (!await _currentContext.ManageUsers(orgGuidId)) + { + throw new NotFoundException(); } - [HttpPost("{organizationUserId}/accept")] - public async Task Accept(Guid orgId, Guid organizationUserId, [FromBody] OrganizationUserAcceptRequestModel model) + var userId = _userService.GetProperUserId(User); + var result = await _organizationService.InviteUsersAsync(orgGuidId, userId.Value, + new (OrganizationUserInvite, string)[] { (new OrganizationUserInvite(model.ToData()), null) }); + } + + [HttpPost("reinvite")] + public async Task> BulkReinvite(string orgId, [FromBody] OrganizationUserBulkRequestModel model) + { + var orgGuidId = new Guid(orgId); + if (!await _currentContext.ManageUsers(orgGuidId)) { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var masterPasswordPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(orgId, PolicyType.ResetPassword); - var useMasterPasswordPolicy = masterPasswordPolicy != null && - masterPasswordPolicy.Enabled && - masterPasswordPolicy.GetDataModel().AutoEnrollEnabled; - - if (useMasterPasswordPolicy && - string.IsNullOrWhiteSpace(model.ResetPasswordKey)) - { - throw new BadRequestException(string.Empty, "Master Password reset is required, but not provided."); - } - - await _organizationService.AcceptUserAsync(organizationUserId, user, model.Token, _userService); - - if (useMasterPasswordPolicy) - { - await _organizationService.UpdateUserResetPasswordEnrollmentAsync(orgId, user.Id, model.ResetPasswordKey, user.Id); - } + throw new NotFoundException(); } - [HttpPost("{id}/confirm")] - public async Task Confirm(string orgId, string id, [FromBody] OrganizationUserConfirmRequestModel model) - { - var orgGuidId = new Guid(orgId); - if (!await _currentContext.ManageUsers(orgGuidId)) - { - throw new NotFoundException(); - } + var userId = _userService.GetProperUserId(User); + var result = await _organizationService.ResendInvitesAsync(orgGuidId, userId.Value, model.Ids); + return new ListResponseModel( + result.Select(t => new OrganizationUserBulkResponseModel(t.Item1.Id, t.Item2))); + } - var userId = _userService.GetProperUserId(User); - var result = await _organizationService.ConfirmUserAsync(orgGuidId, new Guid(id), model.Key, userId.Value, - _userService); + [HttpPost("{id}/reinvite")] + public async Task Reinvite(string orgId, string id) + { + var orgGuidId = new Guid(orgId); + if (!await _currentContext.ManageUsers(orgGuidId)) + { + throw new NotFoundException(); } - [HttpPost("confirm")] - public async Task> BulkConfirm(string orgId, - [FromBody] OrganizationUserBulkConfirmRequestModel model) + var userId = _userService.GetProperUserId(User); + await _organizationService.ResendInviteAsync(orgGuidId, userId.Value, new Guid(id)); + } + + [HttpPost("{organizationUserId}/accept")] + public async Task Accept(Guid orgId, Guid organizationUserId, [FromBody] OrganizationUserAcceptRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) { - var orgGuidId = new Guid(orgId); - if (!await _currentContext.ManageUsers(orgGuidId)) - { - throw new NotFoundException(); - } - - var userId = _userService.GetProperUserId(User); - var results = await _organizationService.ConfirmUsersAsync(orgGuidId, model.ToDictionary(), userId.Value, - _userService); - - return new ListResponseModel(results.Select(r => - new OrganizationUserBulkResponseModel(r.Item1.Id, r.Item2))); + throw new UnauthorizedAccessException(); } - [HttpPost("public-keys")] - public async Task> UserPublicKeys(string orgId, [FromBody] OrganizationUserBulkRequestModel model) - { - var orgGuidId = new Guid(orgId); - if (!await _currentContext.ManageUsers(orgGuidId)) - { - throw new NotFoundException(); - } + var masterPasswordPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(orgId, PolicyType.ResetPassword); + var useMasterPasswordPolicy = masterPasswordPolicy != null && + masterPasswordPolicy.Enabled && + masterPasswordPolicy.GetDataModel().AutoEnrollEnabled; - var result = await _organizationUserRepository.GetManyPublicKeysByOrganizationUserAsync(orgGuidId, model.Ids); - var responses = result.Select(r => new OrganizationUserPublicKeyResponseModel(r.Id, r.UserId, r.PublicKey)).ToList(); - return new ListResponseModel(responses); + if (useMasterPasswordPolicy && + string.IsNullOrWhiteSpace(model.ResetPasswordKey)) + { + throw new BadRequestException(string.Empty, "Master Password reset is required, but not provided."); } - [HttpPut("{id}")] - [HttpPost("{id}")] - public async Task Put(string orgId, string id, [FromBody] OrganizationUserUpdateRequestModel model) + await _organizationService.AcceptUserAsync(organizationUserId, user, model.Token, _userService); + + if (useMasterPasswordPolicy) { - var orgGuidId = new Guid(orgId); - if (!await _currentContext.ManageUsers(orgGuidId)) - { - throw new NotFoundException(); - } - - var organizationUser = await _organizationUserRepository.GetByIdAsync(new Guid(id)); - if (organizationUser == null || organizationUser.OrganizationId != orgGuidId) - { - throw new NotFoundException(); - } - - var userId = _userService.GetProperUserId(User); - await _organizationService.SaveUserAsync(model.ToOrganizationUser(organizationUser), userId.Value, - model.Collections?.Select(c => c.ToSelectionReadOnly())); - } - - [HttpPut("{id}/groups")] - [HttpPost("{id}/groups")] - public async Task PutGroups(string orgId, string id, [FromBody] OrganizationUserUpdateGroupsRequestModel model) - { - var orgGuidId = new Guid(orgId); - if (!await _currentContext.ManageUsers(orgGuidId)) - { - throw new NotFoundException(); - } - - var organizationUser = await _organizationUserRepository.GetByIdAsync(new Guid(id)); - if (organizationUser == null || organizationUser.OrganizationId != orgGuidId) - { - throw new NotFoundException(); - } - - var loggedInUserId = _userService.GetProperUserId(User); - await _organizationService.UpdateUserGroupsAsync(organizationUser, model.GroupIds.Select(g => new Guid(g)), loggedInUserId); - } - - [HttpPut("{userId}/reset-password-enrollment")] - public async Task PutResetPasswordEnrollment(string orgId, string userId, [FromBody] OrganizationUserResetPasswordEnrollmentRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (model.ResetPasswordKey != null && !await _userService.VerifySecretAsync(user, model.Secret)) - { - await Task.Delay(2000); - throw new BadRequestException("MasterPasswordHash", "Invalid password."); - } - else - { - var callingUserId = user.Id; - await _organizationService.UpdateUserResetPasswordEnrollmentAsync( - new Guid(orgId), new Guid(userId), model.ResetPasswordKey, callingUserId); - } - } - - [HttpPut("{id}/reset-password")] - public async Task PutResetPassword(string orgId, string id, [FromBody] OrganizationUserResetPasswordRequestModel model) - { - - var orgGuidId = new Guid(orgId); - - // Calling user must have Manage Reset Password permission - if (!await _currentContext.ManageResetPassword(orgGuidId)) - { - throw new NotFoundException(); - } - - // Get the users role, since provider users aren't a member of the organization we use the owner check - var orgUserType = await _currentContext.OrganizationOwner(orgGuidId) - ? OrganizationUserType.Owner - : _currentContext.Organizations?.FirstOrDefault(o => o.Id == orgGuidId)?.Type; - if (orgUserType == null) - { - throw new NotFoundException(); - } - - var result = await _userService.AdminResetPasswordAsync(orgUserType.Value, orgGuidId, new Guid(id), model.NewMasterPasswordHash, model.Key); - if (result.Succeeded) - { - return; - } - - foreach (var error in result.Errors) - { - ModelState.AddModelError(string.Empty, error.Description); - } - - await Task.Delay(2000); - throw new BadRequestException(ModelState); - } - - [HttpDelete("{id}")] - [HttpPost("{id}/delete")] - public async Task Delete(string orgId, string id) - { - var orgGuidId = new Guid(orgId); - if (!await _currentContext.ManageUsers(orgGuidId)) - { - throw new NotFoundException(); - } - - var userId = _userService.GetProperUserId(User); - await _organizationService.DeleteUserAsync(orgGuidId, new Guid(id), userId.Value); - } - - [HttpDelete("")] - [HttpPost("delete")] - public async Task> BulkDelete(string orgId, [FromBody] OrganizationUserBulkRequestModel model) - { - var orgGuidId = new Guid(orgId); - if (!await _currentContext.ManageUsers(orgGuidId)) - { - throw new NotFoundException(); - } - - var userId = _userService.GetProperUserId(User); - var result = await _organizationService.DeleteUsersAsync(orgGuidId, model.Ids, userId.Value); - return new ListResponseModel(result.Select(r => - new OrganizationUserBulkResponseModel(r.Item1.Id, r.Item2))); - } - - [Obsolete("2022-07-22 Moved to {id}/revoke endpoint")] - [HttpPatch("{id}/deactivate")] - [HttpPut("{id}/deactivate")] - public async Task Deactivate(Guid orgId, Guid id) - { - await RevokeAsync(orgId, id); - } - - [Obsolete("2022-07-22 Moved to /revoke endpoint")] - [HttpPatch("deactivate")] - [HttpPut("deactivate")] - public async Task> BulkDeactivate(Guid orgId, [FromBody] OrganizationUserBulkRequestModel model) - { - return await BulkRevokeAsync(orgId, model); - } - - [Obsolete("2022-07-22 Moved to {id}/restore endpoint")] - [HttpPatch("{id}/activate")] - [HttpPut("{id}/activate")] - public async Task Activate(Guid orgId, Guid id) - { - await RestoreAsync(orgId, id); - } - - [Obsolete("2022-07-22 Moved to /restore endpoint")] - [HttpPatch("activate")] - [HttpPut("activate")] - public async Task> BulkActivate(Guid orgId, [FromBody] OrganizationUserBulkRequestModel model) - { - return await BulkRestoreAsync(orgId, model); - } - - [HttpPatch("{id}/revoke")] - [HttpPut("{id}/revoke")] - public async Task RevokeAsync(Guid orgId, Guid id) - { - await RestoreOrRevokeUserAsync(orgId, id, _organizationService.RevokeUserAsync); - } - - [HttpPatch("revoke")] - [HttpPut("revoke")] - public async Task> BulkRevokeAsync(Guid orgId, [FromBody] OrganizationUserBulkRequestModel model) - { - return await RestoreOrRevokeUsersAsync(orgId, model, _organizationService.RevokeUsersAsync); - } - - [HttpPatch("{id}/restore")] - [HttpPut("{id}/restore")] - public async Task RestoreAsync(Guid orgId, Guid id) - { - await RestoreOrRevokeUserAsync(orgId, id, (orgUser, userId) => _organizationService.RestoreUserAsync(orgUser, userId, _userService)); - } - - [HttpPatch("restore")] - [HttpPut("restore")] - public async Task> BulkRestoreAsync(Guid orgId, [FromBody] OrganizationUserBulkRequestModel model) - { - return await RestoreOrRevokeUsersAsync(orgId, model, (orgId, orgUserIds, restoringUserId) => _organizationService.RestoreUsersAsync(orgId, orgUserIds, restoringUserId, _userService)); - } - - private async Task RestoreOrRevokeUserAsync( - Guid orgId, - Guid id, - Func statusAction) - { - if (!await _currentContext.ManageUsers(orgId)) - { - throw new NotFoundException(); - } - - var userId = _userService.GetProperUserId(User); - var orgUser = await _organizationUserRepository.GetByIdAsync(id); - if (orgUser == null || orgUser.OrganizationId != orgId) - { - throw new NotFoundException(); - } - - await statusAction(orgUser, userId); - } - - private async Task> RestoreOrRevokeUsersAsync( - Guid orgId, - OrganizationUserBulkRequestModel model, - Func, Guid?, Task>>> statusAction) - { - if (!await _currentContext.ManageUsers(orgId)) - { - throw new NotFoundException(); - } - - var userId = _userService.GetProperUserId(User); - var result = await statusAction(orgId, model.Ids, userId.Value); - return new ListResponseModel(result.Select(r => - new OrganizationUserBulkResponseModel(r.Item1.Id, r.Item2))); + await _organizationService.UpdateUserResetPasswordEnrollmentAsync(orgId, user.Id, model.ResetPasswordKey, user.Id); } } + + [HttpPost("{id}/confirm")] + public async Task Confirm(string orgId, string id, [FromBody] OrganizationUserConfirmRequestModel model) + { + var orgGuidId = new Guid(orgId); + if (!await _currentContext.ManageUsers(orgGuidId)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User); + var result = await _organizationService.ConfirmUserAsync(orgGuidId, new Guid(id), model.Key, userId.Value, + _userService); + } + + [HttpPost("confirm")] + public async Task> BulkConfirm(string orgId, + [FromBody] OrganizationUserBulkConfirmRequestModel model) + { + var orgGuidId = new Guid(orgId); + if (!await _currentContext.ManageUsers(orgGuidId)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User); + var results = await _organizationService.ConfirmUsersAsync(orgGuidId, model.ToDictionary(), userId.Value, + _userService); + + return new ListResponseModel(results.Select(r => + new OrganizationUserBulkResponseModel(r.Item1.Id, r.Item2))); + } + + [HttpPost("public-keys")] + public async Task> UserPublicKeys(string orgId, [FromBody] OrganizationUserBulkRequestModel model) + { + var orgGuidId = new Guid(orgId); + if (!await _currentContext.ManageUsers(orgGuidId)) + { + throw new NotFoundException(); + } + + var result = await _organizationUserRepository.GetManyPublicKeysByOrganizationUserAsync(orgGuidId, model.Ids); + var responses = result.Select(r => new OrganizationUserPublicKeyResponseModel(r.Id, r.UserId, r.PublicKey)).ToList(); + return new ListResponseModel(responses); + } + + [HttpPut("{id}")] + [HttpPost("{id}")] + public async Task Put(string orgId, string id, [FromBody] OrganizationUserUpdateRequestModel model) + { + var orgGuidId = new Guid(orgId); + if (!await _currentContext.ManageUsers(orgGuidId)) + { + throw new NotFoundException(); + } + + var organizationUser = await _organizationUserRepository.GetByIdAsync(new Guid(id)); + if (organizationUser == null || organizationUser.OrganizationId != orgGuidId) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User); + await _organizationService.SaveUserAsync(model.ToOrganizationUser(organizationUser), userId.Value, + model.Collections?.Select(c => c.ToSelectionReadOnly())); + } + + [HttpPut("{id}/groups")] + [HttpPost("{id}/groups")] + public async Task PutGroups(string orgId, string id, [FromBody] OrganizationUserUpdateGroupsRequestModel model) + { + var orgGuidId = new Guid(orgId); + if (!await _currentContext.ManageUsers(orgGuidId)) + { + throw new NotFoundException(); + } + + var organizationUser = await _organizationUserRepository.GetByIdAsync(new Guid(id)); + if (organizationUser == null || organizationUser.OrganizationId != orgGuidId) + { + throw new NotFoundException(); + } + + var loggedInUserId = _userService.GetProperUserId(User); + await _organizationService.UpdateUserGroupsAsync(organizationUser, model.GroupIds.Select(g => new Guid(g)), loggedInUserId); + } + + [HttpPut("{userId}/reset-password-enrollment")] + public async Task PutResetPasswordEnrollment(string orgId, string userId, [FromBody] OrganizationUserResetPasswordEnrollmentRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (model.ResetPasswordKey != null && !await _userService.VerifySecretAsync(user, model.Secret)) + { + await Task.Delay(2000); + throw new BadRequestException("MasterPasswordHash", "Invalid password."); + } + else + { + var callingUserId = user.Id; + await _organizationService.UpdateUserResetPasswordEnrollmentAsync( + new Guid(orgId), new Guid(userId), model.ResetPasswordKey, callingUserId); + } + } + + [HttpPut("{id}/reset-password")] + public async Task PutResetPassword(string orgId, string id, [FromBody] OrganizationUserResetPasswordRequestModel model) + { + + var orgGuidId = new Guid(orgId); + + // Calling user must have Manage Reset Password permission + if (!await _currentContext.ManageResetPassword(orgGuidId)) + { + throw new NotFoundException(); + } + + // Get the users role, since provider users aren't a member of the organization we use the owner check + var orgUserType = await _currentContext.OrganizationOwner(orgGuidId) + ? OrganizationUserType.Owner + : _currentContext.Organizations?.FirstOrDefault(o => o.Id == orgGuidId)?.Type; + if (orgUserType == null) + { + throw new NotFoundException(); + } + + var result = await _userService.AdminResetPasswordAsync(orgUserType.Value, orgGuidId, new Guid(id), model.NewMasterPasswordHash, model.Key); + if (result.Succeeded) + { + return; + } + + foreach (var error in result.Errors) + { + ModelState.AddModelError(string.Empty, error.Description); + } + + await Task.Delay(2000); + throw new BadRequestException(ModelState); + } + + [HttpDelete("{id}")] + [HttpPost("{id}/delete")] + public async Task Delete(string orgId, string id) + { + var orgGuidId = new Guid(orgId); + if (!await _currentContext.ManageUsers(orgGuidId)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User); + await _organizationService.DeleteUserAsync(orgGuidId, new Guid(id), userId.Value); + } + + [HttpDelete("")] + [HttpPost("delete")] + public async Task> BulkDelete(string orgId, [FromBody] OrganizationUserBulkRequestModel model) + { + var orgGuidId = new Guid(orgId); + if (!await _currentContext.ManageUsers(orgGuidId)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User); + var result = await _organizationService.DeleteUsersAsync(orgGuidId, model.Ids, userId.Value); + return new ListResponseModel(result.Select(r => + new OrganizationUserBulkResponseModel(r.Item1.Id, r.Item2))); + } + + [Obsolete("2022-07-22 Moved to {id}/revoke endpoint")] + [HttpPatch("{id}/deactivate")] + [HttpPut("{id}/deactivate")] + public async Task Deactivate(Guid orgId, Guid id) + { + await RevokeAsync(orgId, id); + } + + [Obsolete("2022-07-22 Moved to /revoke endpoint")] + [HttpPatch("deactivate")] + [HttpPut("deactivate")] + public async Task> BulkDeactivate(Guid orgId, [FromBody] OrganizationUserBulkRequestModel model) + { + return await BulkRevokeAsync(orgId, model); + } + + [Obsolete("2022-07-22 Moved to {id}/restore endpoint")] + [HttpPatch("{id}/activate")] + [HttpPut("{id}/activate")] + public async Task Activate(Guid orgId, Guid id) + { + await RestoreAsync(orgId, id); + } + + [Obsolete("2022-07-22 Moved to /restore endpoint")] + [HttpPatch("activate")] + [HttpPut("activate")] + public async Task> BulkActivate(Guid orgId, [FromBody] OrganizationUserBulkRequestModel model) + { + return await BulkRestoreAsync(orgId, model); + } + + [HttpPatch("{id}/revoke")] + [HttpPut("{id}/revoke")] + public async Task RevokeAsync(Guid orgId, Guid id) + { + await RestoreOrRevokeUserAsync(orgId, id, _organizationService.RevokeUserAsync); + } + + [HttpPatch("revoke")] + [HttpPut("revoke")] + public async Task> BulkRevokeAsync(Guid orgId, [FromBody] OrganizationUserBulkRequestModel model) + { + return await RestoreOrRevokeUsersAsync(orgId, model, _organizationService.RevokeUsersAsync); + } + + [HttpPatch("{id}/restore")] + [HttpPut("{id}/restore")] + public async Task RestoreAsync(Guid orgId, Guid id) + { + await RestoreOrRevokeUserAsync(orgId, id, (orgUser, userId) => _organizationService.RestoreUserAsync(orgUser, userId, _userService)); + } + + [HttpPatch("restore")] + [HttpPut("restore")] + public async Task> BulkRestoreAsync(Guid orgId, [FromBody] OrganizationUserBulkRequestModel model) + { + return await RestoreOrRevokeUsersAsync(orgId, model, (orgId, orgUserIds, restoringUserId) => _organizationService.RestoreUsersAsync(orgId, orgUserIds, restoringUserId, _userService)); + } + + private async Task RestoreOrRevokeUserAsync( + Guid orgId, + Guid id, + Func statusAction) + { + if (!await _currentContext.ManageUsers(orgId)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User); + var orgUser = await _organizationUserRepository.GetByIdAsync(id); + if (orgUser == null || orgUser.OrganizationId != orgId) + { + throw new NotFoundException(); + } + + await statusAction(orgUser, userId); + } + + private async Task> RestoreOrRevokeUsersAsync( + Guid orgId, + OrganizationUserBulkRequestModel model, + Func, Guid?, Task>>> statusAction) + { + if (!await _currentContext.ManageUsers(orgId)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User); + var result = await statusAction(orgId, model.Ids, userId.Value); + return new ListResponseModel(result.Select(r => + new OrganizationUserBulkResponseModel(r.Item1.Id, r.Item2))); + } } diff --git a/src/Api/Controllers/OrganizationsController.cs b/src/Api/Controllers/OrganizationsController.cs index 7a5b26d9e..f38b0dbc3 100644 --- a/src/Api/Controllers/OrganizationsController.cs +++ b/src/Api/Controllers/OrganizationsController.cs @@ -18,698 +18,697 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("organizations")] +[Authorize("Application")] +public class OrganizationsController : Controller { - [Route("organizations")] - [Authorize("Application")] - public class OrganizationsController : Controller + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IPolicyRepository _policyRepository; + private readonly IOrganizationService _organizationService; + private readonly IUserService _userService; + private readonly IPaymentService _paymentService; + private readonly ICurrentContext _currentContext; + private readonly ISsoConfigRepository _ssoConfigRepository; + private readonly ISsoConfigService _ssoConfigService; + private readonly IGetOrganizationApiKeyCommand _getOrganizationApiKeyCommand; + private readonly IRotateOrganizationApiKeyCommand _rotateOrganizationApiKeyCommand; + private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; + private readonly GlobalSettings _globalSettings; + + public OrganizationsController( + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IPolicyRepository policyRepository, + IOrganizationService organizationService, + IUserService userService, + IPaymentService paymentService, + ICurrentContext currentContext, + ISsoConfigRepository ssoConfigRepository, + ISsoConfigService ssoConfigService, + IGetOrganizationApiKeyCommand getOrganizationApiKeyCommand, + IRotateOrganizationApiKeyCommand rotateOrganizationApiKeyCommand, + IOrganizationApiKeyRepository organizationApiKeyRepository, + GlobalSettings globalSettings) { - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IPolicyRepository _policyRepository; - private readonly IOrganizationService _organizationService; - private readonly IUserService _userService; - private readonly IPaymentService _paymentService; - private readonly ICurrentContext _currentContext; - private readonly ISsoConfigRepository _ssoConfigRepository; - private readonly ISsoConfigService _ssoConfigService; - private readonly IGetOrganizationApiKeyCommand _getOrganizationApiKeyCommand; - private readonly IRotateOrganizationApiKeyCommand _rotateOrganizationApiKeyCommand; - private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; - private readonly GlobalSettings _globalSettings; + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _policyRepository = policyRepository; + _organizationService = organizationService; + _userService = userService; + _paymentService = paymentService; + _currentContext = currentContext; + _ssoConfigRepository = ssoConfigRepository; + _ssoConfigService = ssoConfigService; + _getOrganizationApiKeyCommand = getOrganizationApiKeyCommand; + _rotateOrganizationApiKeyCommand = rotateOrganizationApiKeyCommand; + _organizationApiKeyRepository = organizationApiKeyRepository; + _globalSettings = globalSettings; + } - public OrganizationsController( - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IPolicyRepository policyRepository, - IOrganizationService organizationService, - IUserService userService, - IPaymentService paymentService, - ICurrentContext currentContext, - ISsoConfigRepository ssoConfigRepository, - ISsoConfigService ssoConfigService, - IGetOrganizationApiKeyCommand getOrganizationApiKeyCommand, - IRotateOrganizationApiKeyCommand rotateOrganizationApiKeyCommand, - IOrganizationApiKeyRepository organizationApiKeyRepository, - GlobalSettings globalSettings) + [HttpGet("{id}")] + public async Task Get(string id) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.OrganizationOwner(orgIdGuid)) { - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _policyRepository = policyRepository; - _organizationService = organizationService; - _userService = userService; - _paymentService = paymentService; - _currentContext = currentContext; - _ssoConfigRepository = ssoConfigRepository; - _ssoConfigService = ssoConfigService; - _getOrganizationApiKeyCommand = getOrganizationApiKeyCommand; - _rotateOrganizationApiKeyCommand = rotateOrganizationApiKeyCommand; - _organizationApiKeyRepository = organizationApiKeyRepository; - _globalSettings = globalSettings; + throw new NotFoundException(); } - [HttpGet("{id}")] - public async Task Get(string id) + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) { - var orgIdGuid = new Guid(id); - if (!await _currentContext.OrganizationOwner(orgIdGuid)) - { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) - { - throw new NotFoundException(); - } - - return new OrganizationResponseModel(organization); + throw new NotFoundException(); } - [HttpGet("{id}/billing")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task GetBilling(string id) + return new OrganizationResponseModel(organization); + } + + [HttpGet("{id}/billing")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task GetBilling(string id) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManageBilling(orgIdGuid)) { - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManageBilling(orgIdGuid)) - { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) - { - throw new NotFoundException(); - } - - var billingInfo = await _paymentService.GetBillingAsync(organization); - return new BillingResponseModel(billingInfo); + throw new NotFoundException(); } - [HttpGet("{id}/subscription")] - public async Task GetSubscription(string id) + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) { - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManageBilling(orgIdGuid)) - { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) - { - throw new NotFoundException(); - } - - if (!_globalSettings.SelfHosted && organization.Gateway != null) - { - var subscriptionInfo = await _paymentService.GetSubscriptionAsync(organization); - if (subscriptionInfo == null) - { - throw new NotFoundException(); - } - return new OrganizationSubscriptionResponseModel(organization, subscriptionInfo); - } - else - { - return new OrganizationSubscriptionResponseModel(organization); - } + throw new NotFoundException(); } - [HttpGet("{id}/license")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task GetLicense(string id, [FromQuery] Guid installationId) + var billingInfo = await _paymentService.GetBillingAsync(organization); + return new BillingResponseModel(billingInfo); + } + + [HttpGet("{id}/subscription")] + public async Task GetSubscription(string id) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManageBilling(orgIdGuid)) { - var orgIdGuid = new Guid(id); - if (!await _currentContext.OrganizationOwner(orgIdGuid)) - { - throw new NotFoundException(); - } - - var license = await _organizationService.GenerateLicenseAsync(orgIdGuid, installationId); - if (license == null) - { - throw new NotFoundException(); - } - - return license; + throw new NotFoundException(); } - [HttpGet("")] - public async Task> GetUser() + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) { - var userId = _userService.GetProperUserId(User).Value; - var organizations = await _organizationUserRepository.GetManyDetailsByUserAsync(userId, - OrganizationUserStatusType.Confirmed); - var responses = organizations.Select(o => new ProfileOrganizationResponseModel(o)); - return new ListResponseModel(responses); + throw new NotFoundException(); } - [HttpGet("{identifier}/auto-enroll-status")] - public async Task GetAutoEnrollStatus(string identifier) + if (!_globalSettings.SelfHosted && organization.Gateway != null) { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var organization = await _organizationRepository.GetByIdentifierAsync(identifier); - if (organization == null) + var subscriptionInfo = await _paymentService.GetSubscriptionAsync(organization); + if (subscriptionInfo == null) { throw new NotFoundException(); } - - var organizationUser = await _organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id); - if (organizationUser == null) - { - throw new NotFoundException(); - } - - var resetPasswordPolicy = - await _policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword); - if (resetPasswordPolicy == null || !resetPasswordPolicy.Enabled || resetPasswordPolicy.Data == null) - { - return new OrganizationAutoEnrollStatusResponseModel(organization.Id, false); - } - - var data = JsonSerializer.Deserialize(resetPasswordPolicy.Data, JsonHelpers.IgnoreCase); - return new OrganizationAutoEnrollStatusResponseModel(organization.Id, data?.AutoEnrollEnabled ?? false); + return new OrganizationSubscriptionResponseModel(organization, subscriptionInfo); } - - [HttpPost("")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task Post([FromBody] OrganizationCreateRequestModel model) + else { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var organizationSignup = model.ToOrganizationSignup(user); - var result = await _organizationService.SignUpAsync(organizationSignup); - return new OrganizationResponseModel(result.Item1); - } - - [HttpPost("license")] - [SelfHosted(SelfHostedOnly = true)] - public async Task PostLicense(OrganizationCreateLicenseRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var license = await ApiHelpers.ReadJsonFileFromBody(HttpContext, model.License); - if (license == null) - { - throw new BadRequestException("Invalid license"); - } - - var result = await _organizationService.SignUpAsync(license, user, model.Key, - model.CollectionName, model.Keys?.PublicKey, model.Keys?.EncryptedPrivateKey); - return new OrganizationResponseModel(result.Item1); - } - - [HttpPut("{id}")] - [HttpPost("{id}")] - public async Task Put(string id, [FromBody] OrganizationUpdateRequestModel model) - { - var orgIdGuid = new Guid(id); - - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) - { - throw new NotFoundException(); - } - - var updateBilling = !_globalSettings.SelfHosted && (model.BusinessName != organization.BusinessName || - model.BillingEmail != organization.BillingEmail); - - var hasRequiredPermissions = updateBilling - ? await _currentContext.ManageBilling(orgIdGuid) - : await _currentContext.OrganizationOwner(orgIdGuid); - - if (!hasRequiredPermissions) - { - throw new NotFoundException(); - } - - await _organizationService.UpdateAsync(model.ToOrganization(organization, _globalSettings), updateBilling); - return new OrganizationResponseModel(organization); - } - - [HttpPost("{id}/payment")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostPayment(string id, [FromBody] PaymentRequestModel model) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManageBilling(orgIdGuid)) - { - throw new NotFoundException(); - } - - await _organizationService.ReplacePaymentMethodAsync(orgIdGuid, model.PaymentToken, - model.PaymentMethodType.Value, new TaxInfo - { - BillingAddressLine1 = model.Line1, - BillingAddressLine2 = model.Line2, - BillingAddressState = model.State, - BillingAddressCity = model.City, - BillingAddressPostalCode = model.PostalCode, - BillingAddressCountry = model.Country, - TaxIdNumber = model.TaxId, - }); - } - - [HttpPost("{id}/upgrade")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostUpgrade(string id, [FromBody] OrganizationUpgradeRequestModel model) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManageBilling(orgIdGuid)) - { - throw new NotFoundException(); - } - - var result = await _organizationService.UpgradePlanAsync(orgIdGuid, model.ToOrganizationUpgrade()); - return new PaymentResponseModel - { - Success = result.Item1, - PaymentIntentClientSecret = result.Item2 - }; - } - - [HttpPost("{id}/subscription")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostSubscription(string id, [FromBody] OrganizationSubscriptionUpdateRequestModel model) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManageBilling(orgIdGuid)) - { - throw new NotFoundException(); - } - - await _organizationService.UpdateSubscription(orgIdGuid, model.SeatAdjustment, model.MaxAutoscaleSeats); - } - - [HttpPost("{id}/seat")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostSeat(string id, [FromBody] OrganizationSeatRequestModel model) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManageBilling(orgIdGuid)) - { - throw new NotFoundException(); - } - - var result = await _organizationService.AdjustSeatsAsync(orgIdGuid, model.SeatAdjustment.Value); - return new PaymentResponseModel - { - Success = true, - PaymentIntentClientSecret = result - }; - } - - [HttpPost("{id}/storage")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostStorage(string id, [FromBody] StorageRequestModel model) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManageBilling(orgIdGuid)) - { - throw new NotFoundException(); - } - - var result = await _organizationService.AdjustStorageAsync(orgIdGuid, model.StorageGbAdjustment.Value); - return new PaymentResponseModel - { - Success = true, - PaymentIntentClientSecret = result - }; - } - - [HttpPost("{id}/verify-bank")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostVerifyBank(string id, [FromBody] OrganizationVerifyBankRequestModel model) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManageBilling(orgIdGuid)) - { - throw new NotFoundException(); - } - - await _organizationService.VerifyBankAsync(orgIdGuid, model.Amount1.Value, model.Amount2.Value); - } - - [HttpPost("{id}/cancel")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostCancel(string id) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManageBilling(orgIdGuid)) - { - throw new NotFoundException(); - } - - await _organizationService.CancelSubscriptionAsync(orgIdGuid); - } - - [HttpPost("{id}/reinstate")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostReinstate(string id) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManageBilling(orgIdGuid)) - { - throw new NotFoundException(); - } - - await _organizationService.ReinstateSubscriptionAsync(orgIdGuid); - } - - [HttpPost("{id}/leave")] - public async Task Leave(string id) - { - var orgGuidId = new Guid(id); - if (!await _currentContext.OrganizationUser(orgGuidId)) - { - throw new NotFoundException(); - } - - var user = await _userService.GetUserByPrincipalAsync(User); - - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(orgGuidId); - if (ssoConfig?.GetData()?.KeyConnectorEnabled == true && - user.UsesKeyConnector) - { - throw new BadRequestException("Your organization's Single Sign-On settings prevent you from leaving."); - } - - - await _organizationService.DeleteUserAsync(orgGuidId, user.Id); - } - - [HttpDelete("{id}")] - [HttpPost("{id}/delete")] - public async Task Delete(string id, [FromBody] SecretVerificationRequestModel model) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.OrganizationOwner(orgIdGuid)) - { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) - { - throw new NotFoundException(); - } - - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (!await _userService.VerifySecretAsync(user, model.Secret)) - { - await Task.Delay(2000); - throw new BadRequestException(string.Empty, "User verification failed."); - } - else - { - await _organizationService.DeleteAsync(organization); - } - } - - [HttpPost("{id}/license")] - [SelfHosted(SelfHostedOnly = true)] - public async Task PostLicense(string id, LicenseRequestModel model) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.OrganizationOwner(orgIdGuid)) - { - throw new NotFoundException(); - } - - var license = await ApiHelpers.ReadJsonFileFromBody(HttpContext, model.License); - if (license == null) - { - throw new BadRequestException("Invalid license"); - } - - await _organizationService.UpdateLicenseAsync(new Guid(id), license); - } - - [HttpPost("{id}/import")] - public async Task Import(string id, [FromBody] ImportOrganizationUsersRequestModel model) - { - if (!_globalSettings.SelfHosted && !model.LargeImport && - (model.Groups.Count() > 2000 || model.Users.Count(u => !u.Deleted) > 2000)) - { - throw new BadRequestException("You cannot import this much data at once."); - } - - var orgIdGuid = new Guid(id); - if (!await _currentContext.OrganizationAdmin(orgIdGuid)) - { - throw new NotFoundException(); - } - - var userId = _userService.GetProperUserId(User); - await _organizationService.ImportAsync( - orgIdGuid, - userId.Value, - model.Groups.Select(g => g.ToImportedGroup(orgIdGuid)), - model.Users.Where(u => !u.Deleted).Select(u => u.ToImportedOrganizationUser()), - model.Users.Where(u => u.Deleted).Select(u => u.ExternalId), - model.OverwriteExisting); - } - - [HttpPost("{id}/api-key")] - public async Task ApiKey(string id, [FromBody] OrganizationApiKeyRequestModel model) - { - var orgIdGuid = new Guid(id); - if (!await HasApiKeyAccessAsync(orgIdGuid, model.Type)) - { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) - { - throw new NotFoundException(); - } - - if (model.Type == OrganizationApiKeyType.BillingSync || model.Type == OrganizationApiKeyType.Scim) - { - // Non-enterprise orgs should not be able to create or view an apikey of billing sync/scim key types - var plan = StaticStore.GetPlan(organization.PlanType); - if (plan.Product != ProductType.Enterprise) - { - throw new NotFoundException(); - } - } - - var organizationApiKey = await _getOrganizationApiKeyCommand - .GetOrganizationApiKeyAsync(organization.Id, model.Type); - - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (model.Type != OrganizationApiKeyType.Scim - && !await _userService.VerifySecretAsync(user, model.Secret)) - { - await Task.Delay(2000); - throw new BadRequestException("MasterPasswordHash", "Invalid password."); - } - else - { - var response = new ApiKeyResponseModel(organizationApiKey); - return response; - } - } - - [HttpGet("{id}/api-key-information/{type?}")] - public async Task> ApiKeyInformation(Guid id, OrganizationApiKeyType? type) - { - if (!await HasApiKeyAccessAsync(id, type)) - { - throw new NotFoundException(); - } - - var apiKeys = await _organizationApiKeyRepository.GetManyByOrganizationIdTypeAsync(id, type); - - return new ListResponseModel( - apiKeys.Select(k => new OrganizationApiKeyInformation(k))); - } - - [HttpPost("{id}/rotate-api-key")] - public async Task RotateApiKey(string id, [FromBody] OrganizationApiKeyRequestModel model) - { - var orgIdGuid = new Guid(id); - if (!await HasApiKeyAccessAsync(orgIdGuid, model.Type)) - { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) - { - throw new NotFoundException(); - } - - var organizationApiKey = await _getOrganizationApiKeyCommand - .GetOrganizationApiKeyAsync(organization.Id, model.Type); - - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (model.Type != OrganizationApiKeyType.Scim - && !await _userService.VerifySecretAsync(user, model.Secret)) - { - await Task.Delay(2000); - throw new BadRequestException("MasterPasswordHash", "Invalid password."); - } - else - { - await _rotateOrganizationApiKeyCommand.RotateApiKeyAsync(organizationApiKey); - var response = new ApiKeyResponseModel(organizationApiKey); - return response; - } - } - - private async Task HasApiKeyAccessAsync(Guid orgId, OrganizationApiKeyType? type) - { - return type switch - { - OrganizationApiKeyType.Scim => await _currentContext.ManageScim(orgId), - _ => await _currentContext.OrganizationOwner(orgId), - }; - } - - [HttpGet("{id}/tax")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task GetTaxInfo(string id) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.OrganizationOwner(orgIdGuid)) - { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) - { - throw new NotFoundException(); - } - - var taxInfo = await _paymentService.GetTaxInfoAsync(organization); - return new TaxInfoResponseModel(taxInfo); - } - - [HttpPut("{id}/tax")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PutTaxInfo(string id, [FromBody] OrganizationTaxInfoUpdateRequestModel model) - { - var orgIdGuid = new Guid(id); - if (!await _currentContext.OrganizationOwner(orgIdGuid)) - { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) - { - throw new NotFoundException(); - } - - var taxInfo = new TaxInfo - { - TaxIdNumber = model.TaxId, - BillingAddressLine1 = model.Line1, - BillingAddressLine2 = model.Line2, - BillingAddressCity = model.City, - BillingAddressState = model.State, - BillingAddressPostalCode = model.PostalCode, - BillingAddressCountry = model.Country, - }; - await _paymentService.SaveTaxInfoAsync(organization, taxInfo); - } - - [HttpGet("{id}/keys")] - public async Task GetKeys(string id) - { - var org = await _organizationRepository.GetByIdAsync(new Guid(id)); - if (org == null) - { - throw new NotFoundException(); - } - - return new OrganizationKeysResponseModel(org); - } - - [HttpPost("{id}/keys")] - public async Task PostKeys(string id, [FromBody] OrganizationKeysRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var org = await _organizationService.UpdateOrganizationKeysAsync(new Guid(id), model.PublicKey, model.EncryptedPrivateKey); - return new OrganizationKeysResponseModel(org); - } - - [HttpGet("{id:guid}/sso")] - public async Task GetSso(Guid id) - { - if (!await _currentContext.ManageSso(id)) - { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(id); - if (organization == null) - { - throw new NotFoundException(); - } - - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(id); - - return new OrganizationSsoResponseModel(organization, _globalSettings, ssoConfig); - } - - [HttpPost("{id:guid}/sso")] - public async Task PostSso(Guid id, [FromBody] OrganizationSsoRequestModel model) - { - if (!await _currentContext.ManageSso(id)) - { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(id); - if (organization == null) - { - throw new NotFoundException(); - } - - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(id); - ssoConfig = ssoConfig == null ? model.ToSsoConfig(id) : model.ToSsoConfig(ssoConfig); - - await _ssoConfigService.SaveAsync(ssoConfig, organization); - - return new OrganizationSsoResponseModel(organization, _globalSettings, ssoConfig); + return new OrganizationSubscriptionResponseModel(organization); } } + + [HttpGet("{id}/license")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task GetLicense(string id, [FromQuery] Guid installationId) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.OrganizationOwner(orgIdGuid)) + { + throw new NotFoundException(); + } + + var license = await _organizationService.GenerateLicenseAsync(orgIdGuid, installationId); + if (license == null) + { + throw new NotFoundException(); + } + + return license; + } + + [HttpGet("")] + public async Task> GetUser() + { + var userId = _userService.GetProperUserId(User).Value; + var organizations = await _organizationUserRepository.GetManyDetailsByUserAsync(userId, + OrganizationUserStatusType.Confirmed); + var responses = organizations.Select(o => new ProfileOrganizationResponseModel(o)); + return new ListResponseModel(responses); + } + + [HttpGet("{identifier}/auto-enroll-status")] + public async Task GetAutoEnrollStatus(string identifier) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var organization = await _organizationRepository.GetByIdentifierAsync(identifier); + if (organization == null) + { + throw new NotFoundException(); + } + + var organizationUser = await _organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id); + if (organizationUser == null) + { + throw new NotFoundException(); + } + + var resetPasswordPolicy = + await _policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword); + if (resetPasswordPolicy == null || !resetPasswordPolicy.Enabled || resetPasswordPolicy.Data == null) + { + return new OrganizationAutoEnrollStatusResponseModel(organization.Id, false); + } + + var data = JsonSerializer.Deserialize(resetPasswordPolicy.Data, JsonHelpers.IgnoreCase); + return new OrganizationAutoEnrollStatusResponseModel(organization.Id, data?.AutoEnrollEnabled ?? false); + } + + [HttpPost("")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task Post([FromBody] OrganizationCreateRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var organizationSignup = model.ToOrganizationSignup(user); + var result = await _organizationService.SignUpAsync(organizationSignup); + return new OrganizationResponseModel(result.Item1); + } + + [HttpPost("license")] + [SelfHosted(SelfHostedOnly = true)] + public async Task PostLicense(OrganizationCreateLicenseRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var license = await ApiHelpers.ReadJsonFileFromBody(HttpContext, model.License); + if (license == null) + { + throw new BadRequestException("Invalid license"); + } + + var result = await _organizationService.SignUpAsync(license, user, model.Key, + model.CollectionName, model.Keys?.PublicKey, model.Keys?.EncryptedPrivateKey); + return new OrganizationResponseModel(result.Item1); + } + + [HttpPut("{id}")] + [HttpPost("{id}")] + public async Task Put(string id, [FromBody] OrganizationUpdateRequestModel model) + { + var orgIdGuid = new Guid(id); + + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) + { + throw new NotFoundException(); + } + + var updateBilling = !_globalSettings.SelfHosted && (model.BusinessName != organization.BusinessName || + model.BillingEmail != organization.BillingEmail); + + var hasRequiredPermissions = updateBilling + ? await _currentContext.ManageBilling(orgIdGuid) + : await _currentContext.OrganizationOwner(orgIdGuid); + + if (!hasRequiredPermissions) + { + throw new NotFoundException(); + } + + await _organizationService.UpdateAsync(model.ToOrganization(organization, _globalSettings), updateBilling); + return new OrganizationResponseModel(organization); + } + + [HttpPost("{id}/payment")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostPayment(string id, [FromBody] PaymentRequestModel model) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManageBilling(orgIdGuid)) + { + throw new NotFoundException(); + } + + await _organizationService.ReplacePaymentMethodAsync(orgIdGuid, model.PaymentToken, + model.PaymentMethodType.Value, new TaxInfo + { + BillingAddressLine1 = model.Line1, + BillingAddressLine2 = model.Line2, + BillingAddressState = model.State, + BillingAddressCity = model.City, + BillingAddressPostalCode = model.PostalCode, + BillingAddressCountry = model.Country, + TaxIdNumber = model.TaxId, + }); + } + + [HttpPost("{id}/upgrade")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostUpgrade(string id, [FromBody] OrganizationUpgradeRequestModel model) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManageBilling(orgIdGuid)) + { + throw new NotFoundException(); + } + + var result = await _organizationService.UpgradePlanAsync(orgIdGuid, model.ToOrganizationUpgrade()); + return new PaymentResponseModel + { + Success = result.Item1, + PaymentIntentClientSecret = result.Item2 + }; + } + + [HttpPost("{id}/subscription")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostSubscription(string id, [FromBody] OrganizationSubscriptionUpdateRequestModel model) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManageBilling(orgIdGuid)) + { + throw new NotFoundException(); + } + + await _organizationService.UpdateSubscription(orgIdGuid, model.SeatAdjustment, model.MaxAutoscaleSeats); + } + + [HttpPost("{id}/seat")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostSeat(string id, [FromBody] OrganizationSeatRequestModel model) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManageBilling(orgIdGuid)) + { + throw new NotFoundException(); + } + + var result = await _organizationService.AdjustSeatsAsync(orgIdGuid, model.SeatAdjustment.Value); + return new PaymentResponseModel + { + Success = true, + PaymentIntentClientSecret = result + }; + } + + [HttpPost("{id}/storage")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostStorage(string id, [FromBody] StorageRequestModel model) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManageBilling(orgIdGuid)) + { + throw new NotFoundException(); + } + + var result = await _organizationService.AdjustStorageAsync(orgIdGuid, model.StorageGbAdjustment.Value); + return new PaymentResponseModel + { + Success = true, + PaymentIntentClientSecret = result + }; + } + + [HttpPost("{id}/verify-bank")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostVerifyBank(string id, [FromBody] OrganizationVerifyBankRequestModel model) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManageBilling(orgIdGuid)) + { + throw new NotFoundException(); + } + + await _organizationService.VerifyBankAsync(orgIdGuid, model.Amount1.Value, model.Amount2.Value); + } + + [HttpPost("{id}/cancel")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostCancel(string id) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManageBilling(orgIdGuid)) + { + throw new NotFoundException(); + } + + await _organizationService.CancelSubscriptionAsync(orgIdGuid); + } + + [HttpPost("{id}/reinstate")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PostReinstate(string id) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManageBilling(orgIdGuid)) + { + throw new NotFoundException(); + } + + await _organizationService.ReinstateSubscriptionAsync(orgIdGuid); + } + + [HttpPost("{id}/leave")] + public async Task Leave(string id) + { + var orgGuidId = new Guid(id); + if (!await _currentContext.OrganizationUser(orgGuidId)) + { + throw new NotFoundException(); + } + + var user = await _userService.GetUserByPrincipalAsync(User); + + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(orgGuidId); + if (ssoConfig?.GetData()?.KeyConnectorEnabled == true && + user.UsesKeyConnector) + { + throw new BadRequestException("Your organization's Single Sign-On settings prevent you from leaving."); + } + + + await _organizationService.DeleteUserAsync(orgGuidId, user.Id); + } + + [HttpDelete("{id}")] + [HttpPost("{id}/delete")] + public async Task Delete(string id, [FromBody] SecretVerificationRequestModel model) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.OrganizationOwner(orgIdGuid)) + { + throw new NotFoundException(); + } + + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) + { + throw new NotFoundException(); + } + + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (!await _userService.VerifySecretAsync(user, model.Secret)) + { + await Task.Delay(2000); + throw new BadRequestException(string.Empty, "User verification failed."); + } + else + { + await _organizationService.DeleteAsync(organization); + } + } + + [HttpPost("{id}/license")] + [SelfHosted(SelfHostedOnly = true)] + public async Task PostLicense(string id, LicenseRequestModel model) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.OrganizationOwner(orgIdGuid)) + { + throw new NotFoundException(); + } + + var license = await ApiHelpers.ReadJsonFileFromBody(HttpContext, model.License); + if (license == null) + { + throw new BadRequestException("Invalid license"); + } + + await _organizationService.UpdateLicenseAsync(new Guid(id), license); + } + + [HttpPost("{id}/import")] + public async Task Import(string id, [FromBody] ImportOrganizationUsersRequestModel model) + { + if (!_globalSettings.SelfHosted && !model.LargeImport && + (model.Groups.Count() > 2000 || model.Users.Count(u => !u.Deleted) > 2000)) + { + throw new BadRequestException("You cannot import this much data at once."); + } + + var orgIdGuid = new Guid(id); + if (!await _currentContext.OrganizationAdmin(orgIdGuid)) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User); + await _organizationService.ImportAsync( + orgIdGuid, + userId.Value, + model.Groups.Select(g => g.ToImportedGroup(orgIdGuid)), + model.Users.Where(u => !u.Deleted).Select(u => u.ToImportedOrganizationUser()), + model.Users.Where(u => u.Deleted).Select(u => u.ExternalId), + model.OverwriteExisting); + } + + [HttpPost("{id}/api-key")] + public async Task ApiKey(string id, [FromBody] OrganizationApiKeyRequestModel model) + { + var orgIdGuid = new Guid(id); + if (!await HasApiKeyAccessAsync(orgIdGuid, model.Type)) + { + throw new NotFoundException(); + } + + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) + { + throw new NotFoundException(); + } + + if (model.Type == OrganizationApiKeyType.BillingSync || model.Type == OrganizationApiKeyType.Scim) + { + // Non-enterprise orgs should not be able to create or view an apikey of billing sync/scim key types + var plan = StaticStore.GetPlan(organization.PlanType); + if (plan.Product != ProductType.Enterprise) + { + throw new NotFoundException(); + } + } + + var organizationApiKey = await _getOrganizationApiKeyCommand + .GetOrganizationApiKeyAsync(organization.Id, model.Type); + + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (model.Type != OrganizationApiKeyType.Scim + && !await _userService.VerifySecretAsync(user, model.Secret)) + { + await Task.Delay(2000); + throw new BadRequestException("MasterPasswordHash", "Invalid password."); + } + else + { + var response = new ApiKeyResponseModel(organizationApiKey); + return response; + } + } + + [HttpGet("{id}/api-key-information/{type?}")] + public async Task> ApiKeyInformation(Guid id, OrganizationApiKeyType? type) + { + if (!await HasApiKeyAccessAsync(id, type)) + { + throw new NotFoundException(); + } + + var apiKeys = await _organizationApiKeyRepository.GetManyByOrganizationIdTypeAsync(id, type); + + return new ListResponseModel( + apiKeys.Select(k => new OrganizationApiKeyInformation(k))); + } + + [HttpPost("{id}/rotate-api-key")] + public async Task RotateApiKey(string id, [FromBody] OrganizationApiKeyRequestModel model) + { + var orgIdGuid = new Guid(id); + if (!await HasApiKeyAccessAsync(orgIdGuid, model.Type)) + { + throw new NotFoundException(); + } + + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) + { + throw new NotFoundException(); + } + + var organizationApiKey = await _getOrganizationApiKeyCommand + .GetOrganizationApiKeyAsync(organization.Id, model.Type); + + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (model.Type != OrganizationApiKeyType.Scim + && !await _userService.VerifySecretAsync(user, model.Secret)) + { + await Task.Delay(2000); + throw new BadRequestException("MasterPasswordHash", "Invalid password."); + } + else + { + await _rotateOrganizationApiKeyCommand.RotateApiKeyAsync(organizationApiKey); + var response = new ApiKeyResponseModel(organizationApiKey); + return response; + } + } + + private async Task HasApiKeyAccessAsync(Guid orgId, OrganizationApiKeyType? type) + { + return type switch + { + OrganizationApiKeyType.Scim => await _currentContext.ManageScim(orgId), + _ => await _currentContext.OrganizationOwner(orgId), + }; + } + + [HttpGet("{id}/tax")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task GetTaxInfo(string id) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.OrganizationOwner(orgIdGuid)) + { + throw new NotFoundException(); + } + + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) + { + throw new NotFoundException(); + } + + var taxInfo = await _paymentService.GetTaxInfoAsync(organization); + return new TaxInfoResponseModel(taxInfo); + } + + [HttpPut("{id}/tax")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task PutTaxInfo(string id, [FromBody] OrganizationTaxInfoUpdateRequestModel model) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.OrganizationOwner(orgIdGuid)) + { + throw new NotFoundException(); + } + + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) + { + throw new NotFoundException(); + } + + var taxInfo = new TaxInfo + { + TaxIdNumber = model.TaxId, + BillingAddressLine1 = model.Line1, + BillingAddressLine2 = model.Line2, + BillingAddressCity = model.City, + BillingAddressState = model.State, + BillingAddressPostalCode = model.PostalCode, + BillingAddressCountry = model.Country, + }; + await _paymentService.SaveTaxInfoAsync(organization, taxInfo); + } + + [HttpGet("{id}/keys")] + public async Task GetKeys(string id) + { + var org = await _organizationRepository.GetByIdAsync(new Guid(id)); + if (org == null) + { + throw new NotFoundException(); + } + + return new OrganizationKeysResponseModel(org); + } + + [HttpPost("{id}/keys")] + public async Task PostKeys(string id, [FromBody] OrganizationKeysRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var org = await _organizationService.UpdateOrganizationKeysAsync(new Guid(id), model.PublicKey, model.EncryptedPrivateKey); + return new OrganizationKeysResponseModel(org); + } + + [HttpGet("{id:guid}/sso")] + public async Task GetSso(Guid id) + { + if (!await _currentContext.ManageSso(id)) + { + throw new NotFoundException(); + } + + var organization = await _organizationRepository.GetByIdAsync(id); + if (organization == null) + { + throw new NotFoundException(); + } + + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(id); + + return new OrganizationSsoResponseModel(organization, _globalSettings, ssoConfig); + } + + [HttpPost("{id:guid}/sso")] + public async Task PostSso(Guid id, [FromBody] OrganizationSsoRequestModel model) + { + if (!await _currentContext.ManageSso(id)) + { + throw new NotFoundException(); + } + + var organization = await _organizationRepository.GetByIdAsync(id); + if (organization == null) + { + throw new NotFoundException(); + } + + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(id); + ssoConfig = ssoConfig == null ? model.ToSsoConfig(id) : model.ToSsoConfig(ssoConfig); + + await _ssoConfigService.SaveAsync(ssoConfig, organization); + + return new OrganizationSsoResponseModel(organization, _globalSettings, ssoConfig); + } } diff --git a/src/Api/Controllers/PlansController.cs b/src/Api/Controllers/PlansController.cs index 5f5d44c33..d738e60cf 100644 --- a/src/Api/Controllers/PlansController.cs +++ b/src/Api/Controllers/PlansController.cs @@ -4,33 +4,32 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("plans")] +[Authorize("Web")] +public class PlansController : Controller { - [Route("plans")] - [Authorize("Web")] - public class PlansController : Controller + private readonly ITaxRateRepository _taxRateRepository; + public PlansController(ITaxRateRepository taxRateRepository) { - private readonly ITaxRateRepository _taxRateRepository; - public PlansController(ITaxRateRepository taxRateRepository) - { - _taxRateRepository = taxRateRepository; - } + _taxRateRepository = taxRateRepository; + } - [HttpGet("")] - [AllowAnonymous] - public ListResponseModel Get() - { - var data = StaticStore.Plans; - var responses = data.Select(plan => new PlanResponseModel(plan)); - return new ListResponseModel(responses); - } + [HttpGet("")] + [AllowAnonymous] + public ListResponseModel Get() + { + var data = StaticStore.Plans; + var responses = data.Select(plan => new PlanResponseModel(plan)); + return new ListResponseModel(responses); + } - [HttpGet("sales-tax-rates")] - public async Task> GetTaxRates() - { - var data = await _taxRateRepository.GetAllActiveAsync(); - var responses = data.Select(x => new TaxRateResponseModel(x)); - return new ListResponseModel(responses); - } + [HttpGet("sales-tax-rates")] + public async Task> GetTaxRates() + { + var data = await _taxRateRepository.GetAllActiveAsync(); + var responses = data.Select(x => new TaxRateResponseModel(x)); + return new ListResponseModel(responses); } } diff --git a/src/Api/Controllers/PoliciesController.cs b/src/Api/Controllers/PoliciesController.cs index 756b8a9d3..175e1d6a8 100644 --- a/src/Api/Controllers/PoliciesController.cs +++ b/src/Api/Controllers/PoliciesController.cs @@ -11,145 +11,144 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.DataProtection; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("organizations/{orgId}/policies")] +[Authorize("Application")] +public class PoliciesController : Controller { - [Route("organizations/{orgId}/policies")] - [Authorize("Application")] - public class PoliciesController : Controller + private readonly IPolicyRepository _policyRepository; + private readonly IPolicyService _policyService; + private readonly IOrganizationService _organizationService; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IUserService _userService; + private readonly ICurrentContext _currentContext; + private readonly GlobalSettings _globalSettings; + private readonly IDataProtector _organizationServiceDataProtector; + + public PoliciesController( + IPolicyRepository policyRepository, + IPolicyService policyService, + IOrganizationService organizationService, + IOrganizationUserRepository organizationUserRepository, + IUserService userService, + ICurrentContext currentContext, + GlobalSettings globalSettings, + IDataProtectionProvider dataProtectionProvider) { - private readonly IPolicyRepository _policyRepository; - private readonly IPolicyService _policyService; - private readonly IOrganizationService _organizationService; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IUserService _userService; - private readonly ICurrentContext _currentContext; - private readonly GlobalSettings _globalSettings; - private readonly IDataProtector _organizationServiceDataProtector; + _policyRepository = policyRepository; + _policyService = policyService; + _organizationService = organizationService; + _organizationUserRepository = organizationUserRepository; + _userService = userService; + _currentContext = currentContext; + _globalSettings = globalSettings; + _organizationServiceDataProtector = dataProtectionProvider.CreateProtector( + "OrganizationServiceDataProtector"); + } - public PoliciesController( - IPolicyRepository policyRepository, - IPolicyService policyService, - IOrganizationService organizationService, - IOrganizationUserRepository organizationUserRepository, - IUserService userService, - ICurrentContext currentContext, - GlobalSettings globalSettings, - IDataProtectionProvider dataProtectionProvider) + [HttpGet("{type}")] + public async Task Get(string orgId, int type) + { + var orgIdGuid = new Guid(orgId); + if (!await _currentContext.ManagePolicies(orgIdGuid)) { - _policyRepository = policyRepository; - _policyService = policyService; - _organizationService = organizationService; - _organizationUserRepository = organizationUserRepository; - _userService = userService; - _currentContext = currentContext; - _globalSettings = globalSettings; - _organizationServiceDataProtector = dataProtectionProvider.CreateProtector( - "OrganizationServiceDataProtector"); + throw new NotFoundException(); + } + var policy = await _policyRepository.GetByOrganizationIdTypeAsync(orgIdGuid, (PolicyType)type); + if (policy == null) + { + throw new NotFoundException(); } - [HttpGet("{type}")] - public async Task Get(string orgId, int type) - { - var orgIdGuid = new Guid(orgId); - if (!await _currentContext.ManagePolicies(orgIdGuid)) - { - throw new NotFoundException(); - } - var policy = await _policyRepository.GetByOrganizationIdTypeAsync(orgIdGuid, (PolicyType)type); - if (policy == null) - { - throw new NotFoundException(); - } + return new PolicyResponseModel(policy); + } - return new PolicyResponseModel(policy); + [HttpGet("")] + public async Task> Get(string orgId) + { + var orgIdGuid = new Guid(orgId); + if (!await _currentContext.ManagePolicies(orgIdGuid)) + { + throw new NotFoundException(); } - [HttpGet("")] - public async Task> Get(string orgId) - { - var orgIdGuid = new Guid(orgId); - if (!await _currentContext.ManagePolicies(orgIdGuid)) - { - throw new NotFoundException(); - } + var policies = await _policyRepository.GetManyByOrganizationIdAsync(orgIdGuid); + var responses = policies.Select(p => new PolicyResponseModel(p)); + return new ListResponseModel(responses); + } - var policies = await _policyRepository.GetManyByOrganizationIdAsync(orgIdGuid); - var responses = policies.Select(p => new PolicyResponseModel(p)); - return new ListResponseModel(responses); + [AllowAnonymous] + [HttpGet("token")] + public async Task> GetByToken(string orgId, [FromQuery] string email, + [FromQuery] string token, [FromQuery] string organizationUserId) + { + var orgUserId = new Guid(organizationUserId); + var tokenValid = CoreHelpers.UserInviteTokenIsValid(_organizationServiceDataProtector, token, + email, orgUserId, _globalSettings); + if (!tokenValid) + { + throw new NotFoundException(); } - [AllowAnonymous] - [HttpGet("token")] - public async Task> GetByToken(string orgId, [FromQuery] string email, - [FromQuery] string token, [FromQuery] string organizationUserId) + var orgIdGuid = new Guid(orgId); + var orgUser = await _organizationUserRepository.GetByIdAsync(orgUserId); + if (orgUser == null || orgUser.OrganizationId != orgIdGuid) { - var orgUserId = new Guid(organizationUserId); - var tokenValid = CoreHelpers.UserInviteTokenIsValid(_organizationServiceDataProtector, token, - email, orgUserId, _globalSettings); - if (!tokenValid) - { - throw new NotFoundException(); - } - - var orgIdGuid = new Guid(orgId); - var orgUser = await _organizationUserRepository.GetByIdAsync(orgUserId); - if (orgUser == null || orgUser.OrganizationId != orgIdGuid) - { - throw new NotFoundException(); - } - - var policies = await _policyRepository.GetManyByOrganizationIdAsync(orgIdGuid); - var responses = policies.Where(p => p.Enabled).Select(p => new PolicyResponseModel(p)); - return new ListResponseModel(responses); + throw new NotFoundException(); } - [AllowAnonymous] - [HttpGet("invited-user")] - public async Task> GetByInvitedUser(string orgId, [FromQuery] string userId) - { - var user = await _userService.GetUserByIdAsync(new Guid(userId)); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - var orgIdGuid = new Guid(orgId); - var orgUsersByUserId = await _organizationUserRepository.GetManyByUserAsync(user.Id); - var orgUser = orgUsersByUserId.SingleOrDefault(u => u.OrganizationId == orgIdGuid); - if (orgUser == null) - { - throw new NotFoundException(); - } - if (orgUser.Status != OrganizationUserStatusType.Invited) - { - throw new UnauthorizedAccessException(); - } + var policies = await _policyRepository.GetManyByOrganizationIdAsync(orgIdGuid); + var responses = policies.Where(p => p.Enabled).Select(p => new PolicyResponseModel(p)); + return new ListResponseModel(responses); + } - var policies = await _policyRepository.GetManyByOrganizationIdAsync(orgIdGuid); - var responses = policies.Where(p => p.Enabled).Select(p => new PolicyResponseModel(p)); - return new ListResponseModel(responses); + [AllowAnonymous] + [HttpGet("invited-user")] + public async Task> GetByInvitedUser(string orgId, [FromQuery] string userId) + { + var user = await _userService.GetUserByIdAsync(new Guid(userId)); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + var orgIdGuid = new Guid(orgId); + var orgUsersByUserId = await _organizationUserRepository.GetManyByUserAsync(user.Id); + var orgUser = orgUsersByUserId.SingleOrDefault(u => u.OrganizationId == orgIdGuid); + if (orgUser == null) + { + throw new NotFoundException(); + } + if (orgUser.Status != OrganizationUserStatusType.Invited) + { + throw new UnauthorizedAccessException(); } - [HttpPut("{type}")] - public async Task Put(string orgId, int type, [FromBody] PolicyRequestModel model) - { - var orgIdGuid = new Guid(orgId); - if (!await _currentContext.ManagePolicies(orgIdGuid)) - { - throw new NotFoundException(); - } - var policy = await _policyRepository.GetByOrganizationIdTypeAsync(new Guid(orgId), (PolicyType)type); - if (policy == null) - { - policy = model.ToPolicy(orgIdGuid); - } - else - { - policy = model.ToPolicy(policy); - } + var policies = await _policyRepository.GetManyByOrganizationIdAsync(orgIdGuid); + var responses = policies.Where(p => p.Enabled).Select(p => new PolicyResponseModel(p)); + return new ListResponseModel(responses); + } - var userId = _userService.GetProperUserId(User); - await _policyService.SaveAsync(policy, _userService, _organizationService, userId); - return new PolicyResponseModel(policy); + [HttpPut("{type}")] + public async Task Put(string orgId, int type, [FromBody] PolicyRequestModel model) + { + var orgIdGuid = new Guid(orgId); + if (!await _currentContext.ManagePolicies(orgIdGuid)) + { + throw new NotFoundException(); } + var policy = await _policyRepository.GetByOrganizationIdTypeAsync(new Guid(orgId), (PolicyType)type); + if (policy == null) + { + policy = model.ToPolicy(orgIdGuid); + } + else + { + policy = model.ToPolicy(policy); + } + + var userId = _userService.GetProperUserId(User); + await _policyService.SaveAsync(policy, _userService, _organizationService, userId); + return new PolicyResponseModel(policy); } } diff --git a/src/Api/Controllers/ProviderOrganizationsController.cs b/src/Api/Controllers/ProviderOrganizationsController.cs index f4772fbe2..222d11302 100644 --- a/src/Api/Controllers/ProviderOrganizationsController.cs +++ b/src/Api/Controllers/ProviderOrganizationsController.cs @@ -9,87 +9,86 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("providers/{providerId:guid}/organizations")] +[Authorize("Application")] +public class ProviderOrganizationsController : Controller { - [Route("providers/{providerId:guid}/organizations")] - [Authorize("Application")] - public class ProviderOrganizationsController : Controller + + private readonly IProviderOrganizationRepository _providerOrganizationRepository; + private readonly IProviderService _providerService; + private readonly IUserService _userService; + private readonly ICurrentContext _currentContext; + + public ProviderOrganizationsController( + IProviderOrganizationRepository providerOrganizationRepository, + IProviderService providerService, + IUserService userService, + ICurrentContext currentContext) { + _providerOrganizationRepository = providerOrganizationRepository; + _providerService = providerService; + _userService = userService; + _currentContext = currentContext; + } - private readonly IProviderOrganizationRepository _providerOrganizationRepository; - private readonly IProviderService _providerService; - private readonly IUserService _userService; - private readonly ICurrentContext _currentContext; - - public ProviderOrganizationsController( - IProviderOrganizationRepository providerOrganizationRepository, - IProviderService providerService, - IUserService userService, - ICurrentContext currentContext) + [HttpGet("")] + public async Task> Get(Guid providerId) + { + if (!_currentContext.AccessProviderOrganizations(providerId)) { - _providerOrganizationRepository = providerOrganizationRepository; - _providerService = providerService; - _userService = userService; - _currentContext = currentContext; + throw new NotFoundException(); } - [HttpGet("")] - public async Task> Get(Guid providerId) - { - if (!_currentContext.AccessProviderOrganizations(providerId)) - { - throw new NotFoundException(); - } + var providerOrganizations = await _providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId); + var responses = providerOrganizations.Select(o => new ProviderOrganizationOrganizationDetailsResponseModel(o)); + return new ListResponseModel(responses); + } - var providerOrganizations = await _providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId); - var responses = providerOrganizations.Select(o => new ProviderOrganizationOrganizationDetailsResponseModel(o)); - return new ListResponseModel(responses); + [HttpPost("add")] + public async Task Add(Guid providerId, [FromBody] ProviderOrganizationAddRequestModel model) + { + if (!_currentContext.ManageProviderOrganizations(providerId)) + { + throw new NotFoundException(); } - [HttpPost("add")] - public async Task Add(Guid providerId, [FromBody] ProviderOrganizationAddRequestModel model) + var userId = _userService.GetProperUserId(User).Value; + + await _providerService.AddOrganization(providerId, model.OrganizationId, userId, model.Key); + } + + [HttpPost("")] + [SelfHosted(NotSelfHostedOnly = true)] + public async Task Post(Guid providerId, [FromBody] ProviderOrganizationCreateRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) { - if (!_currentContext.ManageProviderOrganizations(providerId)) - { - throw new NotFoundException(); - } - - var userId = _userService.GetProperUserId(User).Value; - - await _providerService.AddOrganization(providerId, model.OrganizationId, userId, model.Key); + throw new UnauthorizedAccessException(); } - [HttpPost("")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task Post(Guid providerId, [FromBody] ProviderOrganizationCreateRequestModel model) + if (!_currentContext.ManageProviderOrganizations(providerId)) { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (!_currentContext.ManageProviderOrganizations(providerId)) - { - throw new NotFoundException(); - } - - var organizationSignup = model.OrganizationCreateRequest.ToOrganizationSignup(user); - var result = await _providerService.CreateOrganizationAsync(providerId, organizationSignup, model.ClientOwnerEmail, user); - return new ProviderOrganizationResponseModel(result); + throw new NotFoundException(); } - [HttpDelete("{id:guid}")] - [HttpPost("{id:guid}/delete")] - public async Task Delete(Guid providerId, Guid id) - { - if (!_currentContext.ManageProviderOrganizations(providerId)) - { - throw new NotFoundException(); - } + var organizationSignup = model.OrganizationCreateRequest.ToOrganizationSignup(user); + var result = await _providerService.CreateOrganizationAsync(providerId, organizationSignup, model.ClientOwnerEmail, user); + return new ProviderOrganizationResponseModel(result); + } - var userId = _userService.GetProperUserId(User); - await _providerService.RemoveOrganizationAsync(providerId, id, userId.Value); + [HttpDelete("{id:guid}")] + [HttpPost("{id:guid}/delete")] + public async Task Delete(Guid providerId, Guid id) + { + if (!_currentContext.ManageProviderOrganizations(providerId)) + { + throw new NotFoundException(); } + + var userId = _userService.GetProperUserId(User); + await _providerService.RemoveOrganizationAsync(providerId, id, userId.Value); } } diff --git a/src/Api/Controllers/ProviderUsersController.cs b/src/Api/Controllers/ProviderUsersController.cs index ad9dec639..f88394c0b 100644 --- a/src/Api/Controllers/ProviderUsersController.cs +++ b/src/Api/Controllers/ProviderUsersController.cs @@ -9,192 +9,191 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("providers/{providerId:guid}/users")] +[Authorize("Application")] +public class ProviderUsersController : Controller { - [Route("providers/{providerId:guid}/users")] - [Authorize("Application")] - public class ProviderUsersController : Controller + private readonly IProviderUserRepository _providerUserRepository; + private readonly IProviderService _providerService; + private readonly IUserService _userService; + private readonly ICurrentContext _currentContext; + + public ProviderUsersController( + IProviderUserRepository providerUserRepository, + IProviderService providerService, + IUserService userService, + ICurrentContext currentContext) { - private readonly IProviderUserRepository _providerUserRepository; - private readonly IProviderService _providerService; - private readonly IUserService _userService; - private readonly ICurrentContext _currentContext; + _providerUserRepository = providerUserRepository; + _providerService = providerService; + _userService = userService; + _currentContext = currentContext; + } - public ProviderUsersController( - IProviderUserRepository providerUserRepository, - IProviderService providerService, - IUserService userService, - ICurrentContext currentContext) + [HttpGet("{id:guid}")] + public async Task Get(Guid providerId, Guid id) + { + var providerUser = await _providerUserRepository.GetByIdAsync(id); + if (providerUser == null || !_currentContext.ProviderManageUsers(providerUser.ProviderId)) { - _providerUserRepository = providerUserRepository; - _providerService = providerService; - _userService = userService; - _currentContext = currentContext; + throw new NotFoundException(); } - [HttpGet("{id:guid}")] - public async Task Get(Guid providerId, Guid id) - { - var providerUser = await _providerUserRepository.GetByIdAsync(id); - if (providerUser == null || !_currentContext.ProviderManageUsers(providerUser.ProviderId)) - { - throw new NotFoundException(); - } + return new ProviderUserResponseModel(providerUser); + } - return new ProviderUserResponseModel(providerUser); + [HttpGet("")] + public async Task> Get(Guid providerId) + { + if (!_currentContext.ProviderManageUsers(providerId)) + { + throw new NotFoundException(); } - [HttpGet("")] - public async Task> Get(Guid providerId) - { - if (!_currentContext.ProviderManageUsers(providerId)) - { - throw new NotFoundException(); - } + var providerUsers = await _providerUserRepository.GetManyDetailsByProviderAsync(providerId); + var responses = providerUsers.Select(o => new ProviderUserUserDetailsResponseModel(o)); + return new ListResponseModel(responses); + } - var providerUsers = await _providerUserRepository.GetManyDetailsByProviderAsync(providerId); - var responses = providerUsers.Select(o => new ProviderUserUserDetailsResponseModel(o)); - return new ListResponseModel(responses); + [HttpPost("invite")] + public async Task Invite(Guid providerId, [FromBody] ProviderUserInviteRequestModel model) + { + if (!_currentContext.ProviderManageUsers(providerId)) + { + throw new NotFoundException(); } - [HttpPost("invite")] - public async Task Invite(Guid providerId, [FromBody] ProviderUserInviteRequestModel model) - { - if (!_currentContext.ProviderManageUsers(providerId)) - { - throw new NotFoundException(); - } + var invite = ProviderUserInviteFactory.CreateIntialInvite(model.Emails, model.Type.Value, + _userService.GetProperUserId(User).Value, providerId); + await _providerService.InviteUserAsync(invite); + } - var invite = ProviderUserInviteFactory.CreateIntialInvite(model.Emails, model.Type.Value, - _userService.GetProperUserId(User).Value, providerId); - await _providerService.InviteUserAsync(invite); + [HttpPost("reinvite")] + public async Task> BulkReinvite(Guid providerId, [FromBody] ProviderUserBulkRequestModel model) + { + if (!_currentContext.ProviderManageUsers(providerId)) + { + throw new NotFoundException(); } - [HttpPost("reinvite")] - public async Task> BulkReinvite(Guid providerId, [FromBody] ProviderUserBulkRequestModel model) - { - if (!_currentContext.ProviderManageUsers(providerId)) - { - throw new NotFoundException(); - } + var invite = ProviderUserInviteFactory.CreateReinvite(model.Ids, _userService.GetProperUserId(User).Value, providerId); + var result = await _providerService.ResendInvitesAsync(invite); + return new ListResponseModel( + result.Select(t => new ProviderUserBulkResponseModel(t.Item1.Id, t.Item2))); + } - var invite = ProviderUserInviteFactory.CreateReinvite(model.Ids, _userService.GetProperUserId(User).Value, providerId); - var result = await _providerService.ResendInvitesAsync(invite); - return new ListResponseModel( - result.Select(t => new ProviderUserBulkResponseModel(t.Item1.Id, t.Item2))); + [HttpPost("{id:guid}/reinvite")] + public async Task Reinvite(Guid providerId, Guid id) + { + if (!_currentContext.ProviderManageUsers(providerId)) + { + throw new NotFoundException(); } - [HttpPost("{id:guid}/reinvite")] - public async Task Reinvite(Guid providerId, Guid id) - { - if (!_currentContext.ProviderManageUsers(providerId)) - { - throw new NotFoundException(); - } + var invite = ProviderUserInviteFactory.CreateReinvite(new[] { id }, + _userService.GetProperUserId(User).Value, providerId); + await _providerService.ResendInvitesAsync(invite); + } - var invite = ProviderUserInviteFactory.CreateReinvite(new[] { id }, - _userService.GetProperUserId(User).Value, providerId); - await _providerService.ResendInvitesAsync(invite); + [HttpPost("{id:guid}/accept")] + public async Task Accept(Guid providerId, Guid id, [FromBody] ProviderUserAcceptRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); } - [HttpPost("{id:guid}/accept")] - public async Task Accept(Guid providerId, Guid id, [FromBody] ProviderUserAcceptRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } + await _providerService.AcceptUserAsync(id, user, model.Token); + } - await _providerService.AcceptUserAsync(id, user, model.Token); + [HttpPost("{id:guid}/confirm")] + public async Task Confirm(Guid providerId, Guid id, [FromBody] ProviderUserConfirmRequestModel model) + { + if (!_currentContext.ProviderManageUsers(providerId)) + { + throw new NotFoundException(); } - [HttpPost("{id:guid}/confirm")] - public async Task Confirm(Guid providerId, Guid id, [FromBody] ProviderUserConfirmRequestModel model) - { - if (!_currentContext.ProviderManageUsers(providerId)) - { - throw new NotFoundException(); - } + var userId = _userService.GetProperUserId(User); + await _providerService.ConfirmUsersAsync(providerId, new Dictionary { [id] = model.Key }, userId.Value); + } - var userId = _userService.GetProperUserId(User); - await _providerService.ConfirmUsersAsync(providerId, new Dictionary { [id] = model.Key }, userId.Value); + [HttpPost("confirm")] + public async Task> BulkConfirm(Guid providerId, + [FromBody] ProviderUserBulkConfirmRequestModel model) + { + if (!_currentContext.ProviderManageUsers(providerId)) + { + throw new NotFoundException(); } - [HttpPost("confirm")] - public async Task> BulkConfirm(Guid providerId, - [FromBody] ProviderUserBulkConfirmRequestModel model) + var userId = _userService.GetProperUserId(User); + var results = await _providerService.ConfirmUsersAsync(providerId, model.ToDictionary(), userId.Value); + + return new ListResponseModel(results.Select(r => + new ProviderUserBulkResponseModel(r.Item1.Id, r.Item2))); + } + + [HttpPost("public-keys")] + public async Task> UserPublicKeys(Guid providerId, [FromBody] ProviderUserBulkRequestModel model) + { + if (!_currentContext.ProviderManageUsers(providerId)) { - if (!_currentContext.ProviderManageUsers(providerId)) - { - throw new NotFoundException(); - } - - var userId = _userService.GetProperUserId(User); - var results = await _providerService.ConfirmUsersAsync(providerId, model.ToDictionary(), userId.Value); - - return new ListResponseModel(results.Select(r => - new ProviderUserBulkResponseModel(r.Item1.Id, r.Item2))); + throw new NotFoundException(); } - [HttpPost("public-keys")] - public async Task> UserPublicKeys(Guid providerId, [FromBody] ProviderUserBulkRequestModel model) - { - if (!_currentContext.ProviderManageUsers(providerId)) - { - throw new NotFoundException(); - } + var result = await _providerUserRepository.GetManyPublicKeysByProviderUserAsync(providerId, model.Ids); + var responses = result.Select(r => new ProviderUserPublicKeyResponseModel(r.Id, r.UserId, r.PublicKey)).ToList(); + return new ListResponseModel(responses); + } - var result = await _providerUserRepository.GetManyPublicKeysByProviderUserAsync(providerId, model.Ids); - var responses = result.Select(r => new ProviderUserPublicKeyResponseModel(r.Id, r.UserId, r.PublicKey)).ToList(); - return new ListResponseModel(responses); + [HttpPut("{id:guid}")] + [HttpPost("{id:guid}")] + public async Task Put(Guid providerId, Guid id, [FromBody] ProviderUserUpdateRequestModel model) + { + if (!_currentContext.ProviderManageUsers(providerId)) + { + throw new NotFoundException(); } - [HttpPut("{id:guid}")] - [HttpPost("{id:guid}")] - public async Task Put(Guid providerId, Guid id, [FromBody] ProviderUserUpdateRequestModel model) + var providerUser = await _providerUserRepository.GetByIdAsync(id); + if (providerUser == null || providerUser.ProviderId != providerId) { - if (!_currentContext.ProviderManageUsers(providerId)) - { - throw new NotFoundException(); - } - - var providerUser = await _providerUserRepository.GetByIdAsync(id); - if (providerUser == null || providerUser.ProviderId != providerId) - { - throw new NotFoundException(); - } - - var userId = _userService.GetProperUserId(User); - await _providerService.SaveUserAsync(model.ToProviderUser(providerUser), userId.Value); + throw new NotFoundException(); } - [HttpDelete("{id:guid}")] - [HttpPost("{id:guid}/delete")] - public async Task Delete(Guid providerId, Guid id) - { - if (!_currentContext.ProviderManageUsers(providerId)) - { - throw new NotFoundException(); - } + var userId = _userService.GetProperUserId(User); + await _providerService.SaveUserAsync(model.ToProviderUser(providerUser), userId.Value); + } - var userId = _userService.GetProperUserId(User); - await _providerService.DeleteUsersAsync(providerId, new[] { id }, userId.Value); + [HttpDelete("{id:guid}")] + [HttpPost("{id:guid}/delete")] + public async Task Delete(Guid providerId, Guid id) + { + if (!_currentContext.ProviderManageUsers(providerId)) + { + throw new NotFoundException(); } - [HttpDelete("")] - [HttpPost("delete")] - public async Task> BulkDelete(Guid providerId, [FromBody] ProviderUserBulkRequestModel model) - { - if (!_currentContext.ProviderManageUsers(providerId)) - { - throw new NotFoundException(); - } + var userId = _userService.GetProperUserId(User); + await _providerService.DeleteUsersAsync(providerId, new[] { id }, userId.Value); + } - var userId = _userService.GetProperUserId(User); - var result = await _providerService.DeleteUsersAsync(providerId, model.Ids, userId.Value); - return new ListResponseModel(result.Select(r => - new ProviderUserBulkResponseModel(r.Item1.Id, r.Item2))); + [HttpDelete("")] + [HttpPost("delete")] + public async Task> BulkDelete(Guid providerId, [FromBody] ProviderUserBulkRequestModel model) + { + if (!_currentContext.ProviderManageUsers(providerId)) + { + throw new NotFoundException(); } + + var userId = _userService.GetProperUserId(User); + var result = await _providerService.DeleteUsersAsync(providerId, model.Ids, userId.Value); + return new ListResponseModel(result.Select(r => + new ProviderUserBulkResponseModel(r.Item1.Id, r.Item2))); } } diff --git a/src/Api/Controllers/ProvidersController.cs b/src/Api/Controllers/ProvidersController.cs index 5969c0c6f..5daf9ce49 100644 --- a/src/Api/Controllers/ProvidersController.cs +++ b/src/Api/Controllers/ProvidersController.cs @@ -8,84 +8,83 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("providers")] +[Authorize("Application")] +public class ProvidersController : Controller { - [Route("providers")] - [Authorize("Application")] - public class ProvidersController : Controller + private readonly IUserService _userService; + private readonly IProviderRepository _providerRepository; + private readonly IProviderService _providerService; + private readonly ICurrentContext _currentContext; + private readonly GlobalSettings _globalSettings; + + public ProvidersController(IUserService userService, IProviderRepository providerRepository, + IProviderService providerService, ICurrentContext currentContext, GlobalSettings globalSettings) { - private readonly IUserService _userService; - private readonly IProviderRepository _providerRepository; - private readonly IProviderService _providerService; - private readonly ICurrentContext _currentContext; - private readonly GlobalSettings _globalSettings; + _userService = userService; + _providerRepository = providerRepository; + _providerService = providerService; + _currentContext = currentContext; + _globalSettings = globalSettings; + } - public ProvidersController(IUserService userService, IProviderRepository providerRepository, - IProviderService providerService, ICurrentContext currentContext, GlobalSettings globalSettings) + [HttpGet("{id:guid}")] + public async Task Get(Guid id) + { + if (!_currentContext.ProviderUser(id)) { - _userService = userService; - _providerRepository = providerRepository; - _providerService = providerService; - _currentContext = currentContext; - _globalSettings = globalSettings; + throw new NotFoundException(); } - [HttpGet("{id:guid}")] - public async Task Get(Guid id) + var provider = await _providerRepository.GetByIdAsync(id); + if (provider == null) { - if (!_currentContext.ProviderUser(id)) - { - throw new NotFoundException(); - } - - var provider = await _providerRepository.GetByIdAsync(id); - if (provider == null) - { - throw new NotFoundException(); - } - - return new ProviderResponseModel(provider); + throw new NotFoundException(); } - [HttpPut("{id:guid}")] - [HttpPost("{id:guid}")] - public async Task Put(Guid id, [FromBody] ProviderUpdateRequestModel model) + return new ProviderResponseModel(provider); + } + + [HttpPut("{id:guid}")] + [HttpPost("{id:guid}")] + public async Task Put(Guid id, [FromBody] ProviderUpdateRequestModel model) + { + if (!_currentContext.ProviderProviderAdmin(id)) { - if (!_currentContext.ProviderProviderAdmin(id)) - { - throw new NotFoundException(); - } - - var provider = await _providerRepository.GetByIdAsync(id); - if (provider == null) - { - throw new NotFoundException(); - } - - await _providerService.UpdateAsync(model.ToProvider(provider, _globalSettings)); - return new ProviderResponseModel(provider); + throw new NotFoundException(); } - [HttpPost("{id:guid}/setup")] - public async Task Setup(Guid id, [FromBody] ProviderSetupRequestModel model) + var provider = await _providerRepository.GetByIdAsync(id); + if (provider == null) { - if (!_currentContext.ProviderProviderAdmin(id)) - { - throw new NotFoundException(); - } - - var provider = await _providerRepository.GetByIdAsync(id); - if (provider == null) - { - throw new NotFoundException(); - } - - var userId = _userService.GetProperUserId(User).Value; - - var response = - await _providerService.CompleteSetupAsync(model.ToProvider(provider), userId, model.Token, model.Key); - - return new ProviderResponseModel(response); + throw new NotFoundException(); } + + await _providerService.UpdateAsync(model.ToProvider(provider, _globalSettings)); + return new ProviderResponseModel(provider); + } + + [HttpPost("{id:guid}/setup")] + public async Task Setup(Guid id, [FromBody] ProviderSetupRequestModel model) + { + if (!_currentContext.ProviderProviderAdmin(id)) + { + throw new NotFoundException(); + } + + var provider = await _providerRepository.GetByIdAsync(id); + if (provider == null) + { + throw new NotFoundException(); + } + + var userId = _userService.GetProperUserId(User).Value; + + var response = + await _providerService.CompleteSetupAsync(model.ToProvider(provider), userId, model.Token, model.Key); + + return new ProviderResponseModel(response); } } diff --git a/src/Api/Controllers/PushController.cs b/src/Api/Controllers/PushController.cs index afeaf92f7..7312cb7b8 100644 --- a/src/Api/Controllers/PushController.cs +++ b/src/Api/Controllers/PushController.cs @@ -7,109 +7,108 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("push")] +[Authorize("Push")] +[SelfHosted(NotSelfHostedOnly = true)] +public class PushController : Controller { - [Route("push")] - [Authorize("Push")] - [SelfHosted(NotSelfHostedOnly = true)] - public class PushController : Controller + private readonly IPushRegistrationService _pushRegistrationService; + private readonly IPushNotificationService _pushNotificationService; + private readonly IWebHostEnvironment _environment; + private readonly ICurrentContext _currentContext; + private readonly GlobalSettings _globalSettings; + + public PushController( + IPushRegistrationService pushRegistrationService, + IPushNotificationService pushNotificationService, + IWebHostEnvironment environment, + ICurrentContext currentContext, + GlobalSettings globalSettings) { - private readonly IPushRegistrationService _pushRegistrationService; - private readonly IPushNotificationService _pushNotificationService; - private readonly IWebHostEnvironment _environment; - private readonly ICurrentContext _currentContext; - private readonly GlobalSettings _globalSettings; + _currentContext = currentContext; + _environment = environment; + _pushRegistrationService = pushRegistrationService; + _pushNotificationService = pushNotificationService; + _globalSettings = globalSettings; + } - public PushController( - IPushRegistrationService pushRegistrationService, - IPushNotificationService pushNotificationService, - IWebHostEnvironment environment, - ICurrentContext currentContext, - GlobalSettings globalSettings) + [HttpPost("register")] + public async Task PostRegister([FromBody] PushRegistrationRequestModel model) + { + CheckUsage(); + await _pushRegistrationService.CreateOrUpdateRegistrationAsync(model.PushToken, Prefix(model.DeviceId), + Prefix(model.UserId), Prefix(model.Identifier), model.Type); + } + + [HttpDelete("{id}")] + public async Task Delete(string id) + { + CheckUsage(); + await _pushRegistrationService.DeleteRegistrationAsync(Prefix(id)); + } + + [HttpPut("add-organization")] + public async Task PutAddOrganization([FromBody] PushUpdateRequestModel model) + { + CheckUsage(); + await _pushRegistrationService.AddUserRegistrationOrganizationAsync( + model.DeviceIds.Select(d => Prefix(d)), Prefix(model.OrganizationId)); + } + + [HttpPut("delete-organization")] + public async Task PutDeleteOrganization([FromBody] PushUpdateRequestModel model) + { + CheckUsage(); + await _pushRegistrationService.DeleteUserRegistrationOrganizationAsync( + model.DeviceIds.Select(d => Prefix(d)), Prefix(model.OrganizationId)); + } + + [HttpPost("send")] + public async Task PostSend([FromBody] PushSendRequestModel model) + { + CheckUsage(); + + if (!string.IsNullOrWhiteSpace(model.UserId)) { - _currentContext = currentContext; - _environment = environment; - _pushRegistrationService = pushRegistrationService; - _pushNotificationService = pushNotificationService; - _globalSettings = globalSettings; + await _pushNotificationService.SendPayloadToUserAsync(Prefix(model.UserId), + model.Type.Value, model.Payload, Prefix(model.Identifier), Prefix(model.DeviceId)); } - - [HttpPost("register")] - public async Task PostRegister([FromBody] PushRegistrationRequestModel model) + else if (!string.IsNullOrWhiteSpace(model.OrganizationId)) { - CheckUsage(); - await _pushRegistrationService.CreateOrUpdateRegistrationAsync(model.PushToken, Prefix(model.DeviceId), - Prefix(model.UserId), Prefix(model.Identifier), model.Type); - } - - [HttpDelete("{id}")] - public async Task Delete(string id) - { - CheckUsage(); - await _pushRegistrationService.DeleteRegistrationAsync(Prefix(id)); - } - - [HttpPut("add-organization")] - public async Task PutAddOrganization([FromBody] PushUpdateRequestModel model) - { - CheckUsage(); - await _pushRegistrationService.AddUserRegistrationOrganizationAsync( - model.DeviceIds.Select(d => Prefix(d)), Prefix(model.OrganizationId)); - } - - [HttpPut("delete-organization")] - public async Task PutDeleteOrganization([FromBody] PushUpdateRequestModel model) - { - CheckUsage(); - await _pushRegistrationService.DeleteUserRegistrationOrganizationAsync( - model.DeviceIds.Select(d => Prefix(d)), Prefix(model.OrganizationId)); - } - - [HttpPost("send")] - public async Task PostSend([FromBody] PushSendRequestModel model) - { - CheckUsage(); - - if (!string.IsNullOrWhiteSpace(model.UserId)) - { - await _pushNotificationService.SendPayloadToUserAsync(Prefix(model.UserId), - model.Type.Value, model.Payload, Prefix(model.Identifier), Prefix(model.DeviceId)); - } - else if (!string.IsNullOrWhiteSpace(model.OrganizationId)) - { - await _pushNotificationService.SendPayloadToOrganizationAsync(Prefix(model.OrganizationId), - model.Type.Value, model.Payload, Prefix(model.Identifier), Prefix(model.DeviceId)); - } - } - - private string Prefix(string value) - { - if (string.IsNullOrWhiteSpace(value)) - { - return null; - } - - return $"{_currentContext.InstallationId.Value}_{value}"; - } - - private void CheckUsage() - { - if (CanUse()) - { - return; - } - - throw new BadRequestException("Not correctly configured for push relays."); - } - - private bool CanUse() - { - if (_environment.IsDevelopment()) - { - return true; - } - - return _currentContext.InstallationId.HasValue && !_globalSettings.SelfHosted; + await _pushNotificationService.SendPayloadToOrganizationAsync(Prefix(model.OrganizationId), + model.Type.Value, model.Payload, Prefix(model.Identifier), Prefix(model.DeviceId)); } } + + private string Prefix(string value) + { + if (string.IsNullOrWhiteSpace(value)) + { + return null; + } + + return $"{_currentContext.InstallationId.Value}_{value}"; + } + + private void CheckUsage() + { + if (CanUse()) + { + return; + } + + throw new BadRequestException("Not correctly configured for push relays."); + } + + private bool CanUse() + { + if (_environment.IsDevelopment()) + { + return true; + } + + return _currentContext.InstallationId.HasValue && !_globalSettings.SelfHosted; + } } diff --git a/src/Api/Controllers/SelfHosted/SelfHostedOrganizationSponsorshipsController.cs b/src/Api/Controllers/SelfHosted/SelfHostedOrganizationSponsorshipsController.cs index b74192983..ffb5c7bb9 100644 --- a/src/Api/Controllers/SelfHosted/SelfHostedOrganizationSponsorshipsController.cs +++ b/src/Api/Controllers/SelfHosted/SelfHostedOrganizationSponsorshipsController.cs @@ -7,61 +7,60 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers.SelfHosted +namespace Bit.Api.Controllers.SelfHosted; + +[Route("organization/sponsorship/self-hosted")] +[Authorize("Application")] +[SelfHosted(SelfHostedOnly = true)] +public class SelfHostedOrganizationSponsorshipsController : Controller { - [Route("organization/sponsorship/self-hosted")] - [Authorize("Application")] - [SelfHosted(SelfHostedOnly = true)] - public class SelfHostedOrganizationSponsorshipsController : Controller + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; + private readonly ICreateSponsorshipCommand _offerSponsorshipCommand; + private readonly IRevokeSponsorshipCommand _revokeSponsorshipCommand; + private readonly ICurrentContext _currentContext; + + public SelfHostedOrganizationSponsorshipsController( + ICreateSponsorshipCommand offerSponsorshipCommand, + IRevokeSponsorshipCommand revokeSponsorshipCommand, + IOrganizationRepository organizationRepository, + IOrganizationSponsorshipRepository organizationSponsorshipRepository, + IOrganizationUserRepository organizationUserRepository, + ICurrentContext currentContext + ) { - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; - private readonly ICreateSponsorshipCommand _offerSponsorshipCommand; - private readonly IRevokeSponsorshipCommand _revokeSponsorshipCommand; - private readonly ICurrentContext _currentContext; + _offerSponsorshipCommand = offerSponsorshipCommand; + _revokeSponsorshipCommand = revokeSponsorshipCommand; + _organizationRepository = organizationRepository; + _organizationSponsorshipRepository = organizationSponsorshipRepository; + _organizationUserRepository = organizationUserRepository; + _currentContext = currentContext; + } - public SelfHostedOrganizationSponsorshipsController( - ICreateSponsorshipCommand offerSponsorshipCommand, - IRevokeSponsorshipCommand revokeSponsorshipCommand, - IOrganizationRepository organizationRepository, - IOrganizationSponsorshipRepository organizationSponsorshipRepository, - IOrganizationUserRepository organizationUserRepository, - ICurrentContext currentContext - ) + [HttpPost("{sponsoringOrgId}/families-for-enterprise")] + public async Task CreateSponsorship(Guid sponsoringOrgId, [FromBody] OrganizationSponsorshipCreateRequestModel model) + { + await _offerSponsorshipCommand.CreateSponsorshipAsync( + await _organizationRepository.GetByIdAsync(sponsoringOrgId), + await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default), + model.PlanSponsorshipType, model.SponsoredEmail, model.FriendlyName); + } + + [HttpDelete("{sponsoringOrgId}")] + [HttpPost("{sponsoringOrgId}/delete")] + public async Task RevokeSponsorship(Guid sponsoringOrgId) + { + var orgUser = await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default); + + if (orgUser == null) { - _offerSponsorshipCommand = offerSponsorshipCommand; - _revokeSponsorshipCommand = revokeSponsorshipCommand; - _organizationRepository = organizationRepository; - _organizationSponsorshipRepository = organizationSponsorshipRepository; - _organizationUserRepository = organizationUserRepository; - _currentContext = currentContext; + throw new BadRequestException("Unknown Organization User"); } - [HttpPost("{sponsoringOrgId}/families-for-enterprise")] - public async Task CreateSponsorship(Guid sponsoringOrgId, [FromBody] OrganizationSponsorshipCreateRequestModel model) - { - await _offerSponsorshipCommand.CreateSponsorshipAsync( - await _organizationRepository.GetByIdAsync(sponsoringOrgId), - await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default), - model.PlanSponsorshipType, model.SponsoredEmail, model.FriendlyName); - } + var existingOrgSponsorship = await _organizationSponsorshipRepository + .GetBySponsoringOrganizationUserIdAsync(orgUser.Id); - [HttpDelete("{sponsoringOrgId}")] - [HttpPost("{sponsoringOrgId}/delete")] - public async Task RevokeSponsorship(Guid sponsoringOrgId) - { - var orgUser = await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default); - - if (orgUser == null) - { - throw new BadRequestException("Unknown Organization User"); - } - - var existingOrgSponsorship = await _organizationSponsorshipRepository - .GetBySponsoringOrganizationUserIdAsync(orgUser.Id); - - await _revokeSponsorshipCommand.RevokeSponsorshipAsync(existingOrgSponsorship); - } + await _revokeSponsorshipCommand.RevokeSponsorshipAsync(existingOrgSponsorship); } } diff --git a/src/Api/Controllers/SendsController.cs b/src/Api/Controllers/SendsController.cs index 405f5c659..5f1d7527a 100644 --- a/src/Api/Controllers/SendsController.cs +++ b/src/Api/Controllers/SendsController.cs @@ -16,323 +16,322 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("sends")] +[Authorize("Application")] +public class SendsController : Controller { - [Route("sends")] - [Authorize("Application")] - public class SendsController : Controller + private readonly ISendRepository _sendRepository; + private readonly IUserService _userService; + private readonly ISendService _sendService; + private readonly ISendFileStorageService _sendFileStorageService; + private readonly ILogger _logger; + private readonly GlobalSettings _globalSettings; + private readonly ICurrentContext _currentContext; + + public SendsController( + ISendRepository sendRepository, + IUserService userService, + ISendService sendService, + ISendFileStorageService sendFileStorageService, + ILogger logger, + GlobalSettings globalSettings, + ICurrentContext currentContext) { - private readonly ISendRepository _sendRepository; - private readonly IUserService _userService; - private readonly ISendService _sendService; - private readonly ISendFileStorageService _sendFileStorageService; - private readonly ILogger _logger; - private readonly GlobalSettings _globalSettings; - private readonly ICurrentContext _currentContext; + _sendRepository = sendRepository; + _userService = userService; + _sendService = sendService; + _sendFileStorageService = sendFileStorageService; + _logger = logger; + _globalSettings = globalSettings; + _currentContext = currentContext; + } - public SendsController( - ISendRepository sendRepository, - IUserService userService, - ISendService sendService, - ISendFileStorageService sendFileStorageService, - ILogger logger, - GlobalSettings globalSettings, - ICurrentContext currentContext) + [AllowAnonymous] + [HttpPost("access/{id}")] + public async Task Access(string id, [FromBody] SendAccessRequestModel model) + { + // Uncomment whenever we want to require the `send-id` header + //if (!_currentContext.HttpContext.Request.Headers.ContainsKey("Send-Id") || + // _currentContext.HttpContext.Request.Headers["Send-Id"] != id) + //{ + // throw new BadRequestException("Invalid Send-Id header."); + //} + + var guid = new Guid(CoreHelpers.Base64UrlDecode(id)); + var (send, passwordRequired, passwordInvalid) = + await _sendService.AccessAsync(guid, model.Password); + if (passwordRequired) { - _sendRepository = sendRepository; - _userService = userService; - _sendService = sendService; - _sendFileStorageService = sendFileStorageService; - _logger = logger; - _globalSettings = globalSettings; - _currentContext = currentContext; + return new UnauthorizedResult(); + } + if (passwordInvalid) + { + await Task.Delay(2000); + throw new BadRequestException("Invalid password."); + } + if (send == null) + { + throw new NotFoundException(); } - [AllowAnonymous] - [HttpPost("access/{id}")] - public async Task Access(string id, [FromBody] SendAccessRequestModel model) + var sendResponse = new SendAccessResponseModel(send, _globalSettings); + if (send.UserId.HasValue && !send.HideEmail.GetValueOrDefault()) { - // Uncomment whenever we want to require the `send-id` header - //if (!_currentContext.HttpContext.Request.Headers.ContainsKey("Send-Id") || - // _currentContext.HttpContext.Request.Headers["Send-Id"] != id) - //{ - // throw new BadRequestException("Invalid Send-Id header."); - //} + var creator = await _userService.GetUserByIdAsync(send.UserId.Value); + sendResponse.CreatorIdentifier = creator.Email; + } + return new ObjectResult(sendResponse); + } - var guid = new Guid(CoreHelpers.Base64UrlDecode(id)); - var (send, passwordRequired, passwordInvalid) = - await _sendService.AccessAsync(guid, model.Password); - if (passwordRequired) - { - return new UnauthorizedResult(); - } - if (passwordInvalid) - { - await Task.Delay(2000); - throw new BadRequestException("Invalid password."); - } - if (send == null) - { - throw new NotFoundException(); - } + [AllowAnonymous] + [HttpPost("{encodedSendId}/access/file/{fileId}")] + public async Task GetSendFileDownloadData(string encodedSendId, + string fileId, [FromBody] SendAccessRequestModel model) + { + // Uncomment whenever we want to require the `send-id` header + //if (!_currentContext.HttpContext.Request.Headers.ContainsKey("Send-Id") || + // _currentContext.HttpContext.Request.Headers["Send-Id"] != encodedSendId) + //{ + // throw new BadRequestException("Invalid Send-Id header."); + //} - var sendResponse = new SendAccessResponseModel(send, _globalSettings); - if (send.UserId.HasValue && !send.HideEmail.GetValueOrDefault()) - { - var creator = await _userService.GetUserByIdAsync(send.UserId.Value); - sendResponse.CreatorIdentifier = creator.Email; - } - return new ObjectResult(sendResponse); + var sendId = new Guid(CoreHelpers.Base64UrlDecode(encodedSendId)); + var send = await _sendRepository.GetByIdAsync(sendId); + + if (send == null) + { + throw new BadRequestException("Could not locate send"); } - [AllowAnonymous] - [HttpPost("{encodedSendId}/access/file/{fileId}")] - public async Task GetSendFileDownloadData(string encodedSendId, - string fileId, [FromBody] SendAccessRequestModel model) + var (url, passwordRequired, passwordInvalid) = await _sendService.GetSendFileDownloadUrlAsync(send, fileId, + model.Password); + + if (passwordRequired) { - // Uncomment whenever we want to require the `send-id` header - //if (!_currentContext.HttpContext.Request.Headers.ContainsKey("Send-Id") || - // _currentContext.HttpContext.Request.Headers["Send-Id"] != encodedSendId) - //{ - // throw new BadRequestException("Invalid Send-Id header."); - //} - - var sendId = new Guid(CoreHelpers.Base64UrlDecode(encodedSendId)); - var send = await _sendRepository.GetByIdAsync(sendId); - - if (send == null) - { - throw new BadRequestException("Could not locate send"); - } - - var (url, passwordRequired, passwordInvalid) = await _sendService.GetSendFileDownloadUrlAsync(send, fileId, - model.Password); - - if (passwordRequired) - { - return new UnauthorizedResult(); - } - if (passwordInvalid) - { - await Task.Delay(2000); - throw new BadRequestException("Invalid password."); - } - if (send == null) - { - throw new NotFoundException(); - } - - return new ObjectResult(new SendFileDownloadDataResponseModel() - { - Id = fileId, - Url = url, - }); + return new UnauthorizedResult(); + } + if (passwordInvalid) + { + await Task.Delay(2000); + throw new BadRequestException("Invalid password."); + } + if (send == null) + { + throw new NotFoundException(); } - [HttpGet("{id}")] - public async Task Get(string id) + return new ObjectResult(new SendFileDownloadDataResponseModel() { - var userId = _userService.GetProperUserId(User).Value; - var send = await _sendRepository.GetByIdAsync(new Guid(id)); - if (send == null || send.UserId != userId) - { - throw new NotFoundException(); - } + Id = fileId, + Url = url, + }); + } - return new SendResponseModel(send, _globalSettings); + [HttpGet("{id}")] + public async Task Get(string id) + { + var userId = _userService.GetProperUserId(User).Value; + var send = await _sendRepository.GetByIdAsync(new Guid(id)); + if (send == null || send.UserId != userId) + { + throw new NotFoundException(); } - [HttpGet("")] - public async Task> Get() + return new SendResponseModel(send, _globalSettings); + } + + [HttpGet("")] + public async Task> Get() + { + var userId = _userService.GetProperUserId(User).Value; + var sends = await _sendRepository.GetManyByUserIdAsync(userId); + var responses = sends.Select(s => new SendResponseModel(s, _globalSettings)); + return new ListResponseModel(responses); + } + + [HttpPost("")] + public async Task Post([FromBody] SendRequestModel model) + { + model.ValidateCreation(); + var userId = _userService.GetProperUserId(User).Value; + var send = model.ToSend(userId, _sendService); + await _sendService.SaveSendAsync(send); + return new SendResponseModel(send, _globalSettings); + } + + [HttpPost("file")] + [Obsolete("Deprecated File Send API", false)] + [RequestSizeLimit(Constants.FileSize101mb)] + [DisableFormValueModelBinding] + public async Task PostFile() + { + if (!Request?.ContentType.Contains("multipart/") ?? true) { - var userId = _userService.GetProperUserId(User).Value; - var sends = await _sendRepository.GetManyByUserIdAsync(userId); - var responses = sends.Select(s => new SendResponseModel(s, _globalSettings)); - return new ListResponseModel(responses); + throw new BadRequestException("Invalid content."); } - [HttpPost("")] - public async Task Post([FromBody] SendRequestModel model) + Send send = null; + await Request.GetSendFileAsync(async (stream, fileName, model) => { model.ValidateCreation(); var userId = _userService.GetProperUserId(User).Value; - var send = model.ToSend(userId, _sendService); - await _sendService.SaveSendAsync(send); - return new SendResponseModel(send, _globalSettings); + var (madeSend, madeData) = model.ToSend(userId, fileName, _sendService); + send = madeSend; + await _sendService.SaveFileSendAsync(send, madeData, model.FileLength.GetValueOrDefault(0)); + await _sendService.UploadFileToExistingSendAsync(stream, send); + }); + + return new SendResponseModel(send, _globalSettings); + } + + + [HttpPost("file/v2")] + public async Task PostFile([FromBody] SendRequestModel model) + { + if (model.Type != SendType.File) + { + throw new BadRequestException("Invalid content."); } - [HttpPost("file")] - [Obsolete("Deprecated File Send API", false)] - [RequestSizeLimit(Constants.FileSize101mb)] - [DisableFormValueModelBinding] - public async Task PostFile() + if (!model.FileLength.HasValue) { - if (!Request?.ContentType.Contains("multipart/") ?? true) - { - throw new BadRequestException("Invalid content."); - } - - Send send = null; - await Request.GetSendFileAsync(async (stream, fileName, model) => - { - model.ValidateCreation(); - var userId = _userService.GetProperUserId(User).Value; - var (madeSend, madeData) = model.ToSend(userId, fileName, _sendService); - send = madeSend; - await _sendService.SaveFileSendAsync(send, madeData, model.FileLength.GetValueOrDefault(0)); - await _sendService.UploadFileToExistingSendAsync(stream, send); - }); - - return new SendResponseModel(send, _globalSettings); + throw new BadRequestException("Invalid content. File size hint is required."); } - - [HttpPost("file/v2")] - public async Task PostFile([FromBody] SendRequestModel model) + if (model.FileLength.Value > SendService.MAX_FILE_SIZE) { - if (model.Type != SendType.File) - { - throw new BadRequestException("Invalid content."); - } - - if (!model.FileLength.HasValue) - { - throw new BadRequestException("Invalid content. File size hint is required."); - } - - if (model.FileLength.Value > SendService.MAX_FILE_SIZE) - { - throw new BadRequestException($"Max file size is {SendService.MAX_FILE_SIZE_READABLE}."); - } - - var userId = _userService.GetProperUserId(User).Value; - var (send, data) = model.ToSend(userId, model.File.FileName, _sendService); - var uploadUrl = await _sendService.SaveFileSendAsync(send, data, model.FileLength.Value); - return new SendFileUploadDataResponseModel - { - Url = uploadUrl, - FileUploadType = _sendFileStorageService.FileUploadType, - SendResponse = new SendResponseModel(send, _globalSettings) - }; + throw new BadRequestException($"Max file size is {SendService.MAX_FILE_SIZE_READABLE}."); } - [HttpGet("{id}/file/{fileId}")] - public async Task RenewFileUpload(string id, string fileId) + var userId = _userService.GetProperUserId(User).Value; + var (send, data) = model.ToSend(userId, model.File.FileName, _sendService); + var uploadUrl = await _sendService.SaveFileSendAsync(send, data, model.FileLength.Value); + return new SendFileUploadDataResponseModel { - var userId = _userService.GetProperUserId(User).Value; - var sendId = new Guid(id); - var send = await _sendRepository.GetByIdAsync(sendId); - var fileData = JsonSerializer.Deserialize(send?.Data); + Url = uploadUrl, + FileUploadType = _sendFileStorageService.FileUploadType, + SendResponse = new SendResponseModel(send, _globalSettings) + }; + } - if (send == null || send.Type != SendType.File || (send.UserId.HasValue && send.UserId.Value != userId) || - !send.UserId.HasValue || fileData.Id != fileId || fileData.Validated) - { - // Not found if Send isn't found, user doesn't have access, request is faulty, - // or we've already validated the file. This last is to emulate create-only blob permissions for Azure - throw new NotFoundException(); - } + [HttpGet("{id}/file/{fileId}")] + public async Task RenewFileUpload(string id, string fileId) + { + var userId = _userService.GetProperUserId(User).Value; + var sendId = new Guid(id); + var send = await _sendRepository.GetByIdAsync(sendId); + var fileData = JsonSerializer.Deserialize(send?.Data); - return new SendFileUploadDataResponseModel - { - Url = await _sendFileStorageService.GetSendFileUploadUrlAsync(send, fileId), - FileUploadType = _sendFileStorageService.FileUploadType, - SendResponse = new SendResponseModel(send, _globalSettings), - }; + if (send == null || send.Type != SendType.File || (send.UserId.HasValue && send.UserId.Value != userId) || + !send.UserId.HasValue || fileData.Id != fileId || fileData.Validated) + { + // Not found if Send isn't found, user doesn't have access, request is faulty, + // or we've already validated the file. This last is to emulate create-only blob permissions for Azure + throw new NotFoundException(); } - [HttpPost("{id}/file/{fileId}")] - [SelfHosted(SelfHostedOnly = true)] - [RequestSizeLimit(Constants.FileSize501mb)] - [DisableFormValueModelBinding] - public async Task PostFileForExistingSend(string id, string fileId) + return new SendFileUploadDataResponseModel { - if (!Request?.ContentType.Contains("multipart/") ?? true) - { - throw new BadRequestException("Invalid content."); - } + Url = await _sendFileStorageService.GetSendFileUploadUrlAsync(send, fileId), + FileUploadType = _sendFileStorageService.FileUploadType, + SendResponse = new SendResponseModel(send, _globalSettings), + }; + } - var send = await _sendRepository.GetByIdAsync(new Guid(id)); - await Request.GetFileAsync(async (stream) => - { - await _sendService.UploadFileToExistingSendAsync(stream, send); - }); + [HttpPost("{id}/file/{fileId}")] + [SelfHosted(SelfHostedOnly = true)] + [RequestSizeLimit(Constants.FileSize501mb)] + [DisableFormValueModelBinding] + public async Task PostFileForExistingSend(string id, string fileId) + { + if (!Request?.ContentType.Contains("multipart/") ?? true) + { + throw new BadRequestException("Invalid content."); } - [AllowAnonymous] - [HttpPost("file/validate/azure")] - public async Task AzureValidateFile() + var send = await _sendRepository.GetByIdAsync(new Guid(id)); + await Request.GetFileAsync(async (stream) => + { + await _sendService.UploadFileToExistingSendAsync(stream, send); + }); + } + + [AllowAnonymous] + [HttpPost("file/validate/azure")] + public async Task AzureValidateFile() + { + return await ApiHelpers.HandleAzureEvents(Request, new Dictionary> { - return await ApiHelpers.HandleAzureEvents(Request, new Dictionary> { + "Microsoft.Storage.BlobCreated", async (eventGridEvent) => { - "Microsoft.Storage.BlobCreated", async (eventGridEvent) => + try { - try + var blobName = eventGridEvent.Subject.Split($"{AzureSendFileStorageService.FilesContainerName}/blobs/")[1]; + var sendId = AzureSendFileStorageService.SendIdFromBlobName(blobName); + var send = await _sendRepository.GetByIdAsync(new Guid(sendId)); + if (send == null) { - var blobName = eventGridEvent.Subject.Split($"{AzureSendFileStorageService.FilesContainerName}/blobs/")[1]; - var sendId = AzureSendFileStorageService.SendIdFromBlobName(blobName); - var send = await _sendRepository.GetByIdAsync(new Guid(sendId)); - if (send == null) + if (_sendFileStorageService is AzureSendFileStorageService azureSendFileStorageService) { - if (_sendFileStorageService is AzureSendFileStorageService azureSendFileStorageService) - { - await azureSendFileStorageService.DeleteBlobAsync(blobName); - } - return; + await azureSendFileStorageService.DeleteBlobAsync(blobName); } - await _sendService.ValidateSendFile(send); - } - catch (Exception e) - { - _logger.LogError(e, $"Uncaught exception occurred while handling event grid event: {JsonSerializer.Serialize(eventGridEvent)}"); return; } + await _sendService.ValidateSendFile(send); + } + catch (Exception e) + { + _logger.LogError(e, $"Uncaught exception occurred while handling event grid event: {JsonSerializer.Serialize(eventGridEvent)}"); + return; } } - }); - } - - [HttpPut("{id}")] - public async Task Put(string id, [FromBody] SendRequestModel model) - { - model.ValidateEdit(); - var userId = _userService.GetProperUserId(User).Value; - var send = await _sendRepository.GetByIdAsync(new Guid(id)); - if (send == null || send.UserId != userId) - { - throw new NotFoundException(); } + }); + } - await _sendService.SaveSendAsync(model.ToSend(send, _sendService)); - return new SendResponseModel(send, _globalSettings); - } - - [HttpPut("{id}/remove-password")] - public async Task PutRemovePassword(string id) + [HttpPut("{id}")] + public async Task Put(string id, [FromBody] SendRequestModel model) + { + model.ValidateEdit(); + var userId = _userService.GetProperUserId(User).Value; + var send = await _sendRepository.GetByIdAsync(new Guid(id)); + if (send == null || send.UserId != userId) { - var userId = _userService.GetProperUserId(User).Value; - var send = await _sendRepository.GetByIdAsync(new Guid(id)); - if (send == null || send.UserId != userId) - { - throw new NotFoundException(); - } - - send.Password = null; - await _sendService.SaveSendAsync(send); - return new SendResponseModel(send, _globalSettings); + throw new NotFoundException(); } - [HttpDelete("{id}")] - public async Task Delete(string id) + await _sendService.SaveSendAsync(model.ToSend(send, _sendService)); + return new SendResponseModel(send, _globalSettings); + } + + [HttpPut("{id}/remove-password")] + public async Task PutRemovePassword(string id) + { + var userId = _userService.GetProperUserId(User).Value; + var send = await _sendRepository.GetByIdAsync(new Guid(id)); + if (send == null || send.UserId != userId) { - var userId = _userService.GetProperUserId(User).Value; - var send = await _sendRepository.GetByIdAsync(new Guid(id)); - if (send == null || send.UserId != userId) - { - throw new NotFoundException(); - } - - await _sendService.DeleteSendAsync(send); + throw new NotFoundException(); } + + send.Password = null; + await _sendService.SaveSendAsync(send); + return new SendResponseModel(send, _globalSettings); + } + + [HttpDelete("{id}")] + public async Task Delete(string id) + { + var userId = _userService.GetProperUserId(User).Value; + var send = await _sendRepository.GetByIdAsync(new Guid(id)); + if (send == null || send.UserId != userId) + { + throw new NotFoundException(); + } + + await _sendService.DeleteSendAsync(send); } } diff --git a/src/Api/Controllers/SettingsController.cs b/src/Api/Controllers/SettingsController.cs index 2db70b017..8489b137e 100644 --- a/src/Api/Controllers/SettingsController.cs +++ b/src/Api/Controllers/SettingsController.cs @@ -4,47 +4,46 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("settings")] +[Authorize("Application")] +public class SettingsController : Controller { - [Route("settings")] - [Authorize("Application")] - public class SettingsController : Controller + private readonly IUserService _userService; + + public SettingsController( + IUserService userService) { - private readonly IUserService _userService; + _userService = userService; + } - public SettingsController( - IUserService userService) + [HttpGet("domains")] + public async Task GetDomains(bool excluded = true) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) { - _userService = userService; + throw new UnauthorizedAccessException(); } - [HttpGet("domains")] - public async Task GetDomains(bool excluded = true) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } + var response = new DomainsResponseModel(user, excluded); + return response; + } - var response = new DomainsResponseModel(user, excluded); - return response; + [HttpPut("domains")] + [HttpPost("domains")] + public async Task PutDomains([FromBody] UpdateDomainsRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); } - [HttpPut("domains")] - [HttpPost("domains")] - public async Task PutDomains([FromBody] UpdateDomainsRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } + await _userService.SaveUserAsync(model.ToUser(user), true); - await _userService.SaveUserAsync(model.ToUser(user), true); - - var response = new DomainsResponseModel(user); - return response; - } + var response = new DomainsResponseModel(user); + return response; } } diff --git a/src/Api/Controllers/SyncController.cs b/src/Api/Controllers/SyncController.cs index c85554386..49ccfeacf 100644 --- a/src/Api/Controllers/SyncController.cs +++ b/src/Api/Controllers/SyncController.cs @@ -10,85 +10,84 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("sync")] +[Authorize("Application")] +public class SyncController : Controller { - [Route("sync")] - [Authorize("Application")] - public class SyncController : Controller + private readonly IUserService _userService; + private readonly IFolderRepository _folderRepository; + private readonly ICipherRepository _cipherRepository; + private readonly ICollectionRepository _collectionRepository; + private readonly ICollectionCipherRepository _collectionCipherRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IProviderUserRepository _providerUserRepository; + private readonly IPolicyRepository _policyRepository; + private readonly ISendRepository _sendRepository; + private readonly GlobalSettings _globalSettings; + + public SyncController( + IUserService userService, + IFolderRepository folderRepository, + ICipherRepository cipherRepository, + ICollectionRepository collectionRepository, + ICollectionCipherRepository collectionCipherRepository, + IOrganizationUserRepository organizationUserRepository, + IProviderUserRepository providerUserRepository, + IPolicyRepository policyRepository, + ISendRepository sendRepository, + GlobalSettings globalSettings) { - private readonly IUserService _userService; - private readonly IFolderRepository _folderRepository; - private readonly ICipherRepository _cipherRepository; - private readonly ICollectionRepository _collectionRepository; - private readonly ICollectionCipherRepository _collectionCipherRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IProviderUserRepository _providerUserRepository; - private readonly IPolicyRepository _policyRepository; - private readonly ISendRepository _sendRepository; - private readonly GlobalSettings _globalSettings; + _userService = userService; + _folderRepository = folderRepository; + _cipherRepository = cipherRepository; + _collectionRepository = collectionRepository; + _collectionCipherRepository = collectionCipherRepository; + _organizationUserRepository = organizationUserRepository; + _providerUserRepository = providerUserRepository; + _policyRepository = policyRepository; + _sendRepository = sendRepository; + _globalSettings = globalSettings; + } - public SyncController( - IUserService userService, - IFolderRepository folderRepository, - ICipherRepository cipherRepository, - ICollectionRepository collectionRepository, - ICollectionCipherRepository collectionCipherRepository, - IOrganizationUserRepository organizationUserRepository, - IProviderUserRepository providerUserRepository, - IPolicyRepository policyRepository, - ISendRepository sendRepository, - GlobalSettings globalSettings) + [HttpGet("")] + public async Task Get([FromQuery] bool excludeDomains = false) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) { - _userService = userService; - _folderRepository = folderRepository; - _cipherRepository = cipherRepository; - _collectionRepository = collectionRepository; - _collectionCipherRepository = collectionCipherRepository; - _organizationUserRepository = organizationUserRepository; - _providerUserRepository = providerUserRepository; - _policyRepository = policyRepository; - _sendRepository = sendRepository; - _globalSettings = globalSettings; + throw new BadRequestException("User not found."); } - [HttpGet("")] - public async Task Get([FromQuery] bool excludeDomains = false) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new BadRequestException("User not found."); - } - - var organizationUserDetails = await _organizationUserRepository.GetManyDetailsByUserAsync(user.Id, - OrganizationUserStatusType.Confirmed); - var providerUserDetails = await _providerUserRepository.GetManyDetailsByUserAsync(user.Id, + var organizationUserDetails = await _organizationUserRepository.GetManyDetailsByUserAsync(user.Id, + OrganizationUserStatusType.Confirmed); + var providerUserDetails = await _providerUserRepository.GetManyDetailsByUserAsync(user.Id, + ProviderUserStatusType.Confirmed); + var providerUserOrganizationDetails = + await _providerUserRepository.GetManyOrganizationDetailsByUserAsync(user.Id, ProviderUserStatusType.Confirmed); - var providerUserOrganizationDetails = - await _providerUserRepository.GetManyOrganizationDetailsByUserAsync(user.Id, - ProviderUserStatusType.Confirmed); - var hasEnabledOrgs = organizationUserDetails.Any(o => o.Enabled); - var folders = await _folderRepository.GetManyByUserIdAsync(user.Id); - var ciphers = await _cipherRepository.GetManyByUserIdAsync(user.Id, hasEnabledOrgs); - var sends = await _sendRepository.GetManyByUserIdAsync(user.Id); + var hasEnabledOrgs = organizationUserDetails.Any(o => o.Enabled); + var folders = await _folderRepository.GetManyByUserIdAsync(user.Id); + var ciphers = await _cipherRepository.GetManyByUserIdAsync(user.Id, hasEnabledOrgs); + var sends = await _sendRepository.GetManyByUserIdAsync(user.Id); - IEnumerable collections = null; - IDictionary> collectionCiphersGroupDict = null; - IEnumerable policies = null; - if (hasEnabledOrgs) - { - collections = await _collectionRepository.GetManyByUserIdAsync(user.Id); - var collectionCiphers = await _collectionCipherRepository.GetManyByUserIdAsync(user.Id); - collectionCiphersGroupDict = collectionCiphers.GroupBy(c => c.CipherId).ToDictionary(s => s.Key); - policies = await _policyRepository.GetManyByUserIdAsync(user.Id); - } - - var userTwoFactorEnabled = await _userService.TwoFactorIsEnabledAsync(user); - var userHasPremiumFromOrganization = await _userService.HasPremiumFromOrganization(user); - var response = new SyncResponseModel(_globalSettings, user, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationUserDetails, - providerUserDetails, providerUserOrganizationDetails, folders, collections, ciphers, - collectionCiphersGroupDict, excludeDomains, policies, sends); - return response; + IEnumerable collections = null; + IDictionary> collectionCiphersGroupDict = null; + IEnumerable policies = null; + if (hasEnabledOrgs) + { + collections = await _collectionRepository.GetManyByUserIdAsync(user.Id); + var collectionCiphers = await _collectionCipherRepository.GetManyByUserIdAsync(user.Id); + collectionCiphersGroupDict = collectionCiphers.GroupBy(c => c.CipherId).ToDictionary(s => s.Key); + policies = await _policyRepository.GetManyByUserIdAsync(user.Id); } + + var userTwoFactorEnabled = await _userService.TwoFactorIsEnabledAsync(user); + var userHasPremiumFromOrganization = await _userService.HasPremiumFromOrganization(user); + var response = new SyncResponseModel(_globalSettings, user, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationUserDetails, + providerUserDetails, providerUserOrganizationDetails, folders, collections, ciphers, + collectionCiphersGroupDict, excludeDomains, policies, sends); + return response; } } diff --git a/src/Api/Controllers/TwoFactorController.cs b/src/Api/Controllers/TwoFactorController.cs index b8e71f5df..6ed2b8796 100644 --- a/src/Api/Controllers/TwoFactorController.cs +++ b/src/Api/Controllers/TwoFactorController.cs @@ -16,443 +16,442 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Identity; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("two-factor")] +[Authorize("Web")] +public class TwoFactorController : Controller { - [Route("two-factor")] - [Authorize("Web")] - public class TwoFactorController : Controller + private readonly IUserService _userService; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationService _organizationService; + private readonly GlobalSettings _globalSettings; + private readonly UserManager _userManager; + private readonly ICurrentContext _currentContext; + + public TwoFactorController( + IUserService userService, + IOrganizationRepository organizationRepository, + IOrganizationService organizationService, + GlobalSettings globalSettings, + UserManager userManager, + ICurrentContext currentContext) { - private readonly IUserService _userService; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationService _organizationService; - private readonly GlobalSettings _globalSettings; - private readonly UserManager _userManager; - private readonly ICurrentContext _currentContext; + _userService = userService; + _organizationRepository = organizationRepository; + _organizationService = organizationService; + _globalSettings = globalSettings; + _userManager = userManager; + _currentContext = currentContext; + } - public TwoFactorController( - IUserService userService, - IOrganizationRepository organizationRepository, - IOrganizationService organizationService, - GlobalSettings globalSettings, - UserManager userManager, - ICurrentContext currentContext) + [HttpGet("")] + public async Task> Get() + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) { - _userService = userService; - _organizationRepository = organizationRepository; - _organizationService = organizationService; - _globalSettings = globalSettings; - _userManager = userManager; - _currentContext = currentContext; + throw new UnauthorizedAccessException(); } - [HttpGet("")] - public async Task> Get() - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } + var providers = user.GetTwoFactorProviders()?.Select( + p => new TwoFactorProviderResponseModel(p.Key, p.Value)); + return new ListResponseModel(providers); + } - var providers = user.GetTwoFactorProviders()?.Select( - p => new TwoFactorProviderResponseModel(p.Key, p.Value)); - return new ListResponseModel(providers); + [HttpGet("~/organizations/{id}/two-factor")] + public async Task> GetOrganization(string id) + { + var orgIdGuid = new Guid(id); + if (!await _currentContext.OrganizationAdmin(orgIdGuid)) + { + throw new NotFoundException(); } - [HttpGet("~/organizations/{id}/two-factor")] - public async Task> GetOrganization(string id) + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) { - var orgIdGuid = new Guid(id); - if (!await _currentContext.OrganizationAdmin(orgIdGuid)) - { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) - { - throw new NotFoundException(); - } - - var providers = organization.GetTwoFactorProviders()?.Select( - p => new TwoFactorProviderResponseModel(p.Key, p.Value)); - return new ListResponseModel(providers); + throw new NotFoundException(); } - [HttpPost("get-authenticator")] - public async Task GetAuthenticator([FromBody] SecretVerificationRequestModel model) + var providers = organization.GetTwoFactorProviders()?.Select( + p => new TwoFactorProviderResponseModel(p.Key, p.Value)); + return new ListResponseModel(providers); + } + + [HttpPost("get-authenticator")] + public async Task GetAuthenticator([FromBody] SecretVerificationRequestModel model) + { + var user = await CheckAsync(model, false); + var response = new TwoFactorAuthenticatorResponseModel(user); + return response; + } + + [HttpPut("authenticator")] + [HttpPost("authenticator")] + public async Task PutAuthenticator( + [FromBody] UpdateTwoFactorAuthenticatorRequestModel model) + { + var user = await CheckAsync(model, false); + model.ToUser(user); + + if (!await _userManager.VerifyTwoFactorTokenAsync(user, + CoreHelpers.CustomProviderName(TwoFactorProviderType.Authenticator), model.Token)) { - var user = await CheckAsync(model, false); - var response = new TwoFactorAuthenticatorResponseModel(user); - return response; - } - - [HttpPut("authenticator")] - [HttpPost("authenticator")] - public async Task PutAuthenticator( - [FromBody] UpdateTwoFactorAuthenticatorRequestModel model) - { - var user = await CheckAsync(model, false); - model.ToUser(user); - - if (!await _userManager.VerifyTwoFactorTokenAsync(user, - CoreHelpers.CustomProviderName(TwoFactorProviderType.Authenticator), model.Token)) - { - await Task.Delay(2000); - throw new BadRequestException("Token", "Invalid token."); - } - - await _userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.Authenticator); - var response = new TwoFactorAuthenticatorResponseModel(user); - return response; - } - - [HttpPost("get-yubikey")] - public async Task GetYubiKey([FromBody] SecretVerificationRequestModel model) - { - var user = await CheckAsync(model, true); - var response = new TwoFactorYubiKeyResponseModel(user); - return response; - } - - [HttpPut("yubikey")] - [HttpPost("yubikey")] - public async Task PutYubiKey([FromBody] UpdateTwoFactorYubicoOtpRequestModel model) - { - var user = await CheckAsync(model, true); - model.ToUser(user); - - await ValidateYubiKeyAsync(user, nameof(model.Key1), model.Key1); - await ValidateYubiKeyAsync(user, nameof(model.Key2), model.Key2); - await ValidateYubiKeyAsync(user, nameof(model.Key3), model.Key3); - await ValidateYubiKeyAsync(user, nameof(model.Key4), model.Key4); - await ValidateYubiKeyAsync(user, nameof(model.Key5), model.Key5); - - await _userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.YubiKey); - var response = new TwoFactorYubiKeyResponseModel(user); - return response; - } - - [HttpPost("get-duo")] - public async Task GetDuo([FromBody] SecretVerificationRequestModel model) - { - var user = await CheckAsync(model, true); - var response = new TwoFactorDuoResponseModel(user); - return response; - } - - [HttpPut("duo")] - [HttpPost("duo")] - public async Task PutDuo([FromBody] UpdateTwoFactorDuoRequestModel model) - { - var user = await CheckAsync(model, true); - try - { - var duoApi = new DuoApi(model.IntegrationKey, model.SecretKey, model.Host); - duoApi.JSONApiCall("GET", "/auth/v2/check"); - } - catch (DuoException) - { - throw new BadRequestException("Duo configuration settings are not valid. Please re-check the Duo Admin panel."); - } - - model.ToUser(user); - await _userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.Duo); - var response = new TwoFactorDuoResponseModel(user); - return response; - } - - [HttpPost("~/organizations/{id}/two-factor/get-duo")] - public async Task GetOrganizationDuo(string id, - [FromBody] SecretVerificationRequestModel model) - { - var user = await CheckAsync(model, false); - - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManagePolicies(orgIdGuid)) - { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) - { - throw new NotFoundException(); - } - - var response = new TwoFactorDuoResponseModel(organization); - return response; - } - - [HttpPut("~/organizations/{id}/two-factor/duo")] - [HttpPost("~/organizations/{id}/two-factor/duo")] - public async Task PutOrganizationDuo(string id, - [FromBody] UpdateTwoFactorDuoRequestModel model) - { - var user = await CheckAsync(model, false); - - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManagePolicies(orgIdGuid)) - { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) - { - throw new NotFoundException(); - } - - try - { - var duoApi = new DuoApi(model.IntegrationKey, model.SecretKey, model.Host); - duoApi.JSONApiCall("GET", "/auth/v2/check"); - } - catch (DuoException) - { - throw new BadRequestException("Duo configuration settings are not valid. Please re-check the Duo Admin panel."); - } - - model.ToOrganization(organization); - await _organizationService.UpdateTwoFactorProviderAsync(organization, - TwoFactorProviderType.OrganizationDuo); - var response = new TwoFactorDuoResponseModel(organization); - return response; - } - - [HttpPost("get-webauthn")] - public async Task GetWebAuthn([FromBody] SecretVerificationRequestModel model) - { - var user = await CheckAsync(model, true); - var response = new TwoFactorWebAuthnResponseModel(user); - return response; - } - - [HttpPost("get-webauthn-challenge")] - public async Task GetWebAuthnChallenge([FromBody] SecretVerificationRequestModel model) - { - var user = await CheckAsync(model, true); - var reg = await _userService.StartWebAuthnRegistrationAsync(user); - return reg; - } - - [HttpPut("webauthn")] - [HttpPost("webauthn")] - public async Task PutWebAuthn([FromBody] TwoFactorWebAuthnRequestModel model) - { - var user = await CheckAsync(model, true); - - var success = await _userService.CompleteWebAuthRegistrationAsync( - user, model.Id.Value, model.Name, model.DeviceResponse); - if (!success) - { - throw new BadRequestException("Unable to complete WebAuthn registration."); - } - var response = new TwoFactorWebAuthnResponseModel(user); - return response; - } - - [HttpDelete("webauthn")] - public async Task DeleteWebAuthn([FromBody] TwoFactorWebAuthnDeleteRequestModel model) - { - var user = await CheckAsync(model, true); - await _userService.DeleteWebAuthnKeyAsync(user, model.Id.Value); - var response = new TwoFactorWebAuthnResponseModel(user); - return response; - } - - [HttpPost("get-email")] - public async Task GetEmail([FromBody] SecretVerificationRequestModel model) - { - var user = await CheckAsync(model, false); - var response = new TwoFactorEmailResponseModel(user); - return response; - } - - [HttpPost("send-email")] - public async Task SendEmail([FromBody] TwoFactorEmailRequestModel model) - { - var user = await CheckAsync(model, false); - model.ToUser(user); - await _userService.SendTwoFactorEmailAsync(user); - } - - [AllowAnonymous] - [HttpPost("send-email-login")] - public async Task SendEmailLogin([FromBody] TwoFactorEmailRequestModel model) - { - var user = await _userManager.FindByEmailAsync(model.Email.ToLowerInvariant()); - if (user != null) - { - if (await _userService.VerifySecretAsync(user, model.Secret)) - { - var isBecauseNewDeviceLogin = false; - if (user.GetTwoFactorProvider(TwoFactorProviderType.Email) is null - && - await _userService.Needs2FABecauseNewDeviceAsync(user, model.DeviceIdentifier, null)) - { - model.ToUser(user); - isBecauseNewDeviceLogin = true; - } - - await _userService.SendTwoFactorEmailAsync(user, isBecauseNewDeviceLogin); - return; - } - } - await Task.Delay(2000); - throw new BadRequestException("Cannot send two-factor email."); + throw new BadRequestException("Token", "Invalid token."); } - [HttpPut("email")] - [HttpPost("email")] - public async Task PutEmail([FromBody] UpdateTwoFactorEmailRequestModel model) + await _userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.Authenticator); + var response = new TwoFactorAuthenticatorResponseModel(user); + return response; + } + + [HttpPost("get-yubikey")] + public async Task GetYubiKey([FromBody] SecretVerificationRequestModel model) + { + var user = await CheckAsync(model, true); + var response = new TwoFactorYubiKeyResponseModel(user); + return response; + } + + [HttpPut("yubikey")] + [HttpPost("yubikey")] + public async Task PutYubiKey([FromBody] UpdateTwoFactorYubicoOtpRequestModel model) + { + var user = await CheckAsync(model, true); + model.ToUser(user); + + await ValidateYubiKeyAsync(user, nameof(model.Key1), model.Key1); + await ValidateYubiKeyAsync(user, nameof(model.Key2), model.Key2); + await ValidateYubiKeyAsync(user, nameof(model.Key3), model.Key3); + await ValidateYubiKeyAsync(user, nameof(model.Key4), model.Key4); + await ValidateYubiKeyAsync(user, nameof(model.Key5), model.Key5); + + await _userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.YubiKey); + var response = new TwoFactorYubiKeyResponseModel(user); + return response; + } + + [HttpPost("get-duo")] + public async Task GetDuo([FromBody] SecretVerificationRequestModel model) + { + var user = await CheckAsync(model, true); + var response = new TwoFactorDuoResponseModel(user); + return response; + } + + [HttpPut("duo")] + [HttpPost("duo")] + public async Task PutDuo([FromBody] UpdateTwoFactorDuoRequestModel model) + { + var user = await CheckAsync(model, true); + try { - var user = await CheckAsync(model, false); - model.ToUser(user); - - if (!await _userManager.VerifyTwoFactorTokenAsync(user, - CoreHelpers.CustomProviderName(TwoFactorProviderType.Email), model.Token)) - { - await Task.Delay(2000); - throw new BadRequestException("Token", "Invalid token."); - } - - await _userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.Email); - var response = new TwoFactorEmailResponseModel(user); - return response; + var duoApi = new DuoApi(model.IntegrationKey, model.SecretKey, model.Host); + duoApi.JSONApiCall("GET", "/auth/v2/check"); + } + catch (DuoException) + { + throw new BadRequestException("Duo configuration settings are not valid. Please re-check the Duo Admin panel."); } - [HttpPut("disable")] - [HttpPost("disable")] - public async Task PutDisable([FromBody] TwoFactorProviderRequestModel model) + model.ToUser(user); + await _userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.Duo); + var response = new TwoFactorDuoResponseModel(user); + return response; + } + + [HttpPost("~/organizations/{id}/two-factor/get-duo")] + public async Task GetOrganizationDuo(string id, + [FromBody] SecretVerificationRequestModel model) + { + var user = await CheckAsync(model, false); + + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManagePolicies(orgIdGuid)) { - var user = await CheckAsync(model, false); - await _userService.DisableTwoFactorProviderAsync(user, model.Type.Value, _organizationService); - var response = new TwoFactorProviderResponseModel(model.Type.Value, user); - return response; + throw new NotFoundException(); } - [HttpPut("~/organizations/{id}/two-factor/disable")] - [HttpPost("~/organizations/{id}/two-factor/disable")] - public async Task PutOrganizationDisable(string id, - [FromBody] TwoFactorProviderRequestModel model) + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) { - var user = await CheckAsync(model, false); - - var orgIdGuid = new Guid(id); - if (!await _currentContext.ManagePolicies(orgIdGuid)) - { - throw new NotFoundException(); - } - - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) - { - throw new NotFoundException(); - } - - await _organizationService.DisableTwoFactorProviderAsync(organization, model.Type.Value); - var response = new TwoFactorProviderResponseModel(model.Type.Value, organization); - return response; + throw new NotFoundException(); } - [HttpPost("get-recover")] - public async Task GetRecover([FromBody] SecretVerificationRequestModel model) + var response = new TwoFactorDuoResponseModel(organization); + return response; + } + + [HttpPut("~/organizations/{id}/two-factor/duo")] + [HttpPost("~/organizations/{id}/two-factor/duo")] + public async Task PutOrganizationDuo(string id, + [FromBody] UpdateTwoFactorDuoRequestModel model) + { + var user = await CheckAsync(model, false); + + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManagePolicies(orgIdGuid)) { - var user = await CheckAsync(model, false); - var response = new TwoFactorRecoverResponseModel(user); - return response; + throw new NotFoundException(); } - [HttpPost("recover")] - [AllowAnonymous] - public async Task PostRecover([FromBody] TwoFactorRecoveryRequestModel model) + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) { - if (!await _userService.RecoverTwoFactorAsync(model.Email, model.MasterPasswordHash, model.RecoveryCode, - _organizationService)) - { - await Task.Delay(2000); - throw new BadRequestException(string.Empty, "Invalid information. Try again."); - } + throw new NotFoundException(); } - [HttpGet("get-device-verification-settings")] - public async Task GetDeviceVerificationSettings() + try { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (User.Claims.HasSsoIdP()) - { - return new DeviceVerificationResponseModel(false, false); - } - - var canUserEditDeviceVerificationSettings = _userService.CanEditDeviceVerificationSettings(user); - return new DeviceVerificationResponseModel(canUserEditDeviceVerificationSettings, canUserEditDeviceVerificationSettings && user.UnknownDeviceVerificationEnabled); + var duoApi = new DuoApi(model.IntegrationKey, model.SecretKey, model.Host); + duoApi.JSONApiCall("GET", "/auth/v2/check"); + } + catch (DuoException) + { + throw new BadRequestException("Duo configuration settings are not valid. Please re-check the Duo Admin panel."); } - [HttpPut("device-verification-settings")] - public async Task PutDeviceVerificationSettings([FromBody] DeviceVerificationRequestModel model) - { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - if (!_userService.CanEditDeviceVerificationSettings(user) - || User.Claims.HasSsoIdP()) - { - throw new InvalidOperationException("Can't update device verification settings"); - } + model.ToOrganization(organization); + await _organizationService.UpdateTwoFactorProviderAsync(organization, + TwoFactorProviderType.OrganizationDuo); + var response = new TwoFactorDuoResponseModel(organization); + return response; + } - model.ToUser(user); - await _userService.SaveUserAsync(user); - return new DeviceVerificationResponseModel(true, user.UnknownDeviceVerificationEnabled); + [HttpPost("get-webauthn")] + public async Task GetWebAuthn([FromBody] SecretVerificationRequestModel model) + { + var user = await CheckAsync(model, true); + var response = new TwoFactorWebAuthnResponseModel(user); + return response; + } + + [HttpPost("get-webauthn-challenge")] + public async Task GetWebAuthnChallenge([FromBody] SecretVerificationRequestModel model) + { + var user = await CheckAsync(model, true); + var reg = await _userService.StartWebAuthnRegistrationAsync(user); + return reg; + } + + [HttpPut("webauthn")] + [HttpPost("webauthn")] + public async Task PutWebAuthn([FromBody] TwoFactorWebAuthnRequestModel model) + { + var user = await CheckAsync(model, true); + + var success = await _userService.CompleteWebAuthRegistrationAsync( + user, model.Id.Value, model.Name, model.DeviceResponse); + if (!success) + { + throw new BadRequestException("Unable to complete WebAuthn registration."); } + var response = new TwoFactorWebAuthnResponseModel(user); + return response; + } - private async Task CheckAsync(SecretVerificationRequestModel model, bool premium) + [HttpDelete("webauthn")] + public async Task DeleteWebAuthn([FromBody] TwoFactorWebAuthnDeleteRequestModel model) + { + var user = await CheckAsync(model, true); + await _userService.DeleteWebAuthnKeyAsync(user, model.Id.Value); + var response = new TwoFactorWebAuthnResponseModel(user); + return response; + } + + [HttpPost("get-email")] + public async Task GetEmail([FromBody] SecretVerificationRequestModel model) + { + var user = await CheckAsync(model, false); + var response = new TwoFactorEmailResponseModel(user); + return response; + } + + [HttpPost("send-email")] + public async Task SendEmail([FromBody] TwoFactorEmailRequestModel model) + { + var user = await CheckAsync(model, false); + model.ToUser(user); + await _userService.SendTwoFactorEmailAsync(user); + } + + [AllowAnonymous] + [HttpPost("send-email-login")] + public async Task SendEmailLogin([FromBody] TwoFactorEmailRequestModel model) + { + var user = await _userManager.FindByEmailAsync(model.Email.ToLowerInvariant()); + if (user != null) { - var user = await _userService.GetUserByPrincipalAsync(User); - if (user == null) + if (await _userService.VerifySecretAsync(user, model.Secret)) { - throw new UnauthorizedAccessException(); - } + var isBecauseNewDeviceLogin = false; + if (user.GetTwoFactorProvider(TwoFactorProviderType.Email) is null + && + await _userService.Needs2FABecauseNewDeviceAsync(user, model.DeviceIdentifier, null)) + { + model.ToUser(user); + isBecauseNewDeviceLogin = true; + } - if (!await _userService.VerifySecretAsync(user, model.Secret)) - { - await Task.Delay(2000); - throw new BadRequestException(string.Empty, "User verification failed."); - } - - if (premium && !(await _userService.CanAccessPremium(user))) - { - throw new BadRequestException("Premium status is required."); - } - - return user; - } - - private async Task ValidateYubiKeyAsync(User user, string name, string value) - { - if (string.IsNullOrWhiteSpace(value) || value.Length == 12) - { + await _userService.SendTwoFactorEmailAsync(user, isBecauseNewDeviceLogin); return; } + } - if (!await _userManager.VerifyTwoFactorTokenAsync(user, - CoreHelpers.CustomProviderName(TwoFactorProviderType.YubiKey), value)) - { - await Task.Delay(2000); - throw new BadRequestException(name, $"{name} is invalid."); - } - else - { - await Task.Delay(500); - } + await Task.Delay(2000); + throw new BadRequestException("Cannot send two-factor email."); + } + + [HttpPut("email")] + [HttpPost("email")] + public async Task PutEmail([FromBody] UpdateTwoFactorEmailRequestModel model) + { + var user = await CheckAsync(model, false); + model.ToUser(user); + + if (!await _userManager.VerifyTwoFactorTokenAsync(user, + CoreHelpers.CustomProviderName(TwoFactorProviderType.Email), model.Token)) + { + await Task.Delay(2000); + throw new BadRequestException("Token", "Invalid token."); + } + + await _userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.Email); + var response = new TwoFactorEmailResponseModel(user); + return response; + } + + [HttpPut("disable")] + [HttpPost("disable")] + public async Task PutDisable([FromBody] TwoFactorProviderRequestModel model) + { + var user = await CheckAsync(model, false); + await _userService.DisableTwoFactorProviderAsync(user, model.Type.Value, _organizationService); + var response = new TwoFactorProviderResponseModel(model.Type.Value, user); + return response; + } + + [HttpPut("~/organizations/{id}/two-factor/disable")] + [HttpPost("~/organizations/{id}/two-factor/disable")] + public async Task PutOrganizationDisable(string id, + [FromBody] TwoFactorProviderRequestModel model) + { + var user = await CheckAsync(model, false); + + var orgIdGuid = new Guid(id); + if (!await _currentContext.ManagePolicies(orgIdGuid)) + { + throw new NotFoundException(); + } + + var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); + if (organization == null) + { + throw new NotFoundException(); + } + + await _organizationService.DisableTwoFactorProviderAsync(organization, model.Type.Value); + var response = new TwoFactorProviderResponseModel(model.Type.Value, organization); + return response; + } + + [HttpPost("get-recover")] + public async Task GetRecover([FromBody] SecretVerificationRequestModel model) + { + var user = await CheckAsync(model, false); + var response = new TwoFactorRecoverResponseModel(user); + return response; + } + + [HttpPost("recover")] + [AllowAnonymous] + public async Task PostRecover([FromBody] TwoFactorRecoveryRequestModel model) + { + if (!await _userService.RecoverTwoFactorAsync(model.Email, model.MasterPasswordHash, model.RecoveryCode, + _organizationService)) + { + await Task.Delay(2000); + throw new BadRequestException(string.Empty, "Invalid information. Try again."); + } + } + + [HttpGet("get-device-verification-settings")] + public async Task GetDeviceVerificationSettings() + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (User.Claims.HasSsoIdP()) + { + return new DeviceVerificationResponseModel(false, false); + } + + var canUserEditDeviceVerificationSettings = _userService.CanEditDeviceVerificationSettings(user); + return new DeviceVerificationResponseModel(canUserEditDeviceVerificationSettings, canUserEditDeviceVerificationSettings && user.UnknownDeviceVerificationEnabled); + } + + [HttpPut("device-verification-settings")] + public async Task PutDeviceVerificationSettings([FromBody] DeviceVerificationRequestModel model) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + if (!_userService.CanEditDeviceVerificationSettings(user) + || User.Claims.HasSsoIdP()) + { + throw new InvalidOperationException("Can't update device verification settings"); + } + + model.ToUser(user); + await _userService.SaveUserAsync(user); + return new DeviceVerificationResponseModel(true, user.UnknownDeviceVerificationEnabled); + } + + private async Task CheckAsync(SecretVerificationRequestModel model, bool premium) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + if (!await _userService.VerifySecretAsync(user, model.Secret)) + { + await Task.Delay(2000); + throw new BadRequestException(string.Empty, "User verification failed."); + } + + if (premium && !(await _userService.CanAccessPremium(user))) + { + throw new BadRequestException("Premium status is required."); + } + + return user; + } + + private async Task ValidateYubiKeyAsync(User user, string name, string value) + { + if (string.IsNullOrWhiteSpace(value) || value.Length == 12) + { + return; + } + + if (!await _userManager.VerifyTwoFactorTokenAsync(user, + CoreHelpers.CustomProviderName(TwoFactorProviderType.YubiKey), value)) + { + await Task.Delay(2000); + throw new BadRequestException(name, $"{name} is invalid."); + } + else + { + await Task.Delay(500); } } } diff --git a/src/Api/Controllers/UsersController.cs b/src/Api/Controllers/UsersController.cs index eeb50301e..4dfd047d3 100644 --- a/src/Api/Controllers/UsersController.cs +++ b/src/Api/Controllers/UsersController.cs @@ -4,31 +4,30 @@ using Bit.Core.Repositories; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers +namespace Bit.Api.Controllers; + +[Route("users")] +[Authorize("Application")] +public class UsersController : Controller { - [Route("users")] - [Authorize("Application")] - public class UsersController : Controller + private readonly IUserRepository _userRepository; + + public UsersController( + IUserRepository userRepository) { - private readonly IUserRepository _userRepository; + _userRepository = userRepository; + } - public UsersController( - IUserRepository userRepository) + [HttpGet("{id}/public-key")] + public async Task Get(string id) + { + var guidId = new Guid(id); + var key = await _userRepository.GetPublicKeyAsync(guidId); + if (key == null) { - _userRepository = userRepository; + throw new NotFoundException(); } - [HttpGet("{id}/public-key")] - public async Task Get(string id) - { - var guidId = new Guid(id); - var key = await _userRepository.GetPublicKeyAsync(guidId); - if (key == null) - { - throw new NotFoundException(); - } - - return new UserKeyResponseModel(guidId, key); - } + return new UserKeyResponseModel(guidId, key); } } diff --git a/src/Api/Jobs/AliveJob.cs b/src/Api/Jobs/AliveJob.cs index 354b8206e..71136ef7c 100644 --- a/src/Api/Jobs/AliveJob.cs +++ b/src/Api/Jobs/AliveJob.cs @@ -2,17 +2,16 @@ using Bit.Core.Jobs; using Quartz; -namespace Bit.Api.Jobs -{ - public class AliveJob : BaseJob - { - public AliveJob(ILogger logger) - : base(logger) { } +namespace Bit.Api.Jobs; - protected override Task ExecuteJobAsync(IJobExecutionContext context) - { - _logger.LogInformation(Constants.BypassFiltersEventId, null, "It's alive!"); - return Task.FromResult(0); - } +public class AliveJob : BaseJob +{ + public AliveJob(ILogger logger) + : base(logger) { } + + protected override Task ExecuteJobAsync(IJobExecutionContext context) + { + _logger.LogInformation(Constants.BypassFiltersEventId, null, "It's alive!"); + return Task.FromResult(0); } } diff --git a/src/Api/Jobs/EmergencyAccessNotificationJob.cs b/src/Api/Jobs/EmergencyAccessNotificationJob.cs index 4851ef38c..6520de352 100644 --- a/src/Api/Jobs/EmergencyAccessNotificationJob.cs +++ b/src/Api/Jobs/EmergencyAccessNotificationJob.cs @@ -2,23 +2,22 @@ using Bit.Core.Services; using Quartz; -namespace Bit.Api.Jobs +namespace Bit.Api.Jobs; + +public class EmergencyAccessNotificationJob : BaseJob { - public class EmergencyAccessNotificationJob : BaseJob + private readonly IServiceScopeFactory _serviceScopeFactory; + + public EmergencyAccessNotificationJob(IServiceScopeFactory serviceScopeFactory, ILogger logger) + : base(logger) { - private readonly IServiceScopeFactory _serviceScopeFactory; + _serviceScopeFactory = serviceScopeFactory; + } - public EmergencyAccessNotificationJob(IServiceScopeFactory serviceScopeFactory, ILogger logger) - : base(logger) - { - _serviceScopeFactory = serviceScopeFactory; - } - - protected override async Task ExecuteJobAsync(IJobExecutionContext context) - { - using var scope = _serviceScopeFactory.CreateScope(); - var emergencyAccessService = scope.ServiceProvider.GetService(typeof(IEmergencyAccessService)) as IEmergencyAccessService; - await emergencyAccessService.SendNotificationsAsync(); - } + protected override async Task ExecuteJobAsync(IJobExecutionContext context) + { + using var scope = _serviceScopeFactory.CreateScope(); + var emergencyAccessService = scope.ServiceProvider.GetService(typeof(IEmergencyAccessService)) as IEmergencyAccessService; + await emergencyAccessService.SendNotificationsAsync(); } } diff --git a/src/Api/Jobs/EmergencyAccessTimeoutJob.cs b/src/Api/Jobs/EmergencyAccessTimeoutJob.cs index 7e7e85c6d..642f4173c 100644 --- a/src/Api/Jobs/EmergencyAccessTimeoutJob.cs +++ b/src/Api/Jobs/EmergencyAccessTimeoutJob.cs @@ -2,23 +2,22 @@ using Bit.Core.Services; using Quartz; -namespace Bit.Api.Jobs +namespace Bit.Api.Jobs; + +public class EmergencyAccessTimeoutJob : BaseJob { - public class EmergencyAccessTimeoutJob : BaseJob + private readonly IServiceScopeFactory _serviceScopeFactory; + + public EmergencyAccessTimeoutJob(IServiceScopeFactory serviceScopeFactory, ILogger logger) + : base(logger) { - private readonly IServiceScopeFactory _serviceScopeFactory; + _serviceScopeFactory = serviceScopeFactory; + } - public EmergencyAccessTimeoutJob(IServiceScopeFactory serviceScopeFactory, ILogger logger) - : base(logger) - { - _serviceScopeFactory = serviceScopeFactory; - } - - protected override async Task ExecuteJobAsync(IJobExecutionContext context) - { - using var scope = _serviceScopeFactory.CreateScope(); - var emergencyAccessService = scope.ServiceProvider.GetService(typeof(IEmergencyAccessService)) as IEmergencyAccessService; - await emergencyAccessService.HandleTimedOutRequestsAsync(); - } + protected override async Task ExecuteJobAsync(IJobExecutionContext context) + { + using var scope = _serviceScopeFactory.CreateScope(); + var emergencyAccessService = scope.ServiceProvider.GetService(typeof(IEmergencyAccessService)) as IEmergencyAccessService; + await emergencyAccessService.HandleTimedOutRequestsAsync(); } } diff --git a/src/Api/Jobs/JobsHostedService.cs b/src/Api/Jobs/JobsHostedService.cs index 99adbb0e2..241a01242 100644 --- a/src/Api/Jobs/JobsHostedService.cs +++ b/src/Api/Jobs/JobsHostedService.cs @@ -2,82 +2,81 @@ using Bit.Core.Settings; using Quartz; -namespace Bit.Api.Jobs +namespace Bit.Api.Jobs; + +public class JobsHostedService : BaseJobsHostedService { - public class JobsHostedService : BaseJobsHostedService + public JobsHostedService( + GlobalSettings globalSettings, + IServiceProvider serviceProvider, + ILogger logger, + ILogger listenerLogger) + : base(globalSettings, serviceProvider, logger, listenerLogger) { } + + public override async Task StartAsync(CancellationToken cancellationToken) { - public JobsHostedService( - GlobalSettings globalSettings, - IServiceProvider serviceProvider, - ILogger logger, - ILogger listenerLogger) - : base(globalSettings, serviceProvider, logger, listenerLogger) { } + var everyTopOfTheHourTrigger = TriggerBuilder.Create() + .WithIdentity("EveryTopOfTheHourTrigger") + .StartNow() + .WithCronSchedule("0 0 * * * ?") + .Build(); + var emergencyAccessNotificationTrigger = TriggerBuilder.Create() + .WithIdentity("EmergencyAccessNotificationTrigger") + .StartNow() + .WithCronSchedule("0 0 * * * ?") + .Build(); + var emergencyAccessTimeoutTrigger = TriggerBuilder.Create() + .WithIdentity("EmergencyAccessTimeoutTrigger") + .StartNow() + .WithCronSchedule("0 0 * * * ?") + .Build(); + var everyTopOfTheSixthHourTrigger = TriggerBuilder.Create() + .WithIdentity("EveryTopOfTheSixthHourTrigger") + .StartNow() + .WithCronSchedule("0 0 */6 * * ?") + .Build(); + var everyTwelfthHourAndThirtyMinutesTrigger = TriggerBuilder.Create() + .WithIdentity("EveryTwelfthHourAndThirtyMinutesTrigger") + .StartNow() + .WithCronSchedule("0 30 */12 * * ?") + .Build(); + var randomDailySponsorshipSyncTrigger = TriggerBuilder.Create() + .WithIdentity("RandomDailySponsorshipSyncTrigger") + .StartAt(DateBuilder.FutureDate(new Random().Next(24), IntervalUnit.Hour)) + .WithSimpleSchedule(x => x + .WithIntervalInHours(24) + .RepeatForever()) + .Build(); - public override async Task StartAsync(CancellationToken cancellationToken) + var jobs = new List> { - var everyTopOfTheHourTrigger = TriggerBuilder.Create() - .WithIdentity("EveryTopOfTheHourTrigger") - .StartNow() - .WithCronSchedule("0 0 * * * ?") - .Build(); - var emergencyAccessNotificationTrigger = TriggerBuilder.Create() - .WithIdentity("EmergencyAccessNotificationTrigger") - .StartNow() - .WithCronSchedule("0 0 * * * ?") - .Build(); - var emergencyAccessTimeoutTrigger = TriggerBuilder.Create() - .WithIdentity("EmergencyAccessTimeoutTrigger") - .StartNow() - .WithCronSchedule("0 0 * * * ?") - .Build(); - var everyTopOfTheSixthHourTrigger = TriggerBuilder.Create() - .WithIdentity("EveryTopOfTheSixthHourTrigger") - .StartNow() - .WithCronSchedule("0 0 */6 * * ?") - .Build(); - var everyTwelfthHourAndThirtyMinutesTrigger = TriggerBuilder.Create() - .WithIdentity("EveryTwelfthHourAndThirtyMinutesTrigger") - .StartNow() - .WithCronSchedule("0 30 */12 * * ?") - .Build(); - var randomDailySponsorshipSyncTrigger = TriggerBuilder.Create() - .WithIdentity("RandomDailySponsorshipSyncTrigger") - .StartAt(DateBuilder.FutureDate(new Random().Next(24), IntervalUnit.Hour)) - .WithSimpleSchedule(x => x - .WithIntervalInHours(24) - .RepeatForever()) - .Build(); + new Tuple(typeof(AliveJob), everyTopOfTheHourTrigger), + new Tuple(typeof(EmergencyAccessNotificationJob), emergencyAccessNotificationTrigger), + new Tuple(typeof(EmergencyAccessTimeoutJob), emergencyAccessTimeoutTrigger), + new Tuple(typeof(ValidateUsersJob), everyTopOfTheSixthHourTrigger), + new Tuple(typeof(ValidateOrganizationsJob), everyTwelfthHourAndThirtyMinutesTrigger) + }; - var jobs = new List> - { - new Tuple(typeof(AliveJob), everyTopOfTheHourTrigger), - new Tuple(typeof(EmergencyAccessNotificationJob), emergencyAccessNotificationTrigger), - new Tuple(typeof(EmergencyAccessTimeoutJob), emergencyAccessTimeoutTrigger), - new Tuple(typeof(ValidateUsersJob), everyTopOfTheSixthHourTrigger), - new Tuple(typeof(ValidateOrganizationsJob), everyTwelfthHourAndThirtyMinutesTrigger) - }; - - if (_globalSettings.SelfHosted && _globalSettings.EnableCloudCommunication) - { - jobs.Add(new Tuple(typeof(SelfHostedSponsorshipSyncJob), randomDailySponsorshipSyncTrigger)); - } - - Jobs = jobs; - - await base.StartAsync(cancellationToken); + if (_globalSettings.SelfHosted && _globalSettings.EnableCloudCommunication) + { + jobs.Add(new Tuple(typeof(SelfHostedSponsorshipSyncJob), randomDailySponsorshipSyncTrigger)); } - public static void AddJobsServices(IServiceCollection services, bool selfHosted) + Jobs = jobs; + + await base.StartAsync(cancellationToken); + } + + public static void AddJobsServices(IServiceCollection services, bool selfHosted) + { + if (selfHosted) { - if (selfHosted) - { - services.AddTransient(); - } - services.AddTransient(); - services.AddTransient(); - services.AddTransient(); - services.AddTransient(); - services.AddTransient(); + services.AddTransient(); } + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); } } diff --git a/src/Api/Jobs/SelfHostedSponsorshipSyncJob.cs b/src/Api/Jobs/SelfHostedSponsorshipSyncJob.cs index 7ffb32d3b..d21759824 100644 --- a/src/Api/Jobs/SelfHostedSponsorshipSyncJob.cs +++ b/src/Api/Jobs/SelfHostedSponsorshipSyncJob.cs @@ -7,59 +7,58 @@ using Bit.Core.Services; using Bit.Core.Settings; using Quartz; -namespace Bit.Api.Jobs -{ - public class SelfHostedSponsorshipSyncJob : BaseJob - { - private readonly IServiceProvider _serviceProvider; - private IOrganizationRepository _organizationRepository; - private IOrganizationConnectionRepository _organizationConnectionRepository; - private readonly ILicensingService _licensingService; - private GlobalSettings _globalSettings; +namespace Bit.Api.Jobs; - public SelfHostedSponsorshipSyncJob( - IServiceProvider serviceProvider, - IOrganizationRepository organizationRepository, - IOrganizationConnectionRepository organizationConnectionRepository, - ILicensingService licensingService, - ILogger logger, - GlobalSettings globalSettings) - : base(logger) +public class SelfHostedSponsorshipSyncJob : BaseJob +{ + private readonly IServiceProvider _serviceProvider; + private IOrganizationRepository _organizationRepository; + private IOrganizationConnectionRepository _organizationConnectionRepository; + private readonly ILicensingService _licensingService; + private GlobalSettings _globalSettings; + + public SelfHostedSponsorshipSyncJob( + IServiceProvider serviceProvider, + IOrganizationRepository organizationRepository, + IOrganizationConnectionRepository organizationConnectionRepository, + ILicensingService licensingService, + ILogger logger, + GlobalSettings globalSettings) + : base(logger) + { + _serviceProvider = serviceProvider; + _organizationRepository = organizationRepository; + _organizationConnectionRepository = organizationConnectionRepository; + _licensingService = licensingService; + _globalSettings = globalSettings; + } + + protected override async Task ExecuteJobAsync(IJobExecutionContext context) + { + if (!_globalSettings.EnableCloudCommunication) { - _serviceProvider = serviceProvider; - _organizationRepository = organizationRepository; - _organizationConnectionRepository = organizationConnectionRepository; - _licensingService = licensingService; - _globalSettings = globalSettings; + _logger.LogInformation("Skipping Organization sync with cloud - Cloud communication is disabled in global settings"); + return; } - protected override async Task ExecuteJobAsync(IJobExecutionContext context) + var organizations = await _organizationRepository.GetManyByEnabledAsync(); + + using (var scope = _serviceProvider.CreateScope()) { - if (!_globalSettings.EnableCloudCommunication) + var syncCommand = scope.ServiceProvider.GetRequiredService(); + foreach (var org in organizations) { - _logger.LogInformation("Skipping Organization sync with cloud - Cloud communication is disabled in global settings"); - return; - } - - var organizations = await _organizationRepository.GetManyByEnabledAsync(); - - using (var scope = _serviceProvider.CreateScope()) - { - var syncCommand = scope.ServiceProvider.GetRequiredService(); - foreach (var org in organizations) + var connection = (await _organizationConnectionRepository.GetEnabledByOrganizationIdTypeAsync(org.Id, OrganizationConnectionType.CloudBillingSync)).FirstOrDefault(); + if (connection != null) { - var connection = (await _organizationConnectionRepository.GetEnabledByOrganizationIdTypeAsync(org.Id, OrganizationConnectionType.CloudBillingSync)).FirstOrDefault(); - if (connection != null) + try { - try - { - var config = connection.GetConfig(); - await syncCommand.SyncOrganization(org.Id, config.CloudOrganizationId, connection); - } - catch (Exception ex) - { - _logger.LogError(ex, $"Sponsorship sync for organization {org.Name} Failed"); - } + var config = connection.GetConfig(); + await syncCommand.SyncOrganization(org.Id, config.CloudOrganizationId, connection); + } + catch (Exception ex) + { + _logger.LogError(ex, $"Sponsorship sync for organization {org.Name} Failed"); } } } diff --git a/src/Api/Jobs/ValidateOrganizationsJob.cs b/src/Api/Jobs/ValidateOrganizationsJob.cs index d3ec2dad5..8c4225a01 100644 --- a/src/Api/Jobs/ValidateOrganizationsJob.cs +++ b/src/Api/Jobs/ValidateOrganizationsJob.cs @@ -2,23 +2,22 @@ using Bit.Core.Services; using Quartz; -namespace Bit.Api.Jobs +namespace Bit.Api.Jobs; + +public class ValidateOrganizationsJob : BaseJob { - public class ValidateOrganizationsJob : BaseJob + private readonly ILicensingService _licensingService; + + public ValidateOrganizationsJob( + ILicensingService licensingService, + ILogger logger) + : base(logger) { - private readonly ILicensingService _licensingService; + _licensingService = licensingService; + } - public ValidateOrganizationsJob( - ILicensingService licensingService, - ILogger logger) - : base(logger) - { - _licensingService = licensingService; - } - - protected async override Task ExecuteJobAsync(IJobExecutionContext context) - { - await _licensingService.ValidateOrganizationsAsync(); - } + protected async override Task ExecuteJobAsync(IJobExecutionContext context) + { + await _licensingService.ValidateOrganizationsAsync(); } } diff --git a/src/Api/Jobs/ValidateUsersJob.cs b/src/Api/Jobs/ValidateUsersJob.cs index 126162427..be531b47d 100644 --- a/src/Api/Jobs/ValidateUsersJob.cs +++ b/src/Api/Jobs/ValidateUsersJob.cs @@ -2,23 +2,22 @@ using Bit.Core.Services; using Quartz; -namespace Bit.Api.Jobs +namespace Bit.Api.Jobs; + +public class ValidateUsersJob : BaseJob { - public class ValidateUsersJob : BaseJob + private readonly ILicensingService _licensingService; + + public ValidateUsersJob( + ILicensingService licensingService, + ILogger logger) + : base(logger) { - private readonly ILicensingService _licensingService; + _licensingService = licensingService; + } - public ValidateUsersJob( - ILicensingService licensingService, - ILogger logger) - : base(logger) - { - _licensingService = licensingService; - } - - protected async override Task ExecuteJobAsync(IJobExecutionContext context) - { - await _licensingService.ValidateUsersAsync(); - } + protected async override Task ExecuteJobAsync(IJobExecutionContext context) + { + await _licensingService.ValidateUsersAsync(); } } diff --git a/src/Api/Models/CipherAttachmentModel.cs b/src/Api/Models/CipherAttachmentModel.cs index a51080658..c1ae19718 100644 --- a/src/Api/Models/CipherAttachmentModel.cs +++ b/src/Api/Models/CipherAttachmentModel.cs @@ -1,21 +1,20 @@ using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Api.Models +namespace Bit.Api.Models; + +public class CipherAttachmentModel { - public class CipherAttachmentModel + public CipherAttachmentModel() { } + + public CipherAttachmentModel(CipherAttachment.MetaData data) { - public CipherAttachmentModel() { } - - public CipherAttachmentModel(CipherAttachment.MetaData data) - { - FileName = data.FileName; - Key = data.Key; - } - - [EncryptedStringLength(1000)] - public string FileName { get; set; } - [EncryptedStringLength(1000)] - public string Key { get; set; } + FileName = data.FileName; + Key = data.Key; } + + [EncryptedStringLength(1000)] + public string FileName { get; set; } + [EncryptedStringLength(1000)] + public string Key { get; set; } } diff --git a/src/Api/Models/CipherCardModel.cs b/src/Api/Models/CipherCardModel.cs index d95123e32..07ea4d1e6 100644 --- a/src/Api/Models/CipherCardModel.cs +++ b/src/Api/Models/CipherCardModel.cs @@ -2,39 +2,38 @@ using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Api.Models +namespace Bit.Api.Models; + +public class CipherCardModel { - public class CipherCardModel + public CipherCardModel() { } + + public CipherCardModel(CipherCardData data) { - public CipherCardModel() { } - - public CipherCardModel(CipherCardData data) - { - CardholderName = data.CardholderName; - Brand = data.Brand; - Number = data.Number; - ExpMonth = data.ExpMonth; - ExpYear = data.ExpYear; - Code = data.Code; - } - - [EncryptedString] - [EncryptedStringLength(1000)] - public string CardholderName { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Brand { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Number { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string ExpMonth { get; set; } - [EncryptedString] - [StringLength(1000)] - public string ExpYear { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Code { get; set; } + CardholderName = data.CardholderName; + Brand = data.Brand; + Number = data.Number; + ExpMonth = data.ExpMonth; + ExpYear = data.ExpYear; + Code = data.Code; } + + [EncryptedString] + [EncryptedStringLength(1000)] + public string CardholderName { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Brand { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Number { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string ExpMonth { get; set; } + [EncryptedString] + [StringLength(1000)] + public string ExpYear { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Code { get; set; } } diff --git a/src/Api/Models/CipherFieldModel.cs b/src/Api/Models/CipherFieldModel.cs index 5ade6e883..675dcfce0 100644 --- a/src/Api/Models/CipherFieldModel.cs +++ b/src/Api/Models/CipherFieldModel.cs @@ -2,36 +2,35 @@ using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Api.Models +namespace Bit.Api.Models; + +public class CipherFieldModel { - public class CipherFieldModel + public CipherFieldModel() { } + + public CipherFieldModel(CipherFieldData data) { - public CipherFieldModel() { } + Type = data.Type; + Name = data.Name; + Value = data.Value; + LinkedId = data.LinkedId ?? null; + } - public CipherFieldModel(CipherFieldData data) + public FieldType Type { get; set; } + [EncryptedStringLength(1000)] + public string Name { get; set; } + [EncryptedStringLength(5000)] + public string Value { get; set; } + public int? LinkedId { get; set; } + + public CipherFieldData ToCipherFieldData() + { + return new CipherFieldData { - Type = data.Type; - Name = data.Name; - Value = data.Value; - LinkedId = data.LinkedId ?? null; - } - - public FieldType Type { get; set; } - [EncryptedStringLength(1000)] - public string Name { get; set; } - [EncryptedStringLength(5000)] - public string Value { get; set; } - public int? LinkedId { get; set; } - - public CipherFieldData ToCipherFieldData() - { - return new CipherFieldData - { - Type = Type, - Name = Name, - Value = Value, - LinkedId = LinkedId ?? null, - }; - } + Type = Type, + Name = Name, + Value = Value, + LinkedId = LinkedId ?? null, + }; } } diff --git a/src/Api/Models/CipherIdentityModel.cs b/src/Api/Models/CipherIdentityModel.cs index ce5016619..7c1fed164 100644 --- a/src/Api/Models/CipherIdentityModel.cs +++ b/src/Api/Models/CipherIdentityModel.cs @@ -2,87 +2,86 @@ using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Api.Models +namespace Bit.Api.Models; + +public class CipherIdentityModel { - public class CipherIdentityModel + public CipherIdentityModel() { } + + public CipherIdentityModel(CipherIdentityData data) { - public CipherIdentityModel() { } - - public CipherIdentityModel(CipherIdentityData data) - { - Title = data.Title; - FirstName = data.FirstName; - MiddleName = data.MiddleName; - LastName = data.LastName; - Address1 = data.Address1; - Address2 = data.Address2; - Address3 = data.Address3; - City = data.City; - State = data.State; - PostalCode = data.PostalCode; - Country = data.Country; - Company = data.Company; - Email = data.Email; - Phone = data.Phone; - SSN = data.SSN; - Username = data.Username; - PassportNumber = data.PassportNumber; - LicenseNumber = data.LicenseNumber; - } - - [EncryptedString] - [EncryptedStringLength(1000)] - public string Title { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string FirstName { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string MiddleName { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string LastName { get; set; } - [EncryptedString] - [StringLength(1000)] - public string Address1 { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Address2 { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Address3 { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string City { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string State { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string PostalCode { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Country { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Company { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Email { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Phone { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string SSN { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Username { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string PassportNumber { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string LicenseNumber { get; set; } + Title = data.Title; + FirstName = data.FirstName; + MiddleName = data.MiddleName; + LastName = data.LastName; + Address1 = data.Address1; + Address2 = data.Address2; + Address3 = data.Address3; + City = data.City; + State = data.State; + PostalCode = data.PostalCode; + Country = data.Country; + Company = data.Company; + Email = data.Email; + Phone = data.Phone; + SSN = data.SSN; + Username = data.Username; + PassportNumber = data.PassportNumber; + LicenseNumber = data.LicenseNumber; } + + [EncryptedString] + [EncryptedStringLength(1000)] + public string Title { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string FirstName { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string MiddleName { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string LastName { get; set; } + [EncryptedString] + [StringLength(1000)] + public string Address1 { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Address2 { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Address3 { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string City { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string State { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string PostalCode { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Country { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Company { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Email { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Phone { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string SSN { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Username { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string PassportNumber { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string LicenseNumber { get; set; } } diff --git a/src/Api/Models/CipherLoginModel.cs b/src/Api/Models/CipherLoginModel.cs index 156da6ba7..134ca09cb 100644 --- a/src/Api/Models/CipherLoginModel.cs +++ b/src/Api/Models/CipherLoginModel.cs @@ -2,84 +2,83 @@ using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Api.Models -{ - public class CipherLoginModel - { - public CipherLoginModel() { } +namespace Bit.Api.Models; - public CipherLoginModel(CipherLoginData data) +public class CipherLoginModel +{ + public CipherLoginModel() { } + + public CipherLoginModel(CipherLoginData data) + { + Uris = data.Uris?.Select(u => new CipherLoginUriModel(u))?.ToList(); + if (!Uris?.Any() ?? true) { - Uris = data.Uris?.Select(u => new CipherLoginUriModel(u))?.ToList(); - if (!Uris?.Any() ?? true) + Uri = data.Uri; + } + + Username = data.Username; + Password = data.Password; + PasswordRevisionDate = data.PasswordRevisionDate; + Totp = data.Totp; + AutofillOnPageLoad = data.AutofillOnPageLoad; + } + + [EncryptedString] + [EncryptedStringLength(10000)] + public string Uri + { + get => Uris?.FirstOrDefault()?.Uri; + set + { + if (string.IsNullOrWhiteSpace(value)) { - Uri = data.Uri; + return; } - Username = data.Username; - Password = data.Password; - PasswordRevisionDate = data.PasswordRevisionDate; - Totp = data.Totp; - AutofillOnPageLoad = data.AutofillOnPageLoad; + if (Uris == null) + { + Uris = new List(); + } + + Uris.Add(new CipherLoginUriModel(value)); + } + } + public List Uris { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Username { get; set; } + [EncryptedString] + [EncryptedStringLength(5000)] + public string Password { get; set; } + public DateTime? PasswordRevisionDate { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Totp { get; set; } + public bool? AutofillOnPageLoad { get; set; } + + public class CipherLoginUriModel + { + public CipherLoginUriModel() { } + + public CipherLoginUriModel(string uri) + { + Uri = uri; + } + + public CipherLoginUriModel(CipherLoginData.CipherLoginUriData uri) + { + Uri = uri.Uri; + Match = uri.Match; } [EncryptedString] [EncryptedStringLength(10000)] - public string Uri + public string Uri { get; set; } + public UriMatchType? Match { get; set; } = null; + + public CipherLoginData.CipherLoginUriData ToCipherLoginUriData() { - get => Uris?.FirstOrDefault()?.Uri; - set - { - if (string.IsNullOrWhiteSpace(value)) - { - return; - } - - if (Uris == null) - { - Uris = new List(); - } - - Uris.Add(new CipherLoginUriModel(value)); - } - } - public List Uris { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Username { get; set; } - [EncryptedString] - [EncryptedStringLength(5000)] - public string Password { get; set; } - public DateTime? PasswordRevisionDate { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Totp { get; set; } - public bool? AutofillOnPageLoad { get; set; } - - public class CipherLoginUriModel - { - public CipherLoginUriModel() { } - - public CipherLoginUriModel(string uri) - { - Uri = uri; - } - - public CipherLoginUriModel(CipherLoginData.CipherLoginUriData uri) - { - Uri = uri.Uri; - Match = uri.Match; - } - - [EncryptedString] - [EncryptedStringLength(10000)] - public string Uri { get; set; } - public UriMatchType? Match { get; set; } = null; - - public CipherLoginData.CipherLoginUriData ToCipherLoginUriData() - { - return new CipherLoginData.CipherLoginUriData { Uri = Uri, Match = Match, }; - } + return new CipherLoginData.CipherLoginUriData { Uri = Uri, Match = Match, }; } } } diff --git a/src/Api/Models/CipherPasswordHistoryModel.cs b/src/Api/Models/CipherPasswordHistoryModel.cs index bd9eb296f..329c2cf27 100644 --- a/src/Api/Models/CipherPasswordHistoryModel.cs +++ b/src/Api/Models/CipherPasswordHistoryModel.cs @@ -2,28 +2,27 @@ using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Api.Models +namespace Bit.Api.Models; + +public class CipherPasswordHistoryModel { - public class CipherPasswordHistoryModel + public CipherPasswordHistoryModel() { } + + public CipherPasswordHistoryModel(CipherPasswordHistoryData data) { - public CipherPasswordHistoryModel() { } + Password = data.Password; + LastUsedDate = data.LastUsedDate; + } - public CipherPasswordHistoryModel(CipherPasswordHistoryData data) - { - Password = data.Password; - LastUsedDate = data.LastUsedDate; - } + [EncryptedString] + [EncryptedStringLength(5000)] + [Required] + public string Password { get; set; } + [Required] + public DateTime? LastUsedDate { get; set; } - [EncryptedString] - [EncryptedStringLength(5000)] - [Required] - public string Password { get; set; } - [Required] - public DateTime? LastUsedDate { get; set; } - - public CipherPasswordHistoryData ToCipherPasswordHistoryData() - { - return new CipherPasswordHistoryData { Password = Password, LastUsedDate = LastUsedDate.Value, }; - } + public CipherPasswordHistoryData ToCipherPasswordHistoryData() + { + return new CipherPasswordHistoryData { Password = Password, LastUsedDate = LastUsedDate.Value, }; } } diff --git a/src/Api/Models/CipherSecureNoteModel.cs b/src/Api/Models/CipherSecureNoteModel.cs index 6ea63d299..5ab35d1e8 100644 --- a/src/Api/Models/CipherSecureNoteModel.cs +++ b/src/Api/Models/CipherSecureNoteModel.cs @@ -1,17 +1,16 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; -namespace Bit.Api.Models +namespace Bit.Api.Models; + +public class CipherSecureNoteModel { - public class CipherSecureNoteModel + public CipherSecureNoteModel() { } + + public CipherSecureNoteModel(CipherSecureNoteData data) { - public CipherSecureNoteModel() { } - - public CipherSecureNoteModel(CipherSecureNoteData data) - { - Type = data.Type; - } - - public SecureNoteType Type { get; set; } + Type = data.Type; } + + public SecureNoteType Type { get; set; } } diff --git a/src/Api/Models/Public/AssociationWithPermissionsBaseModel.cs b/src/Api/Models/Public/AssociationWithPermissionsBaseModel.cs index 54a0a204f..014f67a04 100644 --- a/src/Api/Models/Public/AssociationWithPermissionsBaseModel.cs +++ b/src/Api/Models/Public/AssociationWithPermissionsBaseModel.cs @@ -1,19 +1,18 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Public +namespace Bit.Api.Models.Public; + +public abstract class AssociationWithPermissionsBaseModel { - public abstract class AssociationWithPermissionsBaseModel - { - /// - /// The associated object's unique identifier. - /// - /// bfbc8338-e329-4dc0-b0c9-317c2ebf1a09 - [Required] - public Guid? Id { get; set; } - /// - /// When true, the read only permission will not allow the user or group to make changes to items. - /// - [Required] - public bool? ReadOnly { get; set; } - } + /// + /// The associated object's unique identifier. + /// + /// bfbc8338-e329-4dc0-b0c9-317c2ebf1a09 + [Required] + public Guid? Id { get; set; } + /// + /// When true, the read only permission will not allow the user or group to make changes to items. + /// + [Required] + public bool? ReadOnly { get; set; } } diff --git a/src/Api/Models/Public/CollectionBaseModel.cs b/src/Api/Models/Public/CollectionBaseModel.cs index 5c36ef9b4..0dd4b6ce8 100644 --- a/src/Api/Models/Public/CollectionBaseModel.cs +++ b/src/Api/Models/Public/CollectionBaseModel.cs @@ -1,14 +1,13 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Public +namespace Bit.Api.Models.Public; + +public abstract class CollectionBaseModel { - public abstract class CollectionBaseModel - { - /// - /// External identifier for reference or linking this collection to another system. - /// - /// external_id_123456 - [StringLength(300)] - public string ExternalId { get; set; } - } + /// + /// External identifier for reference or linking this collection to another system. + /// + /// external_id_123456 + [StringLength(300)] + public string ExternalId { get; set; } } diff --git a/src/Api/Models/Public/GroupBaseModel.cs b/src/Api/Models/Public/GroupBaseModel.cs index 28b5ebe08..2b09e2952 100644 --- a/src/Api/Models/Public/GroupBaseModel.cs +++ b/src/Api/Models/Public/GroupBaseModel.cs @@ -1,27 +1,26 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Public +namespace Bit.Api.Models.Public; + +public abstract class GroupBaseModel { - public abstract class GroupBaseModel - { - /// - /// The name of the group. - /// - /// Development Team - [Required] - [StringLength(100)] - public string Name { get; set; } - /// - /// Determines if this group can access all collections within the organization, or only the associated - /// collections. If set to true, this option overrides any collection assignments. - /// - [Required] - public bool? AccessAll { get; set; } - /// - /// External identifier for reference or linking this group to another system, such as a user directory. - /// - /// external_id_123456 - [StringLength(300)] - public string ExternalId { get; set; } - } + /// + /// The name of the group. + /// + /// Development Team + [Required] + [StringLength(100)] + public string Name { get; set; } + /// + /// Determines if this group can access all collections within the organization, or only the associated + /// collections. If set to true, this option overrides any collection assignments. + /// + [Required] + public bool? AccessAll { get; set; } + /// + /// External identifier for reference or linking this group to another system, such as a user directory. + /// + /// external_id_123456 + [StringLength(300)] + public string ExternalId { get; set; } } diff --git a/src/Api/Models/Public/MemberBaseModel.cs b/src/Api/Models/Public/MemberBaseModel.cs index 47621cf18..af57d8064 100644 --- a/src/Api/Models/Public/MemberBaseModel.cs +++ b/src/Api/Models/Public/MemberBaseModel.cs @@ -3,59 +3,58 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Data.Organizations.OrganizationUsers; -namespace Bit.Api.Models.Public +namespace Bit.Api.Models.Public; + +public abstract class MemberBaseModel { - public abstract class MemberBaseModel + public MemberBaseModel() { } + + public MemberBaseModel(OrganizationUser user) { - public MemberBaseModel() { } - - public MemberBaseModel(OrganizationUser user) + if (user == null) { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - Type = user.Type; - AccessAll = user.AccessAll; - ExternalId = user.ExternalId; - ResetPasswordEnrolled = user.ResetPasswordKey != null; + throw new ArgumentNullException(nameof(user)); } - public MemberBaseModel(OrganizationUserUserDetails user) - { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - Type = user.Type; - AccessAll = user.AccessAll; - ExternalId = user.ExternalId; - ResetPasswordEnrolled = user.ResetPasswordKey != null; - } - - /// - /// The member's type (or role) within the organization. - /// - [Required] - public OrganizationUserType? Type { get; set; } - /// - /// Determines if this member can access all collections within the organization, or only the associated - /// collections. If set to true, this option overrides any collection assignments. - /// - [Required] - public bool? AccessAll { get; set; } - /// - /// External identifier for reference or linking this member to another system, such as a user directory. - /// - /// external_id_123456 - [StringLength(300)] - public string ExternalId { get; set; } - /// - /// Returns true if the member has enrolled in Password Reset assistance within the organization - /// - [Required] - public bool ResetPasswordEnrolled { get; set; } + Type = user.Type; + AccessAll = user.AccessAll; + ExternalId = user.ExternalId; + ResetPasswordEnrolled = user.ResetPasswordKey != null; } + + public MemberBaseModel(OrganizationUserUserDetails user) + { + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + Type = user.Type; + AccessAll = user.AccessAll; + ExternalId = user.ExternalId; + ResetPasswordEnrolled = user.ResetPasswordKey != null; + } + + /// + /// The member's type (or role) within the organization. + /// + [Required] + public OrganizationUserType? Type { get; set; } + /// + /// Determines if this member can access all collections within the organization, or only the associated + /// collections. If set to true, this option overrides any collection assignments. + /// + [Required] + public bool? AccessAll { get; set; } + /// + /// External identifier for reference or linking this member to another system, such as a user directory. + /// + /// external_id_123456 + [StringLength(300)] + public string ExternalId { get; set; } + /// + /// Returns true if the member has enrolled in Password Reset assistance within the organization + /// + [Required] + public bool ResetPasswordEnrolled { get; set; } } diff --git a/src/Api/Models/Public/PolicyBaseModel.cs b/src/Api/Models/Public/PolicyBaseModel.cs index 1814c9f4a..2ad8e7600 100644 --- a/src/Api/Models/Public/PolicyBaseModel.cs +++ b/src/Api/Models/Public/PolicyBaseModel.cs @@ -1,17 +1,16 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Public +namespace Bit.Api.Models.Public; + +public abstract class PolicyBaseModel { - public abstract class PolicyBaseModel - { - /// - /// Determines if this policy is enabled and enforced. - /// - [Required] - public bool? Enabled { get; set; } - /// - /// Data for the policy. - /// - public Dictionary Data { get; set; } - } + /// + /// Determines if this policy is enabled and enforced. + /// + [Required] + public bool? Enabled { get; set; } + /// + /// Data for the policy. + /// + public Dictionary Data { get; set; } } diff --git a/src/Api/Models/Public/Request/AssociationWithPermissionsRequestModel.cs b/src/Api/Models/Public/Request/AssociationWithPermissionsRequestModel.cs index 9a87760b9..b93b16e59 100644 --- a/src/Api/Models/Public/Request/AssociationWithPermissionsRequestModel.cs +++ b/src/Api/Models/Public/Request/AssociationWithPermissionsRequestModel.cs @@ -1,16 +1,15 @@ using Bit.Core.Models.Data; -namespace Bit.Api.Models.Public.Request +namespace Bit.Api.Models.Public.Request; + +public class AssociationWithPermissionsRequestModel : AssociationWithPermissionsBaseModel { - public class AssociationWithPermissionsRequestModel : AssociationWithPermissionsBaseModel + public SelectionReadOnly ToSelectionReadOnly() { - public SelectionReadOnly ToSelectionReadOnly() + return new SelectionReadOnly { - return new SelectionReadOnly - { - Id = Id.Value, - ReadOnly = ReadOnly.Value - }; - } + Id = Id.Value, + ReadOnly = ReadOnly.Value + }; } } diff --git a/src/Api/Models/Public/Request/CollectionUpdateRequestModel.cs b/src/Api/Models/Public/Request/CollectionUpdateRequestModel.cs index 36b77137d..f38d1fec7 100644 --- a/src/Api/Models/Public/Request/CollectionUpdateRequestModel.cs +++ b/src/Api/Models/Public/Request/CollectionUpdateRequestModel.cs @@ -1,18 +1,17 @@ using Bit.Core.Entities; -namespace Bit.Api.Models.Public.Request -{ - public class CollectionUpdateRequestModel : CollectionBaseModel - { - /// - /// The associated groups that this collection is assigned to. - /// - public IEnumerable Groups { get; set; } +namespace Bit.Api.Models.Public.Request; - public Collection ToCollection(Collection existingCollection) - { - existingCollection.ExternalId = ExternalId; - return existingCollection; - } +public class CollectionUpdateRequestModel : CollectionBaseModel +{ + /// + /// The associated groups that this collection is assigned to. + /// + public IEnumerable Groups { get; set; } + + public Collection ToCollection(Collection existingCollection) + { + existingCollection.ExternalId = ExternalId; + return existingCollection; } } diff --git a/src/Api/Models/Public/Request/EventFilterRequestModel.cs b/src/Api/Models/Public/Request/EventFilterRequestModel.cs index 74a1700a7..852076eeb 100644 --- a/src/Api/Models/Public/Request/EventFilterRequestModel.cs +++ b/src/Api/Models/Public/Request/EventFilterRequestModel.cs @@ -1,50 +1,49 @@ using Bit.Core.Exceptions; -namespace Bit.Api.Models.Public.Request +namespace Bit.Api.Models.Public.Request; + +public class EventFilterRequestModel { - public class EventFilterRequestModel + /// + /// The start date. Must be less than the end date. + /// + public DateTime? Start { get; set; } + /// + /// The end date. Must be greater than the start date. + /// + public DateTime? End { get; set; } + /// + /// The unique identifier of the user that performed the event. + /// + public Guid? ActingUserId { get; set; } + /// + /// The unique identifier of the related item that the event describes. + /// + public Guid? ItemId { get; set; } + /// + /// A cursor for use in pagination. + /// + public string ContinuationToken { get; set; } + + public Tuple ToDateRange() { - /// - /// The start date. Must be less than the end date. - /// - public DateTime? Start { get; set; } - /// - /// The end date. Must be greater than the start date. - /// - public DateTime? End { get; set; } - /// - /// The unique identifier of the user that performed the event. - /// - public Guid? ActingUserId { get; set; } - /// - /// The unique identifier of the related item that the event describes. - /// - public Guid? ItemId { get; set; } - /// - /// A cursor for use in pagination. - /// - public string ContinuationToken { get; set; } - - public Tuple ToDateRange() + if (!End.HasValue || !Start.HasValue) { - if (!End.HasValue || !Start.HasValue) - { - End = DateTime.UtcNow.Date.AddDays(1).AddMilliseconds(-1); - Start = DateTime.UtcNow.Date.AddDays(-30); - } - else if (Start.Value > End.Value) - { - var newEnd = Start; - Start = End; - End = newEnd; - } - - if ((End.Value - Start.Value) > TimeSpan.FromDays(367)) - { - throw new BadRequestException("Date range must be < 367 days."); - } - - return new Tuple(Start.Value, End.Value); + End = DateTime.UtcNow.Date.AddDays(1).AddMilliseconds(-1); + Start = DateTime.UtcNow.Date.AddDays(-30); } + else if (Start.Value > End.Value) + { + var newEnd = Start; + Start = End; + End = newEnd; + } + + if ((End.Value - Start.Value) > TimeSpan.FromDays(367)) + { + throw new BadRequestException("Date range must be < 367 days."); + } + + return new Tuple(Start.Value, End.Value); } } diff --git a/src/Api/Models/Public/Request/GroupCreateUpdateRequestModel.cs b/src/Api/Models/Public/Request/GroupCreateUpdateRequestModel.cs index 12e7d4489..9b8193b07 100644 --- a/src/Api/Models/Public/Request/GroupCreateUpdateRequestModel.cs +++ b/src/Api/Models/Public/Request/GroupCreateUpdateRequestModel.cs @@ -1,28 +1,27 @@ using Bit.Core.Entities; -namespace Bit.Api.Models.Public.Request +namespace Bit.Api.Models.Public.Request; + +public class GroupCreateUpdateRequestModel : GroupBaseModel { - public class GroupCreateUpdateRequestModel : GroupBaseModel + /// + /// The associated collections that this group can access. + /// + public IEnumerable Collections { get; set; } + + public Group ToGroup(Guid orgId) { - /// - /// The associated collections that this group can access. - /// - public IEnumerable Collections { get; set; } - - public Group ToGroup(Guid orgId) + return ToGroup(new Group { - return ToGroup(new Group - { - OrganizationId = orgId - }); - } + OrganizationId = orgId + }); + } - public Group ToGroup(Group existingGroup) - { - existingGroup.Name = Name; - existingGroup.AccessAll = AccessAll.Value; - existingGroup.ExternalId = ExternalId; - return existingGroup; - } + public Group ToGroup(Group existingGroup) + { + existingGroup.Name = Name; + existingGroup.AccessAll = AccessAll.Value; + existingGroup.ExternalId = ExternalId; + return existingGroup; } } diff --git a/src/Api/Models/Public/Request/MemberCreateRequestModel.cs b/src/Api/Models/Public/Request/MemberCreateRequestModel.cs index 1845fee22..447434e47 100644 --- a/src/Api/Models/Public/Request/MemberCreateRequestModel.cs +++ b/src/Api/Models/Public/Request/MemberCreateRequestModel.cs @@ -2,22 +2,21 @@ using Bit.Core.Entities; using Bit.Core.Utilities; -namespace Bit.Api.Models.Public.Request -{ - public class MemberCreateRequestModel : MemberUpdateRequestModel - { - /// - /// The member's email address. - /// - /// jsmith@example.com - [Required] - [StringLength(256)] - [StrictEmailAddress] - public string Email { get; set; } +namespace Bit.Api.Models.Public.Request; - public override OrganizationUser ToOrganizationUser(OrganizationUser existingUser) - { - throw new NotImplementedException(); - } +public class MemberCreateRequestModel : MemberUpdateRequestModel +{ + /// + /// The member's email address. + /// + /// jsmith@example.com + [Required] + [StringLength(256)] + [StrictEmailAddress] + public string Email { get; set; } + + public override OrganizationUser ToOrganizationUser(OrganizationUser existingUser) + { + throw new NotImplementedException(); } } diff --git a/src/Api/Models/Public/Request/MemberUpdateRequestModel.cs b/src/Api/Models/Public/Request/MemberUpdateRequestModel.cs index 44a07f526..6b5881186 100644 --- a/src/Api/Models/Public/Request/MemberUpdateRequestModel.cs +++ b/src/Api/Models/Public/Request/MemberUpdateRequestModel.cs @@ -1,20 +1,19 @@ using Bit.Core.Entities; -namespace Bit.Api.Models.Public.Request -{ - public class MemberUpdateRequestModel : MemberBaseModel - { - /// - /// The associated collections that this member can access. - /// - public IEnumerable Collections { get; set; } +namespace Bit.Api.Models.Public.Request; - public virtual OrganizationUser ToOrganizationUser(OrganizationUser existingUser) - { - existingUser.Type = Type.Value; - existingUser.AccessAll = AccessAll.Value; - existingUser.ExternalId = ExternalId; - return existingUser; - } +public class MemberUpdateRequestModel : MemberBaseModel +{ + /// + /// The associated collections that this member can access. + /// + public IEnumerable Collections { get; set; } + + public virtual OrganizationUser ToOrganizationUser(OrganizationUser existingUser) + { + existingUser.Type = Type.Value; + existingUser.AccessAll = AccessAll.Value; + existingUser.ExternalId = ExternalId; + return existingUser; } } diff --git a/src/Api/Models/Public/Request/OrganizationImportRequestModel.cs b/src/Api/Models/Public/Request/OrganizationImportRequestModel.cs index 2b2177b48..70bf649a2 100644 --- a/src/Api/Models/Public/Request/OrganizationImportRequestModel.cs +++ b/src/Api/Models/Public/Request/OrganizationImportRequestModel.cs @@ -4,108 +4,107 @@ using Bit.Core.Entities; using Bit.Core.Models.Business; using Bit.Core.Utilities; -namespace Bit.Api.Models.Public.Request +namespace Bit.Api.Models.Public.Request; + +public class OrganizationImportRequestModel { - public class OrganizationImportRequestModel + /// + /// Groups to import. + /// + public OrganizationImportGroupRequestModel[] Groups { get; set; } + /// + /// Members to import. + /// + public OrganizationImportMemberRequestModel[] Members { get; set; } + /// + /// Determines if the data in this request should overwrite or append to the existing organization data. + /// + [Required] + public bool? OverwriteExisting { get; set; } + /// + /// Indicates an import of over 2000 users and/or groups is expected + /// + public bool LargeImport { get; set; } = false; + + public class OrganizationImportGroupRequestModel { /// - /// Groups to import. - /// - public OrganizationImportGroupRequestModel[] Groups { get; set; } - /// - /// Members to import. - /// - public OrganizationImportMemberRequestModel[] Members { get; set; } - /// - /// Determines if the data in this request should overwrite or append to the existing organization data. + /// The name of the group. /// + /// Development Team [Required] - public bool? OverwriteExisting { get; set; } + [StringLength(100)] + public string Name { get; set; } /// - /// Indicates an import of over 2000 users and/or groups is expected + /// External identifier for reference or linking this group to another system, such as a user directory. /// - public bool LargeImport { get; set; } = false; + /// external_id_123456 + [Required] + [StringLength(300)] + [JsonConverter(typeof(PermissiveStringConverter))] + public string ExternalId { get; set; } + /// + /// The associated external ids for members in this group. + /// + [JsonConverter(typeof(PermissiveStringEnumerableConverter))] + public IEnumerable MemberExternalIds { get; set; } - public class OrganizationImportGroupRequestModel + public ImportedGroup ToImportedGroup(Guid organizationId) { - /// - /// The name of the group. - /// - /// Development Team - [Required] - [StringLength(100)] - public string Name { get; set; } - /// - /// External identifier for reference or linking this group to another system, such as a user directory. - /// - /// external_id_123456 - [Required] - [StringLength(300)] - [JsonConverter(typeof(PermissiveStringConverter))] - public string ExternalId { get; set; } - /// - /// The associated external ids for members in this group. - /// - [JsonConverter(typeof(PermissiveStringEnumerableConverter))] - public IEnumerable MemberExternalIds { get; set; } - - public ImportedGroup ToImportedGroup(Guid organizationId) + var importedGroup = new ImportedGroup { - var importedGroup = new ImportedGroup + Group = new Group { - Group = new Group - { - OrganizationId = organizationId, - Name = Name, - ExternalId = ExternalId - }, - ExternalUserIds = new HashSet(MemberExternalIds) - }; + OrganizationId = organizationId, + Name = Name, + ExternalId = ExternalId + }, + ExternalUserIds = new HashSet(MemberExternalIds) + }; - return importedGroup; - } + return importedGroup; + } + } + + public class OrganizationImportMemberRequestModel : IValidatableObject + { + /// + /// The member's email address. Required for non-deleted users. + /// + /// jsmith@example.com + [EmailAddress] + [StringLength(256)] + public string Email { get; set; } + /// + /// External identifier for reference or linking this member to another system, such as a user directory. + /// + /// external_id_123456 + [Required] + [StringLength(300)] + [JsonConverter(typeof(PermissiveStringConverter))] + public string ExternalId { get; set; } + /// + /// Determines if this member should be removed from the organization during import. + /// + public bool Deleted { get; set; } + + public ImportedOrganizationUser ToImportedOrganizationUser() + { + var importedUser = new ImportedOrganizationUser + { + Email = Email.ToLowerInvariant(), + ExternalId = ExternalId + }; + + return importedUser; } - public class OrganizationImportMemberRequestModel : IValidatableObject + public IEnumerable Validate(ValidationContext validationContext) { - /// - /// The member's email address. Required for non-deleted users. - /// - /// jsmith@example.com - [EmailAddress] - [StringLength(256)] - public string Email { get; set; } - /// - /// External identifier for reference or linking this member to another system, such as a user directory. - /// - /// external_id_123456 - [Required] - [StringLength(300)] - [JsonConverter(typeof(PermissiveStringConverter))] - public string ExternalId { get; set; } - /// - /// Determines if this member should be removed from the organization during import. - /// - public bool Deleted { get; set; } - - public ImportedOrganizationUser ToImportedOrganizationUser() + if (string.IsNullOrWhiteSpace(Email) && !Deleted) { - var importedUser = new ImportedOrganizationUser - { - Email = Email.ToLowerInvariant(), - ExternalId = ExternalId - }; - - return importedUser; - } - - public IEnumerable Validate(ValidationContext validationContext) - { - if (string.IsNullOrWhiteSpace(Email) && !Deleted) - { - yield return new ValidationResult("Email is required for enabled members.", - new string[] { nameof(Email) }); - } + yield return new ValidationResult("Email is required for enabled members.", + new string[] { nameof(Email) }); } } } diff --git a/src/Api/Models/Public/Request/PolicyUpdateRequestModel.cs b/src/Api/Models/Public/Request/PolicyUpdateRequestModel.cs index c563ca9d6..251b9358d 100644 --- a/src/Api/Models/Public/Request/PolicyUpdateRequestModel.cs +++ b/src/Api/Models/Public/Request/PolicyUpdateRequestModel.cs @@ -1,23 +1,22 @@ using System.Text.Json; using Bit.Core.Entities; -namespace Bit.Api.Models.Public.Request -{ - public class PolicyUpdateRequestModel : PolicyBaseModel - { - public Policy ToPolicy(Guid orgId) - { - return ToPolicy(new Policy - { - OrganizationId = orgId - }); - } +namespace Bit.Api.Models.Public.Request; - public virtual Policy ToPolicy(Policy existingPolicy) +public class PolicyUpdateRequestModel : PolicyBaseModel +{ + public Policy ToPolicy(Guid orgId) + { + return ToPolicy(new Policy { - existingPolicy.Enabled = Enabled.GetValueOrDefault(); - existingPolicy.Data = Data != null ? JsonSerializer.Serialize(Data) : null; - return existingPolicy; - } + OrganizationId = orgId + }); + } + + public virtual Policy ToPolicy(Policy existingPolicy) + { + existingPolicy.Enabled = Enabled.GetValueOrDefault(); + existingPolicy.Data = Data != null ? JsonSerializer.Serialize(Data) : null; + return existingPolicy; } } diff --git a/src/Api/Models/Public/Request/UpdateGroupIdsRequestModel.cs b/src/Api/Models/Public/Request/UpdateGroupIdsRequestModel.cs index a691777aa..7a818e5bb 100644 --- a/src/Api/Models/Public/Request/UpdateGroupIdsRequestModel.cs +++ b/src/Api/Models/Public/Request/UpdateGroupIdsRequestModel.cs @@ -1,10 +1,9 @@ -namespace Bit.Api.Models.Public.Request +namespace Bit.Api.Models.Public.Request; + +public class UpdateGroupIdsRequestModel { - public class UpdateGroupIdsRequestModel - { - /// - /// The associated group ids that this object can access. - /// - public IEnumerable GroupIds { get; set; } - } + /// + /// The associated group ids that this object can access. + /// + public IEnumerable GroupIds { get; set; } } diff --git a/src/Api/Models/Public/Request/UpdateMemberIdsRequestModel.cs b/src/Api/Models/Public/Request/UpdateMemberIdsRequestModel.cs index 03ea89eac..87a241831 100644 --- a/src/Api/Models/Public/Request/UpdateMemberIdsRequestModel.cs +++ b/src/Api/Models/Public/Request/UpdateMemberIdsRequestModel.cs @@ -1,10 +1,9 @@ -namespace Bit.Api.Models.Public.Request +namespace Bit.Api.Models.Public.Request; + +public class UpdateMemberIdsRequestModel { - public class UpdateMemberIdsRequestModel - { - /// - /// The associated member ids that have access to this object. - /// - public IEnumerable MemberIds { get; set; } - } + /// + /// The associated member ids that have access to this object. + /// + public IEnumerable MemberIds { get; set; } } diff --git a/src/Api/Models/Public/Response/AssociationWithPermissionsResponseModel.cs b/src/Api/Models/Public/Response/AssociationWithPermissionsResponseModel.cs index 823b35904..04863d9b4 100644 --- a/src/Api/Models/Public/Response/AssociationWithPermissionsResponseModel.cs +++ b/src/Api/Models/Public/Response/AssociationWithPermissionsResponseModel.cs @@ -1,17 +1,16 @@ using Bit.Core.Models.Data; -namespace Bit.Api.Models.Public.Response +namespace Bit.Api.Models.Public.Response; + +public class AssociationWithPermissionsResponseModel : AssociationWithPermissionsBaseModel { - public class AssociationWithPermissionsResponseModel : AssociationWithPermissionsBaseModel + public AssociationWithPermissionsResponseModel(SelectionReadOnly selection) { - public AssociationWithPermissionsResponseModel(SelectionReadOnly selection) + if (selection == null) { - if (selection == null) - { - throw new ArgumentNullException(nameof(selection)); - } - Id = selection.Id; - ReadOnly = selection.ReadOnly; + throw new ArgumentNullException(nameof(selection)); } + Id = selection.Id; + ReadOnly = selection.ReadOnly; } } diff --git a/src/Api/Models/Public/Response/CollectionResponseModel.cs b/src/Api/Models/Public/Response/CollectionResponseModel.cs index 8e318e585..93e484801 100644 --- a/src/Api/Models/Public/Response/CollectionResponseModel.cs +++ b/src/Api/Models/Public/Response/CollectionResponseModel.cs @@ -2,40 +2,39 @@ using Bit.Core.Entities; using Bit.Core.Models.Data; -namespace Bit.Api.Models.Public.Response -{ - /// - /// A collection. - /// - public class CollectionResponseModel : CollectionBaseModel, IResponseModel - { - public CollectionResponseModel(Collection collection, IEnumerable groups) - { - if (collection == null) - { - throw new ArgumentNullException(nameof(collection)); - } +namespace Bit.Api.Models.Public.Response; - Id = collection.Id; - ExternalId = collection.ExternalId; - Groups = groups?.Select(c => new AssociationWithPermissionsResponseModel(c)); +/// +/// A collection. +/// +public class CollectionResponseModel : CollectionBaseModel, IResponseModel +{ + public CollectionResponseModel(Collection collection, IEnumerable groups) + { + if (collection == null) + { + throw new ArgumentNullException(nameof(collection)); } - /// - /// String representing the object's type. Objects of the same type share the same properties. - /// - /// collection - [Required] - public string Object => "collection"; - /// - /// The collection's unique identifier. - /// - /// 539a36c5-e0d2-4cf9-979e-51ecf5cf6593 - [Required] - public Guid Id { get; set; } - /// - /// The associated groups that this collection is assigned to. - /// - public IEnumerable Groups { get; set; } + Id = collection.Id; + ExternalId = collection.ExternalId; + Groups = groups?.Select(c => new AssociationWithPermissionsResponseModel(c)); } + + /// + /// String representing the object's type. Objects of the same type share the same properties. + /// + /// collection + [Required] + public string Object => "collection"; + /// + /// The collection's unique identifier. + /// + /// 539a36c5-e0d2-4cf9-979e-51ecf5cf6593 + [Required] + public Guid Id { get; set; } + /// + /// The associated groups that this collection is assigned to. + /// + public IEnumerable Groups { get; set; } } diff --git a/src/Api/Models/Public/Response/ErrorResponseModel.cs b/src/Api/Models/Public/Response/ErrorResponseModel.cs index dd2f2ba0e..4a4887a0e 100644 --- a/src/Api/Models/Public/Response/ErrorResponseModel.cs +++ b/src/Api/Models/Public/Response/ErrorResponseModel.cs @@ -1,77 +1,76 @@ using System.ComponentModel.DataAnnotations; using Microsoft.AspNetCore.Mvc.ModelBinding; -namespace Bit.Api.Models.Public.Response +namespace Bit.Api.Models.Public.Response; + +public class ErrorResponseModel : IResponseModel { - public class ErrorResponseModel : IResponseModel + public ErrorResponseModel(string message) { - public ErrorResponseModel(string message) - { - Message = message; - } - - public ErrorResponseModel(ModelStateDictionary modelState) - { - Message = "The request's model state is invalid."; - Errors = new Dictionary>(); - - var keys = modelState.Keys.ToList(); - var values = modelState.Values.ToList(); - - for (var i = 0; i < values.Count; i++) - { - var value = values[i]; - if (keys.Count <= i) - { - // Keys not available for some reason. - break; - } - - var key = keys[i]; - if (value.ValidationState != ModelValidationState.Invalid || value.Errors.Count == 0) - { - continue; - } - - var errors = value.Errors.Select(e => e.ErrorMessage); - Errors.Add(key, errors); - } - } - - public ErrorResponseModel(Dictionary> errors) - : this("Errors have occurred.", errors) - { } - - public ErrorResponseModel(string errorKey, string errorValue) - : this(errorKey, new string[] { errorValue }) - { } - - public ErrorResponseModel(string errorKey, IEnumerable errorValues) - : this(new Dictionary> { { errorKey, errorValues } }) - { } - - public ErrorResponseModel(string message, Dictionary> errors) - { - Message = message; - Errors = errors; - } - - /// - /// String representing the object's type. Objects of the same type share the same properties. - /// - /// error - [Required] - public string Object => "error"; - /// - /// A human-readable message providing details about the error. - /// - /// The request model is invalid. - [Required] - public string Message { get; set; } - /// - /// If multiple errors occurred, they are listed in dictionary. Errors related to a specific - /// request parameter will include a dictionary key describing that parameter. - /// - public Dictionary> Errors { get; set; } + Message = message; } + + public ErrorResponseModel(ModelStateDictionary modelState) + { + Message = "The request's model state is invalid."; + Errors = new Dictionary>(); + + var keys = modelState.Keys.ToList(); + var values = modelState.Values.ToList(); + + for (var i = 0; i < values.Count; i++) + { + var value = values[i]; + if (keys.Count <= i) + { + // Keys not available for some reason. + break; + } + + var key = keys[i]; + if (value.ValidationState != ModelValidationState.Invalid || value.Errors.Count == 0) + { + continue; + } + + var errors = value.Errors.Select(e => e.ErrorMessage); + Errors.Add(key, errors); + } + } + + public ErrorResponseModel(Dictionary> errors) + : this("Errors have occurred.", errors) + { } + + public ErrorResponseModel(string errorKey, string errorValue) + : this(errorKey, new string[] { errorValue }) + { } + + public ErrorResponseModel(string errorKey, IEnumerable errorValues) + : this(new Dictionary> { { errorKey, errorValues } }) + { } + + public ErrorResponseModel(string message, Dictionary> errors) + { + Message = message; + Errors = errors; + } + + /// + /// String representing the object's type. Objects of the same type share the same properties. + /// + /// error + [Required] + public string Object => "error"; + /// + /// A human-readable message providing details about the error. + /// + /// The request model is invalid. + [Required] + public string Message { get; set; } + /// + /// If multiple errors occurred, they are listed in dictionary. Errors related to a specific + /// request parameter will include a dictionary key describing that parameter. + /// + public Dictionary> Errors { get; set; } } diff --git a/src/Api/Models/Public/Response/EventResponseModel.cs b/src/Api/Models/Public/Response/EventResponseModel.cs index 4a5f9f652..bc8b77e49 100644 --- a/src/Api/Models/Public/Response/EventResponseModel.cs +++ b/src/Api/Models/Public/Response/EventResponseModel.cs @@ -2,92 +2,91 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; -namespace Bit.Api.Models.Public.Response -{ - /// - /// An event log. - /// - public class EventResponseModel : IResponseModel - { - public EventResponseModel(IEvent ev) - { - if (ev == null) - { - throw new ArgumentNullException(nameof(ev)); - } +namespace Bit.Api.Models.Public.Response; - Type = ev.Type; - ItemId = ev.CipherId; - CollectionId = ev.CollectionId; - GroupId = ev.GroupId; - PolicyId = ev.PolicyId; - MemberId = ev.OrganizationUserId; - ActingUserId = ev.ActingUserId; - Date = ev.Date; - Device = ev.DeviceType; - IpAddress = ev.IpAddress; - InstallationId = ev.InstallationId; +/// +/// An event log. +/// +public class EventResponseModel : IResponseModel +{ + public EventResponseModel(IEvent ev) + { + if (ev == null) + { + throw new ArgumentNullException(nameof(ev)); } - /// - /// String representing the object's type. Objects of the same type share the same properties. - /// - /// event - [Required] - public string Object => "event"; - /// - /// The type of event. - /// - [Required] - public EventType Type { get; set; } - /// - /// The unique identifier of the related item that the event describes. - /// - /// 3767a302-8208-4dc6-b842-030428a1cfad - public Guid? ItemId { get; set; } - /// - /// The unique identifier of the related collection that the event describes. - /// - /// bce212a4-25f3-4888-8a0a-4c5736d851e0 - public Guid? CollectionId { get; set; } - /// - /// The unique identifier of the related group that the event describes. - /// - /// f29a2515-91d2-4452-b49b-5e8040e6b0f4 - public Guid? GroupId { get; set; } - /// - /// The unique identifier of the related policy that the event describes. - /// - /// f29a2515-91d2-4452-b49b-5e8040e6b0f4 - public Guid? PolicyId { get; set; } - /// - /// The unique identifier of the related member that the event describes. - /// - /// e68b8629-85eb-4929-92c0-b84464976ba4 - public Guid? MemberId { get; set; } - /// - /// The unique identifier of the user that performed the event. - /// - /// a2549f79-a71f-4eb9-9234-eb7247333f94 - public Guid? ActingUserId { get; set; } - /// - /// The Unique identifier of the Installation that performed the event. - /// - /// - public Guid? InstallationId { get; set; } - /// - /// The date/timestamp when the event occurred. - /// - [Required] - public DateTime Date { get; set; } - /// - /// The type of device used by the acting user when the event occurred. - /// - public DeviceType? Device { get; set; } - /// - /// The IP address of the acting user. - /// - /// 172.16.254.1 - public string IpAddress { get; set; } + Type = ev.Type; + ItemId = ev.CipherId; + CollectionId = ev.CollectionId; + GroupId = ev.GroupId; + PolicyId = ev.PolicyId; + MemberId = ev.OrganizationUserId; + ActingUserId = ev.ActingUserId; + Date = ev.Date; + Device = ev.DeviceType; + IpAddress = ev.IpAddress; + InstallationId = ev.InstallationId; } + + /// + /// String representing the object's type. Objects of the same type share the same properties. + /// + /// event + [Required] + public string Object => "event"; + /// + /// The type of event. + /// + [Required] + public EventType Type { get; set; } + /// + /// The unique identifier of the related item that the event describes. + /// + /// 3767a302-8208-4dc6-b842-030428a1cfad + public Guid? ItemId { get; set; } + /// + /// The unique identifier of the related collection that the event describes. + /// + /// bce212a4-25f3-4888-8a0a-4c5736d851e0 + public Guid? CollectionId { get; set; } + /// + /// The unique identifier of the related group that the event describes. + /// + /// f29a2515-91d2-4452-b49b-5e8040e6b0f4 + public Guid? GroupId { get; set; } + /// + /// The unique identifier of the related policy that the event describes. + /// + /// f29a2515-91d2-4452-b49b-5e8040e6b0f4 + public Guid? PolicyId { get; set; } + /// + /// The unique identifier of the related member that the event describes. + /// + /// e68b8629-85eb-4929-92c0-b84464976ba4 + public Guid? MemberId { get; set; } + /// + /// The unique identifier of the user that performed the event. + /// + /// a2549f79-a71f-4eb9-9234-eb7247333f94 + public Guid? ActingUserId { get; set; } + /// + /// The Unique identifier of the Installation that performed the event. + /// + /// + public Guid? InstallationId { get; set; } + /// + /// The date/timestamp when the event occurred. + /// + [Required] + public DateTime Date { get; set; } + /// + /// The type of device used by the acting user when the event occurred. + /// + public DeviceType? Device { get; set; } + /// + /// The IP address of the acting user. + /// + /// 172.16.254.1 + public string IpAddress { get; set; } } diff --git a/src/Api/Models/Public/Response/GroupResponseModel.cs b/src/Api/Models/Public/Response/GroupResponseModel.cs index 4c6a76c8f..c2e8df4be 100644 --- a/src/Api/Models/Public/Response/GroupResponseModel.cs +++ b/src/Api/Models/Public/Response/GroupResponseModel.cs @@ -2,42 +2,41 @@ using Bit.Core.Entities; using Bit.Core.Models.Data; -namespace Bit.Api.Models.Public.Response -{ - /// - /// A user group. - /// - public class GroupResponseModel : GroupBaseModel, IResponseModel - { - public GroupResponseModel(Group group, IEnumerable collections) - { - if (group == null) - { - throw new ArgumentNullException(nameof(group)); - } +namespace Bit.Api.Models.Public.Response; - Id = group.Id; - Name = group.Name; - AccessAll = group.AccessAll; - ExternalId = group.ExternalId; - Collections = collections?.Select(c => new AssociationWithPermissionsResponseModel(c)); +/// +/// A user group. +/// +public class GroupResponseModel : GroupBaseModel, IResponseModel +{ + public GroupResponseModel(Group group, IEnumerable collections) + { + if (group == null) + { + throw new ArgumentNullException(nameof(group)); } - /// - /// String representing the object's type. Objects of the same type share the same properties. - /// - /// group - [Required] - public string Object => "group"; - /// - /// The group's unique identifier. - /// - /// 539a36c5-e0d2-4cf9-979e-51ecf5cf6593 - [Required] - public Guid Id { get; set; } - /// - /// The associated collections that this group can access. - /// - public IEnumerable Collections { get; set; } + Id = group.Id; + Name = group.Name; + AccessAll = group.AccessAll; + ExternalId = group.ExternalId; + Collections = collections?.Select(c => new AssociationWithPermissionsResponseModel(c)); } + + /// + /// String representing the object's type. Objects of the same type share the same properties. + /// + /// group + [Required] + public string Object => "group"; + /// + /// The group's unique identifier. + /// + /// 539a36c5-e0d2-4cf9-979e-51ecf5cf6593 + [Required] + public Guid Id { get; set; } + /// + /// The associated collections that this group can access. + /// + public IEnumerable Collections { get; set; } } diff --git a/src/Api/Models/Public/Response/IResponseModel.cs b/src/Api/Models/Public/Response/IResponseModel.cs index 3e3333073..1032f5276 100644 --- a/src/Api/Models/Public/Response/IResponseModel.cs +++ b/src/Api/Models/Public/Response/IResponseModel.cs @@ -1,7 +1,6 @@ -namespace Bit.Api.Models.Public.Response +namespace Bit.Api.Models.Public.Response; + +public interface IResponseModel { - public interface IResponseModel - { - string Object { get; } - } + string Object { get; } } diff --git a/src/Api/Models/Public/Response/ListResponseModel.cs b/src/Api/Models/Public/Response/ListResponseModel.cs index 78328c3e1..0865be3e8 100644 --- a/src/Api/Models/Public/Response/ListResponseModel.cs +++ b/src/Api/Models/Public/Response/ListResponseModel.cs @@ -1,29 +1,28 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Public.Response -{ - public class ListResponseModel : IResponseModel where T : IResponseModel - { - public ListResponseModel(IEnumerable data, string continuationToken = null) - { - Data = data; - ContinuationToken = continuationToken; - } +namespace Bit.Api.Models.Public.Response; - /// - /// String representing the object's type. Objects of the same type share the same properties. - /// - /// list - [Required] - public string Object => "list"; - /// - /// An array containing the actual response elements, paginated by any request parameters. - /// - [Required] - public IEnumerable Data { get; set; } - /// - /// A cursor for use in pagination. - /// - public string ContinuationToken { get; set; } +public class ListResponseModel : IResponseModel where T : IResponseModel +{ + public ListResponseModel(IEnumerable data, string continuationToken = null) + { + Data = data; + ContinuationToken = continuationToken; } + + /// + /// String representing the object's type. Objects of the same type share the same properties. + /// + /// list + [Required] + public string Object => "list"; + /// + /// An array containing the actual response elements, paginated by any request parameters. + /// + [Required] + public IEnumerable Data { get; set; } + /// + /// A cursor for use in pagination. + /// + public string ContinuationToken { get; set; } } diff --git a/src/Api/Models/Public/Response/MemberResponseModel.cs b/src/Api/Models/Public/Response/MemberResponseModel.cs index ceac9fca2..ccb8a8c95 100644 --- a/src/Api/Models/Public/Response/MemberResponseModel.cs +++ b/src/Api/Models/Public/Response/MemberResponseModel.cs @@ -4,91 +4,90 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; -namespace Bit.Api.Models.Public.Response +namespace Bit.Api.Models.Public.Response; + +/// +/// An organization member. +/// +public class MemberResponseModel : MemberBaseModel, IResponseModel { - /// - /// An organization member. - /// - public class MemberResponseModel : MemberBaseModel, IResponseModel + public MemberResponseModel(OrganizationUser user, IEnumerable collections) + : base(user) { - public MemberResponseModel(OrganizationUser user, IEnumerable collections) - : base(user) + if (user == null) { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - Id = user.Id; - UserId = user.UserId; - Email = user.Email; - Status = user.Status; - Collections = collections?.Select(c => new AssociationWithPermissionsResponseModel(c)); + throw new ArgumentNullException(nameof(user)); } - public MemberResponseModel(OrganizationUserUserDetails user, bool twoFactorEnabled, - IEnumerable collections) - : base(user) - { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - Id = user.Id; - UserId = user.UserId; - Name = user.Name; - Email = user.Email; - TwoFactorEnabled = twoFactorEnabled; - Status = user.Status; - Collections = collections?.Select(c => new AssociationWithPermissionsResponseModel(c)); - } - - /// - /// String representing the object's type. Objects of the same type share the same properties. - /// - /// member - [Required] - public string Object => "member"; - /// - /// The member's unique identifier within the organization. - /// - /// 539a36c5-e0d2-4cf9-979e-51ecf5cf6593 - [Required] - public Guid Id { get; set; } - /// - /// The member's unique identifier across Bitwarden. - /// - /// 48b47ee1-493e-4c67-aef7-014996c40eca - [Required] - public Guid? UserId { get; set; } - /// - /// The member's name, set from their user account profile. - /// - /// John Smith - public string Name { get; set; } - /// - /// The member's email address. - /// - /// jsmith@example.com - [Required] - public string Email { get; set; } - /// - /// Returns true if the member has a two-step login method enabled on their user account. - /// - [Required] - public bool TwoFactorEnabled { get; set; } - /// - /// The member's status within the organization. All created members start with a status of "Invited". - /// Once a member accept's their invitation to join the organization, their status changes to "Accepted". - /// Accepted members are then "Confirmed" by an organization administrator. Once a member is "Confirmed", - /// their status can no longer change. - /// - [Required] - public OrganizationUserStatusType Status { get; set; } - /// - /// The associated collections that this member can access. - /// - public IEnumerable Collections { get; set; } + Id = user.Id; + UserId = user.UserId; + Email = user.Email; + Status = user.Status; + Collections = collections?.Select(c => new AssociationWithPermissionsResponseModel(c)); } + + public MemberResponseModel(OrganizationUserUserDetails user, bool twoFactorEnabled, + IEnumerable collections) + : base(user) + { + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + Id = user.Id; + UserId = user.UserId; + Name = user.Name; + Email = user.Email; + TwoFactorEnabled = twoFactorEnabled; + Status = user.Status; + Collections = collections?.Select(c => new AssociationWithPermissionsResponseModel(c)); + } + + /// + /// String representing the object's type. Objects of the same type share the same properties. + /// + /// member + [Required] + public string Object => "member"; + /// + /// The member's unique identifier within the organization. + /// + /// 539a36c5-e0d2-4cf9-979e-51ecf5cf6593 + [Required] + public Guid Id { get; set; } + /// + /// The member's unique identifier across Bitwarden. + /// + /// 48b47ee1-493e-4c67-aef7-014996c40eca + [Required] + public Guid? UserId { get; set; } + /// + /// The member's name, set from their user account profile. + /// + /// John Smith + public string Name { get; set; } + /// + /// The member's email address. + /// + /// jsmith@example.com + [Required] + public string Email { get; set; } + /// + /// Returns true if the member has a two-step login method enabled on their user account. + /// + [Required] + public bool TwoFactorEnabled { get; set; } + /// + /// The member's status within the organization. All created members start with a status of "Invited". + /// Once a member accept's their invitation to join the organization, their status changes to "Accepted". + /// Accepted members are then "Confirmed" by an organization administrator. Once a member is "Confirmed", + /// their status can no longer change. + /// + [Required] + public OrganizationUserStatusType Status { get; set; } + /// + /// The associated collections that this member can access. + /// + public IEnumerable Collections { get; set; } } diff --git a/src/Api/Models/Public/Response/PolicyResponseModel.cs b/src/Api/Models/Public/Response/PolicyResponseModel.cs index 9806c96d0..b30c28322 100644 --- a/src/Api/Models/Public/Response/PolicyResponseModel.cs +++ b/src/Api/Models/Public/Response/PolicyResponseModel.cs @@ -3,45 +3,44 @@ using System.Text.Json; using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Api.Models.Public.Response -{ - /// - /// A policy. - /// - public class PolicyResponseModel : PolicyBaseModel, IResponseModel - { - public PolicyResponseModel(Policy policy) - { - if (policy == null) - { - throw new ArgumentNullException(nameof(policy)); - } +namespace Bit.Api.Models.Public.Response; - Id = policy.Id; - Type = policy.Type; - Enabled = policy.Enabled; - if (!string.IsNullOrWhiteSpace(policy.Data)) - { - Data = JsonSerializer.Deserialize>(policy.Data); - } +/// +/// A policy. +/// +public class PolicyResponseModel : PolicyBaseModel, IResponseModel +{ + public PolicyResponseModel(Policy policy) + { + if (policy == null) + { + throw new ArgumentNullException(nameof(policy)); } - /// - /// String representing the object's type. Objects of the same type share the same properties. - /// - /// policy - [Required] - public string Object => "policy"; - /// - /// The policy's unique identifier. - /// - /// 539a36c5-e0d2-4cf9-979e-51ecf5cf6593 - [Required] - public Guid Id { get; set; } - /// - /// The type of policy. - /// - [Required] - public PolicyType? Type { get; set; } + Id = policy.Id; + Type = policy.Type; + Enabled = policy.Enabled; + if (!string.IsNullOrWhiteSpace(policy.Data)) + { + Data = JsonSerializer.Deserialize>(policy.Data); + } } + + /// + /// String representing the object's type. Objects of the same type share the same properties. + /// + /// policy + [Required] + public string Object => "policy"; + /// + /// The policy's unique identifier. + /// + /// 539a36c5-e0d2-4cf9-979e-51ecf5cf6593 + [Required] + public Guid Id { get; set; } + /// + /// The type of policy. + /// + [Required] + public PolicyType? Type { get; set; } } diff --git a/src/Api/Models/Request/Accounts/DeleteRecoverRequestModel.cs b/src/Api/Models/Request/Accounts/DeleteRecoverRequestModel.cs index 635d878a5..541df9a81 100644 --- a/src/Api/Models/Request/Accounts/DeleteRecoverRequestModel.cs +++ b/src/Api/Models/Request/Accounts/DeleteRecoverRequestModel.cs @@ -1,12 +1,11 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Accounts +namespace Bit.Api.Models.Request.Accounts; + +public class DeleteRecoverRequestModel { - public class DeleteRecoverRequestModel - { - [Required] - [EmailAddress] - [StringLength(256)] - public string Email { get; set; } - } + [Required] + [EmailAddress] + [StringLength(256)] + public string Email { get; set; } } diff --git a/src/Api/Models/Request/Accounts/EmailRequestModel.cs b/src/Api/Models/Request/Accounts/EmailRequestModel.cs index 7eabe3e2e..54e8bfbcc 100644 --- a/src/Api/Models/Request/Accounts/EmailRequestModel.cs +++ b/src/Api/Models/Request/Accounts/EmailRequestModel.cs @@ -1,20 +1,19 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request.Accounts +namespace Bit.Api.Models.Request.Accounts; + +public class EmailRequestModel : SecretVerificationRequestModel { - public class EmailRequestModel : SecretVerificationRequestModel - { - [Required] - [StrictEmailAddress] - [StringLength(256)] - public string NewEmail { get; set; } - [Required] - [StringLength(300)] - public string NewMasterPasswordHash { get; set; } - [Required] - public string Token { get; set; } - [Required] - public string Key { get; set; } - } + [Required] + [StrictEmailAddress] + [StringLength(256)] + public string NewEmail { get; set; } + [Required] + [StringLength(300)] + public string NewMasterPasswordHash { get; set; } + [Required] + public string Token { get; set; } + [Required] + public string Key { get; set; } } diff --git a/src/Api/Models/Request/Accounts/EmailTokenRequestModel.cs b/src/Api/Models/Request/Accounts/EmailTokenRequestModel.cs index 298b5918d..c4c4f7814 100644 --- a/src/Api/Models/Request/Accounts/EmailTokenRequestModel.cs +++ b/src/Api/Models/Request/Accounts/EmailTokenRequestModel.cs @@ -1,13 +1,12 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request.Accounts +namespace Bit.Api.Models.Request.Accounts; + +public class EmailTokenRequestModel : SecretVerificationRequestModel { - public class EmailTokenRequestModel : SecretVerificationRequestModel - { - [Required] - [StrictEmailAddress] - [StringLength(256)] - public string NewEmail { get; set; } - } + [Required] + [StrictEmailAddress] + [StringLength(256)] + public string NewEmail { get; set; } } diff --git a/src/Api/Models/Request/Accounts/ImportCiphersRequestModel.cs b/src/Api/Models/Request/Accounts/ImportCiphersRequestModel.cs index 321fef658..2a675fa48 100644 --- a/src/Api/Models/Request/Accounts/ImportCiphersRequestModel.cs +++ b/src/Api/Models/Request/Accounts/ImportCiphersRequestModel.cs @@ -1,9 +1,8 @@ -namespace Bit.Api.Models.Request.Accounts +namespace Bit.Api.Models.Request.Accounts; + +public class ImportCiphersRequestModel { - public class ImportCiphersRequestModel - { - public FolderRequestModel[] Folders { get; set; } - public CipherRequestModel[] Ciphers { get; set; } - public KeyValuePair[] FolderRelationships { get; set; } - } + public FolderRequestModel[] Folders { get; set; } + public CipherRequestModel[] Ciphers { get; set; } + public KeyValuePair[] FolderRelationships { get; set; } } diff --git a/src/Api/Models/Request/Accounts/KdfRequestModel.cs b/src/Api/Models/Request/Accounts/KdfRequestModel.cs index eea6ad201..ac920c7db 100644 --- a/src/Api/Models/Request/Accounts/KdfRequestModel.cs +++ b/src/Api/Models/Request/Accounts/KdfRequestModel.cs @@ -1,30 +1,29 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Enums; -namespace Bit.Api.Models.Request.Accounts -{ - public class KdfRequestModel : PasswordRequestModel, IValidatableObject - { - [Required] - public KdfType? Kdf { get; set; } - [Required] - public int? KdfIterations { get; set; } +namespace Bit.Api.Models.Request.Accounts; - public IEnumerable Validate(ValidationContext validationContext) +public class KdfRequestModel : PasswordRequestModel, IValidatableObject +{ + [Required] + public KdfType? Kdf { get; set; } + [Required] + public int? KdfIterations { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (Kdf.HasValue && KdfIterations.HasValue) { - if (Kdf.HasValue && KdfIterations.HasValue) + switch (Kdf.Value) { - switch (Kdf.Value) - { - case KdfType.PBKDF2_SHA256: - if (KdfIterations.Value < 5000 || KdfIterations.Value > 2_000_000) - { - yield return new ValidationResult("KDF iterations must be between 5000 and 2000000."); - } - break; - default: - break; - } + case KdfType.PBKDF2_SHA256: + if (KdfIterations.Value < 5000 || KdfIterations.Value > 2_000_000) + { + yield return new ValidationResult("KDF iterations must be between 5000 and 2000000."); + } + break; + default: + break; } } } diff --git a/src/Api/Models/Request/Accounts/OrganizationApiKeyRequestModel.cs b/src/Api/Models/Request/Accounts/OrganizationApiKeyRequestModel.cs index 331cd7045..c7e840818 100644 --- a/src/Api/Models/Request/Accounts/OrganizationApiKeyRequestModel.cs +++ b/src/Api/Models/Request/Accounts/OrganizationApiKeyRequestModel.cs @@ -1,9 +1,8 @@ using Bit.Core.Enums; -namespace Bit.Api.Models.Request.Accounts +namespace Bit.Api.Models.Request.Accounts; + +public class OrganizationApiKeyRequestModel : SecretVerificationRequestModel { - public class OrganizationApiKeyRequestModel : SecretVerificationRequestModel - { - public OrganizationApiKeyType Type { get; set; } - } + public OrganizationApiKeyType Type { get; set; } } diff --git a/src/Api/Models/Request/Accounts/PasswordHintRequestModel.cs b/src/Api/Models/Request/Accounts/PasswordHintRequestModel.cs index 148ced2b2..340a89be2 100644 --- a/src/Api/Models/Request/Accounts/PasswordHintRequestModel.cs +++ b/src/Api/Models/Request/Accounts/PasswordHintRequestModel.cs @@ -1,12 +1,11 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Accounts +namespace Bit.Api.Models.Request.Accounts; + +public class PasswordHintRequestModel { - public class PasswordHintRequestModel - { - [Required] - [EmailAddress] - [StringLength(256)] - public string Email { get; set; } - } + [Required] + [EmailAddress] + [StringLength(256)] + public string Email { get; set; } } diff --git a/src/Api/Models/Request/Accounts/PasswordRequestModel.cs b/src/Api/Models/Request/Accounts/PasswordRequestModel.cs index 0df96f527..d7c22da4b 100644 --- a/src/Api/Models/Request/Accounts/PasswordRequestModel.cs +++ b/src/Api/Models/Request/Accounts/PasswordRequestModel.cs @@ -1,15 +1,14 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Accounts +namespace Bit.Api.Models.Request.Accounts; + +public class PasswordRequestModel : SecretVerificationRequestModel { - public class PasswordRequestModel : SecretVerificationRequestModel - { - [Required] - [StringLength(300)] - public string NewMasterPasswordHash { get; set; } - [StringLength(50)] - public string MasterPasswordHint { get; set; } - [Required] - public string Key { get; set; } - } + [Required] + [StringLength(300)] + public string NewMasterPasswordHash { get; set; } + [StringLength(50)] + public string MasterPasswordHint { get; set; } + [Required] + public string Key { get; set; } } diff --git a/src/Api/Models/Request/Accounts/PremiumRequestModel.cs b/src/Api/Models/Request/Accounts/PremiumRequestModel.cs index 3fd95b1ce..26d199381 100644 --- a/src/Api/Models/Request/Accounts/PremiumRequestModel.cs +++ b/src/Api/Models/Request/Accounts/PremiumRequestModel.cs @@ -2,41 +2,40 @@ using Bit.Core.Settings; using Enums = Bit.Core.Enums; -namespace Bit.Api.Models.Request.Accounts +namespace Bit.Api.Models.Request.Accounts; + +public class PremiumRequestModel : IValidatableObject { - public class PremiumRequestModel : IValidatableObject + [Required] + public Enums.PaymentMethodType? PaymentMethodType { get; set; } + public string PaymentToken { get; set; } + [Range(0, 99)] + public short? AdditionalStorageGb { get; set; } + public IFormFile License { get; set; } + public string Country { get; set; } + public string PostalCode { get; set; } + + public bool Validate(GlobalSettings globalSettings) { - [Required] - public Enums.PaymentMethodType? PaymentMethodType { get; set; } - public string PaymentToken { get; set; } - [Range(0, 99)] - public short? AdditionalStorageGb { get; set; } - public IFormFile License { get; set; } - public string Country { get; set; } - public string PostalCode { get; set; } - - public bool Validate(GlobalSettings globalSettings) + if (!(License == null && !globalSettings.SelfHosted) || + (License != null && globalSettings.SelfHosted)) { - if (!(License == null && !globalSettings.SelfHosted) || - (License != null && globalSettings.SelfHosted)) - { - return false; - } - return globalSettings.SelfHosted || !string.IsNullOrWhiteSpace(Country); + return false; } + return globalSettings.SelfHosted || !string.IsNullOrWhiteSpace(Country); + } - public IEnumerable Validate(ValidationContext validationContext) + public IEnumerable Validate(ValidationContext validationContext) + { + var creditType = PaymentMethodType.HasValue && PaymentMethodType.Value == Enums.PaymentMethodType.Credit; + if (string.IsNullOrWhiteSpace(PaymentToken) && !creditType && License == null) { - var creditType = PaymentMethodType.HasValue && PaymentMethodType.Value == Enums.PaymentMethodType.Credit; - if (string.IsNullOrWhiteSpace(PaymentToken) && !creditType && License == null) - { - yield return new ValidationResult("Payment token or license is required."); - } - if (Country == "US" && string.IsNullOrWhiteSpace(PostalCode)) - { - yield return new ValidationResult("Zip / postal code is required.", - new string[] { nameof(PostalCode) }); - } + yield return new ValidationResult("Payment token or license is required."); + } + if (Country == "US" && string.IsNullOrWhiteSpace(PostalCode)) + { + yield return new ValidationResult("Zip / postal code is required.", + new string[] { nameof(PostalCode) }); } } } diff --git a/src/Api/Models/Request/Accounts/RegenerateTwoFactorRequestModel.cs b/src/Api/Models/Request/Accounts/RegenerateTwoFactorRequestModel.cs index 06a6148d4..329b3a0c3 100644 --- a/src/Api/Models/Request/Accounts/RegenerateTwoFactorRequestModel.cs +++ b/src/Api/Models/Request/Accounts/RegenerateTwoFactorRequestModel.cs @@ -1,13 +1,12 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Accounts +namespace Bit.Api.Models.Request.Accounts; + +public class RegenerateTwoFactorRequestModel { - public class RegenerateTwoFactorRequestModel - { - [Required] - public string MasterPasswordHash { get; set; } - [Required] - [StringLength(50)] - public string Token { get; set; } - } + [Required] + public string MasterPasswordHash { get; set; } + [Required] + [StringLength(50)] + public string Token { get; set; } } diff --git a/src/Api/Models/Request/Accounts/SecretVerificationRequestModel.cs b/src/Api/Models/Request/Accounts/SecretVerificationRequestModel.cs index e1042d5a3..f35ea9677 100644 --- a/src/Api/Models/Request/Accounts/SecretVerificationRequestModel.cs +++ b/src/Api/Models/Request/Accounts/SecretVerificationRequestModel.cs @@ -1,20 +1,19 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Accounts -{ - public class SecretVerificationRequestModel : IValidatableObject - { - [StringLength(300)] - public string MasterPasswordHash { get; set; } - public string OTP { get; set; } - public string Secret => !string.IsNullOrEmpty(MasterPasswordHash) ? MasterPasswordHash : OTP; +namespace Bit.Api.Models.Request.Accounts; - public virtual IEnumerable Validate(ValidationContext validationContext) +public class SecretVerificationRequestModel : IValidatableObject +{ + [StringLength(300)] + public string MasterPasswordHash { get; set; } + public string OTP { get; set; } + public string Secret => !string.IsNullOrEmpty(MasterPasswordHash) ? MasterPasswordHash : OTP; + + public virtual IEnumerable Validate(ValidationContext validationContext) + { + if (string.IsNullOrEmpty(Secret)) { - if (string.IsNullOrEmpty(Secret)) - { - yield return new ValidationResult("MasterPasswordHash or OTP must be supplied."); - } + yield return new ValidationResult("MasterPasswordHash or OTP must be supplied."); } } } diff --git a/src/Api/Models/Request/Accounts/SetKeyConnectorKeyRequestModel.cs b/src/Api/Models/Request/Accounts/SetKeyConnectorKeyRequestModel.cs index 39c17bc4b..a4906b1b5 100644 --- a/src/Api/Models/Request/Accounts/SetKeyConnectorKeyRequestModel.cs +++ b/src/Api/Models/Request/Accounts/SetKeyConnectorKeyRequestModel.cs @@ -3,28 +3,27 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Api.Request.Accounts; -namespace Bit.Api.Models.Request.Accounts -{ - public class SetKeyConnectorKeyRequestModel - { - [Required] - public string Key { get; set; } - [Required] - public KeysRequestModel Keys { get; set; } - [Required] - public KdfType Kdf { get; set; } - [Required] - public int KdfIterations { get; set; } - [Required] - public string OrgIdentifier { get; set; } +namespace Bit.Api.Models.Request.Accounts; - public User ToUser(User existingUser) - { - existingUser.Kdf = Kdf; - existingUser.KdfIterations = KdfIterations; - existingUser.Key = Key; - Keys.ToUser(existingUser); - return existingUser; - } +public class SetKeyConnectorKeyRequestModel +{ + [Required] + public string Key { get; set; } + [Required] + public KeysRequestModel Keys { get; set; } + [Required] + public KdfType Kdf { get; set; } + [Required] + public int KdfIterations { get; set; } + [Required] + public string OrgIdentifier { get; set; } + + public User ToUser(User existingUser) + { + existingUser.Kdf = Kdf; + existingUser.KdfIterations = KdfIterations; + existingUser.Key = Key; + Keys.ToUser(existingUser); + return existingUser; } } diff --git a/src/Api/Models/Request/Accounts/SetPasswordRequestModel.cs b/src/Api/Models/Request/Accounts/SetPasswordRequestModel.cs index 287ba8769..8a345001c 100644 --- a/src/Api/Models/Request/Accounts/SetPasswordRequestModel.cs +++ b/src/Api/Models/Request/Accounts/SetPasswordRequestModel.cs @@ -3,33 +3,32 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Api.Request.Accounts; -namespace Bit.Api.Models.Request.Accounts -{ - public class SetPasswordRequestModel - { - [Required] - [StringLength(300)] - public string MasterPasswordHash { get; set; } - [Required] - public string Key { get; set; } - [StringLength(50)] - public string MasterPasswordHint { get; set; } - [Required] - public KeysRequestModel Keys { get; set; } - [Required] - public KdfType Kdf { get; set; } - [Required] - public int KdfIterations { get; set; } - public string OrgIdentifier { get; set; } +namespace Bit.Api.Models.Request.Accounts; - public User ToUser(User existingUser) - { - existingUser.MasterPasswordHint = MasterPasswordHint; - existingUser.Kdf = Kdf; - existingUser.KdfIterations = KdfIterations; - existingUser.Key = Key; - Keys.ToUser(existingUser); - return existingUser; - } +public class SetPasswordRequestModel +{ + [Required] + [StringLength(300)] + public string MasterPasswordHash { get; set; } + [Required] + public string Key { get; set; } + [StringLength(50)] + public string MasterPasswordHint { get; set; } + [Required] + public KeysRequestModel Keys { get; set; } + [Required] + public KdfType Kdf { get; set; } + [Required] + public int KdfIterations { get; set; } + public string OrgIdentifier { get; set; } + + public User ToUser(User existingUser) + { + existingUser.MasterPasswordHint = MasterPasswordHint; + existingUser.Kdf = Kdf; + existingUser.KdfIterations = KdfIterations; + existingUser.Key = Key; + Keys.ToUser(existingUser); + return existingUser; } } diff --git a/src/Api/Models/Request/Accounts/StorageRequestModel.cs b/src/Api/Models/Request/Accounts/StorageRequestModel.cs index beb7c189b..397da7411 100644 --- a/src/Api/Models/Request/Accounts/StorageRequestModel.cs +++ b/src/Api/Models/Request/Accounts/StorageRequestModel.cs @@ -1,19 +1,18 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Accounts -{ - public class StorageRequestModel : IValidatableObject - { - [Required] - public short? StorageGbAdjustment { get; set; } +namespace Bit.Api.Models.Request.Accounts; - public IEnumerable Validate(ValidationContext validationContext) +public class StorageRequestModel : IValidatableObject +{ + [Required] + public short? StorageGbAdjustment { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (StorageGbAdjustment == 0) { - if (StorageGbAdjustment == 0) - { - yield return new ValidationResult("Storage adjustment cannot be 0.", - new string[] { nameof(StorageGbAdjustment) }); - } + yield return new ValidationResult("Storage adjustment cannot be 0.", + new string[] { nameof(StorageGbAdjustment) }); } } } diff --git a/src/Api/Models/Request/Accounts/TaxInfoUpdateRequestModel.cs b/src/Api/Models/Request/Accounts/TaxInfoUpdateRequestModel.cs index 205356e68..f51580408 100644 --- a/src/Api/Models/Request/Accounts/TaxInfoUpdateRequestModel.cs +++ b/src/Api/Models/Request/Accounts/TaxInfoUpdateRequestModel.cs @@ -1,20 +1,19 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Accounts -{ - public class TaxInfoUpdateRequestModel : IValidatableObject - { - [Required] - public string Country { get; set; } - public string PostalCode { get; set; } +namespace Bit.Api.Models.Request.Accounts; - public virtual IEnumerable Validate(ValidationContext validationContext) +public class TaxInfoUpdateRequestModel : IValidatableObject +{ + [Required] + public string Country { get; set; } + public string PostalCode { get; set; } + + public virtual IEnumerable Validate(ValidationContext validationContext) + { + if (Country == "US" && string.IsNullOrWhiteSpace(PostalCode)) { - if (Country == "US" && string.IsNullOrWhiteSpace(PostalCode)) - { - yield return new ValidationResult("Zip / postal code is required.", - new string[] { nameof(PostalCode) }); - } + yield return new ValidationResult("Zip / postal code is required.", + new string[] { nameof(PostalCode) }); } } } diff --git a/src/Api/Models/Request/Accounts/UpdateKeyRequestModel.cs b/src/Api/Models/Request/Accounts/UpdateKeyRequestModel.cs index 31ac2d830..2064c09b9 100644 --- a/src/Api/Models/Request/Accounts/UpdateKeyRequestModel.cs +++ b/src/Api/Models/Request/Accounts/UpdateKeyRequestModel.cs @@ -1,20 +1,19 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Accounts +namespace Bit.Api.Models.Request.Accounts; + +public class UpdateKeyRequestModel { - public class UpdateKeyRequestModel - { - [Required] - [StringLength(300)] - public string MasterPasswordHash { get; set; } - [Required] - public IEnumerable Ciphers { get; set; } - [Required] - public IEnumerable Folders { get; set; } - public IEnumerable Sends { get; set; } - [Required] - public string PrivateKey { get; set; } - [Required] - public string Key { get; set; } - } + [Required] + [StringLength(300)] + public string MasterPasswordHash { get; set; } + [Required] + public IEnumerable Ciphers { get; set; } + [Required] + public IEnumerable Folders { get; set; } + public IEnumerable Sends { get; set; } + [Required] + public string PrivateKey { get; set; } + [Required] + public string Key { get; set; } } diff --git a/src/Api/Models/Request/Accounts/UpdateProfileRequestModel.cs b/src/Api/Models/Request/Accounts/UpdateProfileRequestModel.cs index 9f8506dfc..fd625fe9d 100644 --- a/src/Api/Models/Request/Accounts/UpdateProfileRequestModel.cs +++ b/src/Api/Models/Request/Accounts/UpdateProfileRequestModel.cs @@ -1,21 +1,20 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Entities; -namespace Bit.Api.Models.Request.Accounts -{ - public class UpdateProfileRequestModel - { - [StringLength(50)] - public string Name { get; set; } - [StringLength(50)] - [Obsolete("Changes will be made via the 'password' endpoint going forward.")] - public string MasterPasswordHint { get; set; } +namespace Bit.Api.Models.Request.Accounts; - public User ToUser(User existingUser) - { - existingUser.Name = Name; - existingUser.MasterPasswordHint = string.IsNullOrWhiteSpace(MasterPasswordHint) ? null : MasterPasswordHint; - return existingUser; - } +public class UpdateProfileRequestModel +{ + [StringLength(50)] + public string Name { get; set; } + [StringLength(50)] + [Obsolete("Changes will be made via the 'password' endpoint going forward.")] + public string MasterPasswordHint { get; set; } + + public User ToUser(User existingUser) + { + existingUser.Name = Name; + existingUser.MasterPasswordHint = string.IsNullOrWhiteSpace(MasterPasswordHint) ? null : MasterPasswordHint; + return existingUser; } } diff --git a/src/Api/Models/Request/Accounts/UpdateTempPasswordRequestModel.cs b/src/Api/Models/Request/Accounts/UpdateTempPasswordRequestModel.cs index db1c0dbd7..94bfabeee 100644 --- a/src/Api/Models/Request/Accounts/UpdateTempPasswordRequestModel.cs +++ b/src/Api/Models/Request/Accounts/UpdateTempPasswordRequestModel.cs @@ -1,11 +1,10 @@ using System.ComponentModel.DataAnnotations; using Bit.Api.Models.Request.Organizations; -namespace Bit.Api.Models.Request.Accounts +namespace Bit.Api.Models.Request.Accounts; + +public class UpdateTempPasswordRequestModel : OrganizationUserResetPasswordRequestModel { - public class UpdateTempPasswordRequestModel : OrganizationUserResetPasswordRequestModel - { - [StringLength(50)] - public string MasterPasswordHint { get; set; } - } + [StringLength(50)] + public string MasterPasswordHint { get; set; } } diff --git a/src/Api/Models/Request/Accounts/VerifyDeleteRecoverRequestModel.cs b/src/Api/Models/Request/Accounts/VerifyDeleteRecoverRequestModel.cs index 463750722..1faaade2b 100644 --- a/src/Api/Models/Request/Accounts/VerifyDeleteRecoverRequestModel.cs +++ b/src/Api/Models/Request/Accounts/VerifyDeleteRecoverRequestModel.cs @@ -1,12 +1,11 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Accounts +namespace Bit.Api.Models.Request.Accounts; + +public class VerifyDeleteRecoverRequestModel { - public class VerifyDeleteRecoverRequestModel - { - [Required] - public string UserId { get; set; } - [Required] - public string Token { get; set; } - } + [Required] + public string UserId { get; set; } + [Required] + public string Token { get; set; } } diff --git a/src/Api/Models/Request/Accounts/VerifyEmailRequestModel.cs b/src/Api/Models/Request/Accounts/VerifyEmailRequestModel.cs index d85996681..2e8820e1d 100644 --- a/src/Api/Models/Request/Accounts/VerifyEmailRequestModel.cs +++ b/src/Api/Models/Request/Accounts/VerifyEmailRequestModel.cs @@ -1,12 +1,11 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Accounts +namespace Bit.Api.Models.Request.Accounts; + +public class VerifyEmailRequestModel { - public class VerifyEmailRequestModel - { - [Required] - public string UserId { get; set; } - [Required] - public string Token { get; set; } - } + [Required] + public string UserId { get; set; } + [Required] + public string Token { get; set; } } diff --git a/src/Api/Models/Request/Accounts/VerifyOTPRequestModel.cs b/src/Api/Models/Request/Accounts/VerifyOTPRequestModel.cs index 6466aee7e..63e37cdf1 100644 --- a/src/Api/Models/Request/Accounts/VerifyOTPRequestModel.cs +++ b/src/Api/Models/Request/Accounts/VerifyOTPRequestModel.cs @@ -1,10 +1,9 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Accounts +namespace Bit.Api.Models.Request.Accounts; + +public class VerifyOTPRequestModel { - public class VerifyOTPRequestModel - { - [Required] - public string OTP { get; set; } - } + [Required] + public string OTP { get; set; } } diff --git a/src/Api/Models/Request/AttachmentRequestModel.cs b/src/Api/Models/Request/AttachmentRequestModel.cs index b5ca4fb61..cadeccdc0 100644 --- a/src/Api/Models/Request/AttachmentRequestModel.cs +++ b/src/Api/Models/Request/AttachmentRequestModel.cs @@ -1,10 +1,9 @@ -namespace Bit.Api.Models.Request +namespace Bit.Api.Models.Request; + +public class AttachmentRequestModel { - public class AttachmentRequestModel - { - public string Key { get; set; } - public string FileName { get; set; } - public long FileSize { get; set; } - public bool AdminRequest { get; set; } = false; - } + public string Key { get; set; } + public string FileName { get; set; } + public long FileSize { get; set; } + public bool AdminRequest { get; set; } = false; } diff --git a/src/Api/Models/Request/BitPayInvoiceRequestModel.cs b/src/Api/Models/Request/BitPayInvoiceRequestModel.cs index 9e87cca00..ce1d98638 100644 --- a/src/Api/Models/Request/BitPayInvoiceRequestModel.cs +++ b/src/Api/Models/Request/BitPayInvoiceRequestModel.cs @@ -1,66 +1,65 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Settings; -namespace Bit.Api.Models.Request +namespace Bit.Api.Models.Request; + +public class BitPayInvoiceRequestModel : IValidatableObject { - public class BitPayInvoiceRequestModel : IValidatableObject + public Guid? UserId { get; set; } + public Guid? OrganizationId { get; set; } + public bool Credit { get; set; } + [Required] + public decimal? Amount { get; set; } + public string ReturnUrl { get; set; } + public string Name { get; set; } + public string Email { get; set; } + + public BitPayLight.Models.Invoice.Invoice ToBitpayInvoice(GlobalSettings globalSettings) { - public Guid? UserId { get; set; } - public Guid? OrganizationId { get; set; } - public bool Credit { get; set; } - [Required] - public decimal? Amount { get; set; } - public string ReturnUrl { get; set; } - public string Name { get; set; } - public string Email { get; set; } - - public BitPayLight.Models.Invoice.Invoice ToBitpayInvoice(GlobalSettings globalSettings) + var inv = new BitPayLight.Models.Invoice.Invoice { - var inv = new BitPayLight.Models.Invoice.Invoice + Price = Convert.ToDouble(Amount.Value), + Currency = "USD", + RedirectUrl = ReturnUrl, + Buyer = new BitPayLight.Models.Invoice.Buyer { - Price = Convert.ToDouble(Amount.Value), - Currency = "USD", - RedirectUrl = ReturnUrl, - Buyer = new BitPayLight.Models.Invoice.Buyer - { - Email = Email, - Name = Name - }, - NotificationUrl = globalSettings.BitPay.NotificationUrl, - FullNotifications = true, - ExtendedNotifications = true - }; + Email = Email, + Name = Name + }, + NotificationUrl = globalSettings.BitPay.NotificationUrl, + FullNotifications = true, + ExtendedNotifications = true + }; - var posData = string.Empty; - if (UserId.HasValue) - { - posData = "userId:" + UserId.Value; - } - else if (OrganizationId.HasValue) - { - posData = "organizationId:" + OrganizationId.Value; - } - - if (Credit) - { - posData += ",accountCredit:1"; - inv.ItemDesc = "Bitwarden Account Credit"; - } - else - { - inv.ItemDesc = "Bitwarden"; - } - - inv.PosData = posData; - return inv; + var posData = string.Empty; + if (UserId.HasValue) + { + posData = "userId:" + UserId.Value; + } + else if (OrganizationId.HasValue) + { + posData = "organizationId:" + OrganizationId.Value; } - public IEnumerable Validate(ValidationContext validationContext) + if (Credit) { - if (!UserId.HasValue && !OrganizationId.HasValue) - { - yield return new ValidationResult("User or Ooganization is required."); - } + posData += ",accountCredit:1"; + inv.ItemDesc = "Bitwarden Account Credit"; + } + else + { + inv.ItemDesc = "Bitwarden"; + } + + inv.PosData = posData; + return inv; + } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (!UserId.HasValue && !OrganizationId.HasValue) + { + yield return new ValidationResult("User or Ooganization is required."); } } } diff --git a/src/Api/Models/Request/CipherPartialRequestModel.cs b/src/Api/Models/Request/CipherPartialRequestModel.cs index 996aec5fc..bc58eb427 100644 --- a/src/Api/Models/Request/CipherPartialRequestModel.cs +++ b/src/Api/Models/Request/CipherPartialRequestModel.cs @@ -1,11 +1,10 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request +namespace Bit.Api.Models.Request; + +public class CipherPartialRequestModel { - public class CipherPartialRequestModel - { - [StringLength(36)] - public string FolderId { get; set; } - public bool Favorite { get; set; } - } + [StringLength(36)] + public string FolderId { get; set; } + public bool Favorite { get; set; } } diff --git a/src/Api/Models/Request/CipherRequestModel.cs b/src/Api/Models/Request/CipherRequestModel.cs index 90435132a..f5f3eee42 100644 --- a/src/Api/Models/Request/CipherRequestModel.cs +++ b/src/Api/Models/Request/CipherRequestModel.cs @@ -8,342 +8,341 @@ using Core.Models.Data; using NS = Newtonsoft.Json; using NSL = Newtonsoft.Json.Linq; -namespace Bit.Api.Models.Request +namespace Bit.Api.Models.Request; + +public class CipherRequestModel { - public class CipherRequestModel + public CipherType Type { get; set; } + + [StringLength(36)] + public string OrganizationId { get; set; } + public string FolderId { get; set; } + public bool Favorite { get; set; } + public CipherRepromptType Reprompt { get; set; } + [Required] + [EncryptedString] + [EncryptedStringLength(1000)] + public string Name { get; set; } + [EncryptedString] + [EncryptedStringLength(10000)] + public string Notes { get; set; } + public IEnumerable Fields { get; set; } + public IEnumerable PasswordHistory { get; set; } + [Obsolete] + public Dictionary Attachments { get; set; } + // TODO: Rename to Attachments whenever the above is finally removed. + public Dictionary Attachments2 { get; set; } + + public CipherLoginModel Login { get; set; } + public CipherCardModel Card { get; set; } + public CipherIdentityModel Identity { get; set; } + public CipherSecureNoteModel SecureNote { get; set; } + public DateTime? LastKnownRevisionDate { get; set; } = null; + + public CipherDetails ToCipherDetails(Guid userId, bool allowOrgIdSet = true) { - public CipherType Type { get; set; } - - [StringLength(36)] - public string OrganizationId { get; set; } - public string FolderId { get; set; } - public bool Favorite { get; set; } - public CipherRepromptType Reprompt { get; set; } - [Required] - [EncryptedString] - [EncryptedStringLength(1000)] - public string Name { get; set; } - [EncryptedString] - [EncryptedStringLength(10000)] - public string Notes { get; set; } - public IEnumerable Fields { get; set; } - public IEnumerable PasswordHistory { get; set; } - [Obsolete] - public Dictionary Attachments { get; set; } - // TODO: Rename to Attachments whenever the above is finally removed. - public Dictionary Attachments2 { get; set; } - - public CipherLoginModel Login { get; set; } - public CipherCardModel Card { get; set; } - public CipherIdentityModel Identity { get; set; } - public CipherSecureNoteModel SecureNote { get; set; } - public DateTime? LastKnownRevisionDate { get; set; } = null; - - public CipherDetails ToCipherDetails(Guid userId, bool allowOrgIdSet = true) + var hasOrgId = !string.IsNullOrWhiteSpace(OrganizationId); + var cipher = new CipherDetails { - var hasOrgId = !string.IsNullOrWhiteSpace(OrganizationId); - var cipher = new CipherDetails - { - Type = Type, - UserId = !hasOrgId ? (Guid?)userId : null, - OrganizationId = allowOrgIdSet && hasOrgId ? new Guid(OrganizationId) : (Guid?)null, - Edit = true, - ViewPassword = true, - }; - ToCipherDetails(cipher); - return cipher; + Type = Type, + UserId = !hasOrgId ? (Guid?)userId : null, + OrganizationId = allowOrgIdSet && hasOrgId ? new Guid(OrganizationId) : (Guid?)null, + Edit = true, + ViewPassword = true, + }; + ToCipherDetails(cipher); + return cipher; + } + + public CipherDetails ToCipherDetails(CipherDetails existingCipher) + { + existingCipher.FolderId = string.IsNullOrWhiteSpace(FolderId) ? null : (Guid?)new Guid(FolderId); + existingCipher.Favorite = Favorite; + ToCipher(existingCipher); + return existingCipher; + } + + public Cipher ToCipher(Cipher existingCipher) + { + switch (existingCipher.Type) + { + case CipherType.Login: + var loginObj = NSL.JObject.FromObject(ToCipherLoginData(), + new NS.JsonSerializer { NullValueHandling = NS.NullValueHandling.Ignore }); + // TODO: Switch to JsonNode in .NET 6 https://docs.microsoft.com/en-us/dotnet/standard/serialization/system-text-json-use-dom-utf8jsonreader-utf8jsonwriter?pivots=dotnet-6-0 + loginObj[nameof(CipherLoginData.Uri)]?.Parent?.Remove(); + existingCipher.Data = loginObj.ToString(NS.Formatting.None); + break; + case CipherType.Card: + existingCipher.Data = JsonSerializer.Serialize(ToCipherCardData(), JsonHelpers.IgnoreWritingNull); + break; + case CipherType.Identity: + existingCipher.Data = JsonSerializer.Serialize(ToCipherIdentityData(), JsonHelpers.IgnoreWritingNull); + break; + case CipherType.SecureNote: + existingCipher.Data = JsonSerializer.Serialize(ToCipherSecureNoteData(), JsonHelpers.IgnoreWritingNull); + break; + default: + throw new ArgumentException("Unsupported type: " + nameof(Type) + "."); } - public CipherDetails ToCipherDetails(CipherDetails existingCipher) + existingCipher.Reprompt = Reprompt; + + var hasAttachments2 = (Attachments2?.Count ?? 0) > 0; + var hasAttachments = (Attachments?.Count ?? 0) > 0; + + if (!hasAttachments2 && !hasAttachments) { - existingCipher.FolderId = string.IsNullOrWhiteSpace(FolderId) ? null : (Guid?)new Guid(FolderId); - existingCipher.Favorite = Favorite; - ToCipher(existingCipher); return existingCipher; } - public Cipher ToCipher(Cipher existingCipher) + var attachments = existingCipher.GetAttachments(); + if ((attachments?.Count ?? 0) == 0) { - switch (existingCipher.Type) - { - case CipherType.Login: - var loginObj = NSL.JObject.FromObject(ToCipherLoginData(), - new NS.JsonSerializer { NullValueHandling = NS.NullValueHandling.Ignore }); - // TODO: Switch to JsonNode in .NET 6 https://docs.microsoft.com/en-us/dotnet/standard/serialization/system-text-json-use-dom-utf8jsonreader-utf8jsonwriter?pivots=dotnet-6-0 - loginObj[nameof(CipherLoginData.Uri)]?.Parent?.Remove(); - existingCipher.Data = loginObj.ToString(NS.Formatting.None); - break; - case CipherType.Card: - existingCipher.Data = JsonSerializer.Serialize(ToCipherCardData(), JsonHelpers.IgnoreWritingNull); - break; - case CipherType.Identity: - existingCipher.Data = JsonSerializer.Serialize(ToCipherIdentityData(), JsonHelpers.IgnoreWritingNull); - break; - case CipherType.SecureNote: - existingCipher.Data = JsonSerializer.Serialize(ToCipherSecureNoteData(), JsonHelpers.IgnoreWritingNull); - break; - default: - throw new ArgumentException("Unsupported type: " + nameof(Type) + "."); - } - - existingCipher.Reprompt = Reprompt; - - var hasAttachments2 = (Attachments2?.Count ?? 0) > 0; - var hasAttachments = (Attachments?.Count ?? 0) > 0; - - if (!hasAttachments2 && !hasAttachments) - { - return existingCipher; - } - - var attachments = existingCipher.GetAttachments(); - if ((attachments?.Count ?? 0) == 0) - { - return existingCipher; - } - - if (hasAttachments2) - { - foreach (var attachment in attachments.Where(a => Attachments2.ContainsKey(a.Key))) - { - var attachment2 = Attachments2[attachment.Key]; - attachment.Value.FileName = attachment2.FileName; - attachment.Value.Key = attachment2.Key; - } - } - else if (hasAttachments) - { - foreach (var attachment in attachments.Where(a => Attachments.ContainsKey(a.Key))) - { - attachment.Value.FileName = Attachments[attachment.Key]; - attachment.Value.Key = null; - } - } - - existingCipher.SetAttachments(attachments); return existingCipher; } - public Cipher ToOrganizationCipher() + if (hasAttachments2) { - if (string.IsNullOrWhiteSpace(OrganizationId)) + foreach (var attachment in attachments.Where(a => Attachments2.ContainsKey(a.Key))) { - throw new ArgumentNullException(nameof(OrganizationId)); - } - - return ToCipher(new Cipher - { - Type = Type, - OrganizationId = new Guid(OrganizationId) - }); - } - - public CipherDetails ToOrganizationCipherDetails(Guid orgId) - { - return ToCipherDetails(new CipherDetails - { - Type = Type, - OrganizationId = orgId, - Edit = true - }); - } - - private CipherLoginData ToCipherLoginData() - { - return new CipherLoginData - { - Name = Name, - Notes = Notes, - Fields = Fields?.Select(f => f.ToCipherFieldData()), - PasswordHistory = PasswordHistory?.Select(ph => ph.ToCipherPasswordHistoryData()), - - Uris = - Login.Uris?.Where(u => u != null) - .Select(u => u.ToCipherLoginUriData()), - Username = Login.Username, - Password = Login.Password, - PasswordRevisionDate = Login.PasswordRevisionDate, - Totp = Login.Totp, - AutofillOnPageLoad = Login.AutofillOnPageLoad, - }; - } - - private CipherIdentityData ToCipherIdentityData() - { - return new CipherIdentityData - { - Name = Name, - Notes = Notes, - Fields = Fields?.Select(f => f.ToCipherFieldData()), - PasswordHistory = PasswordHistory?.Select(ph => ph.ToCipherPasswordHistoryData()), - - Title = Identity.Title, - FirstName = Identity.FirstName, - MiddleName = Identity.MiddleName, - LastName = Identity.LastName, - Address1 = Identity.Address1, - Address2 = Identity.Address2, - Address3 = Identity.Address3, - City = Identity.City, - State = Identity.State, - PostalCode = Identity.PostalCode, - Country = Identity.Country, - Company = Identity.Company, - Email = Identity.Email, - Phone = Identity.Phone, - SSN = Identity.SSN, - Username = Identity.Username, - PassportNumber = Identity.PassportNumber, - LicenseNumber = Identity.LicenseNumber, - }; - } - - private CipherCardData ToCipherCardData() - { - return new CipherCardData - { - Name = Name, - Notes = Notes, - Fields = Fields?.Select(f => f.ToCipherFieldData()), - PasswordHistory = PasswordHistory?.Select(ph => ph.ToCipherPasswordHistoryData()), - - CardholderName = Card.CardholderName, - Brand = Card.Brand, - Number = Card.Number, - ExpMonth = Card.ExpMonth, - ExpYear = Card.ExpYear, - Code = Card.Code, - }; - } - - private CipherSecureNoteData ToCipherSecureNoteData() - { - return new CipherSecureNoteData - { - Name = Name, - Notes = Notes, - Fields = Fields?.Select(f => f.ToCipherFieldData()), - PasswordHistory = PasswordHistory?.Select(ph => ph.ToCipherPasswordHistoryData()), - - Type = SecureNote.Type, - }; - } - } - - public class CipherWithIdRequestModel : CipherRequestModel - { - [Required] - public Guid? Id { get; set; } - } - - public class CipherCreateRequestModel : IValidatableObject - { - public IEnumerable CollectionIds { get; set; } - [Required] - public CipherRequestModel Cipher { get; set; } - - public IEnumerable Validate(ValidationContext validationContext) - { - if (!string.IsNullOrWhiteSpace(Cipher.OrganizationId) && (!CollectionIds?.Any() ?? true)) - { - yield return new ValidationResult("You must select at least one collection.", - new string[] { nameof(CollectionIds) }); + var attachment2 = Attachments2[attachment.Key]; + attachment.Value.FileName = attachment2.FileName; + attachment.Value.Key = attachment2.Key; } } - } - - public class CipherShareRequestModel : IValidatableObject - { - [Required] - public IEnumerable CollectionIds { get; set; } - [Required] - public CipherRequestModel Cipher { get; set; } - - public IEnumerable Validate(ValidationContext validationContext) + else if (hasAttachments) { - if (string.IsNullOrWhiteSpace(Cipher.OrganizationId)) + foreach (var attachment in attachments.Where(a => Attachments.ContainsKey(a.Key))) { - yield return new ValidationResult("Cipher OrganizationId is required.", - new string[] { nameof(Cipher.OrganizationId) }); - } - - if (!CollectionIds?.Any() ?? true) - { - yield return new ValidationResult("You must select at least one collection.", - new string[] { nameof(CollectionIds) }); + attachment.Value.FileName = Attachments[attachment.Key]; + attachment.Value.Key = null; } } + + existingCipher.SetAttachments(attachments); + return existingCipher; } - public class CipherCollectionsRequestModel + public Cipher ToOrganizationCipher() { - [Required] - public IEnumerable CollectionIds { get; set; } - } - - public class CipherBulkDeleteRequestModel - { - [Required] - public IEnumerable Ids { get; set; } - public string OrganizationId { get; set; } - } - - public class CipherBulkRestoreRequestModel - { - [Required] - public IEnumerable Ids { get; set; } - } - - public class CipherBulkMoveRequestModel - { - [Required] - public IEnumerable Ids { get; set; } - public string FolderId { get; set; } - } - - public class CipherBulkShareRequestModel : IValidatableObject - { - [Required] - public IEnumerable CollectionIds { get; set; } - [Required] - public IEnumerable Ciphers { get; set; } - - public IEnumerable Validate(ValidationContext validationContext) + if (string.IsNullOrWhiteSpace(OrganizationId)) { - if (!Ciphers?.Any() ?? true) - { - yield return new ValidationResult("You must select at least one cipher.", - new string[] { nameof(Ciphers) }); - } - else - { - var allHaveIds = true; - var organizationIds = new HashSet(); - foreach (var c in Ciphers) - { - organizationIds.Add(c.OrganizationId); - if (allHaveIds) - { - allHaveIds = !(!c.Id.HasValue || string.IsNullOrWhiteSpace(c.OrganizationId)); - } - } + throw new ArgumentNullException(nameof(OrganizationId)); + } - if (!allHaveIds) - { - yield return new ValidationResult("All Ciphers must have an Id and OrganizationId.", - new string[] { nameof(Ciphers) }); - } - else if (organizationIds.Count != 1) - { - yield return new ValidationResult("All ciphers must be for the same organization."); - } - } + return ToCipher(new Cipher + { + Type = Type, + OrganizationId = new Guid(OrganizationId) + }); + } - if (!CollectionIds?.Any() ?? true) - { - yield return new ValidationResult("You must select at least one collection.", - new string[] { nameof(CollectionIds) }); - } + public CipherDetails ToOrganizationCipherDetails(Guid orgId) + { + return ToCipherDetails(new CipherDetails + { + Type = Type, + OrganizationId = orgId, + Edit = true + }); + } + + private CipherLoginData ToCipherLoginData() + { + return new CipherLoginData + { + Name = Name, + Notes = Notes, + Fields = Fields?.Select(f => f.ToCipherFieldData()), + PasswordHistory = PasswordHistory?.Select(ph => ph.ToCipherPasswordHistoryData()), + + Uris = + Login.Uris?.Where(u => u != null) + .Select(u => u.ToCipherLoginUriData()), + Username = Login.Username, + Password = Login.Password, + PasswordRevisionDate = Login.PasswordRevisionDate, + Totp = Login.Totp, + AutofillOnPageLoad = Login.AutofillOnPageLoad, + }; + } + + private CipherIdentityData ToCipherIdentityData() + { + return new CipherIdentityData + { + Name = Name, + Notes = Notes, + Fields = Fields?.Select(f => f.ToCipherFieldData()), + PasswordHistory = PasswordHistory?.Select(ph => ph.ToCipherPasswordHistoryData()), + + Title = Identity.Title, + FirstName = Identity.FirstName, + MiddleName = Identity.MiddleName, + LastName = Identity.LastName, + Address1 = Identity.Address1, + Address2 = Identity.Address2, + Address3 = Identity.Address3, + City = Identity.City, + State = Identity.State, + PostalCode = Identity.PostalCode, + Country = Identity.Country, + Company = Identity.Company, + Email = Identity.Email, + Phone = Identity.Phone, + SSN = Identity.SSN, + Username = Identity.Username, + PassportNumber = Identity.PassportNumber, + LicenseNumber = Identity.LicenseNumber, + }; + } + + private CipherCardData ToCipherCardData() + { + return new CipherCardData + { + Name = Name, + Notes = Notes, + Fields = Fields?.Select(f => f.ToCipherFieldData()), + PasswordHistory = PasswordHistory?.Select(ph => ph.ToCipherPasswordHistoryData()), + + CardholderName = Card.CardholderName, + Brand = Card.Brand, + Number = Card.Number, + ExpMonth = Card.ExpMonth, + ExpYear = Card.ExpYear, + Code = Card.Code, + }; + } + + private CipherSecureNoteData ToCipherSecureNoteData() + { + return new CipherSecureNoteData + { + Name = Name, + Notes = Notes, + Fields = Fields?.Select(f => f.ToCipherFieldData()), + PasswordHistory = PasswordHistory?.Select(ph => ph.ToCipherPasswordHistoryData()), + + Type = SecureNote.Type, + }; + } +} + +public class CipherWithIdRequestModel : CipherRequestModel +{ + [Required] + public Guid? Id { get; set; } +} + +public class CipherCreateRequestModel : IValidatableObject +{ + public IEnumerable CollectionIds { get; set; } + [Required] + public CipherRequestModel Cipher { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (!string.IsNullOrWhiteSpace(Cipher.OrganizationId) && (!CollectionIds?.Any() ?? true)) + { + yield return new ValidationResult("You must select at least one collection.", + new string[] { nameof(CollectionIds) }); + } + } +} + +public class CipherShareRequestModel : IValidatableObject +{ + [Required] + public IEnumerable CollectionIds { get; set; } + [Required] + public CipherRequestModel Cipher { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (string.IsNullOrWhiteSpace(Cipher.OrganizationId)) + { + yield return new ValidationResult("Cipher OrganizationId is required.", + new string[] { nameof(Cipher.OrganizationId) }); + } + + if (!CollectionIds?.Any() ?? true) + { + yield return new ValidationResult("You must select at least one collection.", + new string[] { nameof(CollectionIds) }); + } + } +} + +public class CipherCollectionsRequestModel +{ + [Required] + public IEnumerable CollectionIds { get; set; } +} + +public class CipherBulkDeleteRequestModel +{ + [Required] + public IEnumerable Ids { get; set; } + public string OrganizationId { get; set; } +} + +public class CipherBulkRestoreRequestModel +{ + [Required] + public IEnumerable Ids { get; set; } +} + +public class CipherBulkMoveRequestModel +{ + [Required] + public IEnumerable Ids { get; set; } + public string FolderId { get; set; } +} + +public class CipherBulkShareRequestModel : IValidatableObject +{ + [Required] + public IEnumerable CollectionIds { get; set; } + [Required] + public IEnumerable Ciphers { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (!Ciphers?.Any() ?? true) + { + yield return new ValidationResult("You must select at least one cipher.", + new string[] { nameof(Ciphers) }); + } + else + { + var allHaveIds = true; + var organizationIds = new HashSet(); + foreach (var c in Ciphers) + { + organizationIds.Add(c.OrganizationId); + if (allHaveIds) + { + allHaveIds = !(!c.Id.HasValue || string.IsNullOrWhiteSpace(c.OrganizationId)); + } + } + + if (!allHaveIds) + { + yield return new ValidationResult("All Ciphers must have an Id and OrganizationId.", + new string[] { nameof(Ciphers) }); + } + else if (organizationIds.Count != 1) + { + yield return new ValidationResult("All ciphers must be for the same organization."); + } + } + + if (!CollectionIds?.Any() ?? true) + { + yield return new ValidationResult("You must select at least one collection.", + new string[] { nameof(CollectionIds) }); } } } diff --git a/src/Api/Models/Request/CollectionRequestModel.cs b/src/Api/Models/Request/CollectionRequestModel.cs index e09510347..fb0be314d 100644 --- a/src/Api/Models/Request/CollectionRequestModel.cs +++ b/src/Api/Models/Request/CollectionRequestModel.cs @@ -2,31 +2,30 @@ using Bit.Core.Entities; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request +namespace Bit.Api.Models.Request; + +public class CollectionRequestModel { - public class CollectionRequestModel + [Required] + [EncryptedString] + [EncryptedStringLength(1000)] + public string Name { get; set; } + [StringLength(300)] + public string ExternalId { get; set; } + public IEnumerable Groups { get; set; } + + public Collection ToCollection(Guid orgId) { - [Required] - [EncryptedString] - [EncryptedStringLength(1000)] - public string Name { get; set; } - [StringLength(300)] - public string ExternalId { get; set; } - public IEnumerable Groups { get; set; } - - public Collection ToCollection(Guid orgId) + return ToCollection(new Collection { - return ToCollection(new Collection - { - OrganizationId = orgId - }); - } + OrganizationId = orgId + }); + } - public Collection ToCollection(Collection existingCollection) - { - existingCollection.Name = Name; - existingCollection.ExternalId = ExternalId; - return existingCollection; - } + public Collection ToCollection(Collection existingCollection) + { + existingCollection.Name = Name; + existingCollection.ExternalId = ExternalId; + return existingCollection; } } diff --git a/src/Api/Models/Request/DeviceRequestModels.cs b/src/Api/Models/Request/DeviceRequestModels.cs index b47693e3b..8d88c7f9c 100644 --- a/src/Api/Models/Request/DeviceRequestModels.cs +++ b/src/Api/Models/Request/DeviceRequestModels.cs @@ -2,49 +2,48 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Api.Models.Request +namespace Bit.Api.Models.Request; + +public class DeviceRequestModel { - public class DeviceRequestModel + [Required] + public DeviceType? Type { get; set; } + [Required] + [StringLength(50)] + public string Name { get; set; } + [Required] + [StringLength(50)] + public string Identifier { get; set; } + [StringLength(255)] + public string PushToken { get; set; } + + public Device ToDevice(Guid? userId = null) { - [Required] - public DeviceType? Type { get; set; } - [Required] - [StringLength(50)] - public string Name { get; set; } - [Required] - [StringLength(50)] - public string Identifier { get; set; } - [StringLength(255)] - public string PushToken { get; set; } - - public Device ToDevice(Guid? userId = null) + return ToDevice(new Device { - return ToDevice(new Device - { - UserId = userId == null ? default(Guid) : userId.Value - }); - } - - public Device ToDevice(Device existingDevice) - { - existingDevice.Name = Name; - existingDevice.Identifier = Identifier; - existingDevice.PushToken = PushToken; - existingDevice.Type = Type.Value; - - return existingDevice; - } + UserId = userId == null ? default(Guid) : userId.Value + }); } - public class DeviceTokenRequestModel + public Device ToDevice(Device existingDevice) { - [StringLength(255)] - public string PushToken { get; set; } + existingDevice.Name = Name; + existingDevice.Identifier = Identifier; + existingDevice.PushToken = PushToken; + existingDevice.Type = Type.Value; - public Device ToDevice(Device existingDevice) - { - existingDevice.PushToken = PushToken; - return existingDevice; - } + return existingDevice; + } +} + +public class DeviceTokenRequestModel +{ + [StringLength(255)] + public string PushToken { get; set; } + + public Device ToDevice(Device existingDevice) + { + existingDevice.PushToken = PushToken; + return existingDevice; } } diff --git a/src/Api/Models/Request/DeviceVerificationRequestModel.cs b/src/Api/Models/Request/DeviceVerificationRequestModel.cs index e8c22d9fe..d81471916 100644 --- a/src/Api/Models/Request/DeviceVerificationRequestModel.cs +++ b/src/Api/Models/Request/DeviceVerificationRequestModel.cs @@ -1,17 +1,16 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Entities; -namespace Bit.Api.Models.Request -{ - public class DeviceVerificationRequestModel - { - [Required] - public bool UnknownDeviceVerificationEnabled { get; set; } +namespace Bit.Api.Models.Request; - public User ToUser(User user) - { - user.UnknownDeviceVerificationEnabled = UnknownDeviceVerificationEnabled; - return user; - } +public class DeviceVerificationRequestModel +{ + [Required] + public bool UnknownDeviceVerificationEnabled { get; set; } + + public User ToUser(User user) + { + user.UnknownDeviceVerificationEnabled = UnknownDeviceVerificationEnabled; + return user; } } diff --git a/src/Api/Models/Request/EmergencyAccessRequstModels.cs b/src/Api/Models/Request/EmergencyAccessRequstModels.cs index a8e9f07a0..040316c50 100644 --- a/src/Api/Models/Request/EmergencyAccessRequstModels.cs +++ b/src/Api/Models/Request/EmergencyAccessRequstModels.cs @@ -3,47 +3,46 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request +namespace Bit.Api.Models.Request; + +public class EmergencyAccessInviteRequestModel { - public class EmergencyAccessInviteRequestModel - { - [Required] - [StrictEmailAddress] - [StringLength(256)] - public string Email { get; set; } - [Required] - public EmergencyAccessType? Type { get; set; } - [Required] - public int WaitTimeDays { get; set; } - } + [Required] + [StrictEmailAddress] + [StringLength(256)] + public string Email { get; set; } + [Required] + public EmergencyAccessType? Type { get; set; } + [Required] + public int WaitTimeDays { get; set; } +} - public class EmergencyAccessUpdateRequestModel - { - [Required] - public EmergencyAccessType Type { get; set; } - [Required] - public int WaitTimeDays { get; set; } - public string KeyEncrypted { get; set; } +public class EmergencyAccessUpdateRequestModel +{ + [Required] + public EmergencyAccessType Type { get; set; } + [Required] + public int WaitTimeDays { get; set; } + public string KeyEncrypted { get; set; } - public EmergencyAccess ToEmergencyAccess(EmergencyAccess existingEmergencyAccess) + public EmergencyAccess ToEmergencyAccess(EmergencyAccess existingEmergencyAccess) + { + // Ensure we only set keys for a confirmed emergency access. + if (!string.IsNullOrWhiteSpace(existingEmergencyAccess.KeyEncrypted) && !string.IsNullOrWhiteSpace(KeyEncrypted)) { - // Ensure we only set keys for a confirmed emergency access. - if (!string.IsNullOrWhiteSpace(existingEmergencyAccess.KeyEncrypted) && !string.IsNullOrWhiteSpace(KeyEncrypted)) - { - existingEmergencyAccess.KeyEncrypted = KeyEncrypted; - } - existingEmergencyAccess.Type = Type; - existingEmergencyAccess.WaitTimeDays = WaitTimeDays; - return existingEmergencyAccess; + existingEmergencyAccess.KeyEncrypted = KeyEncrypted; } - } - - public class EmergencyAccessPasswordRequestModel - { - [Required] - [StringLength(300)] - public string NewMasterPasswordHash { get; set; } - [Required] - public string Key { get; set; } + existingEmergencyAccess.Type = Type; + existingEmergencyAccess.WaitTimeDays = WaitTimeDays; + return existingEmergencyAccess; } } + +public class EmergencyAccessPasswordRequestModel +{ + [Required] + [StringLength(300)] + public string NewMasterPasswordHash { get; set; } + [Required] + public string Key { get; set; } +} diff --git a/src/Api/Models/Request/FolderRequestModel.cs b/src/Api/Models/Request/FolderRequestModel.cs index 52b0fcdb3..092b993bb 100644 --- a/src/Api/Models/Request/FolderRequestModel.cs +++ b/src/Api/Models/Request/FolderRequestModel.cs @@ -2,32 +2,31 @@ using Bit.Core.Entities; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request +namespace Bit.Api.Models.Request; + +public class FolderRequestModel { - public class FolderRequestModel + [Required] + [EncryptedString] + [EncryptedStringLength(1000)] + public string Name { get; set; } + + public Folder ToFolder(Guid userId) { - [Required] - [EncryptedString] - [EncryptedStringLength(1000)] - public string Name { get; set; } - - public Folder ToFolder(Guid userId) + return ToFolder(new Folder { - return ToFolder(new Folder - { - UserId = userId - }); - } - - public Folder ToFolder(Folder existingFolder) - { - existingFolder.Name = Name; - return existingFolder; - } + UserId = userId + }); } - public class FolderWithIdRequestModel : FolderRequestModel + public Folder ToFolder(Folder existingFolder) { - public Guid Id { get; set; } + existingFolder.Name = Name; + return existingFolder; } } + +public class FolderWithIdRequestModel : FolderRequestModel +{ + public Guid Id { get; set; } +} diff --git a/src/Api/Models/Request/GroupRequestModel.cs b/src/Api/Models/Request/GroupRequestModel.cs index 23b817624..71e76590b 100644 --- a/src/Api/Models/Request/GroupRequestModel.cs +++ b/src/Api/Models/Request/GroupRequestModel.cs @@ -1,33 +1,32 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Entities; -namespace Bit.Api.Models.Request +namespace Bit.Api.Models.Request; + +public class GroupRequestModel { - public class GroupRequestModel + [Required] + [StringLength(100)] + public string Name { get; set; } + [Required] + public bool? AccessAll { get; set; } + [StringLength(300)] + public string ExternalId { get; set; } + public IEnumerable Collections { get; set; } + + public Group ToGroup(Guid orgId) { - [Required] - [StringLength(100)] - public string Name { get; set; } - [Required] - public bool? AccessAll { get; set; } - [StringLength(300)] - public string ExternalId { get; set; } - public IEnumerable Collections { get; set; } - - public Group ToGroup(Guid orgId) + return ToGroup(new Group { - return ToGroup(new Group - { - OrganizationId = orgId - }); - } + OrganizationId = orgId + }); + } - public Group ToGroup(Group existingGroup) - { - existingGroup.Name = Name; - existingGroup.AccessAll = AccessAll.Value; - existingGroup.ExternalId = ExternalId; - return existingGroup; - } + public Group ToGroup(Group existingGroup) + { + existingGroup.Name = Name; + existingGroup.AccessAll = AccessAll.Value; + existingGroup.ExternalId = ExternalId; + return existingGroup; } } diff --git a/src/Api/Models/Request/IapCheckRequestModel.cs b/src/Api/Models/Request/IapCheckRequestModel.cs index d7ca6ba3b..ededb37ee 100644 --- a/src/Api/Models/Request/IapCheckRequestModel.cs +++ b/src/Api/Models/Request/IapCheckRequestModel.cs @@ -1,20 +1,19 @@ using System.ComponentModel.DataAnnotations; using Enums = Bit.Core.Enums; -namespace Bit.Api.Models.Request -{ - public class IapCheckRequestModel : IValidatableObject - { - [Required] - public Enums.PaymentMethodType? PaymentMethodType { get; set; } +namespace Bit.Api.Models.Request; - public IEnumerable Validate(ValidationContext validationContext) +public class IapCheckRequestModel : IValidatableObject +{ + [Required] + public Enums.PaymentMethodType? PaymentMethodType { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (PaymentMethodType != Enums.PaymentMethodType.AppleInApp) { - if (PaymentMethodType != Enums.PaymentMethodType.AppleInApp) - { - yield return new ValidationResult("Not a supported in-app purchase payment method.", - new string[] { nameof(PaymentMethodType) }); - } + yield return new ValidationResult("Not a supported in-app purchase payment method.", + new string[] { nameof(PaymentMethodType) }); } } } diff --git a/src/Api/Models/Request/InstallationRequestModel.cs b/src/Api/Models/Request/InstallationRequestModel.cs index 9f594f7bb..65b542e62 100644 --- a/src/Api/Models/Request/InstallationRequestModel.cs +++ b/src/Api/Models/Request/InstallationRequestModel.cs @@ -2,23 +2,22 @@ using Bit.Core.Entities; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request -{ - public class InstallationRequestModel - { - [Required] - [EmailAddress] - [StringLength(256)] - public string Email { get; set; } +namespace Bit.Api.Models.Request; - public Installation ToInstallation() +public class InstallationRequestModel +{ + [Required] + [EmailAddress] + [StringLength(256)] + public string Email { get; set; } + + public Installation ToInstallation() + { + return new Installation { - return new Installation - { - Key = CoreHelpers.SecureRandomString(20), - Email = Email, - Enabled = true - }; - } + Key = CoreHelpers.SecureRandomString(20), + Email = Email, + Enabled = true + }; } } diff --git a/src/Api/Models/Request/LicenseRequestModel.cs b/src/Api/Models/Request/LicenseRequestModel.cs index 382f68615..7b66d95f0 100644 --- a/src/Api/Models/Request/LicenseRequestModel.cs +++ b/src/Api/Models/Request/LicenseRequestModel.cs @@ -1,10 +1,9 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request +namespace Bit.Api.Models.Request; + +public class LicenseRequestModel { - public class LicenseRequestModel - { - [Required] - public IFormFile License { get; set; } - } + [Required] + public IFormFile License { get; set; } } diff --git a/src/Api/Models/Request/Organizations/ImportOrganizationCiphersRequestModel.cs b/src/Api/Models/Request/Organizations/ImportOrganizationCiphersRequestModel.cs index b70f39587..3aa6ef68c 100644 --- a/src/Api/Models/Request/Organizations/ImportOrganizationCiphersRequestModel.cs +++ b/src/Api/Models/Request/Organizations/ImportOrganizationCiphersRequestModel.cs @@ -1,9 +1,8 @@ -namespace Bit.Api.Models.Request.Organizations +namespace Bit.Api.Models.Request.Organizations; + +public class ImportOrganizationCiphersRequestModel { - public class ImportOrganizationCiphersRequestModel - { - public CollectionRequestModel[] Collections { get; set; } - public CipherRequestModel[] Ciphers { get; set; } - public KeyValuePair[] CollectionRelationships { get; set; } - } + public CollectionRequestModel[] Collections { get; set; } + public CipherRequestModel[] Ciphers { get; set; } + public KeyValuePair[] CollectionRelationships { get; set; } } diff --git a/src/Api/Models/Request/Organizations/ImportOrganizationUsersRequestModel.cs b/src/Api/Models/Request/Organizations/ImportOrganizationUsersRequestModel.cs index d35e051e9..3f1e2b244 100644 --- a/src/Api/Models/Request/Organizations/ImportOrganizationUsersRequestModel.cs +++ b/src/Api/Models/Request/Organizations/ImportOrganizationUsersRequestModel.cs @@ -1,69 +1,68 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Models.Business; -namespace Bit.Api.Models.Request.Organizations +namespace Bit.Api.Models.Request.Organizations; + +public class ImportOrganizationUsersRequestModel { - public class ImportOrganizationUsersRequestModel + public Group[] Groups { get; set; } + public User[] Users { get; set; } + public bool OverwriteExisting { get; set; } + public bool LargeImport { get; set; } + + public class Group { - public Group[] Groups { get; set; } - public User[] Users { get; set; } - public bool OverwriteExisting { get; set; } - public bool LargeImport { get; set; } + [Required] + [StringLength(100)] + public string Name { get; set; } + [Required] + [StringLength(300)] + public string ExternalId { get; set; } + public IEnumerable Users { get; set; } - public class Group + public ImportedGroup ToImportedGroup(Guid organizationId) { - [Required] - [StringLength(100)] - public string Name { get; set; } - [Required] - [StringLength(300)] - public string ExternalId { get; set; } - public IEnumerable Users { get; set; } - - public ImportedGroup ToImportedGroup(Guid organizationId) + var importedGroup = new ImportedGroup { - var importedGroup = new ImportedGroup + Group = new Core.Entities.Group { - Group = new Core.Entities.Group - { - OrganizationId = organizationId, - Name = Name, - ExternalId = ExternalId - }, - ExternalUserIds = new HashSet(Users) - }; + OrganizationId = organizationId, + Name = Name, + ExternalId = ExternalId + }, + ExternalUserIds = new HashSet(Users) + }; - return importedGroup; - } + return importedGroup; + } + } + + public class User : IValidatableObject + { + [EmailAddress] + [StringLength(256)] + public string Email { get; set; } + public bool Deleted { get; set; } + [Required] + [StringLength(300)] + public string ExternalId { get; set; } + + public ImportedOrganizationUser ToImportedOrganizationUser() + { + var importedUser = new ImportedOrganizationUser + { + Email = Email.ToLowerInvariant(), + ExternalId = ExternalId + }; + + return importedUser; } - public class User : IValidatableObject + public IEnumerable Validate(ValidationContext validationContext) { - [EmailAddress] - [StringLength(256)] - public string Email { get; set; } - public bool Deleted { get; set; } - [Required] - [StringLength(300)] - public string ExternalId { get; set; } - - public ImportedOrganizationUser ToImportedOrganizationUser() + if (string.IsNullOrWhiteSpace(Email) && !Deleted) { - var importedUser = new ImportedOrganizationUser - { - Email = Email.ToLowerInvariant(), - ExternalId = ExternalId - }; - - return importedUser; - } - - public IEnumerable Validate(ValidationContext validationContext) - { - if (string.IsNullOrWhiteSpace(Email) && !Deleted) - { - yield return new ValidationResult("Email is required for enabled users.", new string[] { nameof(Email) }); - } + yield return new ValidationResult("Email is required for enabled users.", new string[] { nameof(Email) }); } } } diff --git a/src/Api/Models/Request/Organizations/OrganizationConnectionRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationConnectionRequestModel.cs index 91132ec5e..9dbc9ca0a 100644 --- a/src/Api/Models/Request/Organizations/OrganizationConnectionRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationConnectionRequestModel.cs @@ -4,48 +4,47 @@ using Bit.Core.Exceptions; using Bit.Core.Models.Data.Organizations.OrganizationConnections; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request.Organizations +namespace Bit.Api.Models.Request.Organizations; + +public class OrganizationConnectionRequestModel { - public class OrganizationConnectionRequestModel - { - public OrganizationConnectionType Type { get; set; } - public Guid OrganizationId { get; set; } - public bool Enabled { get; set; } - public JsonDocument Config { get; set; } + public OrganizationConnectionType Type { get; set; } + public Guid OrganizationId { get; set; } + public bool Enabled { get; set; } + public JsonDocument Config { get; set; } - public OrganizationConnectionRequestModel() { } - } - - - public class OrganizationConnectionRequestModel : OrganizationConnectionRequestModel where T : new() - { - public T ParsedConfig { get; private set; } - - public OrganizationConnectionRequestModel(OrganizationConnectionRequestModel model) - { - Type = model.Type; - OrganizationId = model.OrganizationId; - Enabled = model.Enabled; - Config = model.Config; - - try - { - ParsedConfig = model.Config.ToObject(JsonHelpers.IgnoreCase); - } - catch (JsonException) - { - throw new BadRequestException("Organization Connection configuration malformed"); - } - } - - public OrganizationConnectionData ToData(Guid? id = null) => - new() - { - Id = id, - Type = Type, - OrganizationId = OrganizationId, - Enabled = Enabled, - Config = ParsedConfig, - }; - } + public OrganizationConnectionRequestModel() { } +} + + +public class OrganizationConnectionRequestModel : OrganizationConnectionRequestModel where T : new() +{ + public T ParsedConfig { get; private set; } + + public OrganizationConnectionRequestModel(OrganizationConnectionRequestModel model) + { + Type = model.Type; + OrganizationId = model.OrganizationId; + Enabled = model.Enabled; + Config = model.Config; + + try + { + ParsedConfig = model.Config.ToObject(JsonHelpers.IgnoreCase); + } + catch (JsonException) + { + throw new BadRequestException("Organization Connection configuration malformed"); + } + } + + public OrganizationConnectionData ToData(Guid? id = null) => + new() + { + Id = id, + Type = Type, + OrganizationId = OrganizationId, + Enabled = Enabled, + Config = ParsedConfig, + }; } diff --git a/src/Api/Models/Request/Organizations/OrganizationCreateLicenseRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationCreateLicenseRequestModel.cs index 722d338b9..2d9175158 100644 --- a/src/Api/Models/Request/Organizations/OrganizationCreateLicenseRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationCreateLicenseRequestModel.cs @@ -1,15 +1,14 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request.Organizations +namespace Bit.Api.Models.Request.Organizations; + +public class OrganizationCreateLicenseRequestModel : LicenseRequestModel { - public class OrganizationCreateLicenseRequestModel : LicenseRequestModel - { - [Required] - public string Key { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string CollectionName { get; set; } - public OrganizationKeysRequestModel Keys { get; set; } - } + [Required] + public string Key { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string CollectionName { get; set; } + public OrganizationKeysRequestModel Keys { get; set; } } diff --git a/src/Api/Models/Request/Organizations/OrganizationCreateRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationCreateRequestModel.cs index 4f84ea5c6..3e4602179 100644 --- a/src/Api/Models/Request/Organizations/OrganizationCreateRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationCreateRequestModel.cs @@ -4,99 +4,98 @@ using Bit.Core.Enums; using Bit.Core.Models.Business; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request.Organizations +namespace Bit.Api.Models.Request.Organizations; + +public class OrganizationCreateRequestModel : IValidatableObject { - public class OrganizationCreateRequestModel : IValidatableObject + [Required] + [StringLength(50)] + public string Name { get; set; } + [StringLength(50)] + public string BusinessName { get; set; } + [Required] + [StringLength(256)] + [EmailAddress] + public string BillingEmail { get; set; } + public PlanType PlanType { get; set; } + [Required] + public string Key { get; set; } + public OrganizationKeysRequestModel Keys { get; set; } + public PaymentMethodType? PaymentMethodType { get; set; } + public string PaymentToken { get; set; } + [Range(0, int.MaxValue)] + public int AdditionalSeats { get; set; } + [Range(0, 99)] + public short? AdditionalStorageGb { get; set; } + public bool PremiumAccessAddon { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string CollectionName { get; set; } + public string TaxIdNumber { get; set; } + public string BillingAddressLine1 { get; set; } + public string BillingAddressLine2 { get; set; } + public string BillingAddressCity { get; set; } + public string BillingAddressState { get; set; } + public string BillingAddressPostalCode { get; set; } + [StringLength(2)] + public string BillingAddressCountry { get; set; } + public int? MaxAutoscaleSeats { get; set; } + + public virtual OrganizationSignup ToOrganizationSignup(User user) { - [Required] - [StringLength(50)] - public string Name { get; set; } - [StringLength(50)] - public string BusinessName { get; set; } - [Required] - [StringLength(256)] - [EmailAddress] - public string BillingEmail { get; set; } - public PlanType PlanType { get; set; } - [Required] - public string Key { get; set; } - public OrganizationKeysRequestModel Keys { get; set; } - public PaymentMethodType? PaymentMethodType { get; set; } - public string PaymentToken { get; set; } - [Range(0, int.MaxValue)] - public int AdditionalSeats { get; set; } - [Range(0, 99)] - public short? AdditionalStorageGb { get; set; } - public bool PremiumAccessAddon { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string CollectionName { get; set; } - public string TaxIdNumber { get; set; } - public string BillingAddressLine1 { get; set; } - public string BillingAddressLine2 { get; set; } - public string BillingAddressCity { get; set; } - public string BillingAddressState { get; set; } - public string BillingAddressPostalCode { get; set; } - [StringLength(2)] - public string BillingAddressCountry { get; set; } - public int? MaxAutoscaleSeats { get; set; } - - public virtual OrganizationSignup ToOrganizationSignup(User user) + var orgSignup = new OrganizationSignup { - var orgSignup = new OrganizationSignup + Owner = user, + OwnerKey = Key, + Name = Name, + Plan = PlanType, + PaymentMethodType = PaymentMethodType, + PaymentToken = PaymentToken, + AdditionalSeats = AdditionalSeats, + MaxAutoscaleSeats = MaxAutoscaleSeats, + AdditionalStorageGb = AdditionalStorageGb.GetValueOrDefault(0), + PremiumAccessAddon = PremiumAccessAddon, + BillingEmail = BillingEmail, + BusinessName = BusinessName, + CollectionName = CollectionName, + TaxInfo = new TaxInfo { - Owner = user, - OwnerKey = Key, - Name = Name, - Plan = PlanType, - PaymentMethodType = PaymentMethodType, - PaymentToken = PaymentToken, - AdditionalSeats = AdditionalSeats, - MaxAutoscaleSeats = MaxAutoscaleSeats, - AdditionalStorageGb = AdditionalStorageGb.GetValueOrDefault(0), - PremiumAccessAddon = PremiumAccessAddon, - BillingEmail = BillingEmail, - BusinessName = BusinessName, - CollectionName = CollectionName, - TaxInfo = new TaxInfo - { - TaxIdNumber = TaxIdNumber, - BillingAddressLine1 = BillingAddressLine1, - BillingAddressLine2 = BillingAddressLine2, - BillingAddressCity = BillingAddressCity, - BillingAddressState = BillingAddressState, - BillingAddressPostalCode = BillingAddressPostalCode, - BillingAddressCountry = BillingAddressCountry, - }, - }; + TaxIdNumber = TaxIdNumber, + BillingAddressLine1 = BillingAddressLine1, + BillingAddressLine2 = BillingAddressLine2, + BillingAddressCity = BillingAddressCity, + BillingAddressState = BillingAddressState, + BillingAddressPostalCode = BillingAddressPostalCode, + BillingAddressCountry = BillingAddressCountry, + }, + }; - Keys?.ToOrganizationSignup(orgSignup); + Keys?.ToOrganizationSignup(orgSignup); - return orgSignup; + return orgSignup; + } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (PlanType != PlanType.Free && string.IsNullOrWhiteSpace(PaymentToken)) + { + yield return new ValidationResult("Payment required.", new string[] { nameof(PaymentToken) }); } - - public IEnumerable Validate(ValidationContext validationContext) + if (PlanType != PlanType.Free && !PaymentMethodType.HasValue) { - if (PlanType != PlanType.Free && string.IsNullOrWhiteSpace(PaymentToken)) - { - yield return new ValidationResult("Payment required.", new string[] { nameof(PaymentToken) }); - } - if (PlanType != PlanType.Free && !PaymentMethodType.HasValue) - { - yield return new ValidationResult("Payment method type required.", - new string[] { nameof(PaymentMethodType) }); - } - if (PlanType != PlanType.Free && string.IsNullOrWhiteSpace(BillingAddressCountry)) - { - yield return new ValidationResult("Country required.", - new string[] { nameof(BillingAddressCountry) }); - } - if (PlanType != PlanType.Free && BillingAddressCountry == "US" && - string.IsNullOrWhiteSpace(BillingAddressPostalCode)) - { - yield return new ValidationResult("Zip / postal code is required.", - new string[] { nameof(BillingAddressPostalCode) }); - } + yield return new ValidationResult("Payment method type required.", + new string[] { nameof(PaymentMethodType) }); + } + if (PlanType != PlanType.Free && string.IsNullOrWhiteSpace(BillingAddressCountry)) + { + yield return new ValidationResult("Country required.", + new string[] { nameof(BillingAddressCountry) }); + } + if (PlanType != PlanType.Free && BillingAddressCountry == "US" && + string.IsNullOrWhiteSpace(BillingAddressPostalCode)) + { + yield return new ValidationResult("Zip / postal code is required.", + new string[] { nameof(BillingAddressPostalCode) }); } } } diff --git a/src/Api/Models/Request/Organizations/OrganizationKeysRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationKeysRequestModel.cs index 070b03d19..a22b4eaa6 100644 --- a/src/Api/Models/Request/Organizations/OrganizationKeysRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationKeysRequestModel.cs @@ -2,58 +2,57 @@ using Bit.Core.Entities; using Bit.Core.Models.Business; -namespace Bit.Api.Models.Request.Organizations +namespace Bit.Api.Models.Request.Organizations; + +public class OrganizationKeysRequestModel { - public class OrganizationKeysRequestModel + [Required] + public string PublicKey { get; set; } + [Required] + public string EncryptedPrivateKey { get; set; } + + public OrganizationSignup ToOrganizationSignup(OrganizationSignup existingSignup) { - [Required] - public string PublicKey { get; set; } - [Required] - public string EncryptedPrivateKey { get; set; } - - public OrganizationSignup ToOrganizationSignup(OrganizationSignup existingSignup) + if (string.IsNullOrWhiteSpace(existingSignup.PublicKey)) { - if (string.IsNullOrWhiteSpace(existingSignup.PublicKey)) - { - existingSignup.PublicKey = PublicKey; - } - - if (string.IsNullOrWhiteSpace(existingSignup.PrivateKey)) - { - existingSignup.PrivateKey = EncryptedPrivateKey; - } - - return existingSignup; + existingSignup.PublicKey = PublicKey; } - public OrganizationUpgrade ToOrganizationUpgrade(OrganizationUpgrade existingUpgrade) + if (string.IsNullOrWhiteSpace(existingSignup.PrivateKey)) { - if (string.IsNullOrWhiteSpace(existingUpgrade.PublicKey)) - { - existingUpgrade.PublicKey = PublicKey; - } - - if (string.IsNullOrWhiteSpace(existingUpgrade.PrivateKey)) - { - existingUpgrade.PrivateKey = EncryptedPrivateKey; - } - - return existingUpgrade; + existingSignup.PrivateKey = EncryptedPrivateKey; } - public Organization ToOrganization(Organization existingOrg) + return existingSignup; + } + + public OrganizationUpgrade ToOrganizationUpgrade(OrganizationUpgrade existingUpgrade) + { + if (string.IsNullOrWhiteSpace(existingUpgrade.PublicKey)) { - if (string.IsNullOrWhiteSpace(existingOrg.PublicKey)) - { - existingOrg.PublicKey = PublicKey; - } - - if (string.IsNullOrWhiteSpace(existingOrg.PrivateKey)) - { - existingOrg.PrivateKey = EncryptedPrivateKey; - } - - return existingOrg; + existingUpgrade.PublicKey = PublicKey; } + + if (string.IsNullOrWhiteSpace(existingUpgrade.PrivateKey)) + { + existingUpgrade.PrivateKey = EncryptedPrivateKey; + } + + return existingUpgrade; + } + + public Organization ToOrganization(Organization existingOrg) + { + if (string.IsNullOrWhiteSpace(existingOrg.PublicKey)) + { + existingOrg.PublicKey = PublicKey; + } + + if (string.IsNullOrWhiteSpace(existingOrg.PrivateKey)) + { + existingOrg.PrivateKey = EncryptedPrivateKey; + } + + return existingOrg; } } diff --git a/src/Api/Models/Request/Organizations/OrganizationSeatRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationSeatRequestModel.cs index 068d09624..b3849f0a4 100644 --- a/src/Api/Models/Request/Organizations/OrganizationSeatRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationSeatRequestModel.cs @@ -1,18 +1,17 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Organizations -{ - public class OrganizationSeatRequestModel : IValidatableObject - { - [Required] - public int? SeatAdjustment { get; set; } +namespace Bit.Api.Models.Request.Organizations; - public IEnumerable Validate(ValidationContext validationContext) +public class OrganizationSeatRequestModel : IValidatableObject +{ + [Required] + public int? SeatAdjustment { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (SeatAdjustment == 0) { - if (SeatAdjustment == 0) - { - yield return new ValidationResult("Seat adjustment cannot be 0.", new string[] { nameof(SeatAdjustment) }); - } + yield return new ValidationResult("Seat adjustment cannot be 0.", new string[] { nameof(SeatAdjustment) }); } } } diff --git a/src/Api/Models/Request/Organizations/OrganizationSponsorshipCreateRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationSponsorshipCreateRequestModel.cs index e3848a4cb..ba88f1b90 100644 --- a/src/Api/Models/Request/Organizations/OrganizationSponsorshipCreateRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationSponsorshipCreateRequestModel.cs @@ -2,19 +2,18 @@ using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request.Organizations +namespace Bit.Api.Models.Request.Organizations; + +public class OrganizationSponsorshipCreateRequestModel { - public class OrganizationSponsorshipCreateRequestModel - { - [Required] - public PlanSponsorshipType PlanSponsorshipType { get; set; } + [Required] + public PlanSponsorshipType PlanSponsorshipType { get; set; } - [Required] - [StringLength(256)] - [StrictEmailAddress] - public string SponsoredEmail { get; set; } + [Required] + [StringLength(256)] + [StrictEmailAddress] + public string SponsoredEmail { get; set; } - [StringLength(256)] - public string FriendlyName { get; set; } - } + [StringLength(256)] + public string FriendlyName { get; set; } } diff --git a/src/Api/Models/Request/Organizations/OrganizationSponsorshipRedeemRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationSponsorshipRedeemRequestModel.cs index 4a4cc2602..19b11cd77 100644 --- a/src/Api/Models/Request/Organizations/OrganizationSponsorshipRedeemRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationSponsorshipRedeemRequestModel.cs @@ -1,13 +1,12 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Enums; -namespace Bit.Api.Models.Request.Organizations +namespace Bit.Api.Models.Request.Organizations; + +public class OrganizationSponsorshipRedeemRequestModel { - public class OrganizationSponsorshipRedeemRequestModel - { - [Required] - public PlanSponsorshipType PlanSponsorshipType { get; set; } - [Required] - public Guid SponsoredOrganizationId { get; set; } - } + [Required] + public PlanSponsorshipType PlanSponsorshipType { get; set; } + [Required] + public Guid SponsoredOrganizationId { get; set; } } diff --git a/src/Api/Models/Request/Organizations/OrganizationSsoRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationSsoRequestModel.cs index 5291ce175..47594703d 100644 --- a/src/Api/Models/Request/Organizations/OrganizationSsoRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationSsoRequestModel.cs @@ -10,216 +10,215 @@ using Bit.Core.Sso; using Bit.Core.Utilities; using Microsoft.AspNetCore.Authentication.OpenIdConnect; -namespace Bit.Api.Models.Request.Organizations +namespace Bit.Api.Models.Request.Organizations; + +public class OrganizationSsoRequestModel { - public class OrganizationSsoRequestModel + [Required] + public bool Enabled { get; set; } + [Required] + public SsoConfigurationDataRequest Data { get; set; } + + public SsoConfig ToSsoConfig(Guid organizationId) { - [Required] - public bool Enabled { get; set; } - [Required] - public SsoConfigurationDataRequest Data { get; set; } - - public SsoConfig ToSsoConfig(Guid organizationId) - { - return ToSsoConfig(new SsoConfig { OrganizationId = organizationId }); - } - - public SsoConfig ToSsoConfig(SsoConfig existingConfig) - { - existingConfig.Enabled = Enabled; - var configurationData = Data.ToConfigurationData(); - existingConfig.SetData(configurationData); - return existingConfig; - } + return ToSsoConfig(new SsoConfig { OrganizationId = organizationId }); } - public class SsoConfigurationDataRequest : IValidatableObject + public SsoConfig ToSsoConfig(SsoConfig existingConfig) { - public SsoConfigurationDataRequest() { } - - [Required] - public SsoType ConfigType { get; set; } - - public bool KeyConnectorEnabled { get; set; } - public string KeyConnectorUrl { get; set; } - - // OIDC - public string Authority { get; set; } - public string ClientId { get; set; } - public string ClientSecret { get; set; } - public string MetadataAddress { get; set; } - public OpenIdConnectRedirectBehavior RedirectBehavior { get; set; } - public bool? GetClaimsFromUserInfoEndpoint { get; set; } - public string AdditionalScopes { get; set; } - public string AdditionalUserIdClaimTypes { get; set; } - public string AdditionalEmailClaimTypes { get; set; } - public string AdditionalNameClaimTypes { get; set; } - public string AcrValues { get; set; } - public string ExpectedReturnAcrValue { get; set; } - - // SAML2 SP - public Saml2NameIdFormat SpNameIdFormat { get; set; } - public string SpOutboundSigningAlgorithm { get; set; } - public Saml2SigningBehavior SpSigningBehavior { get; set; } - public bool? SpWantAssertionsSigned { get; set; } - public bool? SpValidateCertificates { get; set; } - public string SpMinIncomingSigningAlgorithm { get; set; } - - // SAML2 IDP - public string IdpEntityId { get; set; } - public Saml2BindingType IdpBindingType { get; set; } - public string IdpSingleSignOnServiceUrl { get; set; } - public string IdpSingleLogoutServiceUrl { get; set; } - public string IdpArtifactResolutionServiceUrl { get => null; set { /*IGNORE*/ } } - public string IdpX509PublicCert { get; set; } - public string IdpOutboundSigningAlgorithm { get; set; } - public bool? IdpAllowUnsolicitedAuthnResponse { get; set; } - public bool? IdpDisableOutboundLogoutRequests { get; set; } - public bool? IdpWantAuthnRequestsSigned { get; set; } - - public IEnumerable Validate(ValidationContext context) - { - var i18nService = context.GetService(typeof(II18nService)) as I18nService; - - if (ConfigType == SsoType.OpenIdConnect) - { - if (string.IsNullOrWhiteSpace(Authority)) - { - yield return new ValidationResult(i18nService.GetLocalizedHtmlString("AuthorityValidationError"), - new[] { nameof(Authority) }); - } - - if (string.IsNullOrWhiteSpace(ClientId)) - { - yield return new ValidationResult(i18nService.GetLocalizedHtmlString("ClientIdValidationError"), - new[] { nameof(ClientId) }); - } - - if (string.IsNullOrWhiteSpace(ClientSecret)) - { - yield return new ValidationResult(i18nService.GetLocalizedHtmlString("ClientSecretValidationError"), - new[] { nameof(ClientSecret) }); - } - } - else if (ConfigType == SsoType.Saml2) - { - if (string.IsNullOrWhiteSpace(IdpEntityId)) - { - yield return new ValidationResult(i18nService.GetLocalizedHtmlString("IdpEntityIdValidationError"), - new[] { nameof(IdpEntityId) }); - } - - if (!Uri.IsWellFormedUriString(IdpEntityId, UriKind.Absolute) && string.IsNullOrWhiteSpace(IdpSingleSignOnServiceUrl)) - { - yield return new ValidationResult(i18nService.GetLocalizedHtmlString("IdpSingleSignOnServiceUrlValidationError"), - new[] { nameof(IdpSingleSignOnServiceUrl) }); - } - - if (InvalidServiceUrl(IdpSingleSignOnServiceUrl)) - { - yield return new ValidationResult(i18nService.GetLocalizedHtmlString("IdpSingleSignOnServiceUrlInvalid"), - new[] { nameof(IdpSingleSignOnServiceUrl) }); - } - - if (InvalidServiceUrl(IdpSingleLogoutServiceUrl)) - { - yield return new ValidationResult(i18nService.GetLocalizedHtmlString("IdpSingleLogoutServiceUrlInvalid"), - new[] { nameof(IdpSingleLogoutServiceUrl) }); - } - - if (!string.IsNullOrWhiteSpace(IdpX509PublicCert)) - { - // Validate the certificate is in a valid format - ValidationResult failedResult = null; - try - { - var certData = CoreHelpers.Base64UrlDecode(StripPemCertificateElements(IdpX509PublicCert)); - new X509Certificate2(certData); - } - catch (FormatException) - { - failedResult = new ValidationResult(i18nService.GetLocalizedHtmlString("IdpX509PublicCertInvalidFormatValidationError"), - new[] { nameof(IdpX509PublicCert) }); - } - catch (CryptographicException cryptoEx) - { - failedResult = new ValidationResult(i18nService.GetLocalizedHtmlString("IdpX509PublicCertCryptographicExceptionValidationError", cryptoEx.Message), - new[] { nameof(IdpX509PublicCert) }); - } - catch (Exception ex) - { - failedResult = new ValidationResult(i18nService.GetLocalizedHtmlString("IdpX509PublicCertValidationError", ex.Message), - new[] { nameof(IdpX509PublicCert) }); - } - if (failedResult != null) - { - yield return failedResult; - } - } - } - } - - public SsoConfigurationData ToConfigurationData() - { - return new SsoConfigurationData - { - ConfigType = ConfigType, - KeyConnectorEnabled = KeyConnectorEnabled, - KeyConnectorUrl = KeyConnectorUrl, - Authority = Authority, - ClientId = ClientId, - ClientSecret = ClientSecret, - MetadataAddress = MetadataAddress, - GetClaimsFromUserInfoEndpoint = GetClaimsFromUserInfoEndpoint.GetValueOrDefault(), - RedirectBehavior = RedirectBehavior, - IdpEntityId = IdpEntityId, - IdpBindingType = IdpBindingType, - IdpSingleSignOnServiceUrl = IdpSingleSignOnServiceUrl, - IdpSingleLogoutServiceUrl = IdpSingleLogoutServiceUrl, - IdpArtifactResolutionServiceUrl = null, - IdpX509PublicCert = StripPemCertificateElements(IdpX509PublicCert), - IdpOutboundSigningAlgorithm = IdpOutboundSigningAlgorithm, - IdpAllowUnsolicitedAuthnResponse = IdpAllowUnsolicitedAuthnResponse.GetValueOrDefault(), - IdpDisableOutboundLogoutRequests = IdpDisableOutboundLogoutRequests.GetValueOrDefault(), - IdpWantAuthnRequestsSigned = IdpWantAuthnRequestsSigned.GetValueOrDefault(), - SpNameIdFormat = SpNameIdFormat, - SpOutboundSigningAlgorithm = SpOutboundSigningAlgorithm ?? SamlSigningAlgorithms.Sha256, - SpSigningBehavior = SpSigningBehavior, - SpWantAssertionsSigned = SpWantAssertionsSigned.GetValueOrDefault(), - SpValidateCertificates = SpValidateCertificates.GetValueOrDefault(), - SpMinIncomingSigningAlgorithm = SpMinIncomingSigningAlgorithm, - AdditionalScopes = AdditionalScopes, - AdditionalUserIdClaimTypes = AdditionalUserIdClaimTypes, - AdditionalEmailClaimTypes = AdditionalEmailClaimTypes, - AdditionalNameClaimTypes = AdditionalNameClaimTypes, - AcrValues = AcrValues, - ExpectedReturnAcrValue = ExpectedReturnAcrValue, - }; - } - - private string StripPemCertificateElements(string certificateText) - { - if (string.IsNullOrWhiteSpace(certificateText)) - { - return null; - } - return Regex.Replace(certificateText, - @"(((BEGIN|END) CERTIFICATE)|([\-\n\r\t\s\f]))", - string.Empty, - RegexOptions.Multiline | RegexOptions.IgnoreCase | RegexOptions.CultureInvariant); - } - - private bool InvalidServiceUrl(string url) - { - if (string.IsNullOrWhiteSpace(url)) - { - return false; - } - if (!url.StartsWith("http://") && !url.StartsWith("https://")) - { - return true; - } - return Regex.IsMatch(url, "[<>\"]"); - } + existingConfig.Enabled = Enabled; + var configurationData = Data.ToConfigurationData(); + existingConfig.SetData(configurationData); + return existingConfig; + } +} + +public class SsoConfigurationDataRequest : IValidatableObject +{ + public SsoConfigurationDataRequest() { } + + [Required] + public SsoType ConfigType { get; set; } + + public bool KeyConnectorEnabled { get; set; } + public string KeyConnectorUrl { get; set; } + + // OIDC + public string Authority { get; set; } + public string ClientId { get; set; } + public string ClientSecret { get; set; } + public string MetadataAddress { get; set; } + public OpenIdConnectRedirectBehavior RedirectBehavior { get; set; } + public bool? GetClaimsFromUserInfoEndpoint { get; set; } + public string AdditionalScopes { get; set; } + public string AdditionalUserIdClaimTypes { get; set; } + public string AdditionalEmailClaimTypes { get; set; } + public string AdditionalNameClaimTypes { get; set; } + public string AcrValues { get; set; } + public string ExpectedReturnAcrValue { get; set; } + + // SAML2 SP + public Saml2NameIdFormat SpNameIdFormat { get; set; } + public string SpOutboundSigningAlgorithm { get; set; } + public Saml2SigningBehavior SpSigningBehavior { get; set; } + public bool? SpWantAssertionsSigned { get; set; } + public bool? SpValidateCertificates { get; set; } + public string SpMinIncomingSigningAlgorithm { get; set; } + + // SAML2 IDP + public string IdpEntityId { get; set; } + public Saml2BindingType IdpBindingType { get; set; } + public string IdpSingleSignOnServiceUrl { get; set; } + public string IdpSingleLogoutServiceUrl { get; set; } + public string IdpArtifactResolutionServiceUrl { get => null; set { /*IGNORE*/ } } + public string IdpX509PublicCert { get; set; } + public string IdpOutboundSigningAlgorithm { get; set; } + public bool? IdpAllowUnsolicitedAuthnResponse { get; set; } + public bool? IdpDisableOutboundLogoutRequests { get; set; } + public bool? IdpWantAuthnRequestsSigned { get; set; } + + public IEnumerable Validate(ValidationContext context) + { + var i18nService = context.GetService(typeof(II18nService)) as I18nService; + + if (ConfigType == SsoType.OpenIdConnect) + { + if (string.IsNullOrWhiteSpace(Authority)) + { + yield return new ValidationResult(i18nService.GetLocalizedHtmlString("AuthorityValidationError"), + new[] { nameof(Authority) }); + } + + if (string.IsNullOrWhiteSpace(ClientId)) + { + yield return new ValidationResult(i18nService.GetLocalizedHtmlString("ClientIdValidationError"), + new[] { nameof(ClientId) }); + } + + if (string.IsNullOrWhiteSpace(ClientSecret)) + { + yield return new ValidationResult(i18nService.GetLocalizedHtmlString("ClientSecretValidationError"), + new[] { nameof(ClientSecret) }); + } + } + else if (ConfigType == SsoType.Saml2) + { + if (string.IsNullOrWhiteSpace(IdpEntityId)) + { + yield return new ValidationResult(i18nService.GetLocalizedHtmlString("IdpEntityIdValidationError"), + new[] { nameof(IdpEntityId) }); + } + + if (!Uri.IsWellFormedUriString(IdpEntityId, UriKind.Absolute) && string.IsNullOrWhiteSpace(IdpSingleSignOnServiceUrl)) + { + yield return new ValidationResult(i18nService.GetLocalizedHtmlString("IdpSingleSignOnServiceUrlValidationError"), + new[] { nameof(IdpSingleSignOnServiceUrl) }); + } + + if (InvalidServiceUrl(IdpSingleSignOnServiceUrl)) + { + yield return new ValidationResult(i18nService.GetLocalizedHtmlString("IdpSingleSignOnServiceUrlInvalid"), + new[] { nameof(IdpSingleSignOnServiceUrl) }); + } + + if (InvalidServiceUrl(IdpSingleLogoutServiceUrl)) + { + yield return new ValidationResult(i18nService.GetLocalizedHtmlString("IdpSingleLogoutServiceUrlInvalid"), + new[] { nameof(IdpSingleLogoutServiceUrl) }); + } + + if (!string.IsNullOrWhiteSpace(IdpX509PublicCert)) + { + // Validate the certificate is in a valid format + ValidationResult failedResult = null; + try + { + var certData = CoreHelpers.Base64UrlDecode(StripPemCertificateElements(IdpX509PublicCert)); + new X509Certificate2(certData); + } + catch (FormatException) + { + failedResult = new ValidationResult(i18nService.GetLocalizedHtmlString("IdpX509PublicCertInvalidFormatValidationError"), + new[] { nameof(IdpX509PublicCert) }); + } + catch (CryptographicException cryptoEx) + { + failedResult = new ValidationResult(i18nService.GetLocalizedHtmlString("IdpX509PublicCertCryptographicExceptionValidationError", cryptoEx.Message), + new[] { nameof(IdpX509PublicCert) }); + } + catch (Exception ex) + { + failedResult = new ValidationResult(i18nService.GetLocalizedHtmlString("IdpX509PublicCertValidationError", ex.Message), + new[] { nameof(IdpX509PublicCert) }); + } + if (failedResult != null) + { + yield return failedResult; + } + } + } + } + + public SsoConfigurationData ToConfigurationData() + { + return new SsoConfigurationData + { + ConfigType = ConfigType, + KeyConnectorEnabled = KeyConnectorEnabled, + KeyConnectorUrl = KeyConnectorUrl, + Authority = Authority, + ClientId = ClientId, + ClientSecret = ClientSecret, + MetadataAddress = MetadataAddress, + GetClaimsFromUserInfoEndpoint = GetClaimsFromUserInfoEndpoint.GetValueOrDefault(), + RedirectBehavior = RedirectBehavior, + IdpEntityId = IdpEntityId, + IdpBindingType = IdpBindingType, + IdpSingleSignOnServiceUrl = IdpSingleSignOnServiceUrl, + IdpSingleLogoutServiceUrl = IdpSingleLogoutServiceUrl, + IdpArtifactResolutionServiceUrl = null, + IdpX509PublicCert = StripPemCertificateElements(IdpX509PublicCert), + IdpOutboundSigningAlgorithm = IdpOutboundSigningAlgorithm, + IdpAllowUnsolicitedAuthnResponse = IdpAllowUnsolicitedAuthnResponse.GetValueOrDefault(), + IdpDisableOutboundLogoutRequests = IdpDisableOutboundLogoutRequests.GetValueOrDefault(), + IdpWantAuthnRequestsSigned = IdpWantAuthnRequestsSigned.GetValueOrDefault(), + SpNameIdFormat = SpNameIdFormat, + SpOutboundSigningAlgorithm = SpOutboundSigningAlgorithm ?? SamlSigningAlgorithms.Sha256, + SpSigningBehavior = SpSigningBehavior, + SpWantAssertionsSigned = SpWantAssertionsSigned.GetValueOrDefault(), + SpValidateCertificates = SpValidateCertificates.GetValueOrDefault(), + SpMinIncomingSigningAlgorithm = SpMinIncomingSigningAlgorithm, + AdditionalScopes = AdditionalScopes, + AdditionalUserIdClaimTypes = AdditionalUserIdClaimTypes, + AdditionalEmailClaimTypes = AdditionalEmailClaimTypes, + AdditionalNameClaimTypes = AdditionalNameClaimTypes, + AcrValues = AcrValues, + ExpectedReturnAcrValue = ExpectedReturnAcrValue, + }; + } + + private string StripPemCertificateElements(string certificateText) + { + if (string.IsNullOrWhiteSpace(certificateText)) + { + return null; + } + return Regex.Replace(certificateText, + @"(((BEGIN|END) CERTIFICATE)|([\-\n\r\t\s\f]))", + string.Empty, + RegexOptions.Multiline | RegexOptions.IgnoreCase | RegexOptions.CultureInvariant); + } + + private bool InvalidServiceUrl(string url) + { + if (string.IsNullOrWhiteSpace(url)) + { + return false; + } + if (!url.StartsWith("http://") && !url.StartsWith("https://")) + { + return true; + } + return Regex.IsMatch(url, "[<>\"]"); } } diff --git a/src/Api/Models/Request/Organizations/OrganizationSubscriptionUpdateRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationSubscriptionUpdateRequestModel.cs index 9adb0b7ec..6db32589a 100644 --- a/src/Api/Models/Request/Organizations/OrganizationSubscriptionUpdateRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationSubscriptionUpdateRequestModel.cs @@ -1,11 +1,10 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Organizations +namespace Bit.Api.Models.Request.Organizations; + +public class OrganizationSubscriptionUpdateRequestModel { - public class OrganizationSubscriptionUpdateRequestModel - { - [Required] - public int SeatAdjustment { get; set; } - public int? MaxAutoscaleSeats { get; set; } - } + [Required] + public int SeatAdjustment { get; set; } + public int? MaxAutoscaleSeats { get; set; } } diff --git a/src/Api/Models/Request/Organizations/OrganizationTaxInfoUpdateRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationTaxInfoUpdateRequestModel.cs index a67cbbb7e..c20fa07af 100644 --- a/src/Api/Models/Request/Organizations/OrganizationTaxInfoUpdateRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationTaxInfoUpdateRequestModel.cs @@ -1,13 +1,12 @@ using Bit.Api.Models.Request.Accounts; -namespace Bit.Api.Models.Request.Organizations +namespace Bit.Api.Models.Request.Organizations; + +public class OrganizationTaxInfoUpdateRequestModel : TaxInfoUpdateRequestModel { - public class OrganizationTaxInfoUpdateRequestModel : TaxInfoUpdateRequestModel - { - public string TaxId { get; set; } - public string Line1 { get; set; } - public string Line2 { get; set; } - public string City { get; set; } - public string State { get; set; } - } + public string TaxId { get; set; } + public string Line1 { get; set; } + public string Line2 { get; set; } + public string City { get; set; } + public string State { get; set; } } diff --git a/src/Api/Models/Request/Organizations/OrganizationUpdateRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationUpdateRequestModel.cs index 24cce9710..f67016bce 100644 --- a/src/Api/Models/Request/Organizations/OrganizationUpdateRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationUpdateRequestModel.cs @@ -3,36 +3,35 @@ using Bit.Core.Entities; using Bit.Core.Models.Data; using Bit.Core.Settings; -namespace Bit.Api.Models.Request.Organizations -{ - public class OrganizationUpdateRequestModel - { - [Required] - [StringLength(50)] - public string Name { get; set; } - [StringLength(50)] - public string BusinessName { get; set; } - [StringLength(50)] - public string Identifier { get; set; } - [EmailAddress] - [Required] - [StringLength(256)] - public string BillingEmail { get; set; } - public Permissions Permissions { get; set; } - public OrganizationKeysRequestModel Keys { get; set; } +namespace Bit.Api.Models.Request.Organizations; - public virtual Organization ToOrganization(Organization existingOrganization, GlobalSettings globalSettings) +public class OrganizationUpdateRequestModel +{ + [Required] + [StringLength(50)] + public string Name { get; set; } + [StringLength(50)] + public string BusinessName { get; set; } + [StringLength(50)] + public string Identifier { get; set; } + [EmailAddress] + [Required] + [StringLength(256)] + public string BillingEmail { get; set; } + public Permissions Permissions { get; set; } + public OrganizationKeysRequestModel Keys { get; set; } + + public virtual Organization ToOrganization(Organization existingOrganization, GlobalSettings globalSettings) + { + if (!globalSettings.SelfHosted) { - if (!globalSettings.SelfHosted) - { - // These items come from the license file - existingOrganization.Name = Name; - existingOrganization.BusinessName = BusinessName; - existingOrganization.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim(); - } - existingOrganization.Identifier = Identifier; - Keys?.ToOrganization(existingOrganization); - return existingOrganization; + // These items come from the license file + existingOrganization.Name = Name; + existingOrganization.BusinessName = BusinessName; + existingOrganization.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim(); } + existingOrganization.Identifier = Identifier; + Keys?.ToOrganization(existingOrganization); + return existingOrganization; } } diff --git a/src/Api/Models/Request/Organizations/OrganizationUpgradeRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationUpgradeRequestModel.cs index 7ceedef08..fb2666cc1 100644 --- a/src/Api/Models/Request/Organizations/OrganizationUpgradeRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationUpgradeRequestModel.cs @@ -2,41 +2,40 @@ using Bit.Core.Enums; using Bit.Core.Models.Business; -namespace Bit.Api.Models.Request.Organizations +namespace Bit.Api.Models.Request.Organizations; + +public class OrganizationUpgradeRequestModel { - public class OrganizationUpgradeRequestModel + [StringLength(50)] + public string BusinessName { get; set; } + public PlanType PlanType { get; set; } + [Range(0, int.MaxValue)] + public int AdditionalSeats { get; set; } + [Range(0, 99)] + public short? AdditionalStorageGb { get; set; } + public bool PremiumAccessAddon { get; set; } + public string BillingAddressCountry { get; set; } + public string BillingAddressPostalCode { get; set; } + public OrganizationKeysRequestModel Keys { get; set; } + + public OrganizationUpgrade ToOrganizationUpgrade() { - [StringLength(50)] - public string BusinessName { get; set; } - public PlanType PlanType { get; set; } - [Range(0, int.MaxValue)] - public int AdditionalSeats { get; set; } - [Range(0, 99)] - public short? AdditionalStorageGb { get; set; } - public bool PremiumAccessAddon { get; set; } - public string BillingAddressCountry { get; set; } - public string BillingAddressPostalCode { get; set; } - public OrganizationKeysRequestModel Keys { get; set; } - - public OrganizationUpgrade ToOrganizationUpgrade() + var orgUpgrade = new OrganizationUpgrade { - var orgUpgrade = new OrganizationUpgrade + AdditionalSeats = AdditionalSeats, + AdditionalStorageGb = AdditionalStorageGb.GetValueOrDefault(), + BusinessName = BusinessName, + Plan = PlanType, + PremiumAccessAddon = PremiumAccessAddon, + TaxInfo = new TaxInfo() { - AdditionalSeats = AdditionalSeats, - AdditionalStorageGb = AdditionalStorageGb.GetValueOrDefault(), - BusinessName = BusinessName, - Plan = PlanType, - PremiumAccessAddon = PremiumAccessAddon, - TaxInfo = new TaxInfo() - { - BillingAddressCountry = BillingAddressCountry, - BillingAddressPostalCode = BillingAddressPostalCode - } - }; + BillingAddressCountry = BillingAddressCountry, + BillingAddressPostalCode = BillingAddressPostalCode + } + }; - Keys?.ToOrganizationUpgrade(orgUpgrade); + Keys?.ToOrganizationUpgrade(orgUpgrade); - return orgUpgrade; - } + return orgUpgrade; } } diff --git a/src/Api/Models/Request/Organizations/OrganizationUserRequestModels.cs b/src/Api/Models/Request/Organizations/OrganizationUserRequestModels.cs index 09cb7efb1..4d6fcfedb 100644 --- a/src/Api/Models/Request/Organizations/OrganizationUserRequestModels.cs +++ b/src/Api/Models/Request/Organizations/OrganizationUserRequestModels.cs @@ -7,99 +7,98 @@ using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request.Organizations +namespace Bit.Api.Models.Request.Organizations; + +public class OrganizationUserInviteRequestModel { - public class OrganizationUserInviteRequestModel - { - [Required] - [StrictEmailAddressList] - public IEnumerable Emails { get; set; } - [Required] - public OrganizationUserType? Type { get; set; } - public bool AccessAll { get; set; } - public Permissions Permissions { get; set; } - public IEnumerable Collections { get; set; } + [Required] + [StrictEmailAddressList] + public IEnumerable Emails { get; set; } + [Required] + public OrganizationUserType? Type { get; set; } + public bool AccessAll { get; set; } + public Permissions Permissions { get; set; } + public IEnumerable Collections { get; set; } - public OrganizationUserInviteData ToData() + public OrganizationUserInviteData ToData() + { + return new OrganizationUserInviteData { - return new OrganizationUserInviteData - { - Emails = Emails, - Type = Type, - AccessAll = AccessAll, - Collections = Collections?.Select(c => c.ToSelectionReadOnly()), - Permissions = Permissions, - }; - } - } - - public class OrganizationUserAcceptRequestModel - { - [Required] - public string Token { get; set; } - // Used to auto-enroll in master password reset - public string ResetPasswordKey { get; set; } - } - - public class OrganizationUserConfirmRequestModel - { - [Required] - public string Key { get; set; } - } - - public class OrganizationUserBulkConfirmRequestModelEntry - { - [Required] - public Guid Id { get; set; } - [Required] - public string Key { get; set; } - } - - public class OrganizationUserBulkConfirmRequestModel - { - [Required] - public IEnumerable Keys { get; set; } - - public Dictionary ToDictionary() - { - return Keys.ToDictionary(e => e.Id, e => e.Key); - } - } - - public class OrganizationUserUpdateRequestModel - { - [Required] - public OrganizationUserType? Type { get; set; } - public bool AccessAll { get; set; } - public Permissions Permissions { get; set; } - public IEnumerable Collections { get; set; } - - public OrganizationUser ToOrganizationUser(OrganizationUser existingUser) - { - existingUser.Type = Type.Value; - existingUser.Permissions = JsonSerializer.Serialize(Permissions, new JsonSerializerOptions - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - }); - existingUser.AccessAll = AccessAll; - return existingUser; - } - } - - public class OrganizationUserUpdateGroupsRequestModel - { - [Required] - public IEnumerable GroupIds { get; set; } - } - - public class OrganizationUserResetPasswordEnrollmentRequestModel : SecretVerificationRequestModel - { - public string ResetPasswordKey { get; set; } - } - - public class OrganizationUserBulkRequestModel - { - [Required] - public IEnumerable Ids { get; set; } + Emails = Emails, + Type = Type, + AccessAll = AccessAll, + Collections = Collections?.Select(c => c.ToSelectionReadOnly()), + Permissions = Permissions, + }; } } + +public class OrganizationUserAcceptRequestModel +{ + [Required] + public string Token { get; set; } + // Used to auto-enroll in master password reset + public string ResetPasswordKey { get; set; } +} + +public class OrganizationUserConfirmRequestModel +{ + [Required] + public string Key { get; set; } +} + +public class OrganizationUserBulkConfirmRequestModelEntry +{ + [Required] + public Guid Id { get; set; } + [Required] + public string Key { get; set; } +} + +public class OrganizationUserBulkConfirmRequestModel +{ + [Required] + public IEnumerable Keys { get; set; } + + public Dictionary ToDictionary() + { + return Keys.ToDictionary(e => e.Id, e => e.Key); + } +} + +public class OrganizationUserUpdateRequestModel +{ + [Required] + public OrganizationUserType? Type { get; set; } + public bool AccessAll { get; set; } + public Permissions Permissions { get; set; } + public IEnumerable Collections { get; set; } + + public OrganizationUser ToOrganizationUser(OrganizationUser existingUser) + { + existingUser.Type = Type.Value; + existingUser.Permissions = JsonSerializer.Serialize(Permissions, new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + }); + existingUser.AccessAll = AccessAll; + return existingUser; + } +} + +public class OrganizationUserUpdateGroupsRequestModel +{ + [Required] + public IEnumerable GroupIds { get; set; } +} + +public class OrganizationUserResetPasswordEnrollmentRequestModel : SecretVerificationRequestModel +{ + public string ResetPasswordKey { get; set; } +} + +public class OrganizationUserBulkRequestModel +{ + [Required] + public IEnumerable Ids { get; set; } +} diff --git a/src/Api/Models/Request/Organizations/OrganizationUserResetPasswordRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationUserResetPasswordRequestModel.cs index 4434a64c9..571f69c1e 100644 --- a/src/Api/Models/Request/Organizations/OrganizationUserResetPasswordRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationUserResetPasswordRequestModel.cs @@ -1,13 +1,12 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Organizations +namespace Bit.Api.Models.Request.Organizations; + +public class OrganizationUserResetPasswordRequestModel { - public class OrganizationUserResetPasswordRequestModel - { - [Required] - [StringLength(300)] - public string NewMasterPasswordHash { get; set; } - [Required] - public string Key { get; set; } - } + [Required] + [StringLength(300)] + public string NewMasterPasswordHash { get; set; } + [Required] + public string Key { get; set; } } diff --git a/src/Api/Models/Request/Organizations/OrganizationVerifyBankRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationVerifyBankRequestModel.cs index 9023cd665..71f687380 100644 --- a/src/Api/Models/Request/Organizations/OrganizationVerifyBankRequestModel.cs +++ b/src/Api/Models/Request/Organizations/OrganizationVerifyBankRequestModel.cs @@ -1,14 +1,13 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Organizations +namespace Bit.Api.Models.Request.Organizations; + +public class OrganizationVerifyBankRequestModel { - public class OrganizationVerifyBankRequestModel - { - [Required] - [Range(1, 99)] - public int? Amount1 { get; set; } - [Required] - [Range(1, 99)] - public int? Amount2 { get; set; } - } + [Required] + [Range(1, 99)] + public int? Amount1 { get; set; } + [Required] + [Range(1, 99)] + public int? Amount2 { get; set; } } diff --git a/src/Api/Models/Request/PaymentRequestModel.cs b/src/Api/Models/Request/PaymentRequestModel.cs index b10b7df0c..47e39b010 100644 --- a/src/Api/Models/Request/PaymentRequestModel.cs +++ b/src/Api/Models/Request/PaymentRequestModel.cs @@ -2,13 +2,12 @@ using Bit.Api.Models.Request.Organizations; using Bit.Core.Enums; -namespace Bit.Api.Models.Request +namespace Bit.Api.Models.Request; + +public class PaymentRequestModel : OrganizationTaxInfoUpdateRequestModel { - public class PaymentRequestModel : OrganizationTaxInfoUpdateRequestModel - { - [Required] - public PaymentMethodType? PaymentMethodType { get; set; } - [Required] - public string PaymentToken { get; set; } - } + [Required] + public PaymentMethodType? PaymentMethodType { get; set; } + [Required] + public string PaymentToken { get; set; } } diff --git a/src/Api/Models/Request/PolicyRequestModel.cs b/src/Api/Models/Request/PolicyRequestModel.cs index 927b7fcc3..bc303cd40 100644 --- a/src/Api/Models/Request/PolicyRequestModel.cs +++ b/src/Api/Models/Request/PolicyRequestModel.cs @@ -3,30 +3,29 @@ using System.Text.Json; using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Api.Models.Request +namespace Bit.Api.Models.Request; + +public class PolicyRequestModel { - public class PolicyRequestModel + [Required] + public PolicyType? Type { get; set; } + [Required] + public bool? Enabled { get; set; } + public Dictionary Data { get; set; } + + public Policy ToPolicy(Guid orgId) { - [Required] - public PolicyType? Type { get; set; } - [Required] - public bool? Enabled { get; set; } - public Dictionary Data { get; set; } - - public Policy ToPolicy(Guid orgId) + return ToPolicy(new Policy { - return ToPolicy(new Policy - { - Type = Type.Value, - OrganizationId = orgId - }); - } + Type = Type.Value, + OrganizationId = orgId + }); + } - public Policy ToPolicy(Policy existingPolicy) - { - existingPolicy.Enabled = Enabled.GetValueOrDefault(); - existingPolicy.Data = Data != null ? JsonSerializer.Serialize(Data) : null; - return existingPolicy; - } + public Policy ToPolicy(Policy existingPolicy) + { + existingPolicy.Enabled = Enabled.GetValueOrDefault(); + existingPolicy.Data = Data != null ? JsonSerializer.Serialize(Data) : null; + return existingPolicy; } } diff --git a/src/Api/Models/Request/Providers/ProviderOrganizationAddRequestModel.cs b/src/Api/Models/Request/Providers/ProviderOrganizationAddRequestModel.cs index d3d203f4d..b6ea8759e 100644 --- a/src/Api/Models/Request/Providers/ProviderOrganizationAddRequestModel.cs +++ b/src/Api/Models/Request/Providers/ProviderOrganizationAddRequestModel.cs @@ -1,13 +1,12 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request.Providers -{ - public class ProviderOrganizationAddRequestModel - { - [Required] - public Guid OrganizationId { get; set; } +namespace Bit.Api.Models.Request.Providers; - [Required] - public string Key { get; set; } - } +public class ProviderOrganizationAddRequestModel +{ + [Required] + public Guid OrganizationId { get; set; } + + [Required] + public string Key { get; set; } } diff --git a/src/Api/Models/Request/Providers/ProviderOrganizationCreateRequestModel.cs b/src/Api/Models/Request/Providers/ProviderOrganizationCreateRequestModel.cs index cd796c11f..7fead717b 100644 --- a/src/Api/Models/Request/Providers/ProviderOrganizationCreateRequestModel.cs +++ b/src/Api/Models/Request/Providers/ProviderOrganizationCreateRequestModel.cs @@ -2,15 +2,14 @@ using Bit.Api.Models.Request.Organizations; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request.Providers +namespace Bit.Api.Models.Request.Providers; + +public class ProviderOrganizationCreateRequestModel { - public class ProviderOrganizationCreateRequestModel - { - [Required] - [StringLength(256)] - [StrictEmailAddress] - public string ClientOwnerEmail { get; set; } - [Required] - public OrganizationCreateRequestModel OrganizationCreateRequest { get; set; } - } + [Required] + [StringLength(256)] + [StrictEmailAddress] + public string ClientOwnerEmail { get; set; } + [Required] + public OrganizationCreateRequestModel OrganizationCreateRequest { get; set; } } diff --git a/src/Api/Models/Request/Providers/ProviderSetupRequestModel.cs b/src/Api/Models/Request/Providers/ProviderSetupRequestModel.cs index f0b5fffe9..51191f947 100644 --- a/src/Api/Models/Request/Providers/ProviderSetupRequestModel.cs +++ b/src/Api/Models/Request/Providers/ProviderSetupRequestModel.cs @@ -1,31 +1,30 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Entities.Provider; -namespace Bit.Api.Models.Request.Providers +namespace Bit.Api.Models.Request.Providers; + +public class ProviderSetupRequestModel { - public class ProviderSetupRequestModel + [Required] + [StringLength(50)] + public string Name { get; set; } + [StringLength(50)] + public string BusinessName { get; set; } + [Required] + [StringLength(256)] + [EmailAddress] + public string BillingEmail { get; set; } + [Required] + public string Token { get; set; } + [Required] + public string Key { get; set; } + + public virtual Provider ToProvider(Provider provider) { - [Required] - [StringLength(50)] - public string Name { get; set; } - [StringLength(50)] - public string BusinessName { get; set; } - [Required] - [StringLength(256)] - [EmailAddress] - public string BillingEmail { get; set; } - [Required] - public string Token { get; set; } - [Required] - public string Key { get; set; } + provider.Name = Name; + provider.BusinessName = BusinessName; + provider.BillingEmail = BillingEmail; - public virtual Provider ToProvider(Provider provider) - { - provider.Name = Name; - provider.BusinessName = BusinessName; - provider.BillingEmail = BillingEmail; - - return provider; - } + return provider; } } diff --git a/src/Api/Models/Request/Providers/ProviderUpdateRequestModel.cs b/src/Api/Models/Request/Providers/ProviderUpdateRequestModel.cs index 339ac3180..ceec796dc 100644 --- a/src/Api/Models/Request/Providers/ProviderUpdateRequestModel.cs +++ b/src/Api/Models/Request/Providers/ProviderUpdateRequestModel.cs @@ -2,30 +2,29 @@ using Bit.Core.Entities.Provider; using Bit.Core.Settings; -namespace Bit.Api.Models.Request.Providers -{ - public class ProviderUpdateRequestModel - { - [Required] - [StringLength(50)] - public string Name { get; set; } - [StringLength(50)] - public string BusinessName { get; set; } - [EmailAddress] - [Required] - [StringLength(256)] - public string BillingEmail { get; set; } +namespace Bit.Api.Models.Request.Providers; - public virtual Provider ToProvider(Provider existingProvider, GlobalSettings globalSettings) +public class ProviderUpdateRequestModel +{ + [Required] + [StringLength(50)] + public string Name { get; set; } + [StringLength(50)] + public string BusinessName { get; set; } + [EmailAddress] + [Required] + [StringLength(256)] + public string BillingEmail { get; set; } + + public virtual Provider ToProvider(Provider existingProvider, GlobalSettings globalSettings) + { + if (!globalSettings.SelfHosted) { - if (!globalSettings.SelfHosted) - { - // These items come from the license file - existingProvider.Name = Name; - existingProvider.BusinessName = BusinessName; - existingProvider.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim(); - } - return existingProvider; + // These items come from the license file + existingProvider.Name = Name; + existingProvider.BusinessName = BusinessName; + existingProvider.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim(); } + return existingProvider; } } diff --git a/src/Api/Models/Request/Providers/ProviderUserRequestModels.cs b/src/Api/Models/Request/Providers/ProviderUserRequestModels.cs index cbce7bc90..9c451d8ad 100644 --- a/src/Api/Models/Request/Providers/ProviderUserRequestModels.cs +++ b/src/Api/Models/Request/Providers/ProviderUserRequestModels.cs @@ -3,63 +3,62 @@ using Bit.Core.Entities.Provider; using Bit.Core.Enums.Provider; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request.Providers +namespace Bit.Api.Models.Request.Providers; + +public class ProviderUserInviteRequestModel { - public class ProviderUserInviteRequestModel + [Required] + [StrictEmailAddressList] + public IEnumerable Emails { get; set; } + [Required] + public ProviderUserType? Type { get; set; } +} + +public class ProviderUserAcceptRequestModel +{ + [Required] + public string Token { get; set; } +} + +public class ProviderUserConfirmRequestModel +{ + [Required] + public string Key { get; set; } +} + +public class ProviderUserBulkConfirmRequestModelEntry +{ + [Required] + public Guid Id { get; set; } + [Required] + public string Key { get; set; } +} + +public class ProviderUserBulkConfirmRequestModel +{ + [Required] + public IEnumerable Keys { get; set; } + + public Dictionary ToDictionary() { - [Required] - [StrictEmailAddressList] - public IEnumerable Emails { get; set; } - [Required] - public ProviderUserType? Type { get; set; } - } - - public class ProviderUserAcceptRequestModel - { - [Required] - public string Token { get; set; } - } - - public class ProviderUserConfirmRequestModel - { - [Required] - public string Key { get; set; } - } - - public class ProviderUserBulkConfirmRequestModelEntry - { - [Required] - public Guid Id { get; set; } - [Required] - public string Key { get; set; } - } - - public class ProviderUserBulkConfirmRequestModel - { - [Required] - public IEnumerable Keys { get; set; } - - public Dictionary ToDictionary() - { - return Keys.ToDictionary(e => e.Id, e => e.Key); - } - } - - public class ProviderUserUpdateRequestModel - { - [Required] - public ProviderUserType? Type { get; set; } - - public ProviderUser ToProviderUser(ProviderUser existingUser) - { - existingUser.Type = Type.Value; - return existingUser; - } - } - - public class ProviderUserBulkRequestModel - { - [Required] - public IEnumerable Ids { get; set; } + return Keys.ToDictionary(e => e.Id, e => e.Key); } } + +public class ProviderUserUpdateRequestModel +{ + [Required] + public ProviderUserType? Type { get; set; } + + public ProviderUser ToProviderUser(ProviderUser existingUser) + { + existingUser.Type = Type.Value; + return existingUser; + } +} + +public class ProviderUserBulkRequestModel +{ + [Required] + public IEnumerable Ids { get; set; } +} diff --git a/src/Api/Models/Request/SelectionReadOnlyRequestModel.cs b/src/Api/Models/Request/SelectionReadOnlyRequestModel.cs index f5d2043d5..5b82dc3e3 100644 --- a/src/Api/Models/Request/SelectionReadOnlyRequestModel.cs +++ b/src/Api/Models/Request/SelectionReadOnlyRequestModel.cs @@ -1,23 +1,22 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Models.Data; -namespace Bit.Api.Models.Request -{ - public class SelectionReadOnlyRequestModel - { - [Required] - public string Id { get; set; } - public bool ReadOnly { get; set; } - public bool HidePasswords { get; set; } +namespace Bit.Api.Models.Request; - public SelectionReadOnly ToSelectionReadOnly() +public class SelectionReadOnlyRequestModel +{ + [Required] + public string Id { get; set; } + public bool ReadOnly { get; set; } + public bool HidePasswords { get; set; } + + public SelectionReadOnly ToSelectionReadOnly() + { + return new SelectionReadOnly { - return new SelectionReadOnly - { - Id = new Guid(Id), - ReadOnly = ReadOnly, - HidePasswords = HidePasswords, - }; - } + Id = new Guid(Id), + ReadOnly = ReadOnly, + HidePasswords = HidePasswords, + }; } } diff --git a/src/Api/Models/Request/SendAccessRequestModel.cs b/src/Api/Models/Request/SendAccessRequestModel.cs index 3ee43985f..2a8b3f40a 100644 --- a/src/Api/Models/Request/SendAccessRequestModel.cs +++ b/src/Api/Models/Request/SendAccessRequestModel.cs @@ -1,10 +1,9 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Api.Models.Request +namespace Bit.Api.Models.Request; + +public class SendAccessRequestModel { - public class SendAccessRequestModel - { - [StringLength(300)] - public string Password { get; set; } - } + [StringLength(300)] + public string Password { get; set; } } diff --git a/src/Api/Models/Request/SendRequestModel.cs b/src/Api/Models/Request/SendRequestModel.cs index 1e3182359..51b15cba3 100644 --- a/src/Api/Models/Request/SendRequestModel.cs +++ b/src/Api/Models/Request/SendRequestModel.cs @@ -7,135 +7,134 @@ using Bit.Core.Models.Data; using Bit.Core.Services; using Bit.Core.Utilities; -namespace Bit.Api.Models.Request +namespace Bit.Api.Models.Request; + +public class SendRequestModel { - public class SendRequestModel + public SendType Type { get; set; } + public long? FileLength { get; set; } = null; + [EncryptedString] + [EncryptedStringLength(1000)] + public string Name { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string Notes { get; set; } + [Required] + [EncryptedString] + [EncryptedStringLength(1000)] + public string Key { get; set; } + [Range(1, int.MaxValue)] + public int? MaxAccessCount { get; set; } + public DateTime? ExpirationDate { get; set; } + [Required] + public DateTime? DeletionDate { get; set; } + public SendFileModel File { get; set; } + public SendTextModel Text { get; set; } + [StringLength(1000)] + public string Password { get; set; } + [Required] + public bool? Disabled { get; set; } + public bool? HideEmail { get; set; } + + public Send ToSend(Guid userId, ISendService sendService) { - public SendType Type { get; set; } - public long? FileLength { get; set; } = null; - [EncryptedString] - [EncryptedStringLength(1000)] - public string Name { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string Notes { get; set; } - [Required] - [EncryptedString] - [EncryptedStringLength(1000)] - public string Key { get; set; } - [Range(1, int.MaxValue)] - public int? MaxAccessCount { get; set; } - public DateTime? ExpirationDate { get; set; } - [Required] - public DateTime? DeletionDate { get; set; } - public SendFileModel File { get; set; } - public SendTextModel Text { get; set; } - [StringLength(1000)] - public string Password { get; set; } - [Required] - public bool? Disabled { get; set; } - public bool? HideEmail { get; set; } - - public Send ToSend(Guid userId, ISendService sendService) + var send = new Send { - var send = new Send - { - Type = Type, - UserId = (Guid?)userId - }; - ToSend(send, sendService); - return send; + Type = Type, + UserId = (Guid?)userId + }; + ToSend(send, sendService); + return send; + } + + public (Send, SendFileData) ToSend(Guid userId, string fileName, ISendService sendService) + { + var send = ToSendBase(new Send + { + Type = Type, + UserId = (Guid?)userId + }, sendService); + var data = new SendFileData(Name, Notes, fileName); + return (send, data); + } + + public Send ToSend(Send existingSend, ISendService sendService) + { + existingSend = ToSendBase(existingSend, sendService); + switch (existingSend.Type) + { + case SendType.File: + var fileData = JsonSerializer.Deserialize(existingSend.Data); + fileData.Name = Name; + fileData.Notes = Notes; + existingSend.Data = JsonSerializer.Serialize(fileData, JsonHelpers.IgnoreWritingNull); + break; + case SendType.Text: + existingSend.Data = JsonSerializer.Serialize(ToSendTextData(), JsonHelpers.IgnoreWritingNull); + break; + default: + throw new ArgumentException("Unsupported type: " + nameof(Type) + "."); } + return existingSend; + } - public (Send, SendFileData) ToSend(Guid userId, string fileName, ISendService sendService) + public void ValidateCreation() + { + var now = DateTime.UtcNow; + // Add 1 minute for a sane buffer and client clock float + var nowPlus1Minute = now.AddMinutes(1); + if (ExpirationDate.HasValue && ExpirationDate.Value <= nowPlus1Minute) { - var send = ToSendBase(new Send - { - Type = Type, - UserId = (Guid?)userId - }, sendService); - var data = new SendFileData(Name, Notes, fileName); - return (send, data); + throw new BadRequestException("You cannot create a Send that is already expired. " + + "Adjust the expiration date and try again."); } + ValidateEdit(); + } - public Send ToSend(Send existingSend, ISendService sendService) + public void ValidateEdit() + { + var now = DateTime.UtcNow; + // Add 1 minute for a sane buffer and client clock float + var nowPlus1Minute = now.AddMinutes(1); + if (DeletionDate.HasValue) { - existingSend = ToSendBase(existingSend, sendService); - switch (existingSend.Type) + if (DeletionDate.Value <= nowPlus1Minute) { - case SendType.File: - var fileData = JsonSerializer.Deserialize(existingSend.Data); - fileData.Name = Name; - fileData.Notes = Notes; - existingSend.Data = JsonSerializer.Serialize(fileData, JsonHelpers.IgnoreWritingNull); - break; - case SendType.Text: - existingSend.Data = JsonSerializer.Serialize(ToSendTextData(), JsonHelpers.IgnoreWritingNull); - break; - default: - throw new ArgumentException("Unsupported type: " + nameof(Type) + "."); + throw new BadRequestException("You cannot have a Send with a deletion date in the past. " + + "Adjust the deletion date and try again."); } - return existingSend; - } - - public void ValidateCreation() - { - var now = DateTime.UtcNow; - // Add 1 minute for a sane buffer and client clock float - var nowPlus1Minute = now.AddMinutes(1); - if (ExpirationDate.HasValue && ExpirationDate.Value <= nowPlus1Minute) + if (DeletionDate.Value > now.AddDays(31)) { - throw new BadRequestException("You cannot create a Send that is already expired. " + - "Adjust the expiration date and try again."); + throw new BadRequestException("You cannot have a Send with a deletion date that far " + + "into the future. Adjust the Deletion Date to a value less than 31 days from now " + + "and try again."); } - ValidateEdit(); - } - - public void ValidateEdit() - { - var now = DateTime.UtcNow; - // Add 1 minute for a sane buffer and client clock float - var nowPlus1Minute = now.AddMinutes(1); - if (DeletionDate.HasValue) - { - if (DeletionDate.Value <= nowPlus1Minute) - { - throw new BadRequestException("You cannot have a Send with a deletion date in the past. " + - "Adjust the deletion date and try again."); - } - if (DeletionDate.Value > now.AddDays(31)) - { - throw new BadRequestException("You cannot have a Send with a deletion date that far " + - "into the future. Adjust the Deletion Date to a value less than 31 days from now " + - "and try again."); - } - } - } - - private Send ToSendBase(Send existingSend, ISendService sendService) - { - existingSend.Key = Key; - existingSend.ExpirationDate = ExpirationDate; - existingSend.DeletionDate = DeletionDate.Value; - existingSend.MaxAccessCount = MaxAccessCount; - if (!string.IsNullOrWhiteSpace(Password)) - { - existingSend.Password = sendService.HashPassword(Password); - } - existingSend.Disabled = Disabled.GetValueOrDefault(); - existingSend.HideEmail = HideEmail.GetValueOrDefault(); - return existingSend; - } - - private SendTextData ToSendTextData() - { - return new SendTextData(Name, Notes, Text.Text, Text.Hidden); } } - public class SendWithIdRequestModel : SendRequestModel + private Send ToSendBase(Send existingSend, ISendService sendService) { - [Required] - public Guid? Id { get; set; } + existingSend.Key = Key; + existingSend.ExpirationDate = ExpirationDate; + existingSend.DeletionDate = DeletionDate.Value; + existingSend.MaxAccessCount = MaxAccessCount; + if (!string.IsNullOrWhiteSpace(Password)) + { + existingSend.Password = sendService.HashPassword(Password); + } + existingSend.Disabled = Disabled.GetValueOrDefault(); + existingSend.HideEmail = HideEmail.GetValueOrDefault(); + return existingSend; + } + + private SendTextData ToSendTextData() + { + return new SendTextData(Name, Notes, Text.Text, Text.Hidden); } } + +public class SendWithIdRequestModel : SendRequestModel +{ + [Required] + public Guid? Id { get; set; } +} diff --git a/src/Api/Models/Request/TwoFactorRequestModels.cs b/src/Api/Models/Request/TwoFactorRequestModels.cs index 4ccc209a6..3ce42cdb9 100644 --- a/src/Api/Models/Request/TwoFactorRequestModels.cs +++ b/src/Api/Models/Request/TwoFactorRequestModels.cs @@ -5,270 +5,269 @@ using Bit.Core.Enums; using Bit.Core.Models; using Fido2NetLib; -namespace Bit.Api.Models.Request +namespace Bit.Api.Models.Request; + +public class UpdateTwoFactorAuthenticatorRequestModel : SecretVerificationRequestModel { - public class UpdateTwoFactorAuthenticatorRequestModel : SecretVerificationRequestModel + [Required] + [StringLength(50)] + public string Token { get; set; } + [Required] + [StringLength(50)] + public string Key { get; set; } + + public User ToUser(User extistingUser) { - [Required] - [StringLength(50)] - public string Token { get; set; } - [Required] - [StringLength(50)] - public string Key { get; set; } - - public User ToUser(User extistingUser) + var providers = extistingUser.GetTwoFactorProviders(); + if (providers == null) { - var providers = extistingUser.GetTwoFactorProviders(); - if (providers == null) - { - providers = new Dictionary(); - } - else if (providers.ContainsKey(TwoFactorProviderType.Authenticator)) - { - providers.Remove(TwoFactorProviderType.Authenticator); - } - - providers.Add(TwoFactorProviderType.Authenticator, new TwoFactorProvider - { - MetaData = new Dictionary { ["Key"] = Key }, - Enabled = true - }); - extistingUser.SetTwoFactorProviders(providers); - return extistingUser; + providers = new Dictionary(); } - } - - public class UpdateTwoFactorDuoRequestModel : SecretVerificationRequestModel, IValidatableObject - { - [Required] - [StringLength(50)] - public string IntegrationKey { get; set; } - [Required] - [StringLength(50)] - public string SecretKey { get; set; } - [Required] - [StringLength(50)] - public string Host { get; set; } - - public User ToUser(User extistingUser) + else if (providers.ContainsKey(TwoFactorProviderType.Authenticator)) { - var providers = extistingUser.GetTwoFactorProviders(); - if (providers == null) - { - providers = new Dictionary(); - } - else if (providers.ContainsKey(TwoFactorProviderType.Duo)) - { - providers.Remove(TwoFactorProviderType.Duo); - } - - providers.Add(TwoFactorProviderType.Duo, new TwoFactorProvider - { - MetaData = new Dictionary - { - ["SKey"] = SecretKey, - ["IKey"] = IntegrationKey, - ["Host"] = Host - }, - Enabled = true - }); - extistingUser.SetTwoFactorProviders(providers); - return extistingUser; + providers.Remove(TwoFactorProviderType.Authenticator); } - public Organization ToOrganization(Organization extistingOrg) + providers.Add(TwoFactorProviderType.Authenticator, new TwoFactorProvider { - var providers = extistingOrg.GetTwoFactorProviders(); - if (providers == null) - { - providers = new Dictionary(); - } - else if (providers.ContainsKey(TwoFactorProviderType.OrganizationDuo)) - { - providers.Remove(TwoFactorProviderType.OrganizationDuo); - } - - providers.Add(TwoFactorProviderType.OrganizationDuo, new TwoFactorProvider - { - MetaData = new Dictionary - { - ["SKey"] = SecretKey, - ["IKey"] = IntegrationKey, - ["Host"] = Host - }, - Enabled = true - }); - extistingOrg.SetTwoFactorProviders(providers); - return extistingOrg; - } - - public IEnumerable Validate(ValidationContext validationContext) - { - if (!Core.Utilities.Duo.DuoApi.ValidHost(Host)) - { - yield return new ValidationResult("Host is invalid.", new string[] { nameof(Host) }); - } - } - } - - public class UpdateTwoFactorYubicoOtpRequestModel : SecretVerificationRequestModel, IValidatableObject - { - public string Key1 { get; set; } - public string Key2 { get; set; } - public string Key3 { get; set; } - public string Key4 { get; set; } - public string Key5 { get; set; } - [Required] - public bool? Nfc { get; set; } - - public User ToUser(User extistingUser) - { - var providers = extistingUser.GetTwoFactorProviders(); - if (providers == null) - { - providers = new Dictionary(); - } - else if (providers.ContainsKey(TwoFactorProviderType.YubiKey)) - { - providers.Remove(TwoFactorProviderType.YubiKey); - } - - providers.Add(TwoFactorProviderType.YubiKey, new TwoFactorProvider - { - MetaData = new Dictionary - { - ["Key1"] = FormatKey(Key1), - ["Key2"] = FormatKey(Key2), - ["Key3"] = FormatKey(Key3), - ["Key4"] = FormatKey(Key4), - ["Key5"] = FormatKey(Key5), - ["Nfc"] = Nfc.Value - }, - Enabled = true - }); - extistingUser.SetTwoFactorProviders(providers); - return extistingUser; - } - - private string FormatKey(string keyValue) - { - if (string.IsNullOrWhiteSpace(keyValue)) - { - return null; - } - - return keyValue.Substring(0, 12); - } - - public IEnumerable Validate(ValidationContext validationContext) - { - if (string.IsNullOrWhiteSpace(Key1) && string.IsNullOrWhiteSpace(Key2) && string.IsNullOrWhiteSpace(Key3) && - string.IsNullOrWhiteSpace(Key4) && string.IsNullOrWhiteSpace(Key5)) - { - yield return new ValidationResult("A key is required.", new string[] { nameof(Key1) }); - } - - if (!string.IsNullOrWhiteSpace(Key1) && Key1.Length < 12) - { - yield return new ValidationResult("Key 1 in invalid.", new string[] { nameof(Key1) }); - } - - if (!string.IsNullOrWhiteSpace(Key2) && Key2.Length < 12) - { - yield return new ValidationResult("Key 2 in invalid.", new string[] { nameof(Key2) }); - } - - if (!string.IsNullOrWhiteSpace(Key3) && Key3.Length < 12) - { - yield return new ValidationResult("Key 3 in invalid.", new string[] { nameof(Key3) }); - } - - if (!string.IsNullOrWhiteSpace(Key4) && Key4.Length < 12) - { - yield return new ValidationResult("Key 4 in invalid.", new string[] { nameof(Key4) }); - } - - if (!string.IsNullOrWhiteSpace(Key5) && Key5.Length < 12) - { - yield return new ValidationResult("Key 5 in invalid.", new string[] { nameof(Key5) }); - } - } - } - - public class TwoFactorEmailRequestModel : SecretVerificationRequestModel - { - [Required] - [EmailAddress] - [StringLength(256)] - public string Email { get; set; } - - public string DeviceIdentifier { get; set; } - - public User ToUser(User extistingUser) - { - var providers = extistingUser.GetTwoFactorProviders(); - if (providers == null) - { - providers = new Dictionary(); - } - else if (providers.ContainsKey(TwoFactorProviderType.Email)) - { - providers.Remove(TwoFactorProviderType.Email); - } - - providers.Add(TwoFactorProviderType.Email, new TwoFactorProvider - { - MetaData = new Dictionary { ["Email"] = Email.ToLowerInvariant() }, - Enabled = true - }); - extistingUser.SetTwoFactorProviders(providers); - return extistingUser; - } - } - - public class TwoFactorWebAuthnRequestModel : TwoFactorWebAuthnDeleteRequestModel - { - [Required] - public AuthenticatorAttestationRawResponse DeviceResponse { get; set; } - public string Name { get; set; } - } - - public class TwoFactorWebAuthnDeleteRequestModel : SecretVerificationRequestModel, IValidatableObject - { - [Required] - public int? Id { get; set; } - - public override IEnumerable Validate(ValidationContext validationContext) - { - foreach (var validationResult in base.Validate(validationContext)) - { - yield return validationResult; - } - - if (!Id.HasValue || Id < 0 || Id > 5) - { - yield return new ValidationResult("Invalid Key Id", new string[] { nameof(Id) }); - } - } - } - - public class UpdateTwoFactorEmailRequestModel : TwoFactorEmailRequestModel - { - [Required] - [StringLength(50)] - public string Token { get; set; } - } - - public class TwoFactorProviderRequestModel : SecretVerificationRequestModel - { - [Required] - public TwoFactorProviderType? Type { get; set; } - } - - public class TwoFactorRecoveryRequestModel : TwoFactorEmailRequestModel - { - [Required] - [StringLength(32)] - public string RecoveryCode { get; set; } + MetaData = new Dictionary { ["Key"] = Key }, + Enabled = true + }); + extistingUser.SetTwoFactorProviders(providers); + return extistingUser; } } + +public class UpdateTwoFactorDuoRequestModel : SecretVerificationRequestModel, IValidatableObject +{ + [Required] + [StringLength(50)] + public string IntegrationKey { get; set; } + [Required] + [StringLength(50)] + public string SecretKey { get; set; } + [Required] + [StringLength(50)] + public string Host { get; set; } + + public User ToUser(User extistingUser) + { + var providers = extistingUser.GetTwoFactorProviders(); + if (providers == null) + { + providers = new Dictionary(); + } + else if (providers.ContainsKey(TwoFactorProviderType.Duo)) + { + providers.Remove(TwoFactorProviderType.Duo); + } + + providers.Add(TwoFactorProviderType.Duo, new TwoFactorProvider + { + MetaData = new Dictionary + { + ["SKey"] = SecretKey, + ["IKey"] = IntegrationKey, + ["Host"] = Host + }, + Enabled = true + }); + extistingUser.SetTwoFactorProviders(providers); + return extistingUser; + } + + public Organization ToOrganization(Organization extistingOrg) + { + var providers = extistingOrg.GetTwoFactorProviders(); + if (providers == null) + { + providers = new Dictionary(); + } + else if (providers.ContainsKey(TwoFactorProviderType.OrganizationDuo)) + { + providers.Remove(TwoFactorProviderType.OrganizationDuo); + } + + providers.Add(TwoFactorProviderType.OrganizationDuo, new TwoFactorProvider + { + MetaData = new Dictionary + { + ["SKey"] = SecretKey, + ["IKey"] = IntegrationKey, + ["Host"] = Host + }, + Enabled = true + }); + extistingOrg.SetTwoFactorProviders(providers); + return extistingOrg; + } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (!Core.Utilities.Duo.DuoApi.ValidHost(Host)) + { + yield return new ValidationResult("Host is invalid.", new string[] { nameof(Host) }); + } + } +} + +public class UpdateTwoFactorYubicoOtpRequestModel : SecretVerificationRequestModel, IValidatableObject +{ + public string Key1 { get; set; } + public string Key2 { get; set; } + public string Key3 { get; set; } + public string Key4 { get; set; } + public string Key5 { get; set; } + [Required] + public bool? Nfc { get; set; } + + public User ToUser(User extistingUser) + { + var providers = extistingUser.GetTwoFactorProviders(); + if (providers == null) + { + providers = new Dictionary(); + } + else if (providers.ContainsKey(TwoFactorProviderType.YubiKey)) + { + providers.Remove(TwoFactorProviderType.YubiKey); + } + + providers.Add(TwoFactorProviderType.YubiKey, new TwoFactorProvider + { + MetaData = new Dictionary + { + ["Key1"] = FormatKey(Key1), + ["Key2"] = FormatKey(Key2), + ["Key3"] = FormatKey(Key3), + ["Key4"] = FormatKey(Key4), + ["Key5"] = FormatKey(Key5), + ["Nfc"] = Nfc.Value + }, + Enabled = true + }); + extistingUser.SetTwoFactorProviders(providers); + return extistingUser; + } + + private string FormatKey(string keyValue) + { + if (string.IsNullOrWhiteSpace(keyValue)) + { + return null; + } + + return keyValue.Substring(0, 12); + } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (string.IsNullOrWhiteSpace(Key1) && string.IsNullOrWhiteSpace(Key2) && string.IsNullOrWhiteSpace(Key3) && + string.IsNullOrWhiteSpace(Key4) && string.IsNullOrWhiteSpace(Key5)) + { + yield return new ValidationResult("A key is required.", new string[] { nameof(Key1) }); + } + + if (!string.IsNullOrWhiteSpace(Key1) && Key1.Length < 12) + { + yield return new ValidationResult("Key 1 in invalid.", new string[] { nameof(Key1) }); + } + + if (!string.IsNullOrWhiteSpace(Key2) && Key2.Length < 12) + { + yield return new ValidationResult("Key 2 in invalid.", new string[] { nameof(Key2) }); + } + + if (!string.IsNullOrWhiteSpace(Key3) && Key3.Length < 12) + { + yield return new ValidationResult("Key 3 in invalid.", new string[] { nameof(Key3) }); + } + + if (!string.IsNullOrWhiteSpace(Key4) && Key4.Length < 12) + { + yield return new ValidationResult("Key 4 in invalid.", new string[] { nameof(Key4) }); + } + + if (!string.IsNullOrWhiteSpace(Key5) && Key5.Length < 12) + { + yield return new ValidationResult("Key 5 in invalid.", new string[] { nameof(Key5) }); + } + } +} + +public class TwoFactorEmailRequestModel : SecretVerificationRequestModel +{ + [Required] + [EmailAddress] + [StringLength(256)] + public string Email { get; set; } + + public string DeviceIdentifier { get; set; } + + public User ToUser(User extistingUser) + { + var providers = extistingUser.GetTwoFactorProviders(); + if (providers == null) + { + providers = new Dictionary(); + } + else if (providers.ContainsKey(TwoFactorProviderType.Email)) + { + providers.Remove(TwoFactorProviderType.Email); + } + + providers.Add(TwoFactorProviderType.Email, new TwoFactorProvider + { + MetaData = new Dictionary { ["Email"] = Email.ToLowerInvariant() }, + Enabled = true + }); + extistingUser.SetTwoFactorProviders(providers); + return extistingUser; + } +} + +public class TwoFactorWebAuthnRequestModel : TwoFactorWebAuthnDeleteRequestModel +{ + [Required] + public AuthenticatorAttestationRawResponse DeviceResponse { get; set; } + public string Name { get; set; } +} + +public class TwoFactorWebAuthnDeleteRequestModel : SecretVerificationRequestModel, IValidatableObject +{ + [Required] + public int? Id { get; set; } + + public override IEnumerable Validate(ValidationContext validationContext) + { + foreach (var validationResult in base.Validate(validationContext)) + { + yield return validationResult; + } + + if (!Id.HasValue || Id < 0 || Id > 5) + { + yield return new ValidationResult("Invalid Key Id", new string[] { nameof(Id) }); + } + } +} + +public class UpdateTwoFactorEmailRequestModel : TwoFactorEmailRequestModel +{ + [Required] + [StringLength(50)] + public string Token { get; set; } +} + +public class TwoFactorProviderRequestModel : SecretVerificationRequestModel +{ + [Required] + public TwoFactorProviderType? Type { get; set; } +} + +public class TwoFactorRecoveryRequestModel : TwoFactorEmailRequestModel +{ + [Required] + [StringLength(32)] + public string RecoveryCode { get; set; } +} diff --git a/src/Api/Models/Request/UpdateDomainsRequestModel.cs b/src/Api/Models/Request/UpdateDomainsRequestModel.cs index 0bc3f0385..47c5d05de 100644 --- a/src/Api/Models/Request/UpdateDomainsRequestModel.cs +++ b/src/Api/Models/Request/UpdateDomainsRequestModel.cs @@ -2,19 +2,18 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Api.Models.Request -{ - public class UpdateDomainsRequestModel - { - public IEnumerable> EquivalentDomains { get; set; } - public IEnumerable ExcludedGlobalEquivalentDomains { get; set; } +namespace Bit.Api.Models.Request; - public User ToUser(User existingUser) - { - existingUser.EquivalentDomains = EquivalentDomains != null ? JsonSerializer.Serialize(EquivalentDomains) : null; - existingUser.ExcludedGlobalEquivalentDomains = ExcludedGlobalEquivalentDomains != null ? - JsonSerializer.Serialize(ExcludedGlobalEquivalentDomains) : null; - return existingUser; - } +public class UpdateDomainsRequestModel +{ + public IEnumerable> EquivalentDomains { get; set; } + public IEnumerable ExcludedGlobalEquivalentDomains { get; set; } + + public User ToUser(User existingUser) + { + existingUser.EquivalentDomains = EquivalentDomains != null ? JsonSerializer.Serialize(EquivalentDomains) : null; + existingUser.ExcludedGlobalEquivalentDomains = ExcludedGlobalEquivalentDomains != null ? + JsonSerializer.Serialize(ExcludedGlobalEquivalentDomains) : null; + return existingUser; } } diff --git a/src/Api/Models/Response/ApiKeyResponseModel.cs b/src/Api/Models/Response/ApiKeyResponseModel.cs index 3987e22d5..0661b17bc 100644 --- a/src/Api/Models/Response/ApiKeyResponseModel.cs +++ b/src/Api/Models/Response/ApiKeyResponseModel.cs @@ -1,33 +1,32 @@ using Bit.Core.Entities; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response +namespace Bit.Api.Models.Response; + +public class ApiKeyResponseModel : ResponseModel { - public class ApiKeyResponseModel : ResponseModel + public ApiKeyResponseModel(OrganizationApiKey organizationApiKey, string obj = "apiKey") + : base(obj) { - public ApiKeyResponseModel(OrganizationApiKey organizationApiKey, string obj = "apiKey") - : base(obj) + if (organizationApiKey == null) { - if (organizationApiKey == null) - { - throw new ArgumentNullException(nameof(organizationApiKey)); - } - ApiKey = organizationApiKey.ApiKey; - RevisionDate = organizationApiKey.RevisionDate; + throw new ArgumentNullException(nameof(organizationApiKey)); } - - public ApiKeyResponseModel(User user, string obj = "apiKey") - : base(obj) - { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - ApiKey = user.ApiKey; - RevisionDate = user.RevisionDate; - } - - public string ApiKey { get; set; } - public DateTime RevisionDate { get; set; } + ApiKey = organizationApiKey.ApiKey; + RevisionDate = organizationApiKey.RevisionDate; } + + public ApiKeyResponseModel(User user, string obj = "apiKey") + : base(obj) + { + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + ApiKey = user.ApiKey; + RevisionDate = user.RevisionDate; + } + + public string ApiKey { get; set; } + public DateTime RevisionDate { get; set; } } diff --git a/src/Api/Models/Response/AttachmentResponseModel.cs b/src/Api/Models/Response/AttachmentResponseModel.cs index 5659cb535..018cdd650 100644 --- a/src/Api/Models/Response/AttachmentResponseModel.cs +++ b/src/Api/Models/Response/AttachmentResponseModel.cs @@ -5,49 +5,48 @@ using Bit.Core.Models.Data; using Bit.Core.Settings; using Bit.Core.Utilities; -namespace Bit.Api.Models.Response +namespace Bit.Api.Models.Response; + +public class AttachmentResponseModel : ResponseModel { - public class AttachmentResponseModel : ResponseModel + public AttachmentResponseModel(AttachmentResponseData data) : base("attachment") { - public AttachmentResponseModel(AttachmentResponseData data) : base("attachment") + Id = data.Id; + Url = data.Url; + FileName = data.Data.FileName; + Key = data.Data.Key; + Size = data.Data.Size; + SizeName = CoreHelpers.ReadableBytesSize(data.Data.Size); + } + + public AttachmentResponseModel(string id, CipherAttachment.MetaData data, Cipher cipher, + IGlobalSettings globalSettings) + : base("attachment") + { + Id = id; + Url = $"{globalSettings.Attachment.BaseUrl}/{cipher.Id}/{id}"; + FileName = data.FileName; + Key = data.Key; + Size = data.Size; + SizeName = CoreHelpers.ReadableBytesSize(data.Size); + } + + public string Id { get; set; } + public string Url { get; set; } + public string FileName { get; set; } + public string Key { get; set; } + [JsonNumberHandling(JsonNumberHandling.AllowReadingFromString | JsonNumberHandling.WriteAsString)] + public long Size { get; set; } + public string SizeName { get; set; } + + public static IEnumerable FromCipher(Cipher cipher, IGlobalSettings globalSettings) + { + var attachments = cipher.GetAttachments(); + if (attachments == null) { - Id = data.Id; - Url = data.Url; - FileName = data.Data.FileName; - Key = data.Data.Key; - Size = data.Data.Size; - SizeName = CoreHelpers.ReadableBytesSize(data.Data.Size); + return null; } - public AttachmentResponseModel(string id, CipherAttachment.MetaData data, Cipher cipher, - IGlobalSettings globalSettings) - : base("attachment") - { - Id = id; - Url = $"{globalSettings.Attachment.BaseUrl}/{cipher.Id}/{id}"; - FileName = data.FileName; - Key = data.Key; - Size = data.Size; - SizeName = CoreHelpers.ReadableBytesSize(data.Size); - } - - public string Id { get; set; } - public string Url { get; set; } - public string FileName { get; set; } - public string Key { get; set; } - [JsonNumberHandling(JsonNumberHandling.AllowReadingFromString | JsonNumberHandling.WriteAsString)] - public long Size { get; set; } - public string SizeName { get; set; } - - public static IEnumerable FromCipher(Cipher cipher, IGlobalSettings globalSettings) - { - var attachments = cipher.GetAttachments(); - if (attachments == null) - { - return null; - } - - return attachments.Select(a => new AttachmentResponseModel(a.Key, a.Value, cipher, globalSettings)); - } + return attachments.Select(a => new AttachmentResponseModel(a.Key, a.Value, cipher, globalSettings)); } } diff --git a/src/Api/Models/Response/AttachmentUploadDataResponseModel.cs b/src/Api/Models/Response/AttachmentUploadDataResponseModel.cs index 7acc5715c..1c9a5d2a7 100644 --- a/src/Api/Models/Response/AttachmentUploadDataResponseModel.cs +++ b/src/Api/Models/Response/AttachmentUploadDataResponseModel.cs @@ -1,16 +1,15 @@ using Bit.Core.Enums; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response -{ - public class AttachmentUploadDataResponseModel : ResponseModel - { - public string AttachmentId { get; set; } - public string Url { get; set; } - public FileUploadType FileUploadType { get; set; } - public CipherResponseModel CipherResponse { get; set; } - public CipherMiniResponseModel CipherMiniResponse { get; set; } +namespace Bit.Api.Models.Response; - public AttachmentUploadDataResponseModel() : base("attachment-fileUpload") { } - } +public class AttachmentUploadDataResponseModel : ResponseModel +{ + public string AttachmentId { get; set; } + public string Url { get; set; } + public FileUploadType FileUploadType { get; set; } + public CipherResponseModel CipherResponse { get; set; } + public CipherMiniResponseModel CipherMiniResponse { get; set; } + + public AttachmentUploadDataResponseModel() : base("attachment-fileUpload") { } } diff --git a/src/Api/Models/Response/BillingHistoryResponseModel.cs b/src/Api/Models/Response/BillingHistoryResponseModel.cs index 892a6530d..e0e85f069 100644 --- a/src/Api/Models/Response/BillingHistoryResponseModel.cs +++ b/src/Api/Models/Response/BillingHistoryResponseModel.cs @@ -1,17 +1,16 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Business; -namespace Bit.Api.Models.Response +namespace Bit.Api.Models.Response; + +public class BillingHistoryResponseModel : ResponseModel { - public class BillingHistoryResponseModel : ResponseModel + public BillingHistoryResponseModel(BillingInfo billing) + : base("billingHistory") { - public BillingHistoryResponseModel(BillingInfo billing) - : base("billingHistory") - { - Transactions = billing.Transactions?.Select(t => new BillingTransaction(t)); - Invoices = billing.Invoices?.Select(i => new BillingInvoice(i)); - } - public IEnumerable Invoices { get; set; } - public IEnumerable Transactions { get; set; } + Transactions = billing.Transactions?.Select(t => new BillingTransaction(t)); + Invoices = billing.Invoices?.Select(i => new BillingInvoice(i)); } + public IEnumerable Invoices { get; set; } + public IEnumerable Transactions { get; set; } } diff --git a/src/Api/Models/Response/BillingPaymentResponseModel.cs b/src/Api/Models/Response/BillingPaymentResponseModel.cs index 12c14c4d6..dcc004613 100644 --- a/src/Api/Models/Response/BillingPaymentResponseModel.cs +++ b/src/Api/Models/Response/BillingPaymentResponseModel.cs @@ -1,18 +1,17 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Business; -namespace Bit.Api.Models.Response -{ - public class BillingPaymentResponseModel : ResponseModel - { - public BillingPaymentResponseModel(BillingInfo billing) - : base("billingPayment") - { - Balance = billing.Balance; - PaymentSource = billing.PaymentSource != null ? new BillingSource(billing.PaymentSource) : null; - } +namespace Bit.Api.Models.Response; - public decimal Balance { get; set; } - public BillingSource PaymentSource { get; set; } +public class BillingPaymentResponseModel : ResponseModel +{ + public BillingPaymentResponseModel(BillingInfo billing) + : base("billingPayment") + { + Balance = billing.Balance; + PaymentSource = billing.PaymentSource != null ? new BillingSource(billing.PaymentSource) : null; } + + public decimal Balance { get; set; } + public BillingSource PaymentSource { get; set; } } diff --git a/src/Api/Models/Response/BillingResponseModel.cs b/src/Api/Models/Response/BillingResponseModel.cs index 6e2930e1b..c5232242f 100644 --- a/src/Api/Models/Response/BillingResponseModel.cs +++ b/src/Api/Models/Response/BillingResponseModel.cs @@ -2,82 +2,81 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Business; -namespace Bit.Api.Models.Response +namespace Bit.Api.Models.Response; + +public class BillingResponseModel : ResponseModel { - public class BillingResponseModel : ResponseModel + public BillingResponseModel(BillingInfo billing) + : base("billing") { - public BillingResponseModel(BillingInfo billing) - : base("billing") - { - Balance = billing.Balance; - PaymentSource = billing.PaymentSource != null ? new BillingSource(billing.PaymentSource) : null; - Transactions = billing.Transactions?.Select(t => new BillingTransaction(t)); - Invoices = billing.Invoices?.Select(i => new BillingInvoice(i)); - } - - public decimal Balance { get; set; } - public BillingSource PaymentSource { get; set; } - public IEnumerable Invoices { get; set; } - public IEnumerable Transactions { get; set; } + Balance = billing.Balance; + PaymentSource = billing.PaymentSource != null ? new BillingSource(billing.PaymentSource) : null; + Transactions = billing.Transactions?.Select(t => new BillingTransaction(t)); + Invoices = billing.Invoices?.Select(i => new BillingInvoice(i)); } - public class BillingSource - { - public BillingSource(BillingInfo.BillingSource source) - { - Type = source.Type; - CardBrand = source.CardBrand; - Description = source.Description; - NeedsVerification = source.NeedsVerification; - } - - public PaymentMethodType Type { get; set; } - public string CardBrand { get; set; } - public string Description { get; set; } - public bool NeedsVerification { get; set; } - } - - public class BillingInvoice - { - public BillingInvoice(BillingInfo.BillingInvoice inv) - { - Amount = inv.Amount; - Date = inv.Date; - Url = inv.Url; - PdfUrl = inv.PdfUrl; - Number = inv.Number; - Paid = inv.Paid; - } - - public decimal Amount { get; set; } - public DateTime? Date { get; set; } - public string Url { get; set; } - public string PdfUrl { get; set; } - public string Number { get; set; } - public bool Paid { get; set; } - } - - public class BillingTransaction - { - public BillingTransaction(BillingInfo.BillingTransaction transaction) - { - CreatedDate = transaction.CreatedDate; - Amount = transaction.Amount; - Refunded = transaction.Refunded; - RefundedAmount = transaction.RefundedAmount; - PartiallyRefunded = transaction.PartiallyRefunded; - Type = transaction.Type; - PaymentMethodType = transaction.PaymentMethodType; - Details = transaction.Details; - } - - public DateTime CreatedDate { get; set; } - public decimal Amount { get; set; } - public bool? Refunded { get; set; } - public bool? PartiallyRefunded { get; set; } - public decimal? RefundedAmount { get; set; } - public TransactionType Type { get; set; } - public PaymentMethodType? PaymentMethodType { get; set; } - public string Details { get; set; } - } + public decimal Balance { get; set; } + public BillingSource PaymentSource { get; set; } + public IEnumerable Invoices { get; set; } + public IEnumerable Transactions { get; set; } +} + +public class BillingSource +{ + public BillingSource(BillingInfo.BillingSource source) + { + Type = source.Type; + CardBrand = source.CardBrand; + Description = source.Description; + NeedsVerification = source.NeedsVerification; + } + + public PaymentMethodType Type { get; set; } + public string CardBrand { get; set; } + public string Description { get; set; } + public bool NeedsVerification { get; set; } +} + +public class BillingInvoice +{ + public BillingInvoice(BillingInfo.BillingInvoice inv) + { + Amount = inv.Amount; + Date = inv.Date; + Url = inv.Url; + PdfUrl = inv.PdfUrl; + Number = inv.Number; + Paid = inv.Paid; + } + + public decimal Amount { get; set; } + public DateTime? Date { get; set; } + public string Url { get; set; } + public string PdfUrl { get; set; } + public string Number { get; set; } + public bool Paid { get; set; } +} + +public class BillingTransaction +{ + public BillingTransaction(BillingInfo.BillingTransaction transaction) + { + CreatedDate = transaction.CreatedDate; + Amount = transaction.Amount; + Refunded = transaction.Refunded; + RefundedAmount = transaction.RefundedAmount; + PartiallyRefunded = transaction.PartiallyRefunded; + Type = transaction.Type; + PaymentMethodType = transaction.PaymentMethodType; + Details = transaction.Details; + } + + public DateTime CreatedDate { get; set; } + public decimal Amount { get; set; } + public bool? Refunded { get; set; } + public bool? PartiallyRefunded { get; set; } + public decimal? RefundedAmount { get; set; } + public TransactionType Type { get; set; } + public PaymentMethodType? PaymentMethodType { get; set; } + public string Details { get; set; } } diff --git a/src/Api/Models/Response/CipherResponseModel.cs b/src/Api/Models/Response/CipherResponseModel.cs index 5edc27145..9b0d95894 100644 --- a/src/Api/Models/Response/CipherResponseModel.cs +++ b/src/Api/Models/Response/CipherResponseModel.cs @@ -6,142 +6,141 @@ using Bit.Core.Models.Data; using Bit.Core.Settings; using Core.Models.Data; -namespace Bit.Api.Models.Response +namespace Bit.Api.Models.Response; + +public class CipherMiniResponseModel : ResponseModel { - public class CipherMiniResponseModel : ResponseModel + public CipherMiniResponseModel(Cipher cipher, IGlobalSettings globalSettings, bool orgUseTotp, string obj = "cipherMini") + : base(obj) { - public CipherMiniResponseModel(Cipher cipher, IGlobalSettings globalSettings, bool orgUseTotp, string obj = "cipherMini") - : base(obj) + if (cipher == null) { - if (cipher == null) - { - throw new ArgumentNullException(nameof(cipher)); - } - - Id = cipher.Id.ToString(); - Type = cipher.Type; - - CipherData cipherData; - switch (cipher.Type) - { - case CipherType.Login: - var loginData = JsonSerializer.Deserialize(cipher.Data); - cipherData = loginData; - Data = loginData; - Login = new CipherLoginModel(loginData); - break; - case CipherType.SecureNote: - var secureNoteData = JsonSerializer.Deserialize(cipher.Data); - Data = secureNoteData; - cipherData = secureNoteData; - SecureNote = new CipherSecureNoteModel(secureNoteData); - break; - case CipherType.Card: - var cardData = JsonSerializer.Deserialize(cipher.Data); - Data = cardData; - cipherData = cardData; - Card = new CipherCardModel(cardData); - break; - case CipherType.Identity: - var identityData = JsonSerializer.Deserialize(cipher.Data); - Data = identityData; - cipherData = identityData; - Identity = new CipherIdentityModel(identityData); - break; - default: - throw new ArgumentException("Unsupported " + nameof(Type) + "."); - } - - Name = cipherData.Name; - Notes = cipherData.Notes; - Fields = cipherData.Fields?.Select(f => new CipherFieldModel(f)); - PasswordHistory = cipherData.PasswordHistory?.Select(ph => new CipherPasswordHistoryModel(ph)); - RevisionDate = cipher.RevisionDate; - OrganizationId = cipher.OrganizationId?.ToString(); - Attachments = AttachmentResponseModel.FromCipher(cipher, globalSettings); - OrganizationUseTotp = orgUseTotp; - DeletedDate = cipher.DeletedDate; - Reprompt = cipher.Reprompt.GetValueOrDefault(CipherRepromptType.None); + throw new ArgumentNullException(nameof(cipher)); } - public string Id { get; set; } - public string OrganizationId { get; set; } - public CipherType Type { get; set; } - public dynamic Data { get; set; } - public string Name { get; set; } - public string Notes { get; set; } - public CipherLoginModel Login { get; set; } - public CipherCardModel Card { get; set; } - public CipherIdentityModel Identity { get; set; } - public CipherSecureNoteModel SecureNote { get; set; } - public IEnumerable Fields { get; set; } - public IEnumerable PasswordHistory { get; set; } - public IEnumerable Attachments { get; set; } - public bool OrganizationUseTotp { get; set; } - public DateTime RevisionDate { get; set; } - public DateTime? DeletedDate { get; set; } - public CipherRepromptType Reprompt { get; set; } + Id = cipher.Id.ToString(); + Type = cipher.Type; + + CipherData cipherData; + switch (cipher.Type) + { + case CipherType.Login: + var loginData = JsonSerializer.Deserialize(cipher.Data); + cipherData = loginData; + Data = loginData; + Login = new CipherLoginModel(loginData); + break; + case CipherType.SecureNote: + var secureNoteData = JsonSerializer.Deserialize(cipher.Data); + Data = secureNoteData; + cipherData = secureNoteData; + SecureNote = new CipherSecureNoteModel(secureNoteData); + break; + case CipherType.Card: + var cardData = JsonSerializer.Deserialize(cipher.Data); + Data = cardData; + cipherData = cardData; + Card = new CipherCardModel(cardData); + break; + case CipherType.Identity: + var identityData = JsonSerializer.Deserialize(cipher.Data); + Data = identityData; + cipherData = identityData; + Identity = new CipherIdentityModel(identityData); + break; + default: + throw new ArgumentException("Unsupported " + nameof(Type) + "."); + } + + Name = cipherData.Name; + Notes = cipherData.Notes; + Fields = cipherData.Fields?.Select(f => new CipherFieldModel(f)); + PasswordHistory = cipherData.PasswordHistory?.Select(ph => new CipherPasswordHistoryModel(ph)); + RevisionDate = cipher.RevisionDate; + OrganizationId = cipher.OrganizationId?.ToString(); + Attachments = AttachmentResponseModel.FromCipher(cipher, globalSettings); + OrganizationUseTotp = orgUseTotp; + DeletedDate = cipher.DeletedDate; + Reprompt = cipher.Reprompt.GetValueOrDefault(CipherRepromptType.None); } - public class CipherResponseModel : CipherMiniResponseModel - { - public CipherResponseModel(CipherDetails cipher, IGlobalSettings globalSettings, string obj = "cipher") - : base(cipher, globalSettings, cipher.OrganizationUseTotp, obj) - { - FolderId = cipher.FolderId?.ToString(); - Favorite = cipher.Favorite; - Edit = cipher.Edit; - ViewPassword = cipher.ViewPassword; - } - - public string FolderId { get; set; } - public bool Favorite { get; set; } - public bool Edit { get; set; } - public bool ViewPassword { get; set; } - } - - public class CipherDetailsResponseModel : CipherResponseModel - { - public CipherDetailsResponseModel(CipherDetails cipher, GlobalSettings globalSettings, - IDictionary> collectionCiphers, string obj = "cipherDetails") - : base(cipher, globalSettings, obj) - { - if (collectionCiphers?.ContainsKey(cipher.Id) ?? false) - { - CollectionIds = collectionCiphers[cipher.Id].Select(c => c.CollectionId); - } - else - { - CollectionIds = new Guid[] { }; - } - } - - public CipherDetailsResponseModel(CipherDetails cipher, GlobalSettings globalSettings, - IEnumerable collectionCiphers, string obj = "cipherDetails") - : base(cipher, globalSettings, obj) - { - CollectionIds = collectionCiphers?.Select(c => c.CollectionId) ?? new List(); - } - - public IEnumerable CollectionIds { get; set; } - } - - public class CipherMiniDetailsResponseModel : CipherMiniResponseModel - { - public CipherMiniDetailsResponseModel(Cipher cipher, GlobalSettings globalSettings, - IDictionary> collectionCiphers, bool orgUseTotp, string obj = "cipherMiniDetails") - : base(cipher, globalSettings, orgUseTotp, obj) - { - if (collectionCiphers?.ContainsKey(cipher.Id) ?? false) - { - CollectionIds = collectionCiphers[cipher.Id].Select(c => c.CollectionId); - } - else - { - CollectionIds = new Guid[] { }; - } - } - - public IEnumerable CollectionIds { get; set; } - } + public string Id { get; set; } + public string OrganizationId { get; set; } + public CipherType Type { get; set; } + public dynamic Data { get; set; } + public string Name { get; set; } + public string Notes { get; set; } + public CipherLoginModel Login { get; set; } + public CipherCardModel Card { get; set; } + public CipherIdentityModel Identity { get; set; } + public CipherSecureNoteModel SecureNote { get; set; } + public IEnumerable Fields { get; set; } + public IEnumerable PasswordHistory { get; set; } + public IEnumerable Attachments { get; set; } + public bool OrganizationUseTotp { get; set; } + public DateTime RevisionDate { get; set; } + public DateTime? DeletedDate { get; set; } + public CipherRepromptType Reprompt { get; set; } +} + +public class CipherResponseModel : CipherMiniResponseModel +{ + public CipherResponseModel(CipherDetails cipher, IGlobalSettings globalSettings, string obj = "cipher") + : base(cipher, globalSettings, cipher.OrganizationUseTotp, obj) + { + FolderId = cipher.FolderId?.ToString(); + Favorite = cipher.Favorite; + Edit = cipher.Edit; + ViewPassword = cipher.ViewPassword; + } + + public string FolderId { get; set; } + public bool Favorite { get; set; } + public bool Edit { get; set; } + public bool ViewPassword { get; set; } +} + +public class CipherDetailsResponseModel : CipherResponseModel +{ + public CipherDetailsResponseModel(CipherDetails cipher, GlobalSettings globalSettings, + IDictionary> collectionCiphers, string obj = "cipherDetails") + : base(cipher, globalSettings, obj) + { + if (collectionCiphers?.ContainsKey(cipher.Id) ?? false) + { + CollectionIds = collectionCiphers[cipher.Id].Select(c => c.CollectionId); + } + else + { + CollectionIds = new Guid[] { }; + } + } + + public CipherDetailsResponseModel(CipherDetails cipher, GlobalSettings globalSettings, + IEnumerable collectionCiphers, string obj = "cipherDetails") + : base(cipher, globalSettings, obj) + { + CollectionIds = collectionCiphers?.Select(c => c.CollectionId) ?? new List(); + } + + public IEnumerable CollectionIds { get; set; } +} + +public class CipherMiniDetailsResponseModel : CipherMiniResponseModel +{ + public CipherMiniDetailsResponseModel(Cipher cipher, GlobalSettings globalSettings, + IDictionary> collectionCiphers, bool orgUseTotp, string obj = "cipherMiniDetails") + : base(cipher, globalSettings, orgUseTotp, obj) + { + if (collectionCiphers?.ContainsKey(cipher.Id) ?? false) + { + CollectionIds = collectionCiphers[cipher.Id].Select(c => c.CollectionId); + } + else + { + CollectionIds = new Guid[] { }; + } + } + + public IEnumerable CollectionIds { get; set; } } diff --git a/src/Api/Models/Response/CollectionResponseModel.cs b/src/Api/Models/Response/CollectionResponseModel.cs index 5ac923a9d..aa56402c0 100644 --- a/src/Api/Models/Response/CollectionResponseModel.cs +++ b/src/Api/Models/Response/CollectionResponseModel.cs @@ -2,51 +2,50 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Data; -namespace Bit.Api.Models.Response +namespace Bit.Api.Models.Response; + +public class CollectionResponseModel : ResponseModel { - public class CollectionResponseModel : ResponseModel + public CollectionResponseModel(Collection collection, string obj = "collection") + : base(obj) { - public CollectionResponseModel(Collection collection, string obj = "collection") - : base(obj) + if (collection == null) { - if (collection == null) - { - throw new ArgumentNullException(nameof(collection)); - } - - Id = collection.Id.ToString(); - OrganizationId = collection.OrganizationId.ToString(); - Name = collection.Name; - ExternalId = collection.ExternalId; + throw new ArgumentNullException(nameof(collection)); } - public string Id { get; set; } - public string OrganizationId { get; set; } - public string Name { get; set; } - public string ExternalId { get; set; } + Id = collection.Id.ToString(); + OrganizationId = collection.OrganizationId.ToString(); + Name = collection.Name; + ExternalId = collection.ExternalId; } - public class CollectionDetailsResponseModel : CollectionResponseModel - { - public CollectionDetailsResponseModel(CollectionDetails collectionDetails) - : base(collectionDetails, "collectionDetails") - { - ReadOnly = collectionDetails.ReadOnly; - HidePasswords = collectionDetails.HidePasswords; - } - - public bool ReadOnly { get; set; } - public bool HidePasswords { get; set; } - } - - public class CollectionGroupDetailsResponseModel : CollectionResponseModel - { - public CollectionGroupDetailsResponseModel(Collection collection, IEnumerable groups) - : base(collection, "collectionGroupDetails") - { - Groups = groups.Select(g => new SelectionReadOnlyResponseModel(g)); - } - - public IEnumerable Groups { get; set; } - } + public string Id { get; set; } + public string OrganizationId { get; set; } + public string Name { get; set; } + public string ExternalId { get; set; } +} + +public class CollectionDetailsResponseModel : CollectionResponseModel +{ + public CollectionDetailsResponseModel(CollectionDetails collectionDetails) + : base(collectionDetails, "collectionDetails") + { + ReadOnly = collectionDetails.ReadOnly; + HidePasswords = collectionDetails.HidePasswords; + } + + public bool ReadOnly { get; set; } + public bool HidePasswords { get; set; } +} + +public class CollectionGroupDetailsResponseModel : CollectionResponseModel +{ + public CollectionGroupDetailsResponseModel(Collection collection, IEnumerable groups) + : base(collection, "collectionGroupDetails") + { + Groups = groups.Select(g => new SelectionReadOnlyResponseModel(g)); + } + + public IEnumerable Groups { get; set; } } diff --git a/src/Api/Models/Response/DeviceResponseModel.cs b/src/Api/Models/Response/DeviceResponseModel.cs index e25562cbb..e88dff9fa 100644 --- a/src/Api/Models/Response/DeviceResponseModel.cs +++ b/src/Api/Models/Response/DeviceResponseModel.cs @@ -2,29 +2,28 @@ using Bit.Core.Enums; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response -{ - public class DeviceResponseModel : ResponseModel - { - public DeviceResponseModel(Device device) - : base("device") - { - if (device == null) - { - throw new ArgumentNullException(nameof(device)); - } +namespace Bit.Api.Models.Response; - Id = device.Id.ToString(); - Name = device.Name; - Type = device.Type; - Identifier = device.Identifier; - CreationDate = device.CreationDate; +public class DeviceResponseModel : ResponseModel +{ + public DeviceResponseModel(Device device) + : base("device") + { + if (device == null) + { + throw new ArgumentNullException(nameof(device)); } - public string Id { get; set; } - public string Name { get; set; } - public DeviceType Type { get; set; } - public string Identifier { get; set; } - public DateTime CreationDate { get; set; } + Id = device.Id.ToString(); + Name = device.Name; + Type = device.Type; + Identifier = device.Identifier; + CreationDate = device.CreationDate; } + + public string Id { get; set; } + public string Name { get; set; } + public DeviceType Type { get; set; } + public string Identifier { get; set; } + public DateTime CreationDate { get; set; } } diff --git a/src/Api/Models/Response/DeviceVerificationResponseModel.cs b/src/Api/Models/Response/DeviceVerificationResponseModel.cs index 1f4754761..0358ff777 100644 --- a/src/Api/Models/Response/DeviceVerificationResponseModel.cs +++ b/src/Api/Models/Response/DeviceVerificationResponseModel.cs @@ -1,17 +1,16 @@ using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response -{ - public class DeviceVerificationResponseModel : ResponseModel - { - public DeviceVerificationResponseModel(bool isDeviceVerificationSectionEnabled, bool unknownDeviceVerificationEnabled) - : base("deviceVerificationSettings") - { - IsDeviceVerificationSectionEnabled = isDeviceVerificationSectionEnabled; - UnknownDeviceVerificationEnabled = unknownDeviceVerificationEnabled; - } +namespace Bit.Api.Models.Response; - public bool IsDeviceVerificationSectionEnabled { get; } - public bool UnknownDeviceVerificationEnabled { get; } +public class DeviceVerificationResponseModel : ResponseModel +{ + public DeviceVerificationResponseModel(bool isDeviceVerificationSectionEnabled, bool unknownDeviceVerificationEnabled) + : base("deviceVerificationSettings") + { + IsDeviceVerificationSectionEnabled = isDeviceVerificationSectionEnabled; + UnknownDeviceVerificationEnabled = unknownDeviceVerificationEnabled; } + + public bool IsDeviceVerificationSectionEnabled { get; } + public bool UnknownDeviceVerificationEnabled { get; } } diff --git a/src/Api/Models/Response/DomainsResponseModel.cs b/src/Api/Models/Response/DomainsResponseModel.cs index fd6ea46b6..b7f102845 100644 --- a/src/Api/Models/Response/DomainsResponseModel.cs +++ b/src/Api/Models/Response/DomainsResponseModel.cs @@ -3,54 +3,53 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response +namespace Bit.Api.Models.Response; + +public class DomainsResponseModel : ResponseModel { - public class DomainsResponseModel : ResponseModel + public DomainsResponseModel(User user, bool excluded = true) + : base("domains") { - public DomainsResponseModel(User user, bool excluded = true) - : base("domains") + if (user == null) { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - EquivalentDomains = user.EquivalentDomains != null ? - JsonSerializer.Deserialize>>(user.EquivalentDomains) : null; - - var excludedGlobalEquivalentDomains = user.ExcludedGlobalEquivalentDomains != null ? - JsonSerializer.Deserialize>(user.ExcludedGlobalEquivalentDomains) : - new List(); - var globalDomains = new List(); - var domainsToInclude = excluded ? Core.Utilities.StaticStore.GlobalDomains : - Core.Utilities.StaticStore.GlobalDomains.Where(d => !excludedGlobalEquivalentDomains.Contains(d.Key)); - foreach (var domain in domainsToInclude) - { - globalDomains.Add(new GlobalDomains(domain.Key, domain.Value, excludedGlobalEquivalentDomains, excluded)); - } - GlobalEquivalentDomains = !globalDomains.Any() ? null : globalDomains; + throw new ArgumentNullException(nameof(user)); } - public IEnumerable> EquivalentDomains { get; set; } - public IEnumerable GlobalEquivalentDomains { get; set; } + EquivalentDomains = user.EquivalentDomains != null ? + JsonSerializer.Deserialize>>(user.EquivalentDomains) : null; - - public class GlobalDomains + var excludedGlobalEquivalentDomains = user.ExcludedGlobalEquivalentDomains != null ? + JsonSerializer.Deserialize>(user.ExcludedGlobalEquivalentDomains) : + new List(); + var globalDomains = new List(); + var domainsToInclude = excluded ? Core.Utilities.StaticStore.GlobalDomains : + Core.Utilities.StaticStore.GlobalDomains.Where(d => !excludedGlobalEquivalentDomains.Contains(d.Key)); + foreach (var domain in domainsToInclude) { - public GlobalDomains( - GlobalEquivalentDomainsType globalDomain, - IEnumerable domains, - IEnumerable excludedDomains, - bool excluded) - { - Type = globalDomain; - Domains = domains; - Excluded = excluded && (excludedDomains?.Contains(globalDomain) ?? false); - } - - public GlobalEquivalentDomainsType Type { get; set; } - public IEnumerable Domains { get; set; } - public bool Excluded { get; set; } + globalDomains.Add(new GlobalDomains(domain.Key, domain.Value, excludedGlobalEquivalentDomains, excluded)); } + GlobalEquivalentDomains = !globalDomains.Any() ? null : globalDomains; + } + + public IEnumerable> EquivalentDomains { get; set; } + public IEnumerable GlobalEquivalentDomains { get; set; } + + + public class GlobalDomains + { + public GlobalDomains( + GlobalEquivalentDomainsType globalDomain, + IEnumerable domains, + IEnumerable excludedDomains, + bool excluded) + { + Type = globalDomain; + Domains = domains; + Excluded = excluded && (excludedDomains?.Contains(globalDomain) ?? false); + } + + public GlobalEquivalentDomainsType Type { get; set; } + public IEnumerable Domains { get; set; } + public bool Excluded { get; set; } } } diff --git a/src/Api/Models/Response/EmergencyAccessResponseModel.cs b/src/Api/Models/Response/EmergencyAccessResponseModel.cs index 16d255e92..ec8dbd1ee 100644 --- a/src/Api/Models/Response/EmergencyAccessResponseModel.cs +++ b/src/Api/Models/Response/EmergencyAccessResponseModel.cs @@ -5,114 +5,113 @@ using Bit.Core.Models.Data; using Bit.Core.Settings; using Core.Models.Data; -namespace Bit.Api.Models.Response +namespace Bit.Api.Models.Response; + +public class EmergencyAccessResponseModel : ResponseModel { - public class EmergencyAccessResponseModel : ResponseModel + public EmergencyAccessResponseModel(EmergencyAccess emergencyAccess, string obj = "emergencyAccess") : base(obj) { - public EmergencyAccessResponseModel(EmergencyAccess emergencyAccess, string obj = "emergencyAccess") : base(obj) + if (emergencyAccess == null) { - if (emergencyAccess == null) - { - throw new ArgumentNullException(nameof(emergencyAccess)); - } - - Id = emergencyAccess.Id.ToString(); - Status = emergencyAccess.Status; - Type = emergencyAccess.Type; - WaitTimeDays = emergencyAccess.WaitTimeDays; + throw new ArgumentNullException(nameof(emergencyAccess)); } - public EmergencyAccessResponseModel(EmergencyAccessDetails emergencyAccess, string obj = "emergencyAccess") : base(obj) - { - if (emergencyAccess == null) - { - throw new ArgumentNullException(nameof(emergencyAccess)); - } - - Id = emergencyAccess.Id.ToString(); - Status = emergencyAccess.Status; - Type = emergencyAccess.Type; - WaitTimeDays = emergencyAccess.WaitTimeDays; - } - - public string Id { get; private set; } - public EmergencyAccessStatusType Status { get; private set; } - public EmergencyAccessType Type { get; private set; } - public int WaitTimeDays { get; private set; } + Id = emergencyAccess.Id.ToString(); + Status = emergencyAccess.Status; + Type = emergencyAccess.Type; + WaitTimeDays = emergencyAccess.WaitTimeDays; } - public class EmergencyAccessGranteeDetailsResponseModel : EmergencyAccessResponseModel + public EmergencyAccessResponseModel(EmergencyAccessDetails emergencyAccess, string obj = "emergencyAccess") : base(obj) { - public EmergencyAccessGranteeDetailsResponseModel(EmergencyAccessDetails emergencyAccess) - : base(emergencyAccess, "emergencyAccessGranteeDetails") + if (emergencyAccess == null) { - if (emergencyAccess == null) - { - throw new ArgumentNullException(nameof(emergencyAccess)); - } - - GranteeId = emergencyAccess.GranteeId.ToString(); - Email = emergencyAccess.GranteeEmail; - Name = emergencyAccess.GranteeName; + throw new ArgumentNullException(nameof(emergencyAccess)); } - public string GranteeId { get; private set; } - public string Name { get; private set; } - public string Email { get; private set; } + Id = emergencyAccess.Id.ToString(); + Status = emergencyAccess.Status; + Type = emergencyAccess.Type; + WaitTimeDays = emergencyAccess.WaitTimeDays; } - public class EmergencyAccessGrantorDetailsResponseModel : EmergencyAccessResponseModel - { - public EmergencyAccessGrantorDetailsResponseModel(EmergencyAccessDetails emergencyAccess) - : base(emergencyAccess, "emergencyAccessGrantorDetails") - { - if (emergencyAccess == null) - { - throw new ArgumentNullException(nameof(emergencyAccess)); - } - - GrantorId = emergencyAccess.GrantorId.ToString(); - Email = emergencyAccess.GrantorEmail; - Name = emergencyAccess.GrantorName; - } - - public string GrantorId { get; private set; } - public string Name { get; private set; } - public string Email { get; private set; } - } - - public class EmergencyAccessTakeoverResponseModel : ResponseModel - { - public EmergencyAccessTakeoverResponseModel(EmergencyAccess emergencyAccess, User grantor, string obj = "emergencyAccessTakeover") : base(obj) - { - if (emergencyAccess == null) - { - throw new ArgumentNullException(nameof(emergencyAccess)); - } - - KeyEncrypted = emergencyAccess.KeyEncrypted; - Kdf = grantor.Kdf; - KdfIterations = grantor.KdfIterations; - } - - public int KdfIterations { get; private set; } - public KdfType Kdf { get; private set; } - public string KeyEncrypted { get; private set; } - } - - public class EmergencyAccessViewResponseModel : ResponseModel - { - public EmergencyAccessViewResponseModel( - IGlobalSettings globalSettings, - EmergencyAccess emergencyAccess, - IEnumerable ciphers) - : base("emergencyAccessView") - { - KeyEncrypted = emergencyAccess.KeyEncrypted; - Ciphers = ciphers.Select(c => new CipherResponseModel(c, globalSettings)); - } - - public string KeyEncrypted { get; set; } - public IEnumerable Ciphers { get; set; } - } + public string Id { get; private set; } + public EmergencyAccessStatusType Status { get; private set; } + public EmergencyAccessType Type { get; private set; } + public int WaitTimeDays { get; private set; } +} + +public class EmergencyAccessGranteeDetailsResponseModel : EmergencyAccessResponseModel +{ + public EmergencyAccessGranteeDetailsResponseModel(EmergencyAccessDetails emergencyAccess) + : base(emergencyAccess, "emergencyAccessGranteeDetails") + { + if (emergencyAccess == null) + { + throw new ArgumentNullException(nameof(emergencyAccess)); + } + + GranteeId = emergencyAccess.GranteeId.ToString(); + Email = emergencyAccess.GranteeEmail; + Name = emergencyAccess.GranteeName; + } + + public string GranteeId { get; private set; } + public string Name { get; private set; } + public string Email { get; private set; } +} + +public class EmergencyAccessGrantorDetailsResponseModel : EmergencyAccessResponseModel +{ + public EmergencyAccessGrantorDetailsResponseModel(EmergencyAccessDetails emergencyAccess) + : base(emergencyAccess, "emergencyAccessGrantorDetails") + { + if (emergencyAccess == null) + { + throw new ArgumentNullException(nameof(emergencyAccess)); + } + + GrantorId = emergencyAccess.GrantorId.ToString(); + Email = emergencyAccess.GrantorEmail; + Name = emergencyAccess.GrantorName; + } + + public string GrantorId { get; private set; } + public string Name { get; private set; } + public string Email { get; private set; } +} + +public class EmergencyAccessTakeoverResponseModel : ResponseModel +{ + public EmergencyAccessTakeoverResponseModel(EmergencyAccess emergencyAccess, User grantor, string obj = "emergencyAccessTakeover") : base(obj) + { + if (emergencyAccess == null) + { + throw new ArgumentNullException(nameof(emergencyAccess)); + } + + KeyEncrypted = emergencyAccess.KeyEncrypted; + Kdf = grantor.Kdf; + KdfIterations = grantor.KdfIterations; + } + + public int KdfIterations { get; private set; } + public KdfType Kdf { get; private set; } + public string KeyEncrypted { get; private set; } +} + +public class EmergencyAccessViewResponseModel : ResponseModel +{ + public EmergencyAccessViewResponseModel( + IGlobalSettings globalSettings, + EmergencyAccess emergencyAccess, + IEnumerable ciphers) + : base("emergencyAccessView") + { + KeyEncrypted = emergencyAccess.KeyEncrypted; + Ciphers = ciphers.Select(c => new CipherResponseModel(c, globalSettings)); + } + + public string KeyEncrypted { get; set; } + public IEnumerable Ciphers { get; set; } } diff --git a/src/Api/Models/Response/EventResponseModel.cs b/src/Api/Models/Response/EventResponseModel.cs index 40ccbb7e1..6cb723c6e 100644 --- a/src/Api/Models/Response/EventResponseModel.cs +++ b/src/Api/Models/Response/EventResponseModel.cs @@ -2,51 +2,50 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Data; -namespace Bit.Api.Models.Response -{ - public class EventResponseModel : ResponseModel - { - public EventResponseModel(IEvent ev) - : base("event") - { - if (ev == null) - { - throw new ArgumentNullException(nameof(ev)); - } +namespace Bit.Api.Models.Response; - Type = ev.Type; - UserId = ev.UserId; - OrganizationId = ev.OrganizationId; - ProviderId = ev.ProviderId; - CipherId = ev.CipherId; - CollectionId = ev.CollectionId; - GroupId = ev.GroupId; - PolicyId = ev.PolicyId; - OrganizationUserId = ev.OrganizationUserId; - ProviderUserId = ev.ProviderUserId; - ProviderOrganizationId = ev.ProviderOrganizationId; - ActingUserId = ev.ActingUserId; - Date = ev.Date; - DeviceType = ev.DeviceType; - IpAddress = ev.IpAddress; - InstallationId = ev.InstallationId; +public class EventResponseModel : ResponseModel +{ + public EventResponseModel(IEvent ev) + : base("event") + { + if (ev == null) + { + throw new ArgumentNullException(nameof(ev)); } - public EventType Type { get; set; } - public Guid? UserId { get; set; } - public Guid? OrganizationId { get; set; } - public Guid? ProviderId { get; set; } - public Guid? CipherId { get; set; } - public Guid? CollectionId { get; set; } - public Guid? GroupId { get; set; } - public Guid? PolicyId { get; set; } - public Guid? OrganizationUserId { get; set; } - public Guid? ProviderUserId { get; set; } - public Guid? ProviderOrganizationId { get; set; } - public Guid? ActingUserId { get; set; } - public Guid? InstallationId { get; set; } - public DateTime Date { get; set; } - public DeviceType? DeviceType { get; set; } - public string IpAddress { get; set; } + Type = ev.Type; + UserId = ev.UserId; + OrganizationId = ev.OrganizationId; + ProviderId = ev.ProviderId; + CipherId = ev.CipherId; + CollectionId = ev.CollectionId; + GroupId = ev.GroupId; + PolicyId = ev.PolicyId; + OrganizationUserId = ev.OrganizationUserId; + ProviderUserId = ev.ProviderUserId; + ProviderOrganizationId = ev.ProviderOrganizationId; + ActingUserId = ev.ActingUserId; + Date = ev.Date; + DeviceType = ev.DeviceType; + IpAddress = ev.IpAddress; + InstallationId = ev.InstallationId; } + + public EventType Type { get; set; } + public Guid? UserId { get; set; } + public Guid? OrganizationId { get; set; } + public Guid? ProviderId { get; set; } + public Guid? CipherId { get; set; } + public Guid? CollectionId { get; set; } + public Guid? GroupId { get; set; } + public Guid? PolicyId { get; set; } + public Guid? OrganizationUserId { get; set; } + public Guid? ProviderUserId { get; set; } + public Guid? ProviderOrganizationId { get; set; } + public Guid? ActingUserId { get; set; } + public Guid? InstallationId { get; set; } + public DateTime Date { get; set; } + public DeviceType? DeviceType { get; set; } + public string IpAddress { get; set; } } diff --git a/src/Api/Models/Response/FolderResponseModel.cs b/src/Api/Models/Response/FolderResponseModel.cs index 0396471e1..03971b4e3 100644 --- a/src/Api/Models/Response/FolderResponseModel.cs +++ b/src/Api/Models/Response/FolderResponseModel.cs @@ -1,25 +1,24 @@ using Bit.Core.Entities; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response -{ - public class FolderResponseModel : ResponseModel - { - public FolderResponseModel(Folder folder) - : base("folder") - { - if (folder == null) - { - throw new ArgumentNullException(nameof(folder)); - } +namespace Bit.Api.Models.Response; - Id = folder.Id.ToString(); - Name = folder.Name; - RevisionDate = folder.RevisionDate; +public class FolderResponseModel : ResponseModel +{ + public FolderResponseModel(Folder folder) + : base("folder") + { + if (folder == null) + { + throw new ArgumentNullException(nameof(folder)); } - public string Id { get; set; } - public string Name { get; set; } - public DateTime RevisionDate { get; set; } + Id = folder.Id.ToString(); + Name = folder.Name; + RevisionDate = folder.RevisionDate; } + + public string Id { get; set; } + public string Name { get; set; } + public DateTime RevisionDate { get; set; } } diff --git a/src/Api/Models/Response/GroupResponseModel.cs b/src/Api/Models/Response/GroupResponseModel.cs index c75ff31e2..4b6496a40 100644 --- a/src/Api/Models/Response/GroupResponseModel.cs +++ b/src/Api/Models/Response/GroupResponseModel.cs @@ -2,40 +2,39 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Data; -namespace Bit.Api.Models.Response +namespace Bit.Api.Models.Response; + +public class GroupResponseModel : ResponseModel { - public class GroupResponseModel : ResponseModel + public GroupResponseModel(Group group, string obj = "group") + : base(obj) { - public GroupResponseModel(Group group, string obj = "group") - : base(obj) + if (group == null) { - if (group == null) - { - throw new ArgumentNullException(nameof(group)); - } - - Id = group.Id.ToString(); - OrganizationId = group.OrganizationId.ToString(); - Name = group.Name; - AccessAll = group.AccessAll; - ExternalId = group.ExternalId; + throw new ArgumentNullException(nameof(group)); } - public string Id { get; set; } - public string OrganizationId { get; set; } - public string Name { get; set; } - public bool AccessAll { get; set; } - public string ExternalId { get; set; } + Id = group.Id.ToString(); + OrganizationId = group.OrganizationId.ToString(); + Name = group.Name; + AccessAll = group.AccessAll; + ExternalId = group.ExternalId; } - public class GroupDetailsResponseModel : GroupResponseModel - { - public GroupDetailsResponseModel(Group group, IEnumerable collections) - : base(group, "groupDetails") - { - Collections = collections.Select(c => new SelectionReadOnlyResponseModel(c)); - } - - public IEnumerable Collections { get; set; } - } + public string Id { get; set; } + public string OrganizationId { get; set; } + public string Name { get; set; } + public bool AccessAll { get; set; } + public string ExternalId { get; set; } +} + +public class GroupDetailsResponseModel : GroupResponseModel +{ + public GroupDetailsResponseModel(Group group, IEnumerable collections) + : base(group, "groupDetails") + { + Collections = collections.Select(c => new SelectionReadOnlyResponseModel(c)); + } + + public IEnumerable Collections { get; set; } } diff --git a/src/Api/Models/Response/InstallationResponseModel.cs b/src/Api/Models/Response/InstallationResponseModel.cs index 68e1524b1..75000471d 100644 --- a/src/Api/Models/Response/InstallationResponseModel.cs +++ b/src/Api/Models/Response/InstallationResponseModel.cs @@ -1,20 +1,19 @@ using Bit.Core.Entities; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response -{ - public class InstallationResponseModel : ResponseModel - { - public InstallationResponseModel(Installation installation, bool withKey) - : base("installation") - { - Id = installation.Id.ToString(); - Key = withKey ? installation.Key : null; - Enabled = installation.Enabled; - } +namespace Bit.Api.Models.Response; - public string Id { get; set; } - public string Key { get; set; } - public bool Enabled { get; set; } +public class InstallationResponseModel : ResponseModel +{ + public InstallationResponseModel(Installation installation, bool withKey) + : base("installation") + { + Id = installation.Id.ToString(); + Key = withKey ? installation.Key : null; + Enabled = installation.Enabled; } + + public string Id { get; set; } + public string Key { get; set; } + public bool Enabled { get; set; } } diff --git a/src/Api/Models/Response/KeysResponseModel.cs b/src/Api/Models/Response/KeysResponseModel.cs index 1ca1ae052..2f7e5e730 100644 --- a/src/Api/Models/Response/KeysResponseModel.cs +++ b/src/Api/Models/Response/KeysResponseModel.cs @@ -1,25 +1,24 @@ using Bit.Core.Entities; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response -{ - public class KeysResponseModel : ResponseModel - { - public KeysResponseModel(User user) - : base("keys") - { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } +namespace Bit.Api.Models.Response; - Key = user.Key; - PublicKey = user.PublicKey; - PrivateKey = user.PrivateKey; +public class KeysResponseModel : ResponseModel +{ + public KeysResponseModel(User user) + : base("keys") + { + if (user == null) + { + throw new ArgumentNullException(nameof(user)); } - public string Key { get; set; } - public string PublicKey { get; set; } - public string PrivateKey { get; set; } + Key = user.Key; + PublicKey = user.PublicKey; + PrivateKey = user.PrivateKey; } + + public string Key { get; set; } + public string PublicKey { get; set; } + public string PrivateKey { get; set; } } diff --git a/src/Api/Models/Response/ListResponseModel.cs b/src/Api/Models/Response/ListResponseModel.cs index c16a3461c..ecfe0a7e1 100644 --- a/src/Api/Models/Response/ListResponseModel.cs +++ b/src/Api/Models/Response/ListResponseModel.cs @@ -1,17 +1,16 @@ using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response -{ - public class ListResponseModel : ResponseModel where T : ResponseModel - { - public ListResponseModel(IEnumerable data, string continuationToken = null) - : base("list") - { - Data = data; - ContinuationToken = continuationToken; - } +namespace Bit.Api.Models.Response; - public IEnumerable Data { get; set; } - public string ContinuationToken { get; set; } +public class ListResponseModel : ResponseModel where T : ResponseModel +{ + public ListResponseModel(IEnumerable data, string continuationToken = null) + : base("list") + { + Data = data; + ContinuationToken = continuationToken; } + + public IEnumerable Data { get; set; } + public string ContinuationToken { get; set; } } diff --git a/src/Api/Models/Response/OrganizationExportResponseModel.cs b/src/Api/Models/Response/OrganizationExportResponseModel.cs index f5ce61873..a7533c918 100644 --- a/src/Api/Models/Response/OrganizationExportResponseModel.cs +++ b/src/Api/Models/Response/OrganizationExportResponseModel.cs @@ -1,13 +1,12 @@ -namespace Bit.Api.Models.Response +namespace Bit.Api.Models.Response; + +public class OrganizationExportResponseModel { - public class OrganizationExportResponseModel + public OrganizationExportResponseModel() { - public OrganizationExportResponseModel() - { - } - - public ListResponseModel Collections { get; set; } - - public ListResponseModel Ciphers { get; set; } } + + public ListResponseModel Collections { get; set; } + + public ListResponseModel Ciphers { get; set; } } diff --git a/src/Api/Models/Response/Organizations/OrganizationApiKeyInformationResponseModel.cs b/src/Api/Models/Response/Organizations/OrganizationApiKeyInformationResponseModel.cs index 05adb502a..a25cb8935 100644 --- a/src/Api/Models/Response/Organizations/OrganizationApiKeyInformationResponseModel.cs +++ b/src/Api/Models/Response/Organizations/OrganizationApiKeyInformationResponseModel.cs @@ -2,17 +2,16 @@ using Bit.Core.Enums; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response.Organizations -{ - public class OrganizationApiKeyInformation : ResponseModel - { - public OrganizationApiKeyInformation(OrganizationApiKey key) : base("keyInformation") - { - KeyType = key.Type; - RevisionDate = key.RevisionDate; - } +namespace Bit.Api.Models.Response.Organizations; - public OrganizationApiKeyType KeyType { get; set; } - public DateTime RevisionDate { get; set; } +public class OrganizationApiKeyInformation : ResponseModel +{ + public OrganizationApiKeyInformation(OrganizationApiKey key) : base("keyInformation") + { + KeyType = key.Type; + RevisionDate = key.RevisionDate; } + + public OrganizationApiKeyType KeyType { get; set; } + public DateTime RevisionDate { get; set; } } diff --git a/src/Api/Models/Response/Organizations/OrganizationAutoEnrollStatusResponseModel.cs b/src/Api/Models/Response/Organizations/OrganizationAutoEnrollStatusResponseModel.cs index 529168c6a..9c1f0ee22 100644 --- a/src/Api/Models/Response/Organizations/OrganizationAutoEnrollStatusResponseModel.cs +++ b/src/Api/Models/Response/Organizations/OrganizationAutoEnrollStatusResponseModel.cs @@ -1,16 +1,15 @@ using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response.Organizations -{ - public class OrganizationAutoEnrollStatusResponseModel : ResponseModel - { - public OrganizationAutoEnrollStatusResponseModel(Guid orgId, bool resetPasswordEnabled) : base("organizationAutoEnrollStatus") - { - Id = orgId.ToString(); - ResetPasswordEnabled = resetPasswordEnabled; - } +namespace Bit.Api.Models.Response.Organizations; - public string Id { get; set; } - public bool ResetPasswordEnabled { get; set; } +public class OrganizationAutoEnrollStatusResponseModel : ResponseModel +{ + public OrganizationAutoEnrollStatusResponseModel(Guid orgId, bool resetPasswordEnabled) : base("organizationAutoEnrollStatus") + { + Id = orgId.ToString(); + ResetPasswordEnabled = resetPasswordEnabled; } + + public string Id { get; set; } + public bool ResetPasswordEnabled { get; set; } } diff --git a/src/Api/Models/Response/Organizations/OrganizationConnectionResponseModel.cs b/src/Api/Models/Response/Organizations/OrganizationConnectionResponseModel.cs index 86fb9b4db..f199ce56c 100644 --- a/src/Api/Models/Response/Organizations/OrganizationConnectionResponseModel.cs +++ b/src/Api/Models/Response/Organizations/OrganizationConnectionResponseModel.cs @@ -2,28 +2,27 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Api.Models.Response.Organizations +namespace Bit.Api.Models.Response.Organizations; + +public class OrganizationConnectionResponseModel { - public class OrganizationConnectionResponseModel + public Guid? Id { get; set; } + public OrganizationConnectionType Type { get; set; } + public Guid OrganizationId { get; set; } + public bool Enabled { get; set; } + public JsonDocument Config { get; set; } + + public OrganizationConnectionResponseModel(OrganizationConnection connection, Type configType) { - public Guid? Id { get; set; } - public OrganizationConnectionType Type { get; set; } - public Guid OrganizationId { get; set; } - public bool Enabled { get; set; } - public JsonDocument Config { get; set; } - - public OrganizationConnectionResponseModel(OrganizationConnection connection, Type configType) + if (connection == null) { - if (connection == null) - { - return; - } - - Id = connection.Id; - Type = connection.Type; - OrganizationId = connection.OrganizationId; - Enabled = connection.Enabled; - Config = JsonDocument.Parse(connection.Config); + return; } + + Id = connection.Id; + Type = connection.Type; + OrganizationId = connection.OrganizationId; + Enabled = connection.Enabled; + Config = JsonDocument.Parse(connection.Config); } } diff --git a/src/Api/Models/Response/Organizations/OrganizationKeysResponseModel.cs b/src/Api/Models/Response/Organizations/OrganizationKeysResponseModel.cs index 06430bef2..35c2f77e7 100644 --- a/src/Api/Models/Response/Organizations/OrganizationKeysResponseModel.cs +++ b/src/Api/Models/Response/Organizations/OrganizationKeysResponseModel.cs @@ -1,22 +1,21 @@ using Bit.Core.Entities; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response.Organizations -{ - public class OrganizationKeysResponseModel : ResponseModel - { - public OrganizationKeysResponseModel(Organization org) : base("organizationKeys") - { - if (org == null) - { - throw new ArgumentNullException(nameof(org)); - } +namespace Bit.Api.Models.Response.Organizations; - PublicKey = org.PublicKey; - PrivateKey = org.PrivateKey; +public class OrganizationKeysResponseModel : ResponseModel +{ + public OrganizationKeysResponseModel(Organization org) : base("organizationKeys") + { + if (org == null) + { + throw new ArgumentNullException(nameof(org)); } - public string PublicKey { get; set; } - public string PrivateKey { get; set; } + PublicKey = org.PublicKey; + PrivateKey = org.PrivateKey; } + + public string PublicKey { get; set; } + public string PrivateKey { get; set; } } diff --git a/src/Api/Models/Response/Organizations/OrganizationResponseModel.cs b/src/Api/Models/Response/Organizations/OrganizationResponseModel.cs index eab23bee9..4aa83d201 100644 --- a/src/Api/Models/Response/Organizations/OrganizationResponseModel.cs +++ b/src/Api/Models/Response/Organizations/OrganizationResponseModel.cs @@ -4,110 +4,109 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Business; using Bit.Core.Utilities; -namespace Bit.Api.Models.Response.Organizations +namespace Bit.Api.Models.Response.Organizations; + +public class OrganizationResponseModel : ResponseModel { - public class OrganizationResponseModel : ResponseModel + public OrganizationResponseModel(Organization organization, string obj = "organization") + : base(obj) { - public OrganizationResponseModel(Organization organization, string obj = "organization") - : base(obj) + if (organization == null) { - if (organization == null) - { - throw new ArgumentNullException(nameof(organization)); - } - - Id = organization.Id.ToString(); - Identifier = organization.Identifier; - Name = organization.Name; - BusinessName = organization.BusinessName; - BusinessAddress1 = organization.BusinessAddress1; - BusinessAddress2 = organization.BusinessAddress2; - BusinessAddress3 = organization.BusinessAddress3; - BusinessCountry = organization.BusinessCountry; - BusinessTaxNumber = organization.BusinessTaxNumber; - BillingEmail = organization.BillingEmail; - Plan = new PlanResponseModel(StaticStore.Plans.FirstOrDefault(plan => plan.Type == organization.PlanType)); - PlanType = organization.PlanType; - Seats = organization.Seats; - MaxAutoscaleSeats = organization.MaxAutoscaleSeats; - MaxCollections = organization.MaxCollections; - MaxStorageGb = organization.MaxStorageGb; - UsePolicies = organization.UsePolicies; - UseSso = organization.UseSso; - UseKeyConnector = organization.UseKeyConnector; - UseScim = organization.UseScim; - UseGroups = organization.UseGroups; - UseDirectory = organization.UseDirectory; - UseEvents = organization.UseEvents; - UseTotp = organization.UseTotp; - Use2fa = organization.Use2fa; - UseApi = organization.UseApi; - UseResetPassword = organization.UseResetPassword; - UsersGetPremium = organization.UsersGetPremium; - SelfHost = organization.SelfHost; - HasPublicAndPrivateKeys = organization.PublicKey != null && organization.PrivateKey != null; + throw new ArgumentNullException(nameof(organization)); } - public string Id { get; set; } - public string Identifier { get; set; } - public string Name { get; set; } - public string BusinessName { get; set; } - public string BusinessAddress1 { get; set; } - public string BusinessAddress2 { get; set; } - public string BusinessAddress3 { get; set; } - public string BusinessCountry { get; set; } - public string BusinessTaxNumber { get; set; } - public string BillingEmail { get; set; } - public PlanResponseModel Plan { get; set; } - public PlanType PlanType { get; set; } - public int? Seats { get; set; } - public int? MaxAutoscaleSeats { get; set; } = null; - public short? MaxCollections { get; set; } - public short? MaxStorageGb { get; set; } - public bool UsePolicies { get; set; } - public bool UseSso { get; set; } - public bool UseKeyConnector { get; set; } - public bool UseScim { get; set; } - public bool UseGroups { get; set; } - public bool UseDirectory { get; set; } - public bool UseEvents { get; set; } - public bool UseTotp { get; set; } - public bool Use2fa { get; set; } - public bool UseApi { get; set; } - public bool UseResetPassword { get; set; } - public bool UsersGetPremium { get; set; } - public bool SelfHost { get; set; } - public bool HasPublicAndPrivateKeys { get; set; } + Id = organization.Id.ToString(); + Identifier = organization.Identifier; + Name = organization.Name; + BusinessName = organization.BusinessName; + BusinessAddress1 = organization.BusinessAddress1; + BusinessAddress2 = organization.BusinessAddress2; + BusinessAddress3 = organization.BusinessAddress3; + BusinessCountry = organization.BusinessCountry; + BusinessTaxNumber = organization.BusinessTaxNumber; + BillingEmail = organization.BillingEmail; + Plan = new PlanResponseModel(StaticStore.Plans.FirstOrDefault(plan => plan.Type == organization.PlanType)); + PlanType = organization.PlanType; + Seats = organization.Seats; + MaxAutoscaleSeats = organization.MaxAutoscaleSeats; + MaxCollections = organization.MaxCollections; + MaxStorageGb = organization.MaxStorageGb; + UsePolicies = organization.UsePolicies; + UseSso = organization.UseSso; + UseKeyConnector = organization.UseKeyConnector; + UseScim = organization.UseScim; + UseGroups = organization.UseGroups; + UseDirectory = organization.UseDirectory; + UseEvents = organization.UseEvents; + UseTotp = organization.UseTotp; + Use2fa = organization.Use2fa; + UseApi = organization.UseApi; + UseResetPassword = organization.UseResetPassword; + UsersGetPremium = organization.UsersGetPremium; + SelfHost = organization.SelfHost; + HasPublicAndPrivateKeys = organization.PublicKey != null && organization.PrivateKey != null; } - public class OrganizationSubscriptionResponseModel : OrganizationResponseModel - { - public OrganizationSubscriptionResponseModel(Organization organization, SubscriptionInfo subscription = null) - : base(organization, "organizationSubscription") - { - if (subscription != null) - { - Subscription = subscription.Subscription != null ? - new BillingSubscription(subscription.Subscription) : null; - UpcomingInvoice = subscription.UpcomingInvoice != null ? - new BillingSubscriptionUpcomingInvoice(subscription.UpcomingInvoice) : null; - Expiration = DateTime.UtcNow.AddYears(1); // Not used, so just give it a value. - } - else - { - Expiration = organization.ExpirationDate; - } - - StorageName = organization.Storage.HasValue ? - CoreHelpers.ReadableBytesSize(organization.Storage.Value) : null; - StorageGb = organization.Storage.HasValue ? - Math.Round(organization.Storage.Value / 1073741824D, 2) : 0; // 1 GB - } - - public string StorageName { get; set; } - public double? StorageGb { get; set; } - public BillingSubscription Subscription { get; set; } - public BillingSubscriptionUpcomingInvoice UpcomingInvoice { get; set; } - public DateTime? Expiration { get; set; } - } + public string Id { get; set; } + public string Identifier { get; set; } + public string Name { get; set; } + public string BusinessName { get; set; } + public string BusinessAddress1 { get; set; } + public string BusinessAddress2 { get; set; } + public string BusinessAddress3 { get; set; } + public string BusinessCountry { get; set; } + public string BusinessTaxNumber { get; set; } + public string BillingEmail { get; set; } + public PlanResponseModel Plan { get; set; } + public PlanType PlanType { get; set; } + public int? Seats { get; set; } + public int? MaxAutoscaleSeats { get; set; } = null; + public short? MaxCollections { get; set; } + public short? MaxStorageGb { get; set; } + public bool UsePolicies { get; set; } + public bool UseSso { get; set; } + public bool UseKeyConnector { get; set; } + public bool UseScim { get; set; } + public bool UseGroups { get; set; } + public bool UseDirectory { get; set; } + public bool UseEvents { get; set; } + public bool UseTotp { get; set; } + public bool Use2fa { get; set; } + public bool UseApi { get; set; } + public bool UseResetPassword { get; set; } + public bool UsersGetPremium { get; set; } + public bool SelfHost { get; set; } + public bool HasPublicAndPrivateKeys { get; set; } +} + +public class OrganizationSubscriptionResponseModel : OrganizationResponseModel +{ + public OrganizationSubscriptionResponseModel(Organization organization, SubscriptionInfo subscription = null) + : base(organization, "organizationSubscription") + { + if (subscription != null) + { + Subscription = subscription.Subscription != null ? + new BillingSubscription(subscription.Subscription) : null; + UpcomingInvoice = subscription.UpcomingInvoice != null ? + new BillingSubscriptionUpcomingInvoice(subscription.UpcomingInvoice) : null; + Expiration = DateTime.UtcNow.AddYears(1); // Not used, so just give it a value. + } + else + { + Expiration = organization.ExpirationDate; + } + + StorageName = organization.Storage.HasValue ? + CoreHelpers.ReadableBytesSize(organization.Storage.Value) : null; + StorageGb = organization.Storage.HasValue ? + Math.Round(organization.Storage.Value / 1073741824D, 2) : 0; // 1 GB + } + + public string StorageName { get; set; } + public double? StorageGb { get; set; } + public BillingSubscription Subscription { get; set; } + public BillingSubscriptionUpcomingInvoice UpcomingInvoice { get; set; } + public DateTime? Expiration { get; set; } } diff --git a/src/Api/Models/Response/Organizations/OrganizationSponsorshipSyncStatusResponseModel.cs b/src/Api/Models/Response/Organizations/OrganizationSponsorshipSyncStatusResponseModel.cs index 33e349bbf..33862f391 100644 --- a/src/Api/Models/Response/Organizations/OrganizationSponsorshipSyncStatusResponseModel.cs +++ b/src/Api/Models/Response/Organizations/OrganizationSponsorshipSyncStatusResponseModel.cs @@ -1,15 +1,14 @@ using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response.Organizations -{ - public class OrganizationSponsorshipSyncStatusResponseModel : ResponseModel - { - public OrganizationSponsorshipSyncStatusResponseModel(DateTime? lastSyncDate) - : base("syncStatus") - { - LastSyncDate = lastSyncDate; - } +namespace Bit.Api.Models.Response.Organizations; - public DateTime? LastSyncDate { get; set; } +public class OrganizationSponsorshipSyncStatusResponseModel : ResponseModel +{ + public OrganizationSponsorshipSyncStatusResponseModel(DateTime? lastSyncDate) + : base("syncStatus") + { + LastSyncDate = lastSyncDate; } + + public DateTime? LastSyncDate { get; set; } } diff --git a/src/Api/Models/Response/Organizations/OrganizationSsoResponseModel.cs b/src/Api/Models/Response/Organizations/OrganizationSsoResponseModel.cs index dd828630b..cd7e6c266 100644 --- a/src/Api/Models/Response/Organizations/OrganizationSsoResponseModel.cs +++ b/src/Api/Models/Response/Organizations/OrganizationSsoResponseModel.cs @@ -3,42 +3,41 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Data; using Bit.Core.Settings; -namespace Bit.Api.Models.Response.Organizations +namespace Bit.Api.Models.Response.Organizations; + +public class OrganizationSsoResponseModel : ResponseModel { - public class OrganizationSsoResponseModel : ResponseModel + public OrganizationSsoResponseModel(Organization organization, GlobalSettings globalSettings, + SsoConfig config = null) : base("organizationSso") { - public OrganizationSsoResponseModel(Organization organization, GlobalSettings globalSettings, - SsoConfig config = null) : base("organizationSso") + if (config != null) { - if (config != null) - { - Enabled = config.Enabled; - Data = config.GetData(); - } - - Urls = new SsoUrls(organization.Id.ToString(), globalSettings); + Enabled = config.Enabled; + Data = config.GetData(); } - public bool Enabled { get; set; } - public SsoConfigurationData Data { get; set; } - public SsoUrls Urls { get; set; } + Urls = new SsoUrls(organization.Id.ToString(), globalSettings); } - public class SsoUrls - { - public SsoUrls(string organizationId, GlobalSettings globalSettings) - { - CallbackPath = SsoConfigurationData.BuildCallbackPath(globalSettings.BaseServiceUri.Sso); - SignedOutCallbackPath = SsoConfigurationData.BuildSignedOutCallbackPath(globalSettings.BaseServiceUri.Sso); - SpEntityId = SsoConfigurationData.BuildSaml2ModulePath(globalSettings.BaseServiceUri.Sso); - SpMetadataUrl = SsoConfigurationData.BuildSaml2MetadataUrl(globalSettings.BaseServiceUri.Sso, organizationId); - SpAcsUrl = SsoConfigurationData.BuildSaml2AcsUrl(globalSettings.BaseServiceUri.Sso, organizationId); - } - - public string CallbackPath { get; set; } - public string SignedOutCallbackPath { get; set; } - public string SpEntityId { get; set; } - public string SpMetadataUrl { get; set; } - public string SpAcsUrl { get; set; } - } + public bool Enabled { get; set; } + public SsoConfigurationData Data { get; set; } + public SsoUrls Urls { get; set; } +} + +public class SsoUrls +{ + public SsoUrls(string organizationId, GlobalSettings globalSettings) + { + CallbackPath = SsoConfigurationData.BuildCallbackPath(globalSettings.BaseServiceUri.Sso); + SignedOutCallbackPath = SsoConfigurationData.BuildSignedOutCallbackPath(globalSettings.BaseServiceUri.Sso); + SpEntityId = SsoConfigurationData.BuildSaml2ModulePath(globalSettings.BaseServiceUri.Sso); + SpMetadataUrl = SsoConfigurationData.BuildSaml2MetadataUrl(globalSettings.BaseServiceUri.Sso, organizationId); + SpAcsUrl = SsoConfigurationData.BuildSaml2AcsUrl(globalSettings.BaseServiceUri.Sso, organizationId); + } + + public string CallbackPath { get; set; } + public string SignedOutCallbackPath { get; set; } + public string SpEntityId { get; set; } + public string SpMetadataUrl { get; set; } + public string SpAcsUrl { get; set; } } diff --git a/src/Api/Models/Response/Organizations/OrganizationUserResponseModel.cs b/src/Api/Models/Response/Organizations/OrganizationUserResponseModel.cs index 7be68c41a..619769b06 100644 --- a/src/Api/Models/Response/Organizations/OrganizationUserResponseModel.cs +++ b/src/Api/Models/Response/Organizations/OrganizationUserResponseModel.cs @@ -5,139 +5,138 @@ using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Utilities; -namespace Bit.Api.Models.Response.Organizations +namespace Bit.Api.Models.Response.Organizations; + +public class OrganizationUserResponseModel : ResponseModel { - public class OrganizationUserResponseModel : ResponseModel + public OrganizationUserResponseModel(OrganizationUser organizationUser, string obj = "organizationUser") + : base(obj) { - public OrganizationUserResponseModel(OrganizationUser organizationUser, string obj = "organizationUser") - : base(obj) + if (organizationUser == null) { - if (organizationUser == null) - { - throw new ArgumentNullException(nameof(organizationUser)); - } - - Id = organizationUser.Id.ToString(); - UserId = organizationUser.UserId?.ToString(); - Type = organizationUser.Type; - Status = organizationUser.Status; - AccessAll = organizationUser.AccessAll; - Permissions = CoreHelpers.LoadClassFromJsonData(organizationUser.Permissions); - ResetPasswordEnrolled = !string.IsNullOrEmpty(organizationUser.ResetPasswordKey); + throw new ArgumentNullException(nameof(organizationUser)); } - public OrganizationUserResponseModel(OrganizationUserUserDetails organizationUser, string obj = "organizationUser") - : base(obj) - { - if (organizationUser == null) - { - throw new ArgumentNullException(nameof(organizationUser)); - } - - Id = organizationUser.Id.ToString(); - UserId = organizationUser.UserId?.ToString(); - Type = organizationUser.Type; - Status = organizationUser.Status; - AccessAll = organizationUser.AccessAll; - Permissions = CoreHelpers.LoadClassFromJsonData(organizationUser.Permissions); - ResetPasswordEnrolled = !string.IsNullOrEmpty(organizationUser.ResetPasswordKey); - UsesKeyConnector = organizationUser.UsesKeyConnector; - } - - public string Id { get; set; } - public string UserId { get; set; } - public OrganizationUserType Type { get; set; } - public OrganizationUserStatusType Status { get; set; } - public bool AccessAll { get; set; } - public Permissions Permissions { get; set; } - public bool ResetPasswordEnrolled { get; set; } - public bool UsesKeyConnector { get; set; } + Id = organizationUser.Id.ToString(); + UserId = organizationUser.UserId?.ToString(); + Type = organizationUser.Type; + Status = organizationUser.Status; + AccessAll = organizationUser.AccessAll; + Permissions = CoreHelpers.LoadClassFromJsonData(organizationUser.Permissions); + ResetPasswordEnrolled = !string.IsNullOrEmpty(organizationUser.ResetPasswordKey); } - public class OrganizationUserDetailsResponseModel : OrganizationUserResponseModel + public OrganizationUserResponseModel(OrganizationUserUserDetails organizationUser, string obj = "organizationUser") + : base(obj) { - public OrganizationUserDetailsResponseModel(OrganizationUser organizationUser, - IEnumerable collections) - : base(organizationUser, "organizationUserDetails") + if (organizationUser == null) { - Collections = collections.Select(c => new SelectionReadOnlyResponseModel(c)); + throw new ArgumentNullException(nameof(organizationUser)); } - public IEnumerable Collections { get; set; } + Id = organizationUser.Id.ToString(); + UserId = organizationUser.UserId?.ToString(); + Type = organizationUser.Type; + Status = organizationUser.Status; + AccessAll = organizationUser.AccessAll; + Permissions = CoreHelpers.LoadClassFromJsonData(organizationUser.Permissions); + ResetPasswordEnrolled = !string.IsNullOrEmpty(organizationUser.ResetPasswordKey); + UsesKeyConnector = organizationUser.UsesKeyConnector; } - public class OrganizationUserUserDetailsResponseModel : OrganizationUserResponseModel - { - public OrganizationUserUserDetailsResponseModel(OrganizationUserUserDetails organizationUser, - bool twoFactorEnabled, string obj = "organizationUserUserDetails") - : base(organizationUser, obj) - { - if (organizationUser == null) - { - throw new ArgumentNullException(nameof(organizationUser)); - } - - Name = organizationUser.Name; - Email = organizationUser.Email; - TwoFactorEnabled = twoFactorEnabled; - SsoBound = !string.IsNullOrWhiteSpace(organizationUser.SsoExternalId); - // Prevent reset password when using key connector. - ResetPasswordEnrolled = ResetPasswordEnrolled && !organizationUser.UsesKeyConnector; - } - - public string Name { get; set; } - public string Email { get; set; } - public bool TwoFactorEnabled { get; set; } - public bool SsoBound { get; set; } - } - - public class OrganizationUserResetPasswordDetailsResponseModel : ResponseModel - { - public OrganizationUserResetPasswordDetailsResponseModel(OrganizationUserResetPasswordDetails orgUser, - string obj = "organizationUserResetPasswordDetails") : base(obj) - { - if (orgUser == null) - { - throw new ArgumentNullException(nameof(orgUser)); - } - - Kdf = orgUser.Kdf; - KdfIterations = orgUser.KdfIterations; - ResetPasswordKey = orgUser.ResetPasswordKey; - EncryptedPrivateKey = orgUser.EncryptedPrivateKey; - } - - public KdfType Kdf { get; set; } - public int KdfIterations { get; set; } - public string ResetPasswordKey { get; set; } - public string EncryptedPrivateKey { get; set; } - } - - public class OrganizationUserPublicKeyResponseModel : ResponseModel - { - public OrganizationUserPublicKeyResponseModel(Guid id, Guid userId, - string key, string obj = "organizationUserPublicKeyResponseModel") : - base(obj) - { - Id = id; - UserId = userId; - Key = key; - } - - public Guid Id { get; set; } - public Guid UserId { get; set; } - public string Key { get; set; } - } - - public class OrganizationUserBulkResponseModel : ResponseModel - { - public OrganizationUserBulkResponseModel(Guid id, string error, - string obj = "OrganizationBulkConfirmResponseModel") : base(obj) - { - Id = id; - Error = error; - } - public Guid Id { get; set; } - public string Error { get; set; } - } + public string Id { get; set; } + public string UserId { get; set; } + public OrganizationUserType Type { get; set; } + public OrganizationUserStatusType Status { get; set; } + public bool AccessAll { get; set; } + public Permissions Permissions { get; set; } + public bool ResetPasswordEnrolled { get; set; } + public bool UsesKeyConnector { get; set; } +} + +public class OrganizationUserDetailsResponseModel : OrganizationUserResponseModel +{ + public OrganizationUserDetailsResponseModel(OrganizationUser organizationUser, + IEnumerable collections) + : base(organizationUser, "organizationUserDetails") + { + Collections = collections.Select(c => new SelectionReadOnlyResponseModel(c)); + } + + public IEnumerable Collections { get; set; } +} + +public class OrganizationUserUserDetailsResponseModel : OrganizationUserResponseModel +{ + public OrganizationUserUserDetailsResponseModel(OrganizationUserUserDetails organizationUser, + bool twoFactorEnabled, string obj = "organizationUserUserDetails") + : base(organizationUser, obj) + { + if (organizationUser == null) + { + throw new ArgumentNullException(nameof(organizationUser)); + } + + Name = organizationUser.Name; + Email = organizationUser.Email; + TwoFactorEnabled = twoFactorEnabled; + SsoBound = !string.IsNullOrWhiteSpace(organizationUser.SsoExternalId); + // Prevent reset password when using key connector. + ResetPasswordEnrolled = ResetPasswordEnrolled && !organizationUser.UsesKeyConnector; + } + + public string Name { get; set; } + public string Email { get; set; } + public bool TwoFactorEnabled { get; set; } + public bool SsoBound { get; set; } +} + +public class OrganizationUserResetPasswordDetailsResponseModel : ResponseModel +{ + public OrganizationUserResetPasswordDetailsResponseModel(OrganizationUserResetPasswordDetails orgUser, + string obj = "organizationUserResetPasswordDetails") : base(obj) + { + if (orgUser == null) + { + throw new ArgumentNullException(nameof(orgUser)); + } + + Kdf = orgUser.Kdf; + KdfIterations = orgUser.KdfIterations; + ResetPasswordKey = orgUser.ResetPasswordKey; + EncryptedPrivateKey = orgUser.EncryptedPrivateKey; + } + + public KdfType Kdf { get; set; } + public int KdfIterations { get; set; } + public string ResetPasswordKey { get; set; } + public string EncryptedPrivateKey { get; set; } +} + +public class OrganizationUserPublicKeyResponseModel : ResponseModel +{ + public OrganizationUserPublicKeyResponseModel(Guid id, Guid userId, + string key, string obj = "organizationUserPublicKeyResponseModel") : + base(obj) + { + Id = id; + UserId = userId; + Key = key; + } + + public Guid Id { get; set; } + public Guid UserId { get; set; } + public string Key { get; set; } +} + +public class OrganizationUserBulkResponseModel : ResponseModel +{ + public OrganizationUserBulkResponseModel(Guid id, string error, + string obj = "OrganizationBulkConfirmResponseModel") : base(obj) + { + Id = id; + Error = error; + } + public Guid Id { get; set; } + public string Error { get; set; } } diff --git a/src/Api/Models/Response/PaymentResponseModel.cs b/src/Api/Models/Response/PaymentResponseModel.cs index 43edb3216..067ac969e 100644 --- a/src/Api/Models/Response/PaymentResponseModel.cs +++ b/src/Api/Models/Response/PaymentResponseModel.cs @@ -1,15 +1,14 @@ using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response -{ - public class PaymentResponseModel : ResponseModel - { - public PaymentResponseModel() - : base("payment") - { } +namespace Bit.Api.Models.Response; - public ProfileResponseModel UserProfile { get; set; } - public string PaymentIntentClientSecret { get; set; } - public bool Success { get; set; } - } +public class PaymentResponseModel : ResponseModel +{ + public PaymentResponseModel() + : base("payment") + { } + + public ProfileResponseModel UserProfile { get; set; } + public string PaymentIntentClientSecret { get; set; } + public bool Success { get; set; } } diff --git a/src/Api/Models/Response/PlanResponseModel.cs b/src/Api/Models/Response/PlanResponseModel.cs index 5974772a8..fd2934e73 100644 --- a/src/Api/Models/Response/PlanResponseModel.cs +++ b/src/Api/Models/Response/PlanResponseModel.cs @@ -2,101 +2,100 @@ using Bit.Core.Models.Api; using Bit.Core.Models.StaticStore; -namespace Bit.Api.Models.Response -{ - public class PlanResponseModel : ResponseModel - { - public PlanResponseModel(Plan plan, string obj = "plan") - : base(obj) - { - if (plan == null) - { - throw new ArgumentNullException(nameof(plan)); - } +namespace Bit.Api.Models.Response; - Type = plan.Type; - Product = plan.Product; - Name = plan.Name; - IsAnnual = plan.IsAnnual; - NameLocalizationKey = plan.NameLocalizationKey; - DescriptionLocalizationKey = plan.DescriptionLocalizationKey; - CanBeUsedByBusiness = plan.CanBeUsedByBusiness; - BaseSeats = plan.BaseSeats; - BaseStorageGb = plan.BaseStorageGb; - MaxCollections = plan.MaxCollections; - MaxUsers = plan.MaxUsers; - HasAdditionalSeatsOption = plan.HasAdditionalSeatsOption; - HasAdditionalStorageOption = plan.HasAdditionalStorageOption; - MaxAdditionalSeats = plan.MaxAdditionalSeats; - MaxAdditionalStorage = plan.MaxAdditionalStorage; - HasPremiumAccessOption = plan.HasPremiumAccessOption; - TrialPeriodDays = plan.TrialPeriodDays; - HasSelfHost = plan.HasSelfHost; - HasPolicies = plan.HasPolicies; - HasGroups = plan.HasGroups; - HasDirectory = plan.HasDirectory; - HasEvents = plan.HasEvents; - HasTotp = plan.HasTotp; - Has2fa = plan.Has2fa; - HasSso = plan.HasSso; - HasResetPassword = plan.HasResetPassword; - UsersGetPremium = plan.UsersGetPremium; - UpgradeSortOrder = plan.UpgradeSortOrder; - DisplaySortOrder = plan.DisplaySortOrder; - LegacyYear = plan.LegacyYear; - Disabled = plan.Disabled; - StripePlanId = plan.StripePlanId; - StripeSeatPlanId = plan.StripeSeatPlanId; - StripeStoragePlanId = plan.StripeStoragePlanId; - BasePrice = plan.BasePrice; - SeatPrice = plan.SeatPrice; - AdditionalStoragePricePerGb = plan.AdditionalStoragePricePerGb; - PremiumAccessOptionPrice = plan.PremiumAccessOptionPrice; +public class PlanResponseModel : ResponseModel +{ + public PlanResponseModel(Plan plan, string obj = "plan") + : base(obj) + { + if (plan == null) + { + throw new ArgumentNullException(nameof(plan)); } - public PlanType Type { get; set; } - public ProductType Product { get; set; } - public string Name { get; set; } - public bool IsAnnual { get; set; } - public string NameLocalizationKey { get; set; } - public string DescriptionLocalizationKey { get; set; } - public bool CanBeUsedByBusiness { get; set; } - public int BaseSeats { get; set; } - public short? BaseStorageGb { get; set; } - public short? MaxCollections { get; set; } - public short? MaxUsers { get; set; } - - public bool HasAdditionalSeatsOption { get; set; } - public int? MaxAdditionalSeats { get; set; } - public bool HasAdditionalStorageOption { get; set; } - public short? MaxAdditionalStorage { get; set; } - public bool HasPremiumAccessOption { get; set; } - public int? TrialPeriodDays { get; set; } - - public bool HasSelfHost { get; set; } - public bool HasPolicies { get; set; } - public bool HasGroups { get; set; } - public bool HasDirectory { get; set; } - public bool HasEvents { get; set; } - public bool HasTotp { get; set; } - public bool Has2fa { get; set; } - public bool HasApi { get; set; } - public bool HasSso { get; set; } - public bool HasResetPassword { get; set; } - public bool UsersGetPremium { get; set; } - - public int UpgradeSortOrder { get; set; } - public int DisplaySortOrder { get; set; } - public int? LegacyYear { get; set; } - public bool Disabled { get; set; } - - public string StripePlanId { get; set; } - public string StripeSeatPlanId { get; set; } - public string StripeStoragePlanId { get; set; } - public string StripePremiumAccessPlanId { get; set; } - public decimal BasePrice { get; set; } - public decimal SeatPrice { get; set; } - public decimal AdditionalStoragePricePerGb { get; set; } - public decimal PremiumAccessOptionPrice { get; set; } + Type = plan.Type; + Product = plan.Product; + Name = plan.Name; + IsAnnual = plan.IsAnnual; + NameLocalizationKey = plan.NameLocalizationKey; + DescriptionLocalizationKey = plan.DescriptionLocalizationKey; + CanBeUsedByBusiness = plan.CanBeUsedByBusiness; + BaseSeats = plan.BaseSeats; + BaseStorageGb = plan.BaseStorageGb; + MaxCollections = plan.MaxCollections; + MaxUsers = plan.MaxUsers; + HasAdditionalSeatsOption = plan.HasAdditionalSeatsOption; + HasAdditionalStorageOption = plan.HasAdditionalStorageOption; + MaxAdditionalSeats = plan.MaxAdditionalSeats; + MaxAdditionalStorage = plan.MaxAdditionalStorage; + HasPremiumAccessOption = plan.HasPremiumAccessOption; + TrialPeriodDays = plan.TrialPeriodDays; + HasSelfHost = plan.HasSelfHost; + HasPolicies = plan.HasPolicies; + HasGroups = plan.HasGroups; + HasDirectory = plan.HasDirectory; + HasEvents = plan.HasEvents; + HasTotp = plan.HasTotp; + Has2fa = plan.Has2fa; + HasSso = plan.HasSso; + HasResetPassword = plan.HasResetPassword; + UsersGetPremium = plan.UsersGetPremium; + UpgradeSortOrder = plan.UpgradeSortOrder; + DisplaySortOrder = plan.DisplaySortOrder; + LegacyYear = plan.LegacyYear; + Disabled = plan.Disabled; + StripePlanId = plan.StripePlanId; + StripeSeatPlanId = plan.StripeSeatPlanId; + StripeStoragePlanId = plan.StripeStoragePlanId; + BasePrice = plan.BasePrice; + SeatPrice = plan.SeatPrice; + AdditionalStoragePricePerGb = plan.AdditionalStoragePricePerGb; + PremiumAccessOptionPrice = plan.PremiumAccessOptionPrice; } + + public PlanType Type { get; set; } + public ProductType Product { get; set; } + public string Name { get; set; } + public bool IsAnnual { get; set; } + public string NameLocalizationKey { get; set; } + public string DescriptionLocalizationKey { get; set; } + public bool CanBeUsedByBusiness { get; set; } + public int BaseSeats { get; set; } + public short? BaseStorageGb { get; set; } + public short? MaxCollections { get; set; } + public short? MaxUsers { get; set; } + + public bool HasAdditionalSeatsOption { get; set; } + public int? MaxAdditionalSeats { get; set; } + public bool HasAdditionalStorageOption { get; set; } + public short? MaxAdditionalStorage { get; set; } + public bool HasPremiumAccessOption { get; set; } + public int? TrialPeriodDays { get; set; } + + public bool HasSelfHost { get; set; } + public bool HasPolicies { get; set; } + public bool HasGroups { get; set; } + public bool HasDirectory { get; set; } + public bool HasEvents { get; set; } + public bool HasTotp { get; set; } + public bool Has2fa { get; set; } + public bool HasApi { get; set; } + public bool HasSso { get; set; } + public bool HasResetPassword { get; set; } + public bool UsersGetPremium { get; set; } + + public int UpgradeSortOrder { get; set; } + public int DisplaySortOrder { get; set; } + public int? LegacyYear { get; set; } + public bool Disabled { get; set; } + + public string StripePlanId { get; set; } + public string StripeSeatPlanId { get; set; } + public string StripeStoragePlanId { get; set; } + public string StripePremiumAccessPlanId { get; set; } + public decimal BasePrice { get; set; } + public decimal SeatPrice { get; set; } + public decimal AdditionalStoragePricePerGb { get; set; } + public decimal PremiumAccessOptionPrice { get; set; } } diff --git a/src/Api/Models/Response/PolicyResponseModel.cs b/src/Api/Models/Response/PolicyResponseModel.cs index 7f725ba31..a812a911d 100644 --- a/src/Api/Models/Response/PolicyResponseModel.cs +++ b/src/Api/Models/Response/PolicyResponseModel.cs @@ -3,32 +3,31 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response -{ - public class PolicyResponseModel : ResponseModel - { - public PolicyResponseModel(Policy policy, string obj = "policy") - : base(obj) - { - if (policy == null) - { - throw new ArgumentNullException(nameof(policy)); - } +namespace Bit.Api.Models.Response; - Id = policy.Id.ToString(); - OrganizationId = policy.OrganizationId.ToString(); - Type = policy.Type; - Enabled = policy.Enabled; - if (!string.IsNullOrWhiteSpace(policy.Data)) - { - Data = JsonSerializer.Deserialize>(policy.Data); - } +public class PolicyResponseModel : ResponseModel +{ + public PolicyResponseModel(Policy policy, string obj = "policy") + : base(obj) + { + if (policy == null) + { + throw new ArgumentNullException(nameof(policy)); } - public string Id { get; set; } - public string OrganizationId { get; set; } - public PolicyType Type { get; set; } - public Dictionary Data { get; set; } - public bool Enabled { get; set; } + Id = policy.Id.ToString(); + OrganizationId = policy.OrganizationId.ToString(); + Type = policy.Type; + Enabled = policy.Enabled; + if (!string.IsNullOrWhiteSpace(policy.Data)) + { + Data = JsonSerializer.Deserialize>(policy.Data); + } } + + public string Id { get; set; } + public string OrganizationId { get; set; } + public PolicyType Type { get; set; } + public Dictionary Data { get; set; } + public bool Enabled { get; set; } } diff --git a/src/Api/Models/Response/ProfileOrganizationResponseModel.cs b/src/Api/Models/Response/ProfileOrganizationResponseModel.cs index 969dbbaf1..4285ae432 100644 --- a/src/Api/Models/Response/ProfileOrganizationResponseModel.cs +++ b/src/Api/Models/Response/ProfileOrganizationResponseModel.cs @@ -4,98 +4,97 @@ using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Utilities; -namespace Bit.Api.Models.Response +namespace Bit.Api.Models.Response; + +public class ProfileOrganizationResponseModel : ResponseModel { - public class ProfileOrganizationResponseModel : ResponseModel + public ProfileOrganizationResponseModel(string str) : base(str) { } + + public ProfileOrganizationResponseModel(OrganizationUserOrganizationDetails organization) : this("profileOrganization") { - public ProfileOrganizationResponseModel(string str) : base(str) { } + Id = organization.OrganizationId.ToString(); + Name = organization.Name; + UsePolicies = organization.UsePolicies; + UseSso = organization.UseSso; + UseKeyConnector = organization.UseKeyConnector; + UseScim = organization.UseScim; + UseGroups = organization.UseGroups; + UseDirectory = organization.UseDirectory; + UseEvents = organization.UseEvents; + UseTotp = organization.UseTotp; + Use2fa = organization.Use2fa; + UseApi = organization.UseApi; + UseResetPassword = organization.UseResetPassword; + UsersGetPremium = organization.UsersGetPremium; + SelfHost = organization.SelfHost; + Seats = organization.Seats; + MaxCollections = organization.MaxCollections; + MaxStorageGb = organization.MaxStorageGb; + Key = organization.Key; + HasPublicAndPrivateKeys = organization.PublicKey != null && organization.PrivateKey != null; + Status = organization.Status; + Type = organization.Type; + Enabled = organization.Enabled; + SsoBound = !string.IsNullOrWhiteSpace(organization.SsoExternalId); + Identifier = organization.Identifier; + Permissions = CoreHelpers.LoadClassFromJsonData(organization.Permissions); + ResetPasswordEnrolled = organization.ResetPasswordKey != null; + UserId = organization.UserId?.ToString(); + ProviderId = organization.ProviderId?.ToString(); + ProviderName = organization.ProviderName; + FamilySponsorshipFriendlyName = organization.FamilySponsorshipFriendlyName; + FamilySponsorshipAvailable = FamilySponsorshipFriendlyName == null && + StaticStore.GetSponsoredPlan(PlanSponsorshipType.FamiliesForEnterprise) + .UsersCanSponsor(organization); + PlanProductType = StaticStore.GetPlan(organization.PlanType).Product; + FamilySponsorshipLastSyncDate = organization.FamilySponsorshipLastSyncDate; + FamilySponsorshipToDelete = organization.FamilySponsorshipToDelete; + FamilySponsorshipValidUntil = organization.FamilySponsorshipValidUntil; - public ProfileOrganizationResponseModel(OrganizationUserOrganizationDetails organization) : this("profileOrganization") + if (organization.SsoConfig != null) { - Id = organization.OrganizationId.ToString(); - Name = organization.Name; - UsePolicies = organization.UsePolicies; - UseSso = organization.UseSso; - UseKeyConnector = organization.UseKeyConnector; - UseScim = organization.UseScim; - UseGroups = organization.UseGroups; - UseDirectory = organization.UseDirectory; - UseEvents = organization.UseEvents; - UseTotp = organization.UseTotp; - Use2fa = organization.Use2fa; - UseApi = organization.UseApi; - UseResetPassword = organization.UseResetPassword; - UsersGetPremium = organization.UsersGetPremium; - SelfHost = organization.SelfHost; - Seats = organization.Seats; - MaxCollections = organization.MaxCollections; - MaxStorageGb = organization.MaxStorageGb; - Key = organization.Key; - HasPublicAndPrivateKeys = organization.PublicKey != null && organization.PrivateKey != null; - Status = organization.Status; - Type = organization.Type; - Enabled = organization.Enabled; - SsoBound = !string.IsNullOrWhiteSpace(organization.SsoExternalId); - Identifier = organization.Identifier; - Permissions = CoreHelpers.LoadClassFromJsonData(organization.Permissions); - ResetPasswordEnrolled = organization.ResetPasswordKey != null; - UserId = organization.UserId?.ToString(); - ProviderId = organization.ProviderId?.ToString(); - ProviderName = organization.ProviderName; - FamilySponsorshipFriendlyName = organization.FamilySponsorshipFriendlyName; - FamilySponsorshipAvailable = FamilySponsorshipFriendlyName == null && - StaticStore.GetSponsoredPlan(PlanSponsorshipType.FamiliesForEnterprise) - .UsersCanSponsor(organization); - PlanProductType = StaticStore.GetPlan(organization.PlanType).Product; - FamilySponsorshipLastSyncDate = organization.FamilySponsorshipLastSyncDate; - FamilySponsorshipToDelete = organization.FamilySponsorshipToDelete; - FamilySponsorshipValidUntil = organization.FamilySponsorshipValidUntil; - - if (organization.SsoConfig != null) - { - var ssoConfigData = SsoConfigurationData.Deserialize(organization.SsoConfig); - KeyConnectorEnabled = ssoConfigData.KeyConnectorEnabled && !string.IsNullOrEmpty(ssoConfigData.KeyConnectorUrl); - KeyConnectorUrl = ssoConfigData.KeyConnectorUrl; - } + var ssoConfigData = SsoConfigurationData.Deserialize(organization.SsoConfig); + KeyConnectorEnabled = ssoConfigData.KeyConnectorEnabled && !string.IsNullOrEmpty(ssoConfigData.KeyConnectorUrl); + KeyConnectorUrl = ssoConfigData.KeyConnectorUrl; } - - public string Id { get; set; } - public string Name { get; set; } - public bool UsePolicies { get; set; } - public bool UseSso { get; set; } - public bool UseKeyConnector { get; set; } - public bool UseScim { get; set; } - public bool UseGroups { get; set; } - public bool UseDirectory { get; set; } - public bool UseEvents { get; set; } - public bool UseTotp { get; set; } - public bool Use2fa { get; set; } - public bool UseApi { get; set; } - public bool UseResetPassword { get; set; } - public bool UsersGetPremium { get; set; } - public bool SelfHost { get; set; } - public int? Seats { get; set; } - public short? MaxCollections { get; set; } - public short? MaxStorageGb { get; set; } - public string Key { get; set; } - public OrganizationUserStatusType Status { get; set; } - public OrganizationUserType Type { get; set; } - public bool Enabled { get; set; } - public bool SsoBound { get; set; } - public string Identifier { get; set; } - public Permissions Permissions { get; set; } - public bool ResetPasswordEnrolled { get; set; } - public string UserId { get; set; } - public bool HasPublicAndPrivateKeys { get; set; } - public string ProviderId { get; set; } - public string ProviderName { get; set; } - public string FamilySponsorshipFriendlyName { get; set; } - public bool FamilySponsorshipAvailable { get; set; } - public ProductType PlanProductType { get; set; } - public bool KeyConnectorEnabled { get; set; } - public string KeyConnectorUrl { get; set; } - public DateTime? FamilySponsorshipLastSyncDate { get; set; } - public DateTime? FamilySponsorshipValidUntil { get; set; } - public bool? FamilySponsorshipToDelete { get; set; } } + + public string Id { get; set; } + public string Name { get; set; } + public bool UsePolicies { get; set; } + public bool UseSso { get; set; } + public bool UseKeyConnector { get; set; } + public bool UseScim { get; set; } + public bool UseGroups { get; set; } + public bool UseDirectory { get; set; } + public bool UseEvents { get; set; } + public bool UseTotp { get; set; } + public bool Use2fa { get; set; } + public bool UseApi { get; set; } + public bool UseResetPassword { get; set; } + public bool UsersGetPremium { get; set; } + public bool SelfHost { get; set; } + public int? Seats { get; set; } + public short? MaxCollections { get; set; } + public short? MaxStorageGb { get; set; } + public string Key { get; set; } + public OrganizationUserStatusType Status { get; set; } + public OrganizationUserType Type { get; set; } + public bool Enabled { get; set; } + public bool SsoBound { get; set; } + public string Identifier { get; set; } + public Permissions Permissions { get; set; } + public bool ResetPasswordEnrolled { get; set; } + public string UserId { get; set; } + public bool HasPublicAndPrivateKeys { get; set; } + public string ProviderId { get; set; } + public string ProviderName { get; set; } + public string FamilySponsorshipFriendlyName { get; set; } + public bool FamilySponsorshipAvailable { get; set; } + public ProductType PlanProductType { get; set; } + public bool KeyConnectorEnabled { get; set; } + public string KeyConnectorUrl { get; set; } + public DateTime? FamilySponsorshipLastSyncDate { get; set; } + public DateTime? FamilySponsorshipValidUntil { get; set; } + public bool? FamilySponsorshipToDelete { get; set; } } diff --git a/src/Api/Models/Response/ProfileProviderOrganizationResponseModel.cs b/src/Api/Models/Response/ProfileProviderOrganizationResponseModel.cs index c2d7858b5..a660662fa 100644 --- a/src/Api/Models/Response/ProfileProviderOrganizationResponseModel.cs +++ b/src/Api/Models/Response/ProfileProviderOrganizationResponseModel.cs @@ -1,43 +1,42 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; -namespace Bit.Api.Models.Response +namespace Bit.Api.Models.Response; + +public class ProfileProviderOrganizationResponseModel : ProfileOrganizationResponseModel { - public class ProfileProviderOrganizationResponseModel : ProfileOrganizationResponseModel + public ProfileProviderOrganizationResponseModel(ProviderUserOrganizationDetails organization) + : base("profileProviderOrganization") { - public ProfileProviderOrganizationResponseModel(ProviderUserOrganizationDetails organization) - : base("profileProviderOrganization") - { - Id = organization.OrganizationId.ToString(); - Name = organization.Name; - UsePolicies = organization.UsePolicies; - UseSso = organization.UseSso; - UseKeyConnector = organization.UseKeyConnector; - UseScim = organization.UseScim; - UseGroups = organization.UseGroups; - UseDirectory = organization.UseDirectory; - UseEvents = organization.UseEvents; - UseTotp = organization.UseTotp; - Use2fa = organization.Use2fa; - UseApi = organization.UseApi; - UseResetPassword = organization.UseResetPassword; - UsersGetPremium = organization.UsersGetPremium; - SelfHost = organization.SelfHost; - Seats = organization.Seats; - MaxCollections = organization.MaxCollections; - MaxStorageGb = organization.MaxStorageGb; - Key = organization.Key; - HasPublicAndPrivateKeys = organization.PublicKey != null && organization.PrivateKey != null; - Status = OrganizationUserStatusType.Confirmed; // Provider users are always confirmed - Type = OrganizationUserType.Owner; // Provider users behave like Owners - Enabled = organization.Enabled; - SsoBound = false; - Identifier = organization.Identifier; - Permissions = new Permissions(); - ResetPasswordEnrolled = false; - UserId = organization.UserId?.ToString(); - ProviderId = organization.ProviderId?.ToString(); - ProviderName = organization.ProviderName; - } + Id = organization.OrganizationId.ToString(); + Name = organization.Name; + UsePolicies = organization.UsePolicies; + UseSso = organization.UseSso; + UseKeyConnector = organization.UseKeyConnector; + UseScim = organization.UseScim; + UseGroups = organization.UseGroups; + UseDirectory = organization.UseDirectory; + UseEvents = organization.UseEvents; + UseTotp = organization.UseTotp; + Use2fa = organization.Use2fa; + UseApi = organization.UseApi; + UseResetPassword = organization.UseResetPassword; + UsersGetPremium = organization.UsersGetPremium; + SelfHost = organization.SelfHost; + Seats = organization.Seats; + MaxCollections = organization.MaxCollections; + MaxStorageGb = organization.MaxStorageGb; + Key = organization.Key; + HasPublicAndPrivateKeys = organization.PublicKey != null && organization.PrivateKey != null; + Status = OrganizationUserStatusType.Confirmed; // Provider users are always confirmed + Type = OrganizationUserType.Owner; // Provider users behave like Owners + Enabled = organization.Enabled; + SsoBound = false; + Identifier = organization.Identifier; + Permissions = new Permissions(); + ResetPasswordEnrolled = false; + UserId = organization.UserId?.ToString(); + ProviderId = organization.ProviderId?.ToString(); + ProviderName = organization.ProviderName; } } diff --git a/src/Api/Models/Response/ProfileResponseModel.cs b/src/Api/Models/Response/ProfileResponseModel.cs index 42e9943c4..dfa9e5dac 100644 --- a/src/Api/Models/Response/ProfileResponseModel.cs +++ b/src/Api/Models/Response/ProfileResponseModel.cs @@ -4,62 +4,61 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; -namespace Bit.Api.Models.Response +namespace Bit.Api.Models.Response; + +public class ProfileResponseModel : ResponseModel { - public class ProfileResponseModel : ResponseModel + public ProfileResponseModel(User user, + IEnumerable organizationsUserDetails, + IEnumerable providerUserDetails, + IEnumerable providerUserOrganizationDetails, + bool twoFactorEnabled, + bool premiumFromOrganization) : base("profile") { - public ProfileResponseModel(User user, - IEnumerable organizationsUserDetails, - IEnumerable providerUserDetails, - IEnumerable providerUserOrganizationDetails, - bool twoFactorEnabled, - bool premiumFromOrganization) : base("profile") + if (user == null) { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - Id = user.Id.ToString(); - Name = user.Name; - Email = user.Email; - EmailVerified = user.EmailVerified; - Premium = user.Premium; - PremiumFromOrganization = premiumFromOrganization; - MasterPasswordHint = string.IsNullOrWhiteSpace(user.MasterPasswordHint) ? null : user.MasterPasswordHint; - Culture = user.Culture; - TwoFactorEnabled = twoFactorEnabled; - Key = user.Key; - PrivateKey = user.PrivateKey; - SecurityStamp = user.SecurityStamp; - ForcePasswordReset = user.ForcePasswordReset; - UsesKeyConnector = user.UsesKeyConnector; - Organizations = organizationsUserDetails?.Select(o => new ProfileOrganizationResponseModel(o)); - Providers = providerUserDetails?.Select(p => new ProfileProviderResponseModel(p)); - ProviderOrganizations = - providerUserOrganizationDetails?.Select(po => new ProfileProviderOrganizationResponseModel(po)); + throw new ArgumentNullException(nameof(user)); } - public ProfileResponseModel() : base("profile") - { - } - - public string Id { get; set; } - public string Name { get; set; } - public string Email { get; set; } - public bool EmailVerified { get; set; } - public bool Premium { get; set; } - public bool PremiumFromOrganization { get; set; } - public string MasterPasswordHint { get; set; } - public string Culture { get; set; } - public bool TwoFactorEnabled { get; set; } - public string Key { get; set; } - public string PrivateKey { get; set; } - public string SecurityStamp { get; set; } - public bool ForcePasswordReset { get; set; } - public bool UsesKeyConnector { get; set; } - public IEnumerable Organizations { get; set; } - public IEnumerable Providers { get; set; } - public IEnumerable ProviderOrganizations { get; set; } + Id = user.Id.ToString(); + Name = user.Name; + Email = user.Email; + EmailVerified = user.EmailVerified; + Premium = user.Premium; + PremiumFromOrganization = premiumFromOrganization; + MasterPasswordHint = string.IsNullOrWhiteSpace(user.MasterPasswordHint) ? null : user.MasterPasswordHint; + Culture = user.Culture; + TwoFactorEnabled = twoFactorEnabled; + Key = user.Key; + PrivateKey = user.PrivateKey; + SecurityStamp = user.SecurityStamp; + ForcePasswordReset = user.ForcePasswordReset; + UsesKeyConnector = user.UsesKeyConnector; + Organizations = organizationsUserDetails?.Select(o => new ProfileOrganizationResponseModel(o)); + Providers = providerUserDetails?.Select(p => new ProfileProviderResponseModel(p)); + ProviderOrganizations = + providerUserOrganizationDetails?.Select(po => new ProfileProviderOrganizationResponseModel(po)); } + + public ProfileResponseModel() : base("profile") + { + } + + public string Id { get; set; } + public string Name { get; set; } + public string Email { get; set; } + public bool EmailVerified { get; set; } + public bool Premium { get; set; } + public bool PremiumFromOrganization { get; set; } + public string MasterPasswordHint { get; set; } + public string Culture { get; set; } + public bool TwoFactorEnabled { get; set; } + public string Key { get; set; } + public string PrivateKey { get; set; } + public string SecurityStamp { get; set; } + public bool ForcePasswordReset { get; set; } + public bool UsesKeyConnector { get; set; } + public IEnumerable Organizations { get; set; } + public IEnumerable Providers { get; set; } + public IEnumerable ProviderOrganizations { get; set; } } diff --git a/src/Api/Models/Response/Providers/ProfileProviderResponseModel.cs b/src/Api/Models/Response/Providers/ProfileProviderResponseModel.cs index c8a0c3818..7a218d1c7 100644 --- a/src/Api/Models/Response/Providers/ProfileProviderResponseModel.cs +++ b/src/Api/Models/Response/Providers/ProfileProviderResponseModel.cs @@ -3,32 +3,31 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Api.Models.Response.Providers -{ - public class ProfileProviderResponseModel : ResponseModel - { - public ProfileProviderResponseModel(ProviderUserProviderDetails provider) - : base("profileProvider") - { - Id = provider.ProviderId.ToString(); - Name = provider.Name; - Key = provider.Key; - Status = provider.Status; - Type = provider.Type; - Enabled = provider.Enabled; - Permissions = CoreHelpers.LoadClassFromJsonData(provider.Permissions); - UserId = provider.UserId?.ToString(); - UseEvents = provider.UseEvents; - } +namespace Bit.Api.Models.Response.Providers; - public string Id { get; set; } - public string Name { get; set; } - public string Key { get; set; } - public ProviderUserStatusType Status { get; set; } - public ProviderUserType Type { get; set; } - public bool Enabled { get; set; } - public Permissions Permissions { get; set; } - public string UserId { get; set; } - public bool UseEvents { get; set; } +public class ProfileProviderResponseModel : ResponseModel +{ + public ProfileProviderResponseModel(ProviderUserProviderDetails provider) + : base("profileProvider") + { + Id = provider.ProviderId.ToString(); + Name = provider.Name; + Key = provider.Key; + Status = provider.Status; + Type = provider.Type; + Enabled = provider.Enabled; + Permissions = CoreHelpers.LoadClassFromJsonData(provider.Permissions); + UserId = provider.UserId?.ToString(); + UseEvents = provider.UseEvents; } + + public string Id { get; set; } + public string Name { get; set; } + public string Key { get; set; } + public ProviderUserStatusType Status { get; set; } + public ProviderUserType Type { get; set; } + public bool Enabled { get; set; } + public Permissions Permissions { get; set; } + public string UserId { get; set; } + public bool UseEvents { get; set; } } diff --git a/src/Api/Models/Response/Providers/ProviderOrganizationResponseModel.cs b/src/Api/Models/Response/Providers/ProviderOrganizationResponseModel.cs index e508787a0..9bc7d52dc 100644 --- a/src/Api/Models/Response/Providers/ProviderOrganizationResponseModel.cs +++ b/src/Api/Models/Response/Providers/ProviderOrganizationResponseModel.cs @@ -2,72 +2,71 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Data; -namespace Bit.Api.Models.Response.Providers +namespace Bit.Api.Models.Response.Providers; + +public class ProviderOrganizationResponseModel : ResponseModel { - public class ProviderOrganizationResponseModel : ResponseModel + public ProviderOrganizationResponseModel(ProviderOrganization providerOrganization, + string obj = "providerOrganization") : base(obj) { - public ProviderOrganizationResponseModel(ProviderOrganization providerOrganization, - string obj = "providerOrganization") : base(obj) + if (providerOrganization == null) { - if (providerOrganization == null) - { - throw new ArgumentNullException(nameof(providerOrganization)); - } - - Id = providerOrganization.Id; - ProviderId = providerOrganization.ProviderId; - OrganizationId = providerOrganization.OrganizationId; - Key = providerOrganization.Key; - Settings = providerOrganization.Settings; - CreationDate = providerOrganization.CreationDate; - RevisionDate = providerOrganization.RevisionDate; + throw new ArgumentNullException(nameof(providerOrganization)); } - public ProviderOrganizationResponseModel(ProviderOrganizationOrganizationDetails providerOrganization, - string obj = "providerOrganization") : base(obj) - { - if (providerOrganization == null) - { - throw new ArgumentNullException(nameof(providerOrganization)); - } - - Id = providerOrganization.Id; - ProviderId = providerOrganization.ProviderId; - OrganizationId = providerOrganization.OrganizationId; - Key = providerOrganization.Key; - Settings = providerOrganization.Settings; - CreationDate = providerOrganization.CreationDate; - RevisionDate = providerOrganization.RevisionDate; - UserCount = providerOrganization.UserCount; - Seats = providerOrganization.Seats; - Plan = providerOrganization.Plan; - } - - public Guid Id { get; set; } - public Guid ProviderId { get; set; } - public Guid OrganizationId { get; set; } - public string Key { get; set; } - public string Settings { get; set; } - public DateTime CreationDate { get; set; } - public DateTime RevisionDate { get; set; } - public int UserCount { get; set; } - public int? Seats { get; set; } - public string Plan { get; set; } + Id = providerOrganization.Id; + ProviderId = providerOrganization.ProviderId; + OrganizationId = providerOrganization.OrganizationId; + Key = providerOrganization.Key; + Settings = providerOrganization.Settings; + CreationDate = providerOrganization.CreationDate; + RevisionDate = providerOrganization.RevisionDate; } - public class ProviderOrganizationOrganizationDetailsResponseModel : ProviderOrganizationResponseModel + public ProviderOrganizationResponseModel(ProviderOrganizationOrganizationDetails providerOrganization, + string obj = "providerOrganization") : base(obj) { - public ProviderOrganizationOrganizationDetailsResponseModel(ProviderOrganizationOrganizationDetails providerOrganization, - string obj = "providerOrganizationOrganizationDetail") : base(providerOrganization, obj) + if (providerOrganization == null) { - if (providerOrganization == null) - { - throw new ArgumentNullException(nameof(providerOrganization)); - } - - OrganizationName = providerOrganization.OrganizationName; + throw new ArgumentNullException(nameof(providerOrganization)); } - public string OrganizationName { get; set; } + Id = providerOrganization.Id; + ProviderId = providerOrganization.ProviderId; + OrganizationId = providerOrganization.OrganizationId; + Key = providerOrganization.Key; + Settings = providerOrganization.Settings; + CreationDate = providerOrganization.CreationDate; + RevisionDate = providerOrganization.RevisionDate; + UserCount = providerOrganization.UserCount; + Seats = providerOrganization.Seats; + Plan = providerOrganization.Plan; } + + public Guid Id { get; set; } + public Guid ProviderId { get; set; } + public Guid OrganizationId { get; set; } + public string Key { get; set; } + public string Settings { get; set; } + public DateTime CreationDate { get; set; } + public DateTime RevisionDate { get; set; } + public int UserCount { get; set; } + public int? Seats { get; set; } + public string Plan { get; set; } +} + +public class ProviderOrganizationOrganizationDetailsResponseModel : ProviderOrganizationResponseModel +{ + public ProviderOrganizationOrganizationDetailsResponseModel(ProviderOrganizationOrganizationDetails providerOrganization, + string obj = "providerOrganizationOrganizationDetail") : base(providerOrganization, obj) + { + if (providerOrganization == null) + { + throw new ArgumentNullException(nameof(providerOrganization)); + } + + OrganizationName = providerOrganization.OrganizationName; + } + + public string OrganizationName { get; set; } } diff --git a/src/Api/Models/Response/Providers/ProviderResponseModel.cs b/src/Api/Models/Response/Providers/ProviderResponseModel.cs index 02cea09d1..ce62fdaa6 100644 --- a/src/Api/Models/Response/Providers/ProviderResponseModel.cs +++ b/src/Api/Models/Response/Providers/ProviderResponseModel.cs @@ -1,36 +1,35 @@ using Bit.Core.Entities.Provider; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response.Providers -{ - public class ProviderResponseModel : ResponseModel - { - public ProviderResponseModel(Provider provider, string obj = "provider") : base(obj) - { - if (provider == null) - { - throw new ArgumentNullException(nameof(provider)); - } +namespace Bit.Api.Models.Response.Providers; - Id = provider.Id; - Name = provider.Name; - BusinessName = provider.BusinessName; - BusinessAddress1 = provider.BusinessAddress1; - BusinessAddress2 = provider.BusinessAddress2; - BusinessAddress3 = provider.BusinessAddress3; - BusinessCountry = provider.BusinessCountry; - BusinessTaxNumber = provider.BusinessTaxNumber; - BillingEmail = provider.BillingEmail; +public class ProviderResponseModel : ResponseModel +{ + public ProviderResponseModel(Provider provider, string obj = "provider") : base(obj) + { + if (provider == null) + { + throw new ArgumentNullException(nameof(provider)); } - public Guid Id { get; set; } - public string Name { get; set; } - public string BusinessName { get; set; } - public string BusinessAddress1 { get; set; } - public string BusinessAddress2 { get; set; } - public string BusinessAddress3 { get; set; } - public string BusinessCountry { get; set; } - public string BusinessTaxNumber { get; set; } - public string BillingEmail { get; set; } + Id = provider.Id; + Name = provider.Name; + BusinessName = provider.BusinessName; + BusinessAddress1 = provider.BusinessAddress1; + BusinessAddress2 = provider.BusinessAddress2; + BusinessAddress3 = provider.BusinessAddress3; + BusinessCountry = provider.BusinessCountry; + BusinessTaxNumber = provider.BusinessTaxNumber; + BillingEmail = provider.BillingEmail; } + + public Guid Id { get; set; } + public string Name { get; set; } + public string BusinessName { get; set; } + public string BusinessAddress1 { get; set; } + public string BusinessAddress2 { get; set; } + public string BusinessAddress3 { get; set; } + public string BusinessCountry { get; set; } + public string BusinessTaxNumber { get; set; } + public string BillingEmail { get; set; } } diff --git a/src/Api/Models/Response/Providers/ProviderUserResponseModel.cs b/src/Api/Models/Response/Providers/ProviderUserResponseModel.cs index 44122b2b0..b08e39e19 100644 --- a/src/Api/Models/Response/Providers/ProviderUserResponseModel.cs +++ b/src/Api/Models/Response/Providers/ProviderUserResponseModel.cs @@ -4,89 +4,88 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Api.Models.Response.Providers +namespace Bit.Api.Models.Response.Providers; + +public class ProviderUserResponseModel : ResponseModel { - public class ProviderUserResponseModel : ResponseModel + public ProviderUserResponseModel(ProviderUser providerUser, string obj = "providerUser") + : base(obj) { - public ProviderUserResponseModel(ProviderUser providerUser, string obj = "providerUser") - : base(obj) + if (providerUser == null) { - if (providerUser == null) - { - throw new ArgumentNullException(nameof(providerUser)); - } - - Id = providerUser.Id.ToString(); - UserId = providerUser.UserId?.ToString(); - Type = providerUser.Type; - Status = providerUser.Status; - Permissions = CoreHelpers.LoadClassFromJsonData(providerUser.Permissions); + throw new ArgumentNullException(nameof(providerUser)); } - public ProviderUserResponseModel(ProviderUserUserDetails providerUser, string obj = "providerUser") - : base(obj) - { - if (providerUser == null) - { - throw new ArgumentNullException(nameof(providerUser)); - } - - Id = providerUser.Id.ToString(); - UserId = providerUser.UserId?.ToString(); - Type = providerUser.Type; - Status = providerUser.Status; - Permissions = CoreHelpers.LoadClassFromJsonData(providerUser.Permissions); - } - - public string Id { get; set; } - public string UserId { get; set; } - public ProviderUserType Type { get; set; } - public ProviderUserStatusType Status { get; set; } - public Permissions Permissions { get; set; } + Id = providerUser.Id.ToString(); + UserId = providerUser.UserId?.ToString(); + Type = providerUser.Type; + Status = providerUser.Status; + Permissions = CoreHelpers.LoadClassFromJsonData(providerUser.Permissions); } - public class ProviderUserUserDetailsResponseModel : ProviderUserResponseModel + public ProviderUserResponseModel(ProviderUserUserDetails providerUser, string obj = "providerUser") + : base(obj) { - public ProviderUserUserDetailsResponseModel(ProviderUserUserDetails providerUser, - string obj = "providerUserUserDetails") : base(providerUser, obj) + if (providerUser == null) { - if (providerUser == null) - { - throw new ArgumentNullException(nameof(providerUser)); - } - - Name = providerUser.Name; - Email = providerUser.Email; + throw new ArgumentNullException(nameof(providerUser)); } - public string Name { get; set; } - public string Email { get; set; } + Id = providerUser.Id.ToString(); + UserId = providerUser.UserId?.ToString(); + Type = providerUser.Type; + Status = providerUser.Status; + Permissions = CoreHelpers.LoadClassFromJsonData(providerUser.Permissions); } - public class ProviderUserPublicKeyResponseModel : ResponseModel - { - public ProviderUserPublicKeyResponseModel(Guid id, Guid userId, string key, - string obj = "providerUserPublicKeyResponseModel") : base(obj) - { - Id = id; - UserId = userId; - Key = key; - } - - public Guid Id { get; set; } - public Guid UserId { get; set; } - public string Key { get; set; } - } - - public class ProviderUserBulkResponseModel : ResponseModel - { - public ProviderUserBulkResponseModel(Guid id, string error, - string obj = "providerBulkConfirmResponseModel") : base(obj) - { - Id = id; - Error = error; - } - public Guid Id { get; set; } - public string Error { get; set; } - } + public string Id { get; set; } + public string UserId { get; set; } + public ProviderUserType Type { get; set; } + public ProviderUserStatusType Status { get; set; } + public Permissions Permissions { get; set; } +} + +public class ProviderUserUserDetailsResponseModel : ProviderUserResponseModel +{ + public ProviderUserUserDetailsResponseModel(ProviderUserUserDetails providerUser, + string obj = "providerUserUserDetails") : base(providerUser, obj) + { + if (providerUser == null) + { + throw new ArgumentNullException(nameof(providerUser)); + } + + Name = providerUser.Name; + Email = providerUser.Email; + } + + public string Name { get; set; } + public string Email { get; set; } +} + +public class ProviderUserPublicKeyResponseModel : ResponseModel +{ + public ProviderUserPublicKeyResponseModel(Guid id, Guid userId, string key, + string obj = "providerUserPublicKeyResponseModel") : base(obj) + { + Id = id; + UserId = userId; + Key = key; + } + + public Guid Id { get; set; } + public Guid UserId { get; set; } + public string Key { get; set; } +} + +public class ProviderUserBulkResponseModel : ResponseModel +{ + public ProviderUserBulkResponseModel(Guid id, string error, + string obj = "providerBulkConfirmResponseModel") : base(obj) + { + Id = id; + Error = error; + } + public Guid Id { get; set; } + public string Error { get; set; } } diff --git a/src/Api/Models/Response/SelectionReadOnlyResponseModel.cs b/src/Api/Models/Response/SelectionReadOnlyResponseModel.cs index a3ff0ddf6..0d4cc637d 100644 --- a/src/Api/Models/Response/SelectionReadOnlyResponseModel.cs +++ b/src/Api/Models/Response/SelectionReadOnlyResponseModel.cs @@ -1,23 +1,22 @@ using Bit.Core.Models.Data; -namespace Bit.Api.Models.Response -{ - public class SelectionReadOnlyResponseModel - { - public SelectionReadOnlyResponseModel(SelectionReadOnly selection) - { - if (selection == null) - { - throw new ArgumentNullException(nameof(selection)); - } +namespace Bit.Api.Models.Response; - Id = selection.Id.ToString(); - ReadOnly = selection.ReadOnly; - HidePasswords = selection.HidePasswords; +public class SelectionReadOnlyResponseModel +{ + public SelectionReadOnlyResponseModel(SelectionReadOnly selection) + { + if (selection == null) + { + throw new ArgumentNullException(nameof(selection)); } - public string Id { get; set; } - public bool ReadOnly { get; set; } - public bool HidePasswords { get; set; } + Id = selection.Id.ToString(); + ReadOnly = selection.ReadOnly; + HidePasswords = selection.HidePasswords; } + + public string Id { get; set; } + public bool ReadOnly { get; set; } + public bool HidePasswords { get; set; } } diff --git a/src/Api/Models/Response/SendAccessResponseModel.cs b/src/Api/Models/Response/SendAccessResponseModel.cs index 7e2adc04f..d4620385b 100644 --- a/src/Api/Models/Response/SendAccessResponseModel.cs +++ b/src/Api/Models/Response/SendAccessResponseModel.cs @@ -6,48 +6,47 @@ using Bit.Core.Models.Data; using Bit.Core.Settings; using Bit.Core.Utilities; -namespace Bit.Api.Models.Response +namespace Bit.Api.Models.Response; + +public class SendAccessResponseModel : ResponseModel { - public class SendAccessResponseModel : ResponseModel + public SendAccessResponseModel(Send send, GlobalSettings globalSettings) + : base("send-access") { - public SendAccessResponseModel(Send send, GlobalSettings globalSettings) - : base("send-access") + if (send == null) { - if (send == null) - { - throw new ArgumentNullException(nameof(send)); - } - - Id = CoreHelpers.Base64UrlEncode(send.Id.ToByteArray()); - Type = send.Type; - - SendData sendData; - switch (send.Type) - { - case SendType.File: - var fileData = JsonSerializer.Deserialize(send.Data); - sendData = fileData; - File = new SendFileModel(fileData); - break; - case SendType.Text: - var textData = JsonSerializer.Deserialize(send.Data); - sendData = textData; - Text = new SendTextModel(textData); - break; - default: - throw new ArgumentException("Unsupported " + nameof(Type) + "."); - } - - Name = sendData.Name; - ExpirationDate = send.ExpirationDate; + throw new ArgumentNullException(nameof(send)); } - public string Id { get; set; } - public SendType Type { get; set; } - public string Name { get; set; } - public SendFileModel File { get; set; } - public SendTextModel Text { get; set; } - public DateTime? ExpirationDate { get; set; } - public string CreatorIdentifier { get; set; } + Id = CoreHelpers.Base64UrlEncode(send.Id.ToByteArray()); + Type = send.Type; + + SendData sendData; + switch (send.Type) + { + case SendType.File: + var fileData = JsonSerializer.Deserialize(send.Data); + sendData = fileData; + File = new SendFileModel(fileData); + break; + case SendType.Text: + var textData = JsonSerializer.Deserialize(send.Data); + sendData = textData; + Text = new SendTextModel(textData); + break; + default: + throw new ArgumentException("Unsupported " + nameof(Type) + "."); + } + + Name = sendData.Name; + ExpirationDate = send.ExpirationDate; } + + public string Id { get; set; } + public SendType Type { get; set; } + public string Name { get; set; } + public SendFileModel File { get; set; } + public SendTextModel Text { get; set; } + public DateTime? ExpirationDate { get; set; } + public string CreatorIdentifier { get; set; } } diff --git a/src/Api/Models/Response/SendFileDownloadDataResponseModel.cs b/src/Api/Models/Response/SendFileDownloadDataResponseModel.cs index e8efed8a4..24e3a53f7 100644 --- a/src/Api/Models/Response/SendFileDownloadDataResponseModel.cs +++ b/src/Api/Models/Response/SendFileDownloadDataResponseModel.cs @@ -1,12 +1,11 @@ using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response -{ - public class SendFileDownloadDataResponseModel : ResponseModel - { - public string Id { get; set; } - public string Url { get; set; } +namespace Bit.Api.Models.Response; - public SendFileDownloadDataResponseModel() : base("send-fileDownload") { } - } +public class SendFileDownloadDataResponseModel : ResponseModel +{ + public string Id { get; set; } + public string Url { get; set; } + + public SendFileDownloadDataResponseModel() : base("send-fileDownload") { } } diff --git a/src/Api/Models/Response/SendFileUploadDataResponseModel.cs b/src/Api/Models/Response/SendFileUploadDataResponseModel.cs index 20e3694fe..0e7b4997c 100644 --- a/src/Api/Models/Response/SendFileUploadDataResponseModel.cs +++ b/src/Api/Models/Response/SendFileUploadDataResponseModel.cs @@ -1,15 +1,14 @@ using Bit.Core.Enums; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response +namespace Bit.Api.Models.Response; + +public class SendFileUploadDataResponseModel : ResponseModel { - public class SendFileUploadDataResponseModel : ResponseModel - { - public SendFileUploadDataResponseModel() : base("send-fileUpload") { } + public SendFileUploadDataResponseModel() : base("send-fileUpload") { } - public string Url { get; set; } - public FileUploadType FileUploadType { get; set; } - public SendResponseModel SendResponse { get; set; } + public string Url { get; set; } + public FileUploadType FileUploadType { get; set; } + public SendResponseModel SendResponse { get; set; } - } } diff --git a/src/Api/Models/Response/SendResponseModel.cs b/src/Api/Models/Response/SendResponseModel.cs index c4f88157d..42552d2a4 100644 --- a/src/Api/Models/Response/SendResponseModel.cs +++ b/src/Api/Models/Response/SendResponseModel.cs @@ -6,67 +6,66 @@ using Bit.Core.Models.Data; using Bit.Core.Settings; using Bit.Core.Utilities; -namespace Bit.Api.Models.Response +namespace Bit.Api.Models.Response; + +public class SendResponseModel : ResponseModel { - public class SendResponseModel : ResponseModel + public SendResponseModel(Send send, GlobalSettings globalSettings) + : base("send") { - public SendResponseModel(Send send, GlobalSettings globalSettings) - : base("send") + if (send == null) { - if (send == null) - { - throw new ArgumentNullException(nameof(send)); - } - - Id = send.Id.ToString(); - AccessId = CoreHelpers.Base64UrlEncode(send.Id.ToByteArray()); - Type = send.Type; - Key = send.Key; - MaxAccessCount = send.MaxAccessCount; - AccessCount = send.AccessCount; - RevisionDate = send.RevisionDate; - ExpirationDate = send.ExpirationDate; - DeletionDate = send.DeletionDate; - Password = send.Password; - Disabled = send.Disabled; - HideEmail = send.HideEmail.GetValueOrDefault(); - - SendData sendData; - switch (send.Type) - { - case SendType.File: - var fileData = JsonSerializer.Deserialize(send.Data); - sendData = fileData; - File = new SendFileModel(fileData); - break; - case SendType.Text: - var textData = JsonSerializer.Deserialize(send.Data); - sendData = textData; - Text = new SendTextModel(textData); - break; - default: - throw new ArgumentException("Unsupported " + nameof(Type) + "."); - } - - Name = sendData.Name; - Notes = sendData.Notes; + throw new ArgumentNullException(nameof(send)); } - public string Id { get; set; } - public string AccessId { get; set; } - public SendType Type { get; set; } - public string Name { get; set; } - public string Notes { get; set; } - public SendFileModel File { get; set; } - public SendTextModel Text { get; set; } - public string Key { get; set; } - public int? MaxAccessCount { get; set; } - public int AccessCount { get; set; } - public string Password { get; set; } - public bool Disabled { get; set; } - public DateTime RevisionDate { get; set; } - public DateTime? ExpirationDate { get; set; } - public DateTime DeletionDate { get; set; } - public bool HideEmail { get; set; } + Id = send.Id.ToString(); + AccessId = CoreHelpers.Base64UrlEncode(send.Id.ToByteArray()); + Type = send.Type; + Key = send.Key; + MaxAccessCount = send.MaxAccessCount; + AccessCount = send.AccessCount; + RevisionDate = send.RevisionDate; + ExpirationDate = send.ExpirationDate; + DeletionDate = send.DeletionDate; + Password = send.Password; + Disabled = send.Disabled; + HideEmail = send.HideEmail.GetValueOrDefault(); + + SendData sendData; + switch (send.Type) + { + case SendType.File: + var fileData = JsonSerializer.Deserialize(send.Data); + sendData = fileData; + File = new SendFileModel(fileData); + break; + case SendType.Text: + var textData = JsonSerializer.Deserialize(send.Data); + sendData = textData; + Text = new SendTextModel(textData); + break; + default: + throw new ArgumentException("Unsupported " + nameof(Type) + "."); + } + + Name = sendData.Name; + Notes = sendData.Notes; } + + public string Id { get; set; } + public string AccessId { get; set; } + public SendType Type { get; set; } + public string Name { get; set; } + public string Notes { get; set; } + public SendFileModel File { get; set; } + public SendTextModel Text { get; set; } + public string Key { get; set; } + public int? MaxAccessCount { get; set; } + public int AccessCount { get; set; } + public string Password { get; set; } + public bool Disabled { get; set; } + public DateTime RevisionDate { get; set; } + public DateTime? ExpirationDate { get; set; } + public DateTime DeletionDate { get; set; } + public bool HideEmail { get; set; } } diff --git a/src/Api/Models/Response/SubscriptionResponseModel.cs b/src/Api/Models/Response/SubscriptionResponseModel.cs index e8b9dbbcb..4888bd208 100644 --- a/src/Api/Models/Response/SubscriptionResponseModel.cs +++ b/src/Api/Models/Response/SubscriptionResponseModel.cs @@ -3,104 +3,103 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Business; using Bit.Core.Utilities; -namespace Bit.Api.Models.Response +namespace Bit.Api.Models.Response; + +public class SubscriptionResponseModel : ResponseModel { - public class SubscriptionResponseModel : ResponseModel + public SubscriptionResponseModel(User user, SubscriptionInfo subscription, UserLicense license) + : base("subscription") { - public SubscriptionResponseModel(User user, SubscriptionInfo subscription, UserLicense license) - : base("subscription") + Subscription = subscription.Subscription != null ? new BillingSubscription(subscription.Subscription) : null; + UpcomingInvoice = subscription.UpcomingInvoice != null ? + new BillingSubscriptionUpcomingInvoice(subscription.UpcomingInvoice) : null; + StorageName = user.Storage.HasValue ? CoreHelpers.ReadableBytesSize(user.Storage.Value) : null; + StorageGb = user.Storage.HasValue ? Math.Round(user.Storage.Value / 1073741824D, 2) : 0; // 1 GB + MaxStorageGb = user.MaxStorageGb; + License = license; + Expiration = License.Expires; + UsingInAppPurchase = subscription.UsingInAppPurchase; + } + + public SubscriptionResponseModel(User user, UserLicense license = null) + : base("subscription") + { + StorageName = user.Storage.HasValue ? CoreHelpers.ReadableBytesSize(user.Storage.Value) : null; + StorageGb = user.Storage.HasValue ? Math.Round(user.Storage.Value / 1073741824D, 2) : 0; // 1 GB + MaxStorageGb = user.MaxStorageGb; + Expiration = user.PremiumExpirationDate; + + if (license != null) { - Subscription = subscription.Subscription != null ? new BillingSubscription(subscription.Subscription) : null; - UpcomingInvoice = subscription.UpcomingInvoice != null ? - new BillingSubscriptionUpcomingInvoice(subscription.UpcomingInvoice) : null; - StorageName = user.Storage.HasValue ? CoreHelpers.ReadableBytesSize(user.Storage.Value) : null; - StorageGb = user.Storage.HasValue ? Math.Round(user.Storage.Value / 1073741824D, 2) : 0; // 1 GB - MaxStorageGb = user.MaxStorageGb; License = license; - Expiration = License.Expires; - UsingInAppPurchase = subscription.UsingInAppPurchase; - } - - public SubscriptionResponseModel(User user, UserLicense license = null) - : base("subscription") - { - StorageName = user.Storage.HasValue ? CoreHelpers.ReadableBytesSize(user.Storage.Value) : null; - StorageGb = user.Storage.HasValue ? Math.Round(user.Storage.Value / 1073741824D, 2) : 0; // 1 GB - MaxStorageGb = user.MaxStorageGb; - Expiration = user.PremiumExpirationDate; - - if (license != null) - { - License = license; - } - } - - public string StorageName { get; set; } - public double? StorageGb { get; set; } - public short? MaxStorageGb { get; set; } - public BillingSubscriptionUpcomingInvoice UpcomingInvoice { get; set; } - public BillingSubscription Subscription { get; set; } - public UserLicense License { get; set; } - public DateTime? Expiration { get; set; } - public bool UsingInAppPurchase { get; set; } - } - - public class BillingSubscription - { - public BillingSubscription(SubscriptionInfo.BillingSubscription sub) - { - Status = sub.Status; - TrialStartDate = sub.TrialStartDate; - TrialEndDate = sub.TrialEndDate; - PeriodStartDate = sub.PeriodStartDate; - PeriodEndDate = sub.PeriodEndDate; - CancelledDate = sub.CancelledDate; - CancelAtEndDate = sub.CancelAtEndDate; - Cancelled = sub.Cancelled; - if (sub.Items != null) - { - Items = sub.Items.Select(i => new BillingSubscriptionItem(i)); - } - } - - public DateTime? TrialStartDate { get; set; } - public DateTime? TrialEndDate { get; set; } - public DateTime? PeriodStartDate { get; set; } - public DateTime? PeriodEndDate { get; set; } - public DateTime? CancelledDate { get; set; } - public bool CancelAtEndDate { get; set; } - public string Status { get; set; } - public bool Cancelled { get; set; } - public IEnumerable Items { get; set; } = new List(); - - public class BillingSubscriptionItem - { - public BillingSubscriptionItem(SubscriptionInfo.BillingSubscription.BillingSubscriptionItem item) - { - Name = item.Name; - Amount = item.Amount; - Interval = item.Interval; - Quantity = item.Quantity; - SponsoredSubscriptionItem = item.SponsoredSubscriptionItem; - } - - public string Name { get; set; } - public decimal Amount { get; set; } - public int Quantity { get; set; } - public string Interval { get; set; } - public bool SponsoredSubscriptionItem { get; set; } } } - public class BillingSubscriptionUpcomingInvoice + public string StorageName { get; set; } + public double? StorageGb { get; set; } + public short? MaxStorageGb { get; set; } + public BillingSubscriptionUpcomingInvoice UpcomingInvoice { get; set; } + public BillingSubscription Subscription { get; set; } + public UserLicense License { get; set; } + public DateTime? Expiration { get; set; } + public bool UsingInAppPurchase { get; set; } +} + +public class BillingSubscription +{ + public BillingSubscription(SubscriptionInfo.BillingSubscription sub) { - public BillingSubscriptionUpcomingInvoice(SubscriptionInfo.BillingUpcomingInvoice inv) + Status = sub.Status; + TrialStartDate = sub.TrialStartDate; + TrialEndDate = sub.TrialEndDate; + PeriodStartDate = sub.PeriodStartDate; + PeriodEndDate = sub.PeriodEndDate; + CancelledDate = sub.CancelledDate; + CancelAtEndDate = sub.CancelAtEndDate; + Cancelled = sub.Cancelled; + if (sub.Items != null) { - Amount = inv.Amount; - Date = inv.Date; + Items = sub.Items.Select(i => new BillingSubscriptionItem(i)); + } + } + + public DateTime? TrialStartDate { get; set; } + public DateTime? TrialEndDate { get; set; } + public DateTime? PeriodStartDate { get; set; } + public DateTime? PeriodEndDate { get; set; } + public DateTime? CancelledDate { get; set; } + public bool CancelAtEndDate { get; set; } + public string Status { get; set; } + public bool Cancelled { get; set; } + public IEnumerable Items { get; set; } = new List(); + + public class BillingSubscriptionItem + { + public BillingSubscriptionItem(SubscriptionInfo.BillingSubscription.BillingSubscriptionItem item) + { + Name = item.Name; + Amount = item.Amount; + Interval = item.Interval; + Quantity = item.Quantity; + SponsoredSubscriptionItem = item.SponsoredSubscriptionItem; } + public string Name { get; set; } public decimal Amount { get; set; } - public DateTime? Date { get; set; } + public int Quantity { get; set; } + public string Interval { get; set; } + public bool SponsoredSubscriptionItem { get; set; } } } + +public class BillingSubscriptionUpcomingInvoice +{ + public BillingSubscriptionUpcomingInvoice(SubscriptionInfo.BillingUpcomingInvoice inv) + { + Amount = inv.Amount; + Date = inv.Date; + } + + public decimal Amount { get; set; } + public DateTime? Date { get; set; } +} diff --git a/src/Api/Models/Response/SyncResponseModel.cs b/src/Api/Models/Response/SyncResponseModel.cs index 8c9f12686..6d028b12f 100644 --- a/src/Api/Models/Response/SyncResponseModel.cs +++ b/src/Api/Models/Response/SyncResponseModel.cs @@ -5,44 +5,43 @@ using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Settings; using Core.Models.Data; -namespace Bit.Api.Models.Response -{ - public class SyncResponseModel : ResponseModel - { - public SyncResponseModel( - GlobalSettings globalSettings, - User user, - bool userTwoFactorEnabled, - bool userHasPremiumFromOrganization, - IEnumerable organizationUserDetails, - IEnumerable providerUserDetails, - IEnumerable providerUserOrganizationDetails, - IEnumerable folders, - IEnumerable collections, - IEnumerable ciphers, - IDictionary> collectionCiphersDict, - bool excludeDomains, - IEnumerable policies, - IEnumerable sends) - : base("sync") - { - Profile = new ProfileResponseModel(user, organizationUserDetails, providerUserDetails, - providerUserOrganizationDetails, userTwoFactorEnabled, userHasPremiumFromOrganization); - Folders = folders.Select(f => new FolderResponseModel(f)); - Ciphers = ciphers.Select(c => new CipherDetailsResponseModel(c, globalSettings, collectionCiphersDict)); - Collections = collections?.Select( - c => new CollectionDetailsResponseModel(c)) ?? new List(); - Domains = excludeDomains ? null : new DomainsResponseModel(user, false); - Policies = policies?.Select(p => new PolicyResponseModel(p)) ?? new List(); - Sends = sends.Select(s => new SendResponseModel(s, globalSettings)); - } +namespace Bit.Api.Models.Response; - public ProfileResponseModel Profile { get; set; } - public IEnumerable Folders { get; set; } - public IEnumerable Collections { get; set; } - public IEnumerable Ciphers { get; set; } - public DomainsResponseModel Domains { get; set; } - public IEnumerable Policies { get; set; } - public IEnumerable Sends { get; set; } +public class SyncResponseModel : ResponseModel +{ + public SyncResponseModel( + GlobalSettings globalSettings, + User user, + bool userTwoFactorEnabled, + bool userHasPremiumFromOrganization, + IEnumerable organizationUserDetails, + IEnumerable providerUserDetails, + IEnumerable providerUserOrganizationDetails, + IEnumerable folders, + IEnumerable collections, + IEnumerable ciphers, + IDictionary> collectionCiphersDict, + bool excludeDomains, + IEnumerable policies, + IEnumerable sends) + : base("sync") + { + Profile = new ProfileResponseModel(user, organizationUserDetails, providerUserDetails, + providerUserOrganizationDetails, userTwoFactorEnabled, userHasPremiumFromOrganization); + Folders = folders.Select(f => new FolderResponseModel(f)); + Ciphers = ciphers.Select(c => new CipherDetailsResponseModel(c, globalSettings, collectionCiphersDict)); + Collections = collections?.Select( + c => new CollectionDetailsResponseModel(c)) ?? new List(); + Domains = excludeDomains ? null : new DomainsResponseModel(user, false); + Policies = policies?.Select(p => new PolicyResponseModel(p)) ?? new List(); + Sends = sends.Select(s => new SendResponseModel(s, globalSettings)); } + + public ProfileResponseModel Profile { get; set; } + public IEnumerable Folders { get; set; } + public IEnumerable Collections { get; set; } + public IEnumerable Ciphers { get; set; } + public DomainsResponseModel Domains { get; set; } + public IEnumerable Policies { get; set; } + public IEnumerable Sends { get; set; } } diff --git a/src/Api/Models/Response/TaxInfoResponseModel.cs b/src/Api/Models/Response/TaxInfoResponseModel.cs index 6ba6bad45..c1cd51267 100644 --- a/src/Api/Models/Response/TaxInfoResponseModel.cs +++ b/src/Api/Models/Response/TaxInfoResponseModel.cs @@ -1,35 +1,34 @@ using Bit.Core.Models.Business; -namespace Bit.Api.Models.Response +namespace Bit.Api.Models.Response; + +public class TaxInfoResponseModel { - public class TaxInfoResponseModel + public TaxInfoResponseModel() { } + + public TaxInfoResponseModel(TaxInfo taxInfo) { - public TaxInfoResponseModel() { } - - public TaxInfoResponseModel(TaxInfo taxInfo) + if (taxInfo == null) { - if (taxInfo == null) - { - return; - } - - TaxIdNumber = taxInfo.TaxIdNumber; - TaxIdType = taxInfo.TaxIdType; - Line1 = taxInfo.BillingAddressLine1; - Line2 = taxInfo.BillingAddressLine2; - City = taxInfo.BillingAddressCity; - State = taxInfo.BillingAddressState; - PostalCode = taxInfo.BillingAddressPostalCode; - Country = taxInfo.BillingAddressCountry; + return; } - public string TaxIdNumber { get; set; } - public string TaxIdType { get; set; } - public string Line1 { get; set; } - public string Line2 { get; set; } - public string City { get; set; } - public string State { get; set; } - public string PostalCode { get; set; } - public string Country { get; set; } + TaxIdNumber = taxInfo.TaxIdNumber; + TaxIdType = taxInfo.TaxIdType; + Line1 = taxInfo.BillingAddressLine1; + Line2 = taxInfo.BillingAddressLine2; + City = taxInfo.BillingAddressCity; + State = taxInfo.BillingAddressState; + PostalCode = taxInfo.BillingAddressPostalCode; + Country = taxInfo.BillingAddressCountry; } + + public string TaxIdNumber { get; set; } + public string TaxIdType { get; set; } + public string Line1 { get; set; } + public string Line2 { get; set; } + public string City { get; set; } + public string State { get; set; } + public string PostalCode { get; set; } + public string Country { get; set; } } diff --git a/src/Api/Models/Response/TaxRateResponseModel.cs b/src/Api/Models/Response/TaxRateResponseModel.cs index ec08cb7f7..2c3335314 100644 --- a/src/Api/Models/Response/TaxRateResponseModel.cs +++ b/src/Api/Models/Response/TaxRateResponseModel.cs @@ -1,29 +1,28 @@ using Bit.Core.Entities; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response -{ - public class TaxRateResponseModel : ResponseModel - { - public TaxRateResponseModel(TaxRate taxRate) - : base("profile") - { - if (taxRate == null) - { - throw new ArgumentNullException(nameof(taxRate)); - } +namespace Bit.Api.Models.Response; - Id = taxRate.Id; - Country = taxRate.Country; - State = taxRate.State; - PostalCode = taxRate.PostalCode; - Rate = taxRate.Rate; +public class TaxRateResponseModel : ResponseModel +{ + public TaxRateResponseModel(TaxRate taxRate) + : base("profile") + { + if (taxRate == null) + { + throw new ArgumentNullException(nameof(taxRate)); } - public string Id { get; set; } - public string Country { get; set; } - public string State { get; set; } - public string PostalCode { get; set; } - public decimal Rate { get; set; } + Id = taxRate.Id; + Country = taxRate.Country; + State = taxRate.State; + PostalCode = taxRate.PostalCode; + Rate = taxRate.Rate; } + + public string Id { get; set; } + public string Country { get; set; } + public string State { get; set; } + public string PostalCode { get; set; } + public decimal Rate { get; set; } } diff --git a/src/Api/Models/Response/TwoFactor/TwoFactorAuthenticatorResponseModel.cs b/src/Api/Models/Response/TwoFactor/TwoFactorAuthenticatorResponseModel.cs index 3747a411a..0a283b7e6 100644 --- a/src/Api/Models/Response/TwoFactor/TwoFactorAuthenticatorResponseModel.cs +++ b/src/Api/Models/Response/TwoFactor/TwoFactorAuthenticatorResponseModel.cs @@ -3,33 +3,32 @@ using Bit.Core.Enums; using Bit.Core.Models.Api; using OtpNet; -namespace Bit.Api.Models.Response.TwoFactor -{ - public class TwoFactorAuthenticatorResponseModel : ResponseModel - { - public TwoFactorAuthenticatorResponseModel(User user) - : base("twoFactorAuthenticator") - { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } +namespace Bit.Api.Models.Response.TwoFactor; - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Authenticator); - if (provider?.MetaData?.ContainsKey("Key") ?? false) - { - Key = (string)provider.MetaData["Key"]; - Enabled = provider.Enabled; - } - else - { - var key = KeyGeneration.GenerateRandomKey(20); - Key = Base32Encoding.ToString(key); - Enabled = false; - } +public class TwoFactorAuthenticatorResponseModel : ResponseModel +{ + public TwoFactorAuthenticatorResponseModel(User user) + : base("twoFactorAuthenticator") + { + if (user == null) + { + throw new ArgumentNullException(nameof(user)); } - public bool Enabled { get; set; } - public string Key { get; set; } + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Authenticator); + if (provider?.MetaData?.ContainsKey("Key") ?? false) + { + Key = (string)provider.MetaData["Key"]; + Enabled = provider.Enabled; + } + else + { + var key = KeyGeneration.GenerateRandomKey(20); + Key = Base32Encoding.ToString(key); + Enabled = false; + } } + + public bool Enabled { get; set; } + public string Key { get; set; } } diff --git a/src/Api/Models/Response/TwoFactor/TwoFactorDuoResponseModel.cs b/src/Api/Models/Response/TwoFactor/TwoFactorDuoResponseModel.cs index c2461abdb..3331a8d76 100644 --- a/src/Api/Models/Response/TwoFactor/TwoFactorDuoResponseModel.cs +++ b/src/Api/Models/Response/TwoFactor/TwoFactorDuoResponseModel.cs @@ -3,64 +3,63 @@ using Bit.Core.Enums; using Bit.Core.Models; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response.TwoFactor +namespace Bit.Api.Models.Response.TwoFactor; + +public class TwoFactorDuoResponseModel : ResponseModel { - public class TwoFactorDuoResponseModel : ResponseModel + private const string ResponseObj = "twoFactorDuo"; + + public TwoFactorDuoResponseModel(User user) + : base(ResponseObj) { - private const string ResponseObj = "twoFactorDuo"; - - public TwoFactorDuoResponseModel(User user) - : base(ResponseObj) + if (user == null) { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Duo); - Build(provider); + throw new ArgumentNullException(nameof(user)); } - public TwoFactorDuoResponseModel(Organization org) - : base(ResponseObj) - { - if (org == null) - { - throw new ArgumentNullException(nameof(org)); - } + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Duo); + Build(provider); + } - var provider = org.GetTwoFactorProvider(TwoFactorProviderType.OrganizationDuo); - Build(provider); + public TwoFactorDuoResponseModel(Organization org) + : base(ResponseObj) + { + if (org == null) + { + throw new ArgumentNullException(nameof(org)); } - public bool Enabled { get; set; } - public string Host { get; set; } - public string SecretKey { get; set; } - public string IntegrationKey { get; set; } + var provider = org.GetTwoFactorProvider(TwoFactorProviderType.OrganizationDuo); + Build(provider); + } - private void Build(TwoFactorProvider provider) + public bool Enabled { get; set; } + public string Host { get; set; } + public string SecretKey { get; set; } + public string IntegrationKey { get; set; } + + private void Build(TwoFactorProvider provider) + { + if (provider?.MetaData != null && provider.MetaData.Count > 0) { - if (provider?.MetaData != null && provider.MetaData.Count > 0) - { - Enabled = provider.Enabled; + Enabled = provider.Enabled; - if (provider.MetaData.ContainsKey("Host")) - { - Host = (string)provider.MetaData["Host"]; - } - if (provider.MetaData.ContainsKey("SKey")) - { - SecretKey = (string)provider.MetaData["SKey"]; - } - if (provider.MetaData.ContainsKey("IKey")) - { - IntegrationKey = (string)provider.MetaData["IKey"]; - } - } - else + if (provider.MetaData.ContainsKey("Host")) { - Enabled = false; + Host = (string)provider.MetaData["Host"]; } + if (provider.MetaData.ContainsKey("SKey")) + { + SecretKey = (string)provider.MetaData["SKey"]; + } + if (provider.MetaData.ContainsKey("IKey")) + { + IntegrationKey = (string)provider.MetaData["IKey"]; + } + } + else + { + Enabled = false; } } } diff --git a/src/Api/Models/Response/TwoFactor/TwoFactorEmailResponseModel.cs b/src/Api/Models/Response/TwoFactor/TwoFactorEmailResponseModel.cs index 9f8fecc4f..f2be91f9d 100644 --- a/src/Api/Models/Response/TwoFactor/TwoFactorEmailResponseModel.cs +++ b/src/Api/Models/Response/TwoFactor/TwoFactorEmailResponseModel.cs @@ -2,31 +2,30 @@ using Bit.Core.Enums; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response.TwoFactor -{ - public class TwoFactorEmailResponseModel : ResponseModel - { - public TwoFactorEmailResponseModel(User user) - : base("twoFactorEmail") - { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } +namespace Bit.Api.Models.Response.TwoFactor; - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Email); - if (provider?.MetaData?.ContainsKey("Email") ?? false) - { - Email = (string)provider.MetaData["Email"]; - Enabled = provider.Enabled; - } - else - { - Enabled = false; - } +public class TwoFactorEmailResponseModel : ResponseModel +{ + public TwoFactorEmailResponseModel(User user) + : base("twoFactorEmail") + { + if (user == null) + { + throw new ArgumentNullException(nameof(user)); } - public bool Enabled { get; set; } - public string Email { get; set; } + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Email); + if (provider?.MetaData?.ContainsKey("Email") ?? false) + { + Email = (string)provider.MetaData["Email"]; + Enabled = provider.Enabled; + } + else + { + Enabled = false; + } } + + public bool Enabled { get; set; } + public string Email { get; set; } } diff --git a/src/Api/Models/Response/TwoFactor/TwoFactorProviderResponseModel.cs b/src/Api/Models/Response/TwoFactor/TwoFactorProviderResponseModel.cs index c742d8b2f..0e8522104 100644 --- a/src/Api/Models/Response/TwoFactor/TwoFactorProviderResponseModel.cs +++ b/src/Api/Models/Response/TwoFactor/TwoFactorProviderResponseModel.cs @@ -3,51 +3,50 @@ using Bit.Core.Enums; using Bit.Core.Models; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response.TwoFactor +namespace Bit.Api.Models.Response.TwoFactor; + +public class TwoFactorProviderResponseModel : ResponseModel { - public class TwoFactorProviderResponseModel : ResponseModel + private const string ResponseObj = "twoFactorProvider"; + + public TwoFactorProviderResponseModel(TwoFactorProviderType type, TwoFactorProvider provider) + : base(ResponseObj) { - private const string ResponseObj = "twoFactorProvider"; - - public TwoFactorProviderResponseModel(TwoFactorProviderType type, TwoFactorProvider provider) - : base(ResponseObj) + if (provider == null) { - if (provider == null) - { - throw new ArgumentNullException(nameof(provider)); - } - - Enabled = provider.Enabled; - Type = type; + throw new ArgumentNullException(nameof(provider)); } - public TwoFactorProviderResponseModel(TwoFactorProviderType type, User user) - : base(ResponseObj) - { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - var provider = user.GetTwoFactorProvider(type); - Enabled = provider?.Enabled ?? false; - Type = type; - } - - public TwoFactorProviderResponseModel(TwoFactorProviderType type, Organization organization) - : base(ResponseObj) - { - if (organization == null) - { - throw new ArgumentNullException(nameof(organization)); - } - - var provider = organization.GetTwoFactorProvider(type); - Enabled = provider?.Enabled ?? false; - Type = type; - } - - public bool Enabled { get; set; } - public TwoFactorProviderType Type { get; set; } + Enabled = provider.Enabled; + Type = type; } + + public TwoFactorProviderResponseModel(TwoFactorProviderType type, User user) + : base(ResponseObj) + { + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + var provider = user.GetTwoFactorProvider(type); + Enabled = provider?.Enabled ?? false; + Type = type; + } + + public TwoFactorProviderResponseModel(TwoFactorProviderType type, Organization organization) + : base(ResponseObj) + { + if (organization == null) + { + throw new ArgumentNullException(nameof(organization)); + } + + var provider = organization.GetTwoFactorProvider(type); + Enabled = provider?.Enabled ?? false; + Type = type; + } + + public bool Enabled { get; set; } + public TwoFactorProviderType Type { get; set; } } diff --git a/src/Api/Models/Response/TwoFactor/TwoFactorRecoverResponseModel.cs b/src/Api/Models/Response/TwoFactor/TwoFactorRecoverResponseModel.cs index 5d87a0e94..26324de7c 100644 --- a/src/Api/Models/Response/TwoFactor/TwoFactorRecoverResponseModel.cs +++ b/src/Api/Models/Response/TwoFactor/TwoFactorRecoverResponseModel.cs @@ -1,21 +1,20 @@ using Bit.Core.Entities; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response.TwoFactor -{ - public class TwoFactorRecoverResponseModel : ResponseModel - { - public TwoFactorRecoverResponseModel(User user) - : base("twoFactorRecover") - { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } +namespace Bit.Api.Models.Response.TwoFactor; - Code = user.TwoFactorRecoveryCode; +public class TwoFactorRecoverResponseModel : ResponseModel +{ + public TwoFactorRecoverResponseModel(User user) + : base("twoFactorRecover") + { + if (user == null) + { + throw new ArgumentNullException(nameof(user)); } - public string Code { get; set; } + Code = user.TwoFactorRecoveryCode; } + + public string Code { get; set; } } diff --git a/src/Api/Models/Response/TwoFactor/TwoFactorWebAuthnResponseModel.cs b/src/Api/Models/Response/TwoFactor/TwoFactorWebAuthnResponseModel.cs index 05c3b2f44..3e2ab2bc6 100644 --- a/src/Api/Models/Response/TwoFactor/TwoFactorWebAuthnResponseModel.cs +++ b/src/Api/Models/Response/TwoFactor/TwoFactorWebAuthnResponseModel.cs @@ -3,40 +3,39 @@ using Bit.Core.Enums; using Bit.Core.Models; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response.TwoFactor +namespace Bit.Api.Models.Response.TwoFactor; + +public class TwoFactorWebAuthnResponseModel : ResponseModel { - public class TwoFactorWebAuthnResponseModel : ResponseModel + public TwoFactorWebAuthnResponseModel(User user) + : base("twoFactorWebAuthn") { - public TwoFactorWebAuthnResponseModel(User user) - : base("twoFactorWebAuthn") + if (user == null) { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); - Enabled = provider?.Enabled ?? false; - Keys = provider?.MetaData? - .Where(k => k.Key.StartsWith("Key")) - .Select(k => new KeyModel(k.Key, new TwoFactorProvider.WebAuthnData((dynamic)k.Value))); + throw new ArgumentNullException(nameof(user)); } - public bool Enabled { get; set; } - public IEnumerable Keys { get; set; } + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); + Enabled = provider?.Enabled ?? false; + Keys = provider?.MetaData? + .Where(k => k.Key.StartsWith("Key")) + .Select(k => new KeyModel(k.Key, new TwoFactorProvider.WebAuthnData((dynamic)k.Value))); + } - public class KeyModel + public bool Enabled { get; set; } + public IEnumerable Keys { get; set; } + + public class KeyModel + { + public KeyModel(string id, TwoFactorProvider.WebAuthnData data) { - public KeyModel(string id, TwoFactorProvider.WebAuthnData data) - { - Name = data.Name; - Id = Convert.ToInt32(id.Replace("Key", string.Empty)); - Migrated = data.Migrated; - } - - public string Name { get; set; } - public int Id { get; set; } - public bool Migrated { get; set; } + Name = data.Name; + Id = Convert.ToInt32(id.Replace("Key", string.Empty)); + Migrated = data.Migrated; } + + public string Name { get; set; } + public int Id { get; set; } + public bool Migrated { get; set; } } } diff --git a/src/Api/Models/Response/TwoFactor/TwoFactorYubiKeyResponseModel.cs b/src/Api/Models/Response/TwoFactor/TwoFactorYubiKeyResponseModel.cs index 9654bd1e6..48c7670c3 100644 --- a/src/Api/Models/Response/TwoFactor/TwoFactorYubiKeyResponseModel.cs +++ b/src/Api/Models/Response/TwoFactor/TwoFactorYubiKeyResponseModel.cs @@ -2,60 +2,59 @@ using Bit.Core.Enums; using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response.TwoFactor +namespace Bit.Api.Models.Response.TwoFactor; + +public class TwoFactorYubiKeyResponseModel : ResponseModel { - public class TwoFactorYubiKeyResponseModel : ResponseModel + public TwoFactorYubiKeyResponseModel(User user) + : base("twoFactorYubiKey") { - public TwoFactorYubiKeyResponseModel(User user) - : base("twoFactorYubiKey") + if (user == null) { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.YubiKey); - if (provider?.MetaData != null && provider.MetaData.Count > 0) - { - Enabled = provider.Enabled; - - if (provider.MetaData.ContainsKey("Key1")) - { - Key1 = (string)provider.MetaData["Key1"]; - } - if (provider.MetaData.ContainsKey("Key2")) - { - Key2 = (string)provider.MetaData["Key2"]; - } - if (provider.MetaData.ContainsKey("Key3")) - { - Key3 = (string)provider.MetaData["Key3"]; - } - if (provider.MetaData.ContainsKey("Key4")) - { - Key4 = (string)provider.MetaData["Key4"]; - } - if (provider.MetaData.ContainsKey("Key5")) - { - Key5 = (string)provider.MetaData["Key5"]; - } - if (provider.MetaData.ContainsKey("Nfc")) - { - Nfc = (bool)provider.MetaData["Nfc"]; - } - } - else - { - Enabled = false; - } + throw new ArgumentNullException(nameof(user)); } - public bool Enabled { get; set; } - public string Key1 { get; set; } - public string Key2 { get; set; } - public string Key3 { get; set; } - public string Key4 { get; set; } - public string Key5 { get; set; } - public bool Nfc { get; set; } + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.YubiKey); + if (provider?.MetaData != null && provider.MetaData.Count > 0) + { + Enabled = provider.Enabled; + + if (provider.MetaData.ContainsKey("Key1")) + { + Key1 = (string)provider.MetaData["Key1"]; + } + if (provider.MetaData.ContainsKey("Key2")) + { + Key2 = (string)provider.MetaData["Key2"]; + } + if (provider.MetaData.ContainsKey("Key3")) + { + Key3 = (string)provider.MetaData["Key3"]; + } + if (provider.MetaData.ContainsKey("Key4")) + { + Key4 = (string)provider.MetaData["Key4"]; + } + if (provider.MetaData.ContainsKey("Key5")) + { + Key5 = (string)provider.MetaData["Key5"]; + } + if (provider.MetaData.ContainsKey("Nfc")) + { + Nfc = (bool)provider.MetaData["Nfc"]; + } + } + else + { + Enabled = false; + } } + + public bool Enabled { get; set; } + public string Key1 { get; set; } + public string Key2 { get; set; } + public string Key3 { get; set; } + public string Key4 { get; set; } + public string Key5 { get; set; } + public bool Nfc { get; set; } } diff --git a/src/Api/Models/Response/UserKeyResponseModel.cs b/src/Api/Models/Response/UserKeyResponseModel.cs index b31f1e95a..d80571993 100644 --- a/src/Api/Models/Response/UserKeyResponseModel.cs +++ b/src/Api/Models/Response/UserKeyResponseModel.cs @@ -1,17 +1,16 @@ using Bit.Core.Models.Api; -namespace Bit.Api.Models.Response -{ - public class UserKeyResponseModel : ResponseModel - { - public UserKeyResponseModel(Guid id, string key) - : base("userKey") - { - UserId = id.ToString(); - PublicKey = key; - } +namespace Bit.Api.Models.Response; - public string UserId { get; set; } - public string PublicKey { get; set; } +public class UserKeyResponseModel : ResponseModel +{ + public UserKeyResponseModel(Guid id, string key) + : base("userKey") + { + UserId = id.ToString(); + PublicKey = key; } + + public string UserId { get; set; } + public string PublicKey { get; set; } } diff --git a/src/Api/Models/SendFileModel.cs b/src/Api/Models/SendFileModel.cs index 653510c89..bfe10f86f 100644 --- a/src/Api/Models/SendFileModel.cs +++ b/src/Api/Models/SendFileModel.cs @@ -2,26 +2,25 @@ using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Api.Models +namespace Bit.Api.Models; + +public class SendFileModel { - public class SendFileModel + public SendFileModel() { } + + public SendFileModel(SendFileData data) { - public SendFileModel() { } - - public SendFileModel(SendFileData data) - { - Id = data.Id; - FileName = data.FileName; - Size = data.Size; - SizeName = CoreHelpers.ReadableBytesSize(data.Size); - } - - public string Id { get; set; } - [EncryptedString] - [EncryptedStringLength(1000)] - public string FileName { get; set; } - [JsonNumberHandling(JsonNumberHandling.AllowReadingFromString | JsonNumberHandling.WriteAsString)] - public long? Size { get; set; } - public string SizeName { get; set; } + Id = data.Id; + FileName = data.FileName; + Size = data.Size; + SizeName = CoreHelpers.ReadableBytesSize(data.Size); } + + public string Id { get; set; } + [EncryptedString] + [EncryptedStringLength(1000)] + public string FileName { get; set; } + [JsonNumberHandling(JsonNumberHandling.AllowReadingFromString | JsonNumberHandling.WriteAsString)] + public long? Size { get; set; } + public string SizeName { get; set; } } diff --git a/src/Api/Models/SendTextModel.cs b/src/Api/Models/SendTextModel.cs index a362a61d9..ba2e6f8a6 100644 --- a/src/Api/Models/SendTextModel.cs +++ b/src/Api/Models/SendTextModel.cs @@ -1,21 +1,20 @@ using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Api.Models +namespace Bit.Api.Models; + +public class SendTextModel { - public class SendTextModel + public SendTextModel() { } + + public SendTextModel(SendTextData data) { - public SendTextModel() { } - - public SendTextModel(SendTextData data) - { - Text = data.Text; - Hidden = data.Hidden; - } - - [EncryptedString] - [EncryptedStringLength(1000)] - public string Text { get; set; } - public bool Hidden { get; set; } + Text = data.Text; + Hidden = data.Hidden; } + + [EncryptedString] + [EncryptedStringLength(1000)] + public string Text { get; set; } + public bool Hidden { get; set; } } diff --git a/src/Api/Program.cs b/src/Api/Program.cs index bcd6284af..b7e80d6c2 100644 --- a/src/Api/Program.cs +++ b/src/Api/Program.cs @@ -3,46 +3,45 @@ using Bit.Core.Utilities; using Microsoft.IdentityModel.Tokens; using Serilog.Events; -namespace Bit.Api +namespace Bit.Api; + +public class Program { - public class Program + public static void Main(string[] args) { - public static void Main(string[] args) - { - Host - .CreateDefaultBuilder(args) - .ConfigureCustomAppConfiguration(args) - .ConfigureWebHostDefaults(webBuilder => - { - webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, e => + Host + .CreateDefaultBuilder(args) + .ConfigureCustomAppConfiguration(args) + .ConfigureWebHostDefaults(webBuilder => + { + webBuilder.UseStartup(); + webBuilder.ConfigureLogging((hostingContext, logging) => + logging.AddSerilog(hostingContext, e => + { + var context = e.Properties["SourceContext"].ToString(); + if (e.Exception != null && + (e.Exception.GetType() == typeof(SecurityTokenValidationException) || + e.Exception.Message == "Bad security stamp.")) { - var context = e.Properties["SourceContext"].ToString(); - if (e.Exception != null && - (e.Exception.GetType() == typeof(SecurityTokenValidationException) || - e.Exception.Message == "Bad security stamp.")) - { - return false; - } + return false; + } - if (e.Level == LogEventLevel.Information && - context.Contains(typeof(IpRateLimitMiddleware).FullName)) - { - return true; - } + if (e.Level == LogEventLevel.Information && + context.Contains(typeof(IpRateLimitMiddleware).FullName)) + { + return true; + } - if (context.Contains("IdentityServer4.Validation.TokenValidator") || - context.Contains("IdentityServer4.Validation.TokenRequestValidator")) - { - return e.Level > LogEventLevel.Error; - } + if (context.Contains("IdentityServer4.Validation.TokenValidator") || + context.Contains("IdentityServer4.Validation.TokenRequestValidator")) + { + return e.Level > LogEventLevel.Error; + } - return e.Level >= LogEventLevel.Error; - })); - }) - .Build() - .Run(); - } + return e.Level >= LogEventLevel.Error; + })); + }) + .Build() + .Run(); } } diff --git a/src/Api/Public/Controllers/CollectionsController.cs b/src/Api/Public/Controllers/CollectionsController.cs index 677d53861..ae56d6824 100644 --- a/src/Api/Public/Controllers/CollectionsController.cs +++ b/src/Api/Public/Controllers/CollectionsController.cs @@ -7,114 +7,113 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Public.Controllers +namespace Bit.Api.Public.Controllers; + +[Route("public/collections")] +[Authorize("Organization")] +public class CollectionsController : Controller { - [Route("public/collections")] - [Authorize("Organization")] - public class CollectionsController : Controller + private readonly ICollectionRepository _collectionRepository; + private readonly ICollectionService _collectionService; + private readonly ICurrentContext _currentContext; + + public CollectionsController( + ICollectionRepository collectionRepository, + ICollectionService collectionService, + ICurrentContext currentContext) { - private readonly ICollectionRepository _collectionRepository; - private readonly ICollectionService _collectionService; - private readonly ICurrentContext _currentContext; + _collectionRepository = collectionRepository; + _collectionService = collectionService; + _currentContext = currentContext; + } - public CollectionsController( - ICollectionRepository collectionRepository, - ICollectionService collectionService, - ICurrentContext currentContext) + /// + /// Retrieve a collection. + /// + /// + /// Retrieves the details of an existing collection. You need only supply the unique collection identifier + /// that was returned upon collection creation. + /// + /// The identifier of the collection to be retrieved. + [HttpGet("{id}")] + [ProducesResponseType(typeof(CollectionResponseModel), (int)HttpStatusCode.OK)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task Get(Guid id) + { + var collectionWithGroups = await _collectionRepository.GetByIdWithGroupsAsync(id); + var collection = collectionWithGroups?.Item1; + if (collection == null || collection.OrganizationId != _currentContext.OrganizationId) { - _collectionRepository = collectionRepository; - _collectionService = collectionService; - _currentContext = currentContext; + return new NotFoundResult(); } + var response = new CollectionResponseModel(collection, collectionWithGroups.Item2); + return new JsonResult(response); + } - /// - /// Retrieve a collection. - /// - /// - /// Retrieves the details of an existing collection. You need only supply the unique collection identifier - /// that was returned upon collection creation. - /// - /// The identifier of the collection to be retrieved. - [HttpGet("{id}")] - [ProducesResponseType(typeof(CollectionResponseModel), (int)HttpStatusCode.OK)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task Get(Guid id) - { - var collectionWithGroups = await _collectionRepository.GetByIdWithGroupsAsync(id); - var collection = collectionWithGroups?.Item1; - if (collection == null || collection.OrganizationId != _currentContext.OrganizationId) - { - return new NotFoundResult(); - } - var response = new CollectionResponseModel(collection, collectionWithGroups.Item2); - return new JsonResult(response); - } + /// + /// List all collections. + /// + /// + /// Returns a list of your organization's collections. + /// Collection objects listed in this call do not include information about their associated groups. + /// + [HttpGet] + [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] + public async Task List() + { + var collections = await _collectionRepository.GetManyByOrganizationIdAsync( + _currentContext.OrganizationId.Value); + // TODO: Get all CollectionGroup associations for the organization and marry them up here for the response. + var collectionResponses = collections.Select(c => new CollectionResponseModel(c, null)); + var response = new ListResponseModel(collectionResponses); + return new JsonResult(response); + } - /// - /// List all collections. - /// - /// - /// Returns a list of your organization's collections. - /// Collection objects listed in this call do not include information about their associated groups. - /// - [HttpGet] - [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] - public async Task List() + /// + /// Update a collection. + /// + /// + /// Updates the specified collection object. If a property is not provided, + /// the value of the existing property will be reset. + /// + /// The identifier of the collection to be updated. + /// The request model. + [HttpPut("{id}")] + [ProducesResponseType(typeof(CollectionResponseModel), (int)HttpStatusCode.OK)] + [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task Put(Guid id, [FromBody] CollectionUpdateRequestModel model) + { + var existingCollection = await _collectionRepository.GetByIdAsync(id); + if (existingCollection == null || existingCollection.OrganizationId != _currentContext.OrganizationId) { - var collections = await _collectionRepository.GetManyByOrganizationIdAsync( - _currentContext.OrganizationId.Value); - // TODO: Get all CollectionGroup associations for the organization and marry them up here for the response. - var collectionResponses = collections.Select(c => new CollectionResponseModel(c, null)); - var response = new ListResponseModel(collectionResponses); - return new JsonResult(response); + return new NotFoundResult(); } + var updatedCollection = model.ToCollection(existingCollection); + var associations = model.Groups?.Select(c => c.ToSelectionReadOnly()); + await _collectionService.SaveAsync(updatedCollection, associations); + var response = new CollectionResponseModel(updatedCollection, associations); + return new JsonResult(response); + } - /// - /// Update a collection. - /// - /// - /// Updates the specified collection object. If a property is not provided, - /// the value of the existing property will be reset. - /// - /// The identifier of the collection to be updated. - /// The request model. - [HttpPut("{id}")] - [ProducesResponseType(typeof(CollectionResponseModel), (int)HttpStatusCode.OK)] - [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task Put(Guid id, [FromBody] CollectionUpdateRequestModel model) + /// + /// Delete a collection. + /// + /// + /// Permanently deletes a collection. This cannot be undone. + /// + /// The identifier of the collection to be deleted. + [HttpDelete("{id}")] + [ProducesResponseType((int)HttpStatusCode.OK)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task Delete(Guid id) + { + var collection = await _collectionRepository.GetByIdAsync(id); + if (collection == null || collection.OrganizationId != _currentContext.OrganizationId) { - var existingCollection = await _collectionRepository.GetByIdAsync(id); - if (existingCollection == null || existingCollection.OrganizationId != _currentContext.OrganizationId) - { - return new NotFoundResult(); - } - var updatedCollection = model.ToCollection(existingCollection); - var associations = model.Groups?.Select(c => c.ToSelectionReadOnly()); - await _collectionService.SaveAsync(updatedCollection, associations); - var response = new CollectionResponseModel(updatedCollection, associations); - return new JsonResult(response); - } - - /// - /// Delete a collection. - /// - /// - /// Permanently deletes a collection. This cannot be undone. - /// - /// The identifier of the collection to be deleted. - [HttpDelete("{id}")] - [ProducesResponseType((int)HttpStatusCode.OK)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task Delete(Guid id) - { - var collection = await _collectionRepository.GetByIdAsync(id); - if (collection == null || collection.OrganizationId != _currentContext.OrganizationId) - { - return new NotFoundResult(); - } - await _collectionRepository.DeleteAsync(collection); - return new OkResult(); + return new NotFoundResult(); } + await _collectionRepository.DeleteAsync(collection); + return new OkResult(); } } diff --git a/src/Api/Public/Controllers/EventsController.cs b/src/Api/Public/Controllers/EventsController.cs index 5fe5bdb7b..6e9c734c1 100644 --- a/src/Api/Public/Controllers/EventsController.cs +++ b/src/Api/Public/Controllers/EventsController.cs @@ -7,65 +7,64 @@ using Bit.Core.Repositories; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Public.Controllers +namespace Bit.Api.Public.Controllers; + +[Route("public/events")] +[Authorize("Organization")] +public class EventsController : Controller { - [Route("public/events")] - [Authorize("Organization")] - public class EventsController : Controller + private readonly IEventRepository _eventRepository; + private readonly ICipherRepository _cipherRepository; + private readonly ICurrentContext _currentContext; + + public EventsController( + IEventRepository eventRepository, + ICipherRepository cipherRepository, + ICurrentContext currentContext) { - private readonly IEventRepository _eventRepository; - private readonly ICipherRepository _cipherRepository; - private readonly ICurrentContext _currentContext; + _eventRepository = eventRepository; + _cipherRepository = cipherRepository; + _currentContext = currentContext; + } - public EventsController( - IEventRepository eventRepository, - ICipherRepository cipherRepository, - ICurrentContext currentContext) + /// + /// List all events. + /// + /// + /// Returns a filtered list of your organization's event logs, paged by a continuation token. + /// If no filters are provided, it will return the last 30 days of event for the organization. + /// + [HttpGet] + [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] + public async Task List([FromQuery] EventFilterRequestModel request) + { + var dateRange = request.ToDateRange(); + var result = new PagedResult(); + if (request.ActingUserId.HasValue) { - _eventRepository = eventRepository; - _cipherRepository = cipherRepository; - _currentContext = currentContext; + result = await _eventRepository.GetManyByOrganizationActingUserAsync( + _currentContext.OrganizationId.Value, request.ActingUserId.Value, dateRange.Item1, dateRange.Item2, + new PageOptions { ContinuationToken = request.ContinuationToken }); } - - /// - /// List all events. - /// - /// - /// Returns a filtered list of your organization's event logs, paged by a continuation token. - /// If no filters are provided, it will return the last 30 days of event for the organization. - /// - [HttpGet] - [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] - public async Task List([FromQuery] EventFilterRequestModel request) + else if (request.ItemId.HasValue) { - var dateRange = request.ToDateRange(); - var result = new PagedResult(); - if (request.ActingUserId.HasValue) + var cipher = await _cipherRepository.GetByIdAsync(request.ItemId.Value); + if (cipher != null && cipher.OrganizationId == _currentContext.OrganizationId.Value) { - result = await _eventRepository.GetManyByOrganizationActingUserAsync( - _currentContext.OrganizationId.Value, request.ActingUserId.Value, dateRange.Item1, dateRange.Item2, + result = await _eventRepository.GetManyByCipherAsync( + cipher, dateRange.Item1, dateRange.Item2, new PageOptions { ContinuationToken = request.ContinuationToken }); } - else if (request.ItemId.HasValue) - { - var cipher = await _cipherRepository.GetByIdAsync(request.ItemId.Value); - if (cipher != null && cipher.OrganizationId == _currentContext.OrganizationId.Value) - { - result = await _eventRepository.GetManyByCipherAsync( - cipher, dateRange.Item1, dateRange.Item2, - new PageOptions { ContinuationToken = request.ContinuationToken }); - } - } - else - { - result = await _eventRepository.GetManyByOrganizationAsync( - _currentContext.OrganizationId.Value, dateRange.Item1, dateRange.Item2, - new PageOptions { ContinuationToken = request.ContinuationToken }); - } - - var eventResponses = result.Data.Select(e => new EventResponseModel(e)); - var response = new ListResponseModel(eventResponses, result.ContinuationToken); - return new JsonResult(response); } + else + { + result = await _eventRepository.GetManyByOrganizationAsync( + _currentContext.OrganizationId.Value, dateRange.Item1, dateRange.Item2, + new PageOptions { ContinuationToken = request.ContinuationToken }); + } + + var eventResponses = result.Data.Select(e => new EventResponseModel(e)); + var response = new ListResponseModel(eventResponses, result.ContinuationToken); + return new JsonResult(response); } } diff --git a/src/Api/Public/Controllers/GroupsController.cs b/src/Api/Public/Controllers/GroupsController.cs index ef29db568..f65f7b9fe 100644 --- a/src/Api/Public/Controllers/GroupsController.cs +++ b/src/Api/Public/Controllers/GroupsController.cs @@ -7,177 +7,176 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Public.Controllers +namespace Bit.Api.Public.Controllers; + +[Route("public/groups")] +[Authorize("Organization")] +public class GroupsController : Controller { - [Route("public/groups")] - [Authorize("Organization")] - public class GroupsController : Controller + private readonly IGroupRepository _groupRepository; + private readonly IGroupService _groupService; + private readonly ICurrentContext _currentContext; + + public GroupsController( + IGroupRepository groupRepository, + IGroupService groupService, + ICurrentContext currentContext) { - private readonly IGroupRepository _groupRepository; - private readonly IGroupService _groupService; - private readonly ICurrentContext _currentContext; + _groupRepository = groupRepository; + _groupService = groupService; + _currentContext = currentContext; + } - public GroupsController( - IGroupRepository groupRepository, - IGroupService groupService, - ICurrentContext currentContext) + /// + /// Retrieve a group. + /// + /// + /// Retrieves the details of an existing group. You need only supply the unique group identifier + /// that was returned upon group creation. + /// + /// The identifier of the group to be retrieved. + [HttpGet("{id}")] + [ProducesResponseType(typeof(GroupResponseModel), (int)HttpStatusCode.OK)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task Get(Guid id) + { + var groupDetails = await _groupRepository.GetByIdWithCollectionsAsync(id); + var group = groupDetails?.Item1; + if (group == null || group.OrganizationId != _currentContext.OrganizationId) { - _groupRepository = groupRepository; - _groupService = groupService; - _currentContext = currentContext; + return new NotFoundResult(); } + var response = new GroupResponseModel(group, groupDetails.Item2); + return new JsonResult(response); + } - /// - /// Retrieve a group. - /// - /// - /// Retrieves the details of an existing group. You need only supply the unique group identifier - /// that was returned upon group creation. - /// - /// The identifier of the group to be retrieved. - [HttpGet("{id}")] - [ProducesResponseType(typeof(GroupResponseModel), (int)HttpStatusCode.OK)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task Get(Guid id) + /// + /// Retrieve a groups's member ids + /// + /// + /// Retrieves the unique identifiers for all members that are associated with this group. You need only + /// supply the unique group identifier that was returned upon group creation. + /// + /// The identifier of the group to be retrieved. + [HttpGet("{id}/member-ids")] + [ProducesResponseType(typeof(HashSet), (int)HttpStatusCode.OK)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task GetMemberIds(Guid id) + { + var group = await _groupRepository.GetByIdAsync(id); + if (group == null || group.OrganizationId != _currentContext.OrganizationId) { - var groupDetails = await _groupRepository.GetByIdWithCollectionsAsync(id); - var group = groupDetails?.Item1; - if (group == null || group.OrganizationId != _currentContext.OrganizationId) - { - return new NotFoundResult(); - } - var response = new GroupResponseModel(group, groupDetails.Item2); - return new JsonResult(response); + return new NotFoundResult(); } + var orgUserIds = await _groupRepository.GetManyUserIdsByIdAsync(id); + return new JsonResult(orgUserIds); + } - /// - /// Retrieve a groups's member ids - /// - /// - /// Retrieves the unique identifiers for all members that are associated with this group. You need only - /// supply the unique group identifier that was returned upon group creation. - /// - /// The identifier of the group to be retrieved. - [HttpGet("{id}/member-ids")] - [ProducesResponseType(typeof(HashSet), (int)HttpStatusCode.OK)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task GetMemberIds(Guid id) - { - var group = await _groupRepository.GetByIdAsync(id); - if (group == null || group.OrganizationId != _currentContext.OrganizationId) - { - return new NotFoundResult(); - } - var orgUserIds = await _groupRepository.GetManyUserIdsByIdAsync(id); - return new JsonResult(orgUserIds); - } + /// + /// List all groups. + /// + /// + /// Returns a list of your organization's groups. + /// Group objects listed in this call do not include information about their associated collections. + /// + [HttpGet] + [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] + public async Task List() + { + var groups = await _groupRepository.GetManyByOrganizationIdAsync(_currentContext.OrganizationId.Value); + // TODO: Get all CollectionGroup associations for the organization and marry them up here for the response. + var groupResponses = groups.Select(g => new GroupResponseModel(g, null)); + var response = new ListResponseModel(groupResponses); + return new JsonResult(response); + } - /// - /// List all groups. - /// - /// - /// Returns a list of your organization's groups. - /// Group objects listed in this call do not include information about their associated collections. - /// - [HttpGet] - [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] - public async Task List() - { - var groups = await _groupRepository.GetManyByOrganizationIdAsync(_currentContext.OrganizationId.Value); - // TODO: Get all CollectionGroup associations for the organization and marry them up here for the response. - var groupResponses = groups.Select(g => new GroupResponseModel(g, null)); - var response = new ListResponseModel(groupResponses); - return new JsonResult(response); - } + /// + /// Create a group. + /// + /// + /// Creates a new group object. + /// + /// The request model. + [HttpPost] + [ProducesResponseType(typeof(GroupResponseModel), (int)HttpStatusCode.OK)] + [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] + public async Task Post([FromBody] GroupCreateUpdateRequestModel model) + { + var group = model.ToGroup(_currentContext.OrganizationId.Value); + var associations = model.Collections?.Select(c => c.ToSelectionReadOnly()); + await _groupService.SaveAsync(group, associations); + var response = new GroupResponseModel(group, associations); + return new JsonResult(response); + } - /// - /// Create a group. - /// - /// - /// Creates a new group object. - /// - /// The request model. - [HttpPost] - [ProducesResponseType(typeof(GroupResponseModel), (int)HttpStatusCode.OK)] - [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] - public async Task Post([FromBody] GroupCreateUpdateRequestModel model) + /// + /// Update a group. + /// + /// + /// Updates the specified group object. If a property is not provided, + /// the value of the existing property will be reset. + /// + /// The identifier of the group to be updated. + /// The request model. + [HttpPut("{id}")] + [ProducesResponseType(typeof(GroupResponseModel), (int)HttpStatusCode.OK)] + [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task Put(Guid id, [FromBody] GroupCreateUpdateRequestModel model) + { + var existingGroup = await _groupRepository.GetByIdAsync(id); + if (existingGroup == null || existingGroup.OrganizationId != _currentContext.OrganizationId) { - var group = model.ToGroup(_currentContext.OrganizationId.Value); - var associations = model.Collections?.Select(c => c.ToSelectionReadOnly()); - await _groupService.SaveAsync(group, associations); - var response = new GroupResponseModel(group, associations); - return new JsonResult(response); + return new NotFoundResult(); } + var updatedGroup = model.ToGroup(existingGroup); + var associations = model.Collections?.Select(c => c.ToSelectionReadOnly()); + await _groupService.SaveAsync(updatedGroup, associations); + var response = new GroupResponseModel(updatedGroup, associations); + return new JsonResult(response); + } - /// - /// Update a group. - /// - /// - /// Updates the specified group object. If a property is not provided, - /// the value of the existing property will be reset. - /// - /// The identifier of the group to be updated. - /// The request model. - [HttpPut("{id}")] - [ProducesResponseType(typeof(GroupResponseModel), (int)HttpStatusCode.OK)] - [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task Put(Guid id, [FromBody] GroupCreateUpdateRequestModel model) + /// + /// Update a group's members. + /// + /// + /// Updates the specified group's member associations. + /// + /// The identifier of the group to be updated. + /// The request model. + [HttpPut("{id}/member-ids")] + [ProducesResponseType((int)HttpStatusCode.OK)] + [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task PutMemberIds(Guid id, [FromBody] UpdateMemberIdsRequestModel model) + { + var existingGroup = await _groupRepository.GetByIdAsync(id); + if (existingGroup == null || existingGroup.OrganizationId != _currentContext.OrganizationId) { - var existingGroup = await _groupRepository.GetByIdAsync(id); - if (existingGroup == null || existingGroup.OrganizationId != _currentContext.OrganizationId) - { - return new NotFoundResult(); - } - var updatedGroup = model.ToGroup(existingGroup); - var associations = model.Collections?.Select(c => c.ToSelectionReadOnly()); - await _groupService.SaveAsync(updatedGroup, associations); - var response = new GroupResponseModel(updatedGroup, associations); - return new JsonResult(response); + return new NotFoundResult(); } + await _groupRepository.UpdateUsersAsync(existingGroup.Id, model.MemberIds); + return new OkResult(); + } - /// - /// Update a group's members. - /// - /// - /// Updates the specified group's member associations. - /// - /// The identifier of the group to be updated. - /// The request model. - [HttpPut("{id}/member-ids")] - [ProducesResponseType((int)HttpStatusCode.OK)] - [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task PutMemberIds(Guid id, [FromBody] UpdateMemberIdsRequestModel model) + /// + /// Delete a group. + /// + /// + /// Permanently deletes a group. This cannot be undone. + /// + /// The identifier of the group to be deleted. + [HttpDelete("{id}")] + [ProducesResponseType((int)HttpStatusCode.OK)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task Delete(Guid id) + { + var group = await _groupRepository.GetByIdAsync(id); + if (group == null || group.OrganizationId != _currentContext.OrganizationId) { - var existingGroup = await _groupRepository.GetByIdAsync(id); - if (existingGroup == null || existingGroup.OrganizationId != _currentContext.OrganizationId) - { - return new NotFoundResult(); - } - await _groupRepository.UpdateUsersAsync(existingGroup.Id, model.MemberIds); - return new OkResult(); - } - - /// - /// Delete a group. - /// - /// - /// Permanently deletes a group. This cannot be undone. - /// - /// The identifier of the group to be deleted. - [HttpDelete("{id}")] - [ProducesResponseType((int)HttpStatusCode.OK)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task Delete(Guid id) - { - var group = await _groupRepository.GetByIdAsync(id); - if (group == null || group.OrganizationId != _currentContext.OrganizationId) - { - return new NotFoundResult(); - } - await _groupRepository.DeleteAsync(group); - return new OkResult(); + return new NotFoundResult(); } + await _groupRepository.DeleteAsync(group); + return new OkResult(); } } diff --git a/src/Api/Public/Controllers/MembersController.cs b/src/Api/Public/Controllers/MembersController.cs index bfe7f86b7..5ea079ee3 100644 --- a/src/Api/Public/Controllers/MembersController.cs +++ b/src/Api/Public/Controllers/MembersController.cs @@ -8,227 +8,226 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Public.Controllers +namespace Bit.Api.Public.Controllers; + +[Route("public/members")] +[Authorize("Organization")] +public class MembersController : Controller { - [Route("public/members")] - [Authorize("Organization")] - public class MembersController : Controller + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IGroupRepository _groupRepository; + private readonly IOrganizationService _organizationService; + private readonly IUserService _userService; + private readonly ICurrentContext _currentContext; + + public MembersController( + IOrganizationUserRepository organizationUserRepository, + IGroupRepository groupRepository, + IOrganizationService organizationService, + IUserService userService, + ICurrentContext currentContext) { - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IGroupRepository _groupRepository; - private readonly IOrganizationService _organizationService; - private readonly IUserService _userService; - private readonly ICurrentContext _currentContext; + _organizationUserRepository = organizationUserRepository; + _groupRepository = groupRepository; + _organizationService = organizationService; + _userService = userService; + _currentContext = currentContext; + } - public MembersController( - IOrganizationUserRepository organizationUserRepository, - IGroupRepository groupRepository, - IOrganizationService organizationService, - IUserService userService, - ICurrentContext currentContext) + /// + /// Retrieve a member. + /// + /// + /// Retrieves the details of an existing member of the organization. You need only supply the + /// unique member identifier that was returned upon member creation. + /// + /// The identifier of the member to be retrieved. + [HttpGet("{id}")] + [ProducesResponseType(typeof(MemberResponseModel), (int)HttpStatusCode.OK)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task Get(Guid id) + { + var userDetails = await _organizationUserRepository.GetDetailsByIdWithCollectionsAsync(id); + var orgUser = userDetails?.Item1; + if (orgUser == null || orgUser.OrganizationId != _currentContext.OrganizationId) { - _organizationUserRepository = organizationUserRepository; - _groupRepository = groupRepository; - _organizationService = organizationService; - _userService = userService; - _currentContext = currentContext; + return new NotFoundResult(); } + var response = new MemberResponseModel(orgUser, await _userService.TwoFactorIsEnabledAsync(orgUser), + userDetails.Item2); + return new JsonResult(response); + } - /// - /// Retrieve a member. - /// - /// - /// Retrieves the details of an existing member of the organization. You need only supply the - /// unique member identifier that was returned upon member creation. - /// - /// The identifier of the member to be retrieved. - [HttpGet("{id}")] - [ProducesResponseType(typeof(MemberResponseModel), (int)HttpStatusCode.OK)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task Get(Guid id) + /// + /// Retrieve a member's group ids + /// + /// + /// Retrieves the unique identifiers for all groups that are associated with this member. You need only + /// supply the unique member identifier that was returned upon member creation. + /// + /// The identifier of the member to be retrieved. + [HttpGet("{id}/group-ids")] + [ProducesResponseType(typeof(HashSet), (int)HttpStatusCode.OK)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task GetGroupIds(Guid id) + { + var orgUser = await _organizationUserRepository.GetByIdAsync(id); + if (orgUser == null || orgUser.OrganizationId != _currentContext.OrganizationId) { - var userDetails = await _organizationUserRepository.GetDetailsByIdWithCollectionsAsync(id); - var orgUser = userDetails?.Item1; - if (orgUser == null || orgUser.OrganizationId != _currentContext.OrganizationId) - { - return new NotFoundResult(); - } - var response = new MemberResponseModel(orgUser, await _userService.TwoFactorIsEnabledAsync(orgUser), - userDetails.Item2); - return new JsonResult(response); + return new NotFoundResult(); } + var groupIds = await _groupRepository.GetManyIdsByUserIdAsync(id); + return new JsonResult(groupIds); + } - /// - /// Retrieve a member's group ids - /// - /// - /// Retrieves the unique identifiers for all groups that are associated with this member. You need only - /// supply the unique member identifier that was returned upon member creation. - /// - /// The identifier of the member to be retrieved. - [HttpGet("{id}/group-ids")] - [ProducesResponseType(typeof(HashSet), (int)HttpStatusCode.OK)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task GetGroupIds(Guid id) - { - var orgUser = await _organizationUserRepository.GetByIdAsync(id); - if (orgUser == null || orgUser.OrganizationId != _currentContext.OrganizationId) - { - return new NotFoundResult(); - } - var groupIds = await _groupRepository.GetManyIdsByUserIdAsync(id); - return new JsonResult(groupIds); - } + /// + /// List all members. + /// + /// + /// Returns a list of your organization's members. + /// Member objects listed in this call do not include information about their associated collections. + /// + [HttpGet] + [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] + public async Task List() + { + var users = await _organizationUserRepository.GetManyDetailsByOrganizationAsync( + _currentContext.OrganizationId.Value); + // TODO: Get all CollectionUser associations for the organization and marry them up here for the response. + var memberResponsesTasks = users.Select(async u => new MemberResponseModel(u, + await _userService.TwoFactorIsEnabledAsync(u), null)); + var memberResponses = await Task.WhenAll(memberResponsesTasks); + var response = new ListResponseModel(memberResponses); + return new JsonResult(response); + } - /// - /// List all members. - /// - /// - /// Returns a list of your organization's members. - /// Member objects listed in this call do not include information about their associated collections. - /// - [HttpGet] - [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] - public async Task List() + /// + /// Create a member. + /// + /// + /// Creates a new member object by inviting a user to the organization. + /// + /// The request model. + [HttpPost] + [ProducesResponseType(typeof(MemberResponseModel), (int)HttpStatusCode.OK)] + [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] + public async Task Post([FromBody] MemberCreateRequestModel model) + { + var associations = model.Collections?.Select(c => c.ToSelectionReadOnly()); + var invite = new OrganizationUserInvite { - var users = await _organizationUserRepository.GetManyDetailsByOrganizationAsync( - _currentContext.OrganizationId.Value); - // TODO: Get all CollectionUser associations for the organization and marry them up here for the response. - var memberResponsesTasks = users.Select(async u => new MemberResponseModel(u, - await _userService.TwoFactorIsEnabledAsync(u), null)); - var memberResponses = await Task.WhenAll(memberResponsesTasks); - var response = new ListResponseModel(memberResponses); - return new JsonResult(response); - } + Emails = new List { model.Email }, + Type = model.Type.Value, + AccessAll = model.AccessAll.Value, + Collections = associations + }; + var user = await _organizationService.InviteUserAsync(_currentContext.OrganizationId.Value, null, + model.Email, model.Type.Value, model.AccessAll.Value, model.ExternalId, associations); + var response = new MemberResponseModel(user, associations); + return new JsonResult(response); + } - /// - /// Create a member. - /// - /// - /// Creates a new member object by inviting a user to the organization. - /// - /// The request model. - [HttpPost] - [ProducesResponseType(typeof(MemberResponseModel), (int)HttpStatusCode.OK)] - [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] - public async Task Post([FromBody] MemberCreateRequestModel model) + /// + /// Update a member. + /// + /// + /// Updates the specified member object. If a property is not provided, + /// the value of the existing property will be reset. + /// + /// The identifier of the member to be updated. + /// The request model. + [HttpPut("{id}")] + [ProducesResponseType(typeof(MemberResponseModel), (int)HttpStatusCode.OK)] + [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task Put(Guid id, [FromBody] MemberUpdateRequestModel model) + { + var existingUser = await _organizationUserRepository.GetByIdAsync(id); + if (existingUser == null || existingUser.OrganizationId != _currentContext.OrganizationId) { - var associations = model.Collections?.Select(c => c.ToSelectionReadOnly()); - var invite = new OrganizationUserInvite - { - Emails = new List { model.Email }, - Type = model.Type.Value, - AccessAll = model.AccessAll.Value, - Collections = associations - }; - var user = await _organizationService.InviteUserAsync(_currentContext.OrganizationId.Value, null, - model.Email, model.Type.Value, model.AccessAll.Value, model.ExternalId, associations); - var response = new MemberResponseModel(user, associations); - return new JsonResult(response); + return new NotFoundResult(); } + var updatedUser = model.ToOrganizationUser(existingUser); + var associations = model.Collections?.Select(c => c.ToSelectionReadOnly()); + await _organizationService.SaveUserAsync(updatedUser, null, associations); + MemberResponseModel response = null; + if (existingUser.UserId.HasValue) + { + var existingUserDetails = await _organizationUserRepository.GetDetailsByIdAsync(id); + response = new MemberResponseModel(existingUserDetails, + await _userService.TwoFactorIsEnabledAsync(existingUserDetails), associations); + } + else + { + response = new MemberResponseModel(updatedUser, associations); + } + return new JsonResult(response); + } - /// - /// Update a member. - /// - /// - /// Updates the specified member object. If a property is not provided, - /// the value of the existing property will be reset. - /// - /// The identifier of the member to be updated. - /// The request model. - [HttpPut("{id}")] - [ProducesResponseType(typeof(MemberResponseModel), (int)HttpStatusCode.OK)] - [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task Put(Guid id, [FromBody] MemberUpdateRequestModel model) + /// + /// Update a member's groups. + /// + /// + /// Updates the specified member's group associations. + /// + /// The identifier of the member to be updated. + /// The request model. + [HttpPut("{id}/group-ids")] + [ProducesResponseType((int)HttpStatusCode.OK)] + [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task PutGroupIds(Guid id, [FromBody] UpdateGroupIdsRequestModel model) + { + var existingUser = await _organizationUserRepository.GetByIdAsync(id); + if (existingUser == null || existingUser.OrganizationId != _currentContext.OrganizationId) { - var existingUser = await _organizationUserRepository.GetByIdAsync(id); - if (existingUser == null || existingUser.OrganizationId != _currentContext.OrganizationId) - { - return new NotFoundResult(); - } - var updatedUser = model.ToOrganizationUser(existingUser); - var associations = model.Collections?.Select(c => c.ToSelectionReadOnly()); - await _organizationService.SaveUserAsync(updatedUser, null, associations); - MemberResponseModel response = null; - if (existingUser.UserId.HasValue) - { - var existingUserDetails = await _organizationUserRepository.GetDetailsByIdAsync(id); - response = new MemberResponseModel(existingUserDetails, - await _userService.TwoFactorIsEnabledAsync(existingUserDetails), associations); - } - else - { - response = new MemberResponseModel(updatedUser, associations); - } - return new JsonResult(response); + return new NotFoundResult(); } + await _organizationService.UpdateUserGroupsAsync(existingUser, model.GroupIds, null); + return new OkResult(); + } - /// - /// Update a member's groups. - /// - /// - /// Updates the specified member's group associations. - /// - /// The identifier of the member to be updated. - /// The request model. - [HttpPut("{id}/group-ids")] - [ProducesResponseType((int)HttpStatusCode.OK)] - [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task PutGroupIds(Guid id, [FromBody] UpdateGroupIdsRequestModel model) + /// + /// Delete a member. + /// + /// + /// Permanently deletes a member from the organization. This cannot be undone. + /// The user account will still remain. The user is only removed from the organization. + /// + /// The identifier of the member to be deleted. + [HttpDelete("{id}")] + [ProducesResponseType((int)HttpStatusCode.OK)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task Delete(Guid id) + { + var user = await _organizationUserRepository.GetByIdAsync(id); + if (user == null || user.OrganizationId != _currentContext.OrganizationId) { - var existingUser = await _organizationUserRepository.GetByIdAsync(id); - if (existingUser == null || existingUser.OrganizationId != _currentContext.OrganizationId) - { - return new NotFoundResult(); - } - await _organizationService.UpdateUserGroupsAsync(existingUser, model.GroupIds, null); - return new OkResult(); + return new NotFoundResult(); } + await _organizationService.DeleteUserAsync(_currentContext.OrganizationId.Value, id, null); + return new OkResult(); + } - /// - /// Delete a member. - /// - /// - /// Permanently deletes a member from the organization. This cannot be undone. - /// The user account will still remain. The user is only removed from the organization. - /// - /// The identifier of the member to be deleted. - [HttpDelete("{id}")] - [ProducesResponseType((int)HttpStatusCode.OK)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task Delete(Guid id) + /// + /// Re-invite a member. + /// + /// + /// Re-sends the invitation email to an organization member. + /// + /// The identifier of the member to re-invite. + [HttpPost("{id}/reinvite")] + [ProducesResponseType((int)HttpStatusCode.OK)] + [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task PostReinvite(Guid id) + { + var existingUser = await _organizationUserRepository.GetByIdAsync(id); + if (existingUser == null || existingUser.OrganizationId != _currentContext.OrganizationId) { - var user = await _organizationUserRepository.GetByIdAsync(id); - if (user == null || user.OrganizationId != _currentContext.OrganizationId) - { - return new NotFoundResult(); - } - await _organizationService.DeleteUserAsync(_currentContext.OrganizationId.Value, id, null); - return new OkResult(); - } - - /// - /// Re-invite a member. - /// - /// - /// Re-sends the invitation email to an organization member. - /// - /// The identifier of the member to re-invite. - [HttpPost("{id}/reinvite")] - [ProducesResponseType((int)HttpStatusCode.OK)] - [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task PostReinvite(Guid id) - { - var existingUser = await _organizationUserRepository.GetByIdAsync(id); - if (existingUser == null || existingUser.OrganizationId != _currentContext.OrganizationId) - { - return new NotFoundResult(); - } - await _organizationService.ResendInviteAsync(_currentContext.OrganizationId.Value, null, id); - return new OkResult(); + return new NotFoundResult(); } + await _organizationService.ResendInviteAsync(_currentContext.OrganizationId.Value, null, id); + return new OkResult(); } } diff --git a/src/Api/Public/Controllers/OrganizationController.cs b/src/Api/Public/Controllers/OrganizationController.cs index 978811d39..ce0683b95 100644 --- a/src/Api/Public/Controllers/OrganizationController.cs +++ b/src/Api/Public/Controllers/OrganizationController.cs @@ -8,52 +8,51 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Public.Controllers +namespace Bit.Api.Public.Controllers; + +[Route("public/organization")] +[Authorize("Organization")] +public class OrganizationController : Controller { - [Route("public/organization")] - [Authorize("Organization")] - public class OrganizationController : Controller + private readonly IOrganizationService _organizationService; + private readonly ICurrentContext _currentContext; + private readonly GlobalSettings _globalSettings; + + public OrganizationController( + IOrganizationService organizationService, + ICurrentContext currentContext, + GlobalSettings globalSettings) { - private readonly IOrganizationService _organizationService; - private readonly ICurrentContext _currentContext; - private readonly GlobalSettings _globalSettings; + _organizationService = organizationService; + _currentContext = currentContext; + _globalSettings = globalSettings; + } - public OrganizationController( - IOrganizationService organizationService, - ICurrentContext currentContext, - GlobalSettings globalSettings) + /// + /// Import members and groups. + /// + /// + /// Import members and groups from an external system. + /// + /// The request model. + [HttpPost("import")] + [ProducesResponseType(typeof(MemberResponseModel), (int)HttpStatusCode.OK)] + [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] + public async Task Import([FromBody] OrganizationImportRequestModel model) + { + if (!_globalSettings.SelfHosted && !model.LargeImport && + (model.Groups.Count() > 2000 || model.Members.Count(u => !u.Deleted) > 2000)) { - _organizationService = organizationService; - _currentContext = currentContext; - _globalSettings = globalSettings; + throw new BadRequestException("You cannot import this much data at once."); } - /// - /// Import members and groups. - /// - /// - /// Import members and groups from an external system. - /// - /// The request model. - [HttpPost("import")] - [ProducesResponseType(typeof(MemberResponseModel), (int)HttpStatusCode.OK)] - [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] - public async Task Import([FromBody] OrganizationImportRequestModel model) - { - if (!_globalSettings.SelfHosted && !model.LargeImport && - (model.Groups.Count() > 2000 || model.Members.Count(u => !u.Deleted) > 2000)) - { - throw new BadRequestException("You cannot import this much data at once."); - } - - await _organizationService.ImportAsync( - _currentContext.OrganizationId.Value, - null, - model.Groups.Select(g => g.ToImportedGroup(_currentContext.OrganizationId.Value)), - model.Members.Where(u => !u.Deleted).Select(u => u.ToImportedOrganizationUser()), - model.Members.Where(u => u.Deleted).Select(u => u.ExternalId), - model.OverwriteExisting.GetValueOrDefault()); - return new OkResult(); - } + await _organizationService.ImportAsync( + _currentContext.OrganizationId.Value, + null, + model.Groups.Select(g => g.ToImportedGroup(_currentContext.OrganizationId.Value)), + model.Members.Where(u => !u.Deleted).Select(u => u.ToImportedOrganizationUser()), + model.Members.Where(u => u.Deleted).Select(u => u.ExternalId), + model.OverwriteExisting.GetValueOrDefault()); + return new OkResult(); } } diff --git a/src/Api/Public/Controllers/PoliciesController.cs b/src/Api/Public/Controllers/PoliciesController.cs index 65556ebac..b208938ed 100644 --- a/src/Api/Public/Controllers/PoliciesController.cs +++ b/src/Api/Public/Controllers/PoliciesController.cs @@ -8,98 +8,97 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Public.Controllers +namespace Bit.Api.Public.Controllers; + +[Route("public/policies")] +[Authorize("Organization")] +public class PoliciesController : Controller { - [Route("public/policies")] - [Authorize("Organization")] - public class PoliciesController : Controller + private readonly IPolicyRepository _policyRepository; + private readonly IPolicyService _policyService; + private readonly IUserService _userService; + private readonly IOrganizationService _organizationService; + private readonly ICurrentContext _currentContext; + + public PoliciesController( + IPolicyRepository policyRepository, + IPolicyService policyService, + IUserService userService, + IOrganizationService organizationService, + ICurrentContext currentContext) { - private readonly IPolicyRepository _policyRepository; - private readonly IPolicyService _policyService; - private readonly IUserService _userService; - private readonly IOrganizationService _organizationService; - private readonly ICurrentContext _currentContext; + _policyRepository = policyRepository; + _policyService = policyService; + _userService = userService; + _organizationService = organizationService; + _currentContext = currentContext; + } - public PoliciesController( - IPolicyRepository policyRepository, - IPolicyService policyService, - IUserService userService, - IOrganizationService organizationService, - ICurrentContext currentContext) + /// + /// Retrieve a policy. + /// + /// + /// Retrieves the details of a policy. + /// + /// The type of policy to be retrieved. + [HttpGet("{type}")] + [ProducesResponseType(typeof(GroupResponseModel), (int)HttpStatusCode.OK)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task Get(PolicyType type) + { + var policy = await _policyRepository.GetByOrganizationIdTypeAsync( + _currentContext.OrganizationId.Value, type); + if (policy == null) { - _policyRepository = policyRepository; - _policyService = policyService; - _userService = userService; - _organizationService = organizationService; - _currentContext = currentContext; + return new NotFoundResult(); } + var response = new PolicyResponseModel(policy); + return new JsonResult(response); + } - /// - /// Retrieve a policy. - /// - /// - /// Retrieves the details of a policy. - /// - /// The type of policy to be retrieved. - [HttpGet("{type}")] - [ProducesResponseType(typeof(GroupResponseModel), (int)HttpStatusCode.OK)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task Get(PolicyType type) - { - var policy = await _policyRepository.GetByOrganizationIdTypeAsync( - _currentContext.OrganizationId.Value, type); - if (policy == null) - { - return new NotFoundResult(); - } - var response = new PolicyResponseModel(policy); - return new JsonResult(response); - } + /// + /// List all policies. + /// + /// + /// Returns a list of your organization's policies. + /// + [HttpGet] + [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] + public async Task List() + { + var policies = await _policyRepository.GetManyByOrganizationIdAsync(_currentContext.OrganizationId.Value); + var policyResponses = policies.Select(p => new PolicyResponseModel(p)); + var response = new ListResponseModel(policyResponses); + return new JsonResult(response); + } - /// - /// List all policies. - /// - /// - /// Returns a list of your organization's policies. - /// - [HttpGet] - [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] - public async Task List() + /// + /// Update a policy. + /// + /// + /// Updates the specified policy. If a property is not provided, + /// the value of the existing property will be reset. + /// + /// The type of policy to be updated. + /// The request model. + [HttpPut("{id}")] + [ProducesResponseType(typeof(PolicyResponseModel), (int)HttpStatusCode.OK)] + [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task Put(PolicyType type, [FromBody] PolicyUpdateRequestModel model) + { + var policy = await _policyRepository.GetByOrganizationIdTypeAsync( + _currentContext.OrganizationId.Value, type); + if (policy == null) { - var policies = await _policyRepository.GetManyByOrganizationIdAsync(_currentContext.OrganizationId.Value); - var policyResponses = policies.Select(p => new PolicyResponseModel(p)); - var response = new ListResponseModel(policyResponses); - return new JsonResult(response); + policy = model.ToPolicy(_currentContext.OrganizationId.Value); } - - /// - /// Update a policy. - /// - /// - /// Updates the specified policy. If a property is not provided, - /// the value of the existing property will be reset. - /// - /// The type of policy to be updated. - /// The request model. - [HttpPut("{id}")] - [ProducesResponseType(typeof(PolicyResponseModel), (int)HttpStatusCode.OK)] - [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] - [ProducesResponseType((int)HttpStatusCode.NotFound)] - public async Task Put(PolicyType type, [FromBody] PolicyUpdateRequestModel model) + else { - var policy = await _policyRepository.GetByOrganizationIdTypeAsync( - _currentContext.OrganizationId.Value, type); - if (policy == null) - { - policy = model.ToPolicy(_currentContext.OrganizationId.Value); - } - else - { - policy = model.ToPolicy(policy); - } - await _policyService.SaveAsync(policy, _userService, _organizationService, null); - var response = new PolicyResponseModel(policy); - return new JsonResult(response); + policy = model.ToPolicy(policy); } + await _policyService.SaveAsync(policy, _userService, _organizationService, null); + var response = new PolicyResponseModel(policy); + return new JsonResult(response); } } diff --git a/src/Api/Startup.cs b/src/Api/Startup.cs index 7bebb1a9b..20b707f5d 100644 --- a/src/Api/Startup.cs +++ b/src/Api/Startup.cs @@ -17,209 +17,208 @@ using Microsoft.Extensions.DependencyInjection.Extensions; using Bit.Commercial.Core.Utilities; #endif -namespace Bit.Api +namespace Bit.Api; + +public class Startup { - public class Startup + public Startup(IWebHostEnvironment env, IConfiguration configuration) { - public Startup(IWebHostEnvironment env, IConfiguration configuration) + CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); + Configuration = configuration; + Environment = env; + } + + public IConfiguration Configuration { get; private set; } + public IWebHostEnvironment Environment { get; set; } + + public void ConfigureServices(IServiceCollection services) + { + // Options + services.AddOptions(); + + // Settings + var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); + if (!globalSettings.SelfHosted) { - CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); - Configuration = configuration; - Environment = env; + services.Configure(Configuration.GetSection("IpRateLimitOptions")); + services.Configure(Configuration.GetSection("IpRateLimitPolicies")); } - public IConfiguration Configuration { get; private set; } - public IWebHostEnvironment Environment { get; set; } + // Data Protection + services.AddCustomDataProtectionServices(Environment, globalSettings); - public void ConfigureServices(IServiceCollection services) + // Event Grid + if (!string.IsNullOrWhiteSpace(globalSettings.EventGridKey)) { - // Options - services.AddOptions(); + ApiHelpers.EventGridKey = globalSettings.EventGridKey; + } - // Settings - var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); - if (!globalSettings.SelfHosted) + // Stripe Billing + StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey; + StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries; + + // Repositories + services.AddSqlServerRepositories(globalSettings); + + // Context + services.AddScoped(); + services.TryAddSingleton(); + + // Caching + services.AddMemoryCache(); + services.AddDistributedCache(globalSettings); + + // BitPay + services.AddSingleton(); + + if (!globalSettings.SelfHosted) + { + services.AddIpRateLimiting(globalSettings); + } + + // Identity + services.AddCustomIdentityServices(globalSettings); + services.AddIdentityAuthenticationServices(globalSettings, Environment, config => + { + config.AddPolicy("Application", policy => { - services.Configure(Configuration.GetSection("IpRateLimitOptions")); - services.Configure(Configuration.GetSection("IpRateLimitPolicies")); - } - - // Data Protection - services.AddCustomDataProtectionServices(Environment, globalSettings); - - // Event Grid - if (!string.IsNullOrWhiteSpace(globalSettings.EventGridKey)) - { - ApiHelpers.EventGridKey = globalSettings.EventGridKey; - } - - // Stripe Billing - StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey; - StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries; - - // Repositories - services.AddSqlServerRepositories(globalSettings); - - // Context - services.AddScoped(); - services.TryAddSingleton(); - - // Caching - services.AddMemoryCache(); - services.AddDistributedCache(globalSettings); - - // BitPay - services.AddSingleton(); - - if (!globalSettings.SelfHosted) - { - services.AddIpRateLimiting(globalSettings); - } - - // Identity - services.AddCustomIdentityServices(globalSettings); - services.AddIdentityAuthenticationServices(globalSettings, Environment, config => - { - config.AddPolicy("Application", policy => - { - policy.RequireAuthenticatedUser(); - policy.RequireClaim(JwtClaimTypes.AuthenticationMethod, "Application", "external"); - policy.RequireClaim(JwtClaimTypes.Scope, "api"); - }); - config.AddPolicy("Web", policy => - { - policy.RequireAuthenticatedUser(); - policy.RequireClaim(JwtClaimTypes.AuthenticationMethod, "Application", "external"); - policy.RequireClaim(JwtClaimTypes.Scope, "api"); - policy.RequireClaim(JwtClaimTypes.ClientId, "web"); - }); - config.AddPolicy("Push", policy => - { - policy.RequireAuthenticatedUser(); - policy.RequireClaim(JwtClaimTypes.Scope, "api.push"); - }); - config.AddPolicy("Licensing", policy => - { - policy.RequireAuthenticatedUser(); - policy.RequireClaim(JwtClaimTypes.Scope, "api.licensing"); - }); - config.AddPolicy("Organization", policy => - { - policy.RequireAuthenticatedUser(); - policy.RequireClaim(JwtClaimTypes.Scope, "api.organization"); - }); - config.AddPolicy("Installation", policy => - { - policy.RequireAuthenticatedUser(); - policy.RequireClaim(JwtClaimTypes.Scope, "api.installation"); - }); + policy.RequireAuthenticatedUser(); + policy.RequireClaim(JwtClaimTypes.AuthenticationMethod, "Application", "external"); + policy.RequireClaim(JwtClaimTypes.Scope, "api"); }); + config.AddPolicy("Web", policy => + { + policy.RequireAuthenticatedUser(); + policy.RequireClaim(JwtClaimTypes.AuthenticationMethod, "Application", "external"); + policy.RequireClaim(JwtClaimTypes.Scope, "api"); + policy.RequireClaim(JwtClaimTypes.ClientId, "web"); + }); + config.AddPolicy("Push", policy => + { + policy.RequireAuthenticatedUser(); + policy.RequireClaim(JwtClaimTypes.Scope, "api.push"); + }); + config.AddPolicy("Licensing", policy => + { + policy.RequireAuthenticatedUser(); + policy.RequireClaim(JwtClaimTypes.Scope, "api.licensing"); + }); + config.AddPolicy("Organization", policy => + { + policy.RequireAuthenticatedUser(); + policy.RequireClaim(JwtClaimTypes.Scope, "api.organization"); + }); + config.AddPolicy("Installation", policy => + { + policy.RequireAuthenticatedUser(); + policy.RequireClaim(JwtClaimTypes.Scope, "api.installation"); + }); + }); - services.AddScoped(); + services.AddScoped(); - // Services - services.AddBaseServices(globalSettings); - services.AddDefaultServices(globalSettings); - services.AddCoreLocalizationServices(); + // Services + services.AddBaseServices(globalSettings); + services.AddDefaultServices(globalSettings); + services.AddCoreLocalizationServices(); #if OSS - services.AddOosServices(); + services.AddOosServices(); #else - services.AddCommCoreServices(); + services.AddCommCoreServices(); #endif - // MVC - services.AddMvc(config => - { - config.Conventions.Add(new ApiExplorerGroupConvention()); - config.Conventions.Add(new PublicApiControllersModelConvention()); - }); - - services.AddSwagger(globalSettings); - Jobs.JobsHostedService.AddJobsServices(services, globalSettings.SelfHosted); - services.AddHostedService(); - - if (CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ConnectionString) && - CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ApplicationCacheTopicName)) - { - services.AddHostedService(); - } - } - - public void Configure( - IApplicationBuilder app, - IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, - GlobalSettings globalSettings, - ILogger logger) + // MVC + services.AddMvc(config => { - IdentityModelEventSource.ShowPII = true; - app.UseSerilog(env, appLifetime, globalSettings); + config.Conventions.Add(new ApiExplorerGroupConvention()); + config.Conventions.Add(new PublicApiControllersModelConvention()); + }); - // Add general security headers - app.UseMiddleware(); + services.AddSwagger(globalSettings); + Jobs.JobsHostedService.AddJobsServices(services, globalSettings.SelfHosted); + services.AddHostedService(); - // Default Middleware - app.UseDefaultMiddleware(env, globalSettings); - - if (!globalSettings.SelfHosted) - { - // Rate limiting - app.UseMiddleware(); - } - else - { - app.UseForwardedHeaders(globalSettings); - } - - // Add localization - app.UseCoreLocalization(); - - // Add static files to the request pipeline. - app.UseStaticFiles(); - - // Add routing - app.UseRouting(); - - // Add Cors - app.UseCors(policy => policy.SetIsOriginAllowed(o => CoreHelpers.IsCorsOriginAllowed(o, globalSettings)) - .AllowAnyMethod().AllowAnyHeader().AllowCredentials()); - - // Add authentication and authorization to the request pipeline. - app.UseAuthentication(); - app.UseAuthorization(); - - // Add current context - app.UseMiddleware(); - - // Add endpoints to the request pipeline. - app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); - - // Add Swagger - if (Environment.IsDevelopment() || globalSettings.SelfHosted) - { - app.UseSwagger(config => - { - config.RouteTemplate = "specs/{documentName}/swagger.json"; - config.PreSerializeFilters.Add((swaggerDoc, httpReq) => - swaggerDoc.Servers = new List - { - new OpenApiServer { Url = globalSettings.BaseServiceUri.Api } - }); - }); - app.UseSwaggerUI(config => - { - config.DocumentTitle = "Bitwarden API Documentation"; - config.RoutePrefix = "docs"; - config.SwaggerEndpoint($"{globalSettings.BaseServiceUri.Api}/specs/public/swagger.json", - "Bitwarden Public API"); - config.OAuthClientId("accountType.id"); - config.OAuthClientSecret("secretKey"); - }); - } - - // Log startup - logger.LogInformation(Constants.BypassFiltersEventId, globalSettings.ProjectName + " started."); + if (CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ConnectionString) && + CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ApplicationCacheTopicName)) + { + services.AddHostedService(); } } + + public void Configure( + IApplicationBuilder app, + IWebHostEnvironment env, + IHostApplicationLifetime appLifetime, + GlobalSettings globalSettings, + ILogger logger) + { + IdentityModelEventSource.ShowPII = true; + app.UseSerilog(env, appLifetime, globalSettings); + + // Add general security headers + app.UseMiddleware(); + + // Default Middleware + app.UseDefaultMiddleware(env, globalSettings); + + if (!globalSettings.SelfHosted) + { + // Rate limiting + app.UseMiddleware(); + } + else + { + app.UseForwardedHeaders(globalSettings); + } + + // Add localization + app.UseCoreLocalization(); + + // Add static files to the request pipeline. + app.UseStaticFiles(); + + // Add routing + app.UseRouting(); + + // Add Cors + app.UseCors(policy => policy.SetIsOriginAllowed(o => CoreHelpers.IsCorsOriginAllowed(o, globalSettings)) + .AllowAnyMethod().AllowAnyHeader().AllowCredentials()); + + // Add authentication and authorization to the request pipeline. + app.UseAuthentication(); + app.UseAuthorization(); + + // Add current context + app.UseMiddleware(); + + // Add endpoints to the request pipeline. + app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); + + // Add Swagger + if (Environment.IsDevelopment() || globalSettings.SelfHosted) + { + app.UseSwagger(config => + { + config.RouteTemplate = "specs/{documentName}/swagger.json"; + config.PreSerializeFilters.Add((swaggerDoc, httpReq) => + swaggerDoc.Servers = new List + { + new OpenApiServer { Url = globalSettings.BaseServiceUri.Api } + }); + }); + app.UseSwaggerUI(config => + { + config.DocumentTitle = "Bitwarden API Documentation"; + config.RoutePrefix = "docs"; + config.SwaggerEndpoint($"{globalSettings.BaseServiceUri.Api}/specs/public/swagger.json", + "Bitwarden Public API"); + config.OAuthClientId("accountType.id"); + config.OAuthClientSecret("secretKey"); + }); + } + + // Log startup + logger.LogInformation(Constants.BypassFiltersEventId, globalSettings.ProjectName + " started."); + } } diff --git a/src/Api/Utilities/ApiExplorerGroupConvention.cs b/src/Api/Utilities/ApiExplorerGroupConvention.cs index 5b8d7559a..42b1c8d6e 100644 --- a/src/Api/Utilities/ApiExplorerGroupConvention.cs +++ b/src/Api/Utilities/ApiExplorerGroupConvention.cs @@ -1,13 +1,12 @@ using Microsoft.AspNetCore.Mvc.ApplicationModels; -namespace Bit.Api.Utilities +namespace Bit.Api.Utilities; + +public class ApiExplorerGroupConvention : IControllerModelConvention { - public class ApiExplorerGroupConvention : IControllerModelConvention + public void Apply(ControllerModel controller) { - public void Apply(ControllerModel controller) - { - var controllerNamespace = controller.ControllerType.Namespace; - controller.ApiExplorer.GroupName = controllerNamespace.Contains(".Public.") ? "public" : "internal"; - } + var controllerNamespace = controller.ControllerType.Namespace; + controller.ApiExplorer.GroupName = controllerNamespace.Contains(".Public.") ? "public" : "internal"; } } diff --git a/src/Api/Utilities/ApiHelpers.cs b/src/Api/Utilities/ApiHelpers.cs index 920c15dd2..58097089f 100644 --- a/src/Api/Utilities/ApiHelpers.cs +++ b/src/Api/Utilities/ApiHelpers.cs @@ -4,70 +4,69 @@ using Azure.Messaging.EventGrid.SystemEvents; using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Utilities +namespace Bit.Api.Utilities; + +public static class ApiHelpers { - public static class ApiHelpers + public static string EventGridKey { get; set; } + public async static Task ReadJsonFileFromBody(HttpContext httpContext, IFormFile file, long maxSize = 51200) { - public static string EventGridKey { get; set; } - public async static Task ReadJsonFileFromBody(HttpContext httpContext, IFormFile file, long maxSize = 51200) + T obj = default(T); + if (file != null && httpContext.Request.ContentLength.HasValue && httpContext.Request.ContentLength.Value <= maxSize) { - T obj = default(T); - if (file != null && httpContext.Request.ContentLength.HasValue && httpContext.Request.ContentLength.Value <= maxSize) + try { - try - { - using var stream = file.OpenReadStream(); - obj = await JsonSerializer.DeserializeAsync(stream, JsonHelpers.IgnoreCase); - } - catch { } + using var stream = file.OpenReadStream(); + obj = await JsonSerializer.DeserializeAsync(stream, JsonHelpers.IgnoreCase); } - - return obj; + catch { } } - /// - /// Validates Azure event subscription and calls the appropriate event handler. Responds HttpOk. - /// - /// HttpRequest received from Azure - /// Dictionary of eventType strings and their associated handlers. - /// OkObjectResult - /// Reference https://docs.microsoft.com/en-us/azure/event-grid/receive-events - public async static Task HandleAzureEvents(HttpRequest request, - Dictionary> eventTypeHandlers) + return obj; + } + + /// + /// Validates Azure event subscription and calls the appropriate event handler. Responds HttpOk. + /// + /// HttpRequest received from Azure + /// Dictionary of eventType strings and their associated handlers. + /// OkObjectResult + /// Reference https://docs.microsoft.com/en-us/azure/event-grid/receive-events + public async static Task HandleAzureEvents(HttpRequest request, + Dictionary> eventTypeHandlers) + { + var queryKey = request.Query["key"]; + + if (!CoreHelpers.FixedTimeEquals(queryKey, EventGridKey)) { - var queryKey = request.Query["key"]; + return new UnauthorizedObjectResult("Authentication failed. Please use a valid key."); + } - if (!CoreHelpers.FixedTimeEquals(queryKey, EventGridKey)) + var response = string.Empty; + var requestData = await BinaryData.FromStreamAsync(request.Body); + var eventGridEvents = EventGridEvent.ParseMany(requestData); + foreach (var eventGridEvent in eventGridEvents) + { + if (eventGridEvent.TryGetSystemEventData(out object systemEvent)) { - return new UnauthorizedObjectResult("Authentication failed. Please use a valid key."); - } - - var response = string.Empty; - var requestData = await BinaryData.FromStreamAsync(request.Body); - var eventGridEvents = EventGridEvent.ParseMany(requestData); - foreach (var eventGridEvent in eventGridEvents) - { - if (eventGridEvent.TryGetSystemEventData(out object systemEvent)) + if (systemEvent is SubscriptionValidationEventData eventData) { - if (systemEvent is SubscriptionValidationEventData eventData) + // Might want to enable additional validation: subject, topic etc. + var responseData = new SubscriptionValidationResponse() { - // Might want to enable additional validation: subject, topic etc. - var responseData = new SubscriptionValidationResponse() - { - ValidationResponse = eventData.ValidationCode - }; + ValidationResponse = eventData.ValidationCode + }; - return new OkObjectResult(responseData); - } - } - - if (eventTypeHandlers.ContainsKey(eventGridEvent.EventType)) - { - await eventTypeHandlers[eventGridEvent.EventType](eventGridEvent); + return new OkObjectResult(responseData); } } - return new OkObjectResult(response); + if (eventTypeHandlers.ContainsKey(eventGridEvent.EventType)) + { + await eventTypeHandlers[eventGridEvent.EventType](eventGridEvent); + } } + + return new OkObjectResult(response); } } diff --git a/src/Api/Utilities/DisableFormValueModelBindingAttribute.cs b/src/Api/Utilities/DisableFormValueModelBindingAttribute.cs index 27169a57c..e0c604546 100644 --- a/src/Api/Utilities/DisableFormValueModelBindingAttribute.cs +++ b/src/Api/Utilities/DisableFormValueModelBindingAttribute.cs @@ -1,21 +1,20 @@ using Microsoft.AspNetCore.Mvc.Filters; using Microsoft.AspNetCore.Mvc.ModelBinding; -namespace Bit.Api.Utilities -{ - [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method)] - public class DisableFormValueModelBindingAttribute : Attribute, IResourceFilter - { - public void OnResourceExecuting(ResourceExecutingContext context) - { - var factories = context.ValueProviderFactories; - factories.RemoveType(); - factories.RemoveType(); - factories.RemoveType(); - } +namespace Bit.Api.Utilities; - public void OnResourceExecuted(ResourceExecutedContext context) - { - } +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method)] +public class DisableFormValueModelBindingAttribute : Attribute, IResourceFilter +{ + public void OnResourceExecuting(ResourceExecutingContext context) + { + var factories = context.ValueProviderFactories; + factories.RemoveType(); + factories.RemoveType(); + factories.RemoveType(); + } + + public void OnResourceExecuted(ResourceExecutedContext context) + { } } diff --git a/src/Api/Utilities/ExceptionHandlerFilterAttribute.cs b/src/Api/Utilities/ExceptionHandlerFilterAttribute.cs index 846bcda25..422bfa62d 100644 --- a/src/Api/Utilities/ExceptionHandlerFilterAttribute.cs +++ b/src/Api/Utilities/ExceptionHandlerFilterAttribute.cs @@ -6,117 +6,116 @@ using Microsoft.IdentityModel.Tokens; using Stripe; using InternalApi = Bit.Core.Models.Api; -namespace Bit.Api.Utilities -{ - public class ExceptionHandlerFilterAttribute : ExceptionFilterAttribute - { - private readonly bool _publicApi; +namespace Bit.Api.Utilities; - public ExceptionHandlerFilterAttribute(bool publicApi) +public class ExceptionHandlerFilterAttribute : ExceptionFilterAttribute +{ + private readonly bool _publicApi; + + public ExceptionHandlerFilterAttribute(bool publicApi) + { + _publicApi = publicApi; + } + + public override void OnException(ExceptionContext context) + { + var errorMessage = "An error has occurred."; + + var exception = context.Exception; + if (exception == null) { - _publicApi = publicApi; + // Should never happen. + return; } - public override void OnException(ExceptionContext context) + ErrorResponseModel publicErrorModel = null; + InternalApi.ErrorResponseModel internalErrorModel = null; + if (exception is BadRequestException badRequestException) { - var errorMessage = "An error has occurred."; - - var exception = context.Exception; - if (exception == null) + context.HttpContext.Response.StatusCode = 400; + if (badRequestException.ModelState != null) { - // Should never happen. - return; - } - - ErrorResponseModel publicErrorModel = null; - InternalApi.ErrorResponseModel internalErrorModel = null; - if (exception is BadRequestException badRequestException) - { - context.HttpContext.Response.StatusCode = 400; - if (badRequestException.ModelState != null) - { - if (_publicApi) - { - publicErrorModel = new ErrorResponseModel(badRequestException.ModelState); - } - else - { - internalErrorModel = new InternalApi.ErrorResponseModel(badRequestException.ModelState); - } - } - else - { - errorMessage = badRequestException.Message; - } - } - else if (exception is StripeException stripeException && stripeException?.StripeError?.Type == "card_error") - { - context.HttpContext.Response.StatusCode = 400; if (_publicApi) { - publicErrorModel = new ErrorResponseModel(stripeException.StripeError.Param, - stripeException.Message); + publicErrorModel = new ErrorResponseModel(badRequestException.ModelState); } else { - internalErrorModel = new InternalApi.ErrorResponseModel(stripeException.StripeError.Param, - stripeException.Message); + internalErrorModel = new InternalApi.ErrorResponseModel(badRequestException.ModelState); } } - else if (exception is GatewayException) - { - errorMessage = exception.Message; - context.HttpContext.Response.StatusCode = 400; - } - else if (exception is NotSupportedException && !string.IsNullOrWhiteSpace(exception.Message)) - { - errorMessage = exception.Message; - context.HttpContext.Response.StatusCode = 400; - } - else if (exception is ApplicationException) - { - context.HttpContext.Response.StatusCode = 402; - } - else if (exception is NotFoundException) - { - errorMessage = "Resource not found."; - context.HttpContext.Response.StatusCode = 404; - } - else if (exception is SecurityTokenValidationException) - { - errorMessage = "Invalid token."; - context.HttpContext.Response.StatusCode = 403; - } - else if (exception is UnauthorizedAccessException) - { - errorMessage = "Unauthorized."; - context.HttpContext.Response.StatusCode = 401; - } else { - var logger = context.HttpContext.RequestServices.GetRequiredService>(); - logger.LogError(0, exception, exception.Message); - errorMessage = "An unhandled server error has occurred."; - context.HttpContext.Response.StatusCode = 500; + errorMessage = badRequestException.Message; } - + } + else if (exception is StripeException stripeException && stripeException?.StripeError?.Type == "card_error") + { + context.HttpContext.Response.StatusCode = 400; if (_publicApi) { - var errorModel = publicErrorModel ?? new ErrorResponseModel(errorMessage); - context.Result = new ObjectResult(errorModel); + publicErrorModel = new ErrorResponseModel(stripeException.StripeError.Param, + stripeException.Message); } else { - var errorModel = internalErrorModel ?? new InternalApi.ErrorResponseModel(errorMessage); - var env = context.HttpContext.RequestServices.GetRequiredService(); - if (env.IsDevelopment()) - { - errorModel.ExceptionMessage = exception.Message; - errorModel.ExceptionStackTrace = exception.StackTrace; - errorModel.InnerExceptionMessage = exception?.InnerException?.Message; - } - context.Result = new ObjectResult(errorModel); + internalErrorModel = new InternalApi.ErrorResponseModel(stripeException.StripeError.Param, + stripeException.Message); } } + else if (exception is GatewayException) + { + errorMessage = exception.Message; + context.HttpContext.Response.StatusCode = 400; + } + else if (exception is NotSupportedException && !string.IsNullOrWhiteSpace(exception.Message)) + { + errorMessage = exception.Message; + context.HttpContext.Response.StatusCode = 400; + } + else if (exception is ApplicationException) + { + context.HttpContext.Response.StatusCode = 402; + } + else if (exception is NotFoundException) + { + errorMessage = "Resource not found."; + context.HttpContext.Response.StatusCode = 404; + } + else if (exception is SecurityTokenValidationException) + { + errorMessage = "Invalid token."; + context.HttpContext.Response.StatusCode = 403; + } + else if (exception is UnauthorizedAccessException) + { + errorMessage = "Unauthorized."; + context.HttpContext.Response.StatusCode = 401; + } + else + { + var logger = context.HttpContext.RequestServices.GetRequiredService>(); + logger.LogError(0, exception, exception.Message); + errorMessage = "An unhandled server error has occurred."; + context.HttpContext.Response.StatusCode = 500; + } + + if (_publicApi) + { + var errorModel = publicErrorModel ?? new ErrorResponseModel(errorMessage); + context.Result = new ObjectResult(errorModel); + } + else + { + var errorModel = internalErrorModel ?? new InternalApi.ErrorResponseModel(errorMessage); + var env = context.HttpContext.RequestServices.GetRequiredService(); + if (env.IsDevelopment()) + { + errorModel.ExceptionMessage = exception.Message; + errorModel.ExceptionStackTrace = exception.StackTrace; + errorModel.InnerExceptionMessage = exception?.InnerException?.Message; + } + context.Result = new ObjectResult(errorModel); + } } } diff --git a/src/Api/Utilities/ModelStateValidationFilterAttribute.cs b/src/Api/Utilities/ModelStateValidationFilterAttribute.cs index d6803f91c..3fe4f748f 100644 --- a/src/Api/Utilities/ModelStateValidationFilterAttribute.cs +++ b/src/Api/Utilities/ModelStateValidationFilterAttribute.cs @@ -3,27 +3,26 @@ using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc.Filters; using InternalApi = Bit.Core.Models.Api; -namespace Bit.Api.Utilities +namespace Bit.Api.Utilities; + +public class ModelStateValidationFilterAttribute : SharedWeb.Utilities.ModelStateValidationFilterAttribute { - public class ModelStateValidationFilterAttribute : SharedWeb.Utilities.ModelStateValidationFilterAttribute + private readonly bool _publicApi; + + public ModelStateValidationFilterAttribute(bool publicApi) { - private readonly bool _publicApi; + _publicApi = publicApi; + } - public ModelStateValidationFilterAttribute(bool publicApi) + protected override void OnModelStateInvalid(ActionExecutingContext context) + { + if (_publicApi) { - _publicApi = publicApi; + context.Result = new BadRequestObjectResult(new ErrorResponseModel(context.ModelState)); } - - protected override void OnModelStateInvalid(ActionExecutingContext context) + else { - if (_publicApi) - { - context.Result = new BadRequestObjectResult(new ErrorResponseModel(context.ModelState)); - } - else - { - context.Result = new BadRequestObjectResult(new InternalApi.ErrorResponseModel(context.ModelState)); - } + context.Result = new BadRequestObjectResult(new InternalApi.ErrorResponseModel(context.ModelState)); } } } diff --git a/src/Api/Utilities/MultipartFormDataHelper.cs b/src/Api/Utilities/MultipartFormDataHelper.cs index a3e4b1967..c7ca42d50 100644 --- a/src/Api/Utilities/MultipartFormDataHelper.cs +++ b/src/Api/Utilities/MultipartFormDataHelper.cs @@ -5,75 +5,41 @@ using Microsoft.AspNetCore.WebUtilities; using Microsoft.Extensions.Primitives; using Microsoft.Net.Http.Headers; -namespace Bit.Api.Utilities +namespace Bit.Api.Utilities; + +public static class MultipartFormDataHelper { - public static class MultipartFormDataHelper + private static readonly FormOptions _defaultFormOptions = new FormOptions(); + + public static async Task GetFileAsync(this HttpRequest request, Func callback) { - private static readonly FormOptions _defaultFormOptions = new FormOptions(); + var boundary = GetBoundary(MediaTypeHeaderValue.Parse(request.ContentType), + _defaultFormOptions.MultipartBoundaryLengthLimit); + var reader = new MultipartReader(boundary, request.Body); - public static async Task GetFileAsync(this HttpRequest request, Func callback) + var firstSection = await reader.ReadNextSectionAsync(); + if (firstSection != null) { - var boundary = GetBoundary(MediaTypeHeaderValue.Parse(request.ContentType), - _defaultFormOptions.MultipartBoundaryLengthLimit); - var reader = new MultipartReader(boundary, request.Body); - - var firstSection = await reader.ReadNextSectionAsync(); - if (firstSection != null) + if (ContentDispositionHeaderValue.TryParse(firstSection.ContentDisposition, out var firstContent)) { - if (ContentDispositionHeaderValue.TryParse(firstSection.ContentDisposition, out var firstContent)) + if (HasFileContentDisposition(firstContent)) { - if (HasFileContentDisposition(firstContent)) + // Old style with just data + var fileName = HeaderUtilities.RemoveQuotes(firstContent.FileName).ToString(); + using (firstSection.Body) { - // Old style with just data - var fileName = HeaderUtilities.RemoveQuotes(firstContent.FileName).ToString(); - using (firstSection.Body) - { - await callback(firstSection.Body, fileName, null); - } - } - else if (HasDispositionName(firstContent, "key")) - { - // New style with key, then data - string key = null; - using (var sr = new StreamReader(firstSection.Body)) - { - key = await sr.ReadToEndAsync(); - } - - var secondSection = await reader.ReadNextSectionAsync(); - if (secondSection != null) - { - if (ContentDispositionHeaderValue.TryParse(secondSection.ContentDisposition, - out var secondContent) && HasFileContentDisposition(secondContent)) - { - var fileName = HeaderUtilities.RemoveQuotes(secondContent.FileName).ToString(); - using (secondSection.Body) - { - await callback(secondSection.Body, fileName, key); - } - } - - secondSection = null; - } + await callback(firstSection.Body, fileName, null); } } - - firstSection = null; - } - } - - public static async Task GetSendFileAsync(this HttpRequest request, Func callback) - { - var boundary = GetBoundary(MediaTypeHeaderValue.Parse(request.ContentType), - _defaultFormOptions.MultipartBoundaryLengthLimit); - var reader = new MultipartReader(boundary, request.Body); - - var firstSection = await reader.ReadNextSectionAsync(); - if (firstSection != null) - { - if (ContentDispositionHeaderValue.TryParse(firstSection.ContentDisposition, out _)) + else if (HasDispositionName(firstContent, "key")) { + // New style with key, then data + string key = null; + using (var sr = new StreamReader(firstSection.Body)) + { + key = await sr.ReadToEndAsync(); + } + var secondSection = await reader.ReadNextSectionAsync(); if (secondSection != null) { @@ -83,69 +49,102 @@ namespace Bit.Api.Utilities var fileName = HeaderUtilities.RemoveQuotes(secondContent.FileName).ToString(); using (secondSection.Body) { - var model = await JsonSerializer.DeserializeAsync(firstSection.Body); - await callback(secondSection.Body, fileName, model); + await callback(secondSection.Body, fileName, key); } } secondSection = null; } - } - - firstSection = null; - } - } - - public static async Task GetFileAsync(this HttpRequest request, Func callback) - { - var boundary = GetBoundary(MediaTypeHeaderValue.Parse(request.ContentType), - _defaultFormOptions.MultipartBoundaryLengthLimit); - var reader = new MultipartReader(boundary, request.Body); - - var dataSection = await reader.ReadNextSectionAsync(); - if (dataSection != null) - { - if (ContentDispositionHeaderValue.TryParse(dataSection.ContentDisposition, out var dataContent) - && HasFileContentDisposition(dataContent)) - { - using (dataSection.Body) - { - await callback(dataSection.Body); - } - } - dataSection = null; - } - } - - - private static string GetBoundary(MediaTypeHeaderValue contentType, int lengthLimit) - { - var boundary = HeaderUtilities.RemoveQuotes(contentType.Boundary); - if (StringSegment.IsNullOrEmpty(boundary)) - { - throw new InvalidDataException("Missing content-type boundary."); } - if (boundary.Length > lengthLimit) - { - throw new InvalidDataException($"Multipart boundary length limit {lengthLimit} exceeded."); - } - - return boundary.ToString(); - } - - private static bool HasFileContentDisposition(ContentDispositionHeaderValue content) - { - // Content-Disposition: form-data; name="data"; filename="Misc 002.jpg" - return content != null && content.DispositionType.Equals("form-data") && - (!StringSegment.IsNullOrEmpty(content.FileName) || !StringSegment.IsNullOrEmpty(content.FileNameStar)); - } - - private static bool HasDispositionName(ContentDispositionHeaderValue content, string name) - { - // Content-Disposition: form-data; name="key"; - return content != null && content.DispositionType.Equals("form-data") && content.Name == name; + firstSection = null; } } + + public static async Task GetSendFileAsync(this HttpRequest request, Func callback) + { + var boundary = GetBoundary(MediaTypeHeaderValue.Parse(request.ContentType), + _defaultFormOptions.MultipartBoundaryLengthLimit); + var reader = new MultipartReader(boundary, request.Body); + + var firstSection = await reader.ReadNextSectionAsync(); + if (firstSection != null) + { + if (ContentDispositionHeaderValue.TryParse(firstSection.ContentDisposition, out _)) + { + var secondSection = await reader.ReadNextSectionAsync(); + if (secondSection != null) + { + if (ContentDispositionHeaderValue.TryParse(secondSection.ContentDisposition, + out var secondContent) && HasFileContentDisposition(secondContent)) + { + var fileName = HeaderUtilities.RemoveQuotes(secondContent.FileName).ToString(); + using (secondSection.Body) + { + var model = await JsonSerializer.DeserializeAsync(firstSection.Body); + await callback(secondSection.Body, fileName, model); + } + } + + secondSection = null; + } + + } + + firstSection = null; + } + } + + public static async Task GetFileAsync(this HttpRequest request, Func callback) + { + var boundary = GetBoundary(MediaTypeHeaderValue.Parse(request.ContentType), + _defaultFormOptions.MultipartBoundaryLengthLimit); + var reader = new MultipartReader(boundary, request.Body); + + var dataSection = await reader.ReadNextSectionAsync(); + if (dataSection != null) + { + if (ContentDispositionHeaderValue.TryParse(dataSection.ContentDisposition, out var dataContent) + && HasFileContentDisposition(dataContent)) + { + using (dataSection.Body) + { + await callback(dataSection.Body); + } + } + dataSection = null; + } + } + + + private static string GetBoundary(MediaTypeHeaderValue contentType, int lengthLimit) + { + var boundary = HeaderUtilities.RemoveQuotes(contentType.Boundary); + if (StringSegment.IsNullOrEmpty(boundary)) + { + throw new InvalidDataException("Missing content-type boundary."); + } + + if (boundary.Length > lengthLimit) + { + throw new InvalidDataException($"Multipart boundary length limit {lengthLimit} exceeded."); + } + + return boundary.ToString(); + } + + private static bool HasFileContentDisposition(ContentDispositionHeaderValue content) + { + // Content-Disposition: form-data; name="data"; filename="Misc 002.jpg" + return content != null && content.DispositionType.Equals("form-data") && + (!StringSegment.IsNullOrEmpty(content.FileName) || !StringSegment.IsNullOrEmpty(content.FileNameStar)); + } + + private static bool HasDispositionName(ContentDispositionHeaderValue content, string name) + { + // Content-Disposition: form-data; name="key"; + return content != null && content.DispositionType.Equals("form-data") && content.Name == name; + } } diff --git a/src/Api/Utilities/PublicApiControllersModelConvention.cs b/src/Api/Utilities/PublicApiControllersModelConvention.cs index 64101148e..a7fabb031 100644 --- a/src/Api/Utilities/PublicApiControllersModelConvention.cs +++ b/src/Api/Utilities/PublicApiControllersModelConvention.cs @@ -1,15 +1,14 @@ using Microsoft.AspNetCore.Mvc.ApplicationModels; -namespace Bit.Api.Utilities +namespace Bit.Api.Utilities; + +public class PublicApiControllersModelConvention : IControllerModelConvention { - public class PublicApiControllersModelConvention : IControllerModelConvention + public void Apply(ControllerModel controller) { - public void Apply(ControllerModel controller) - { - var controllerNamespace = controller.ControllerType.Namespace; - var publicApi = controllerNamespace.Contains(".Public."); - controller.Filters.Add(new ExceptionHandlerFilterAttribute(publicApi)); - controller.Filters.Add(new ModelStateValidationFilterAttribute(publicApi)); - } + var controllerNamespace = controller.ControllerType.Namespace; + var publicApi = controllerNamespace.Contains(".Public."); + controller.Filters.Add(new ExceptionHandlerFilterAttribute(publicApi)); + controller.Filters.Add(new ModelStateValidationFilterAttribute(publicApi)); } } diff --git a/src/Api/Utilities/SecretsManagerAttribute.cs b/src/Api/Utilities/SecretsManagerAttribute.cs index 44ba46586..87540c56e 100644 --- a/src/Api/Utilities/SecretsManagerAttribute.cs +++ b/src/Api/Utilities/SecretsManagerAttribute.cs @@ -1,22 +1,20 @@ using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc.Filters; -namespace Bit.Api.Utilities +namespace Bit.Api.Utilities; + +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method)] +public class SecretsManagerAttribute : Attribute, IResourceFilter { - - [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method)] - public class SecretsManagerAttribute : Attribute, IResourceFilter + public void OnResourceExecuting(ResourceExecutingContext context) { - public void OnResourceExecuting(ResourceExecutingContext context) + var env = context.HttpContext.RequestServices.GetService(); + if (!env.IsDevelopment()) { - var env = context.HttpContext.RequestServices.GetService(); - if (!env.IsDevelopment()) - { - context.Result = new NotFoundResult(); - } + context.Result = new NotFoundResult(); } - - public void OnResourceExecuted(ResourceExecutedContext context) { } } + + public void OnResourceExecuted(ResourceExecutedContext context) { } } diff --git a/src/Api/Utilities/ServiceCollectionExtensions.cs b/src/Api/Utilities/ServiceCollectionExtensions.cs index 4e57d5165..ff0ff0705 100644 --- a/src/Api/Utilities/ServiceCollectionExtensions.cs +++ b/src/Api/Utilities/ServiceCollectionExtensions.cs @@ -1,72 +1,71 @@ using Bit.Core.Settings; using Microsoft.OpenApi.Models; -namespace Bit.Api.Utilities +namespace Bit.Api.Utilities; + +public static class ServiceCollectionExtensions { - public static class ServiceCollectionExtensions + public static void AddSwagger(this IServiceCollection services, GlobalSettings globalSettings) { - public static void AddSwagger(this IServiceCollection services, GlobalSettings globalSettings) + services.AddSwaggerGen(config => { - services.AddSwaggerGen(config => + config.SwaggerDoc("public", new OpenApiInfo { - config.SwaggerDoc("public", new OpenApiInfo + Title = "Bitwarden Public API", + Version = "latest", + Contact = new OpenApiContact { - Title = "Bitwarden Public API", - Version = "latest", - Contact = new OpenApiContact - { - Name = "Bitwarden Support", - Url = new Uri("https://bitwarden.com"), - Email = "support@bitwarden.com" - }, - Description = "The Bitwarden public APIs.", - License = new OpenApiLicense - { - Name = "GNU Affero General Public License v3.0", - Url = new Uri("https://github.com/bitwarden/server/blob/master/LICENSE.txt") - } - }); - config.SwaggerDoc("internal", new OpenApiInfo { Title = "Bitwarden Internal API", Version = "latest" }); - - config.AddSecurityDefinition("OAuth2 Client Credentials", new OpenApiSecurityScheme + Name = "Bitwarden Support", + Url = new Uri("https://bitwarden.com"), + Email = "support@bitwarden.com" + }, + Description = "The Bitwarden public APIs.", + License = new OpenApiLicense { - Type = SecuritySchemeType.OAuth2, - Flows = new OpenApiOAuthFlows - { - ClientCredentials = new OpenApiOAuthFlow - { - TokenUrl = new Uri($"{globalSettings.BaseServiceUri.Identity}/connect/token"), - Scopes = new Dictionary - { - { "api.organization", "Organization APIs" }, - }, - } - }, - }); - - config.AddSecurityRequirement(new OpenApiSecurityRequirement - { - { - new OpenApiSecurityScheme - { - Reference = new OpenApiReference - { - Type = ReferenceType.SecurityScheme, - Id = "OAuth2 Client Credentials" - }, - }, - new[] { "api.organization" } - } - }); - - config.DescribeAllParametersInCamelCase(); - // config.UseReferencedDefinitionsForEnums(); - - var apiFilePath = Path.Combine(AppContext.BaseDirectory, "Api.xml"); - config.IncludeXmlComments(apiFilePath, true); - var coreFilePath = Path.Combine(AppContext.BaseDirectory, "Core.xml"); - config.IncludeXmlComments(coreFilePath); + Name = "GNU Affero General Public License v3.0", + Url = new Uri("https://github.com/bitwarden/server/blob/master/LICENSE.txt") + } }); - } + config.SwaggerDoc("internal", new OpenApiInfo { Title = "Bitwarden Internal API", Version = "latest" }); + + config.AddSecurityDefinition("OAuth2 Client Credentials", new OpenApiSecurityScheme + { + Type = SecuritySchemeType.OAuth2, + Flows = new OpenApiOAuthFlows + { + ClientCredentials = new OpenApiOAuthFlow + { + TokenUrl = new Uri($"{globalSettings.BaseServiceUri.Identity}/connect/token"), + Scopes = new Dictionary + { + { "api.organization", "Organization APIs" }, + }, + } + }, + }); + + config.AddSecurityRequirement(new OpenApiSecurityRequirement + { + { + new OpenApiSecurityScheme + { + Reference = new OpenApiReference + { + Type = ReferenceType.SecurityScheme, + Id = "OAuth2 Client Credentials" + }, + }, + new[] { "api.organization" } + } + }); + + config.DescribeAllParametersInCamelCase(); + // config.UseReferencedDefinitionsForEnums(); + + var apiFilePath = Path.Combine(AppContext.BaseDirectory, "Api.xml"); + config.IncludeXmlComments(apiFilePath, true); + var coreFilePath = Path.Combine(AppContext.BaseDirectory, "Core.xml"); + config.IncludeXmlComments(coreFilePath); + }); } } diff --git a/src/Billing/BillingSettings.cs b/src/Billing/BillingSettings.cs index 0e61775ee..5be6b205f 100644 --- a/src/Billing/BillingSettings.cs +++ b/src/Billing/BillingSettings.cs @@ -1,23 +1,22 @@ -namespace Bit.Billing -{ - public class BillingSettings - { - public virtual string JobsKey { get; set; } - public virtual string StripeWebhookKey { get; set; } - public virtual string StripeWebhookSecret { get; set; } - public virtual bool StripeEventParseThrowMismatch { get; set; } = true; - public virtual string BitPayWebhookKey { get; set; } - public virtual string AppleWebhookKey { get; set; } - public virtual string FreshdeskWebhookKey { get; set; } - public virtual string FreshdeskApiKey { get; set; } - public virtual string FreshsalesApiKey { get; set; } - public virtual PayPalSettings PayPal { get; set; } = new PayPalSettings(); +namespace Bit.Billing; - public class PayPalSettings - { - public virtual bool Production { get; set; } - public virtual string BusinessId { get; set; } - public virtual string WebhookKey { get; set; } - } +public class BillingSettings +{ + public virtual string JobsKey { get; set; } + public virtual string StripeWebhookKey { get; set; } + public virtual string StripeWebhookSecret { get; set; } + public virtual bool StripeEventParseThrowMismatch { get; set; } = true; + public virtual string BitPayWebhookKey { get; set; } + public virtual string AppleWebhookKey { get; set; } + public virtual string FreshdeskWebhookKey { get; set; } + public virtual string FreshdeskApiKey { get; set; } + public virtual string FreshsalesApiKey { get; set; } + public virtual PayPalSettings PayPal { get; set; } = new PayPalSettings(); + + public class PayPalSettings + { + public virtual bool Production { get; set; } + public virtual string BusinessId { get; set; } + public virtual string WebhookKey { get; set; } } } diff --git a/src/Billing/Constants/HandledStripeWebhook.cs b/src/Billing/Constants/HandledStripeWebhook.cs index 08d6daafc..f40b370f4 100644 --- a/src/Billing/Constants/HandledStripeWebhook.cs +++ b/src/Billing/Constants/HandledStripeWebhook.cs @@ -1,14 +1,13 @@ -namespace Bit.Billing.Constants +namespace Bit.Billing.Constants; + +public static class HandledStripeWebhook { - public static class HandledStripeWebhook - { - public static string SubscriptionDeleted => "customer.subscription.deleted"; - public static string SubscriptionUpdated => "customer.subscription.updated"; - public static string UpcomingInvoice => "invoice.upcoming"; - public static string ChargeSucceeded => "charge.succeeded"; - public static string ChargeRefunded => "charge.refunded"; - public static string PaymentSucceeded => "invoice.payment_succeeded"; - public static string PaymentFailed => "invoice.payment_failed"; - public static string InvoiceCreated => "invoice.created"; - } + public static string SubscriptionDeleted => "customer.subscription.deleted"; + public static string SubscriptionUpdated => "customer.subscription.updated"; + public static string UpcomingInvoice => "invoice.upcoming"; + public static string ChargeSucceeded => "charge.succeeded"; + public static string ChargeRefunded => "charge.refunded"; + public static string PaymentSucceeded => "invoice.payment_succeeded"; + public static string PaymentFailed => "invoice.payment_failed"; + public static string InvoiceCreated => "invoice.created"; } diff --git a/src/Billing/Controllers/AppleController.cs b/src/Billing/Controllers/AppleController.cs index dc8c82786..1bcbbf2ad 100644 --- a/src/Billing/Controllers/AppleController.cs +++ b/src/Billing/Controllers/AppleController.cs @@ -4,59 +4,58 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.Options; -namespace Bit.Billing.Controllers -{ - [Route("apple")] - public class AppleController : Controller - { - private readonly BillingSettings _billingSettings; - private readonly ILogger _logger; +namespace Bit.Billing.Controllers; - public AppleController( - IOptions billingSettings, - ILogger logger) +[Route("apple")] +public class AppleController : Controller +{ + private readonly BillingSettings _billingSettings; + private readonly ILogger _logger; + + public AppleController( + IOptions billingSettings, + ILogger logger) + { + _billingSettings = billingSettings?.Value; + _logger = logger; + } + + [HttpPost("iap")] + public async Task PostIap() + { + if (HttpContext?.Request?.Query == null) { - _billingSettings = billingSettings?.Value; - _logger = logger; + return new BadRequestResult(); } - [HttpPost("iap")] - public async Task PostIap() + var key = HttpContext.Request.Query.ContainsKey("key") ? + HttpContext.Request.Query["key"].ToString() : null; + if (!CoreHelpers.FixedTimeEquals(key, _billingSettings.AppleWebhookKey)) { - if (HttpContext?.Request?.Query == null) - { - return new BadRequestResult(); - } + return new BadRequestResult(); + } - var key = HttpContext.Request.Query.ContainsKey("key") ? - HttpContext.Request.Query["key"].ToString() : null; - if (!CoreHelpers.FixedTimeEquals(key, _billingSettings.AppleWebhookKey)) - { - return new BadRequestResult(); - } + string body = null; + using (var reader = new StreamReader(HttpContext.Request.Body, Encoding.UTF8)) + { + body = await reader.ReadToEndAsync(); + } - string body = null; - using (var reader = new StreamReader(HttpContext.Request.Body, Encoding.UTF8)) - { - body = await reader.ReadToEndAsync(); - } + if (string.IsNullOrWhiteSpace(body)) + { + return new BadRequestResult(); + } - if (string.IsNullOrWhiteSpace(body)) - { - return new BadRequestResult(); - } - - try - { - var json = JsonSerializer.Serialize(JsonSerializer.Deserialize(body), JsonHelpers.Indented); - _logger.LogInformation(Bit.Core.Constants.BypassFiltersEventId, "Apple IAP Notification:\n\n{0}", json); - return new OkResult(); - } - catch (Exception e) - { - _logger.LogError(e, "Error processing IAP status notification."); - return new BadRequestResult(); - } + try + { + var json = JsonSerializer.Serialize(JsonSerializer.Deserialize(body), JsonHelpers.Indented); + _logger.LogInformation(Bit.Core.Constants.BypassFiltersEventId, "Apple IAP Notification:\n\n{0}", json); + return new OkResult(); + } + catch (Exception e) + { + _logger.LogError(e, "Error processing IAP status notification."); + return new BadRequestResult(); } } } diff --git a/src/Billing/Controllers/BitPayController.cs b/src/Billing/Controllers/BitPayController.cs index 520f00a22..539d35595 100644 --- a/src/Billing/Controllers/BitPayController.cs +++ b/src/Billing/Controllers/BitPayController.cs @@ -9,200 +9,199 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.Options; -namespace Bit.Billing.Controllers -{ - [Route("bitpay")] - public class BitPayController : Controller - { - private readonly BillingSettings _billingSettings; - private readonly BitPayClient _bitPayClient; - private readonly ITransactionRepository _transactionRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly IUserRepository _userRepository; - private readonly IMailService _mailService; - private readonly IPaymentService _paymentService; - private readonly ILogger _logger; +namespace Bit.Billing.Controllers; - public BitPayController( - IOptions billingSettings, - BitPayClient bitPayClient, - ITransactionRepository transactionRepository, - IOrganizationRepository organizationRepository, - IUserRepository userRepository, - IMailService mailService, - IPaymentService paymentService, - ILogger logger) +[Route("bitpay")] +public class BitPayController : Controller +{ + private readonly BillingSettings _billingSettings; + private readonly BitPayClient _bitPayClient; + private readonly ITransactionRepository _transactionRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly IUserRepository _userRepository; + private readonly IMailService _mailService; + private readonly IPaymentService _paymentService; + private readonly ILogger _logger; + + public BitPayController( + IOptions billingSettings, + BitPayClient bitPayClient, + ITransactionRepository transactionRepository, + IOrganizationRepository organizationRepository, + IUserRepository userRepository, + IMailService mailService, + IPaymentService paymentService, + ILogger logger) + { + _billingSettings = billingSettings?.Value; + _bitPayClient = bitPayClient; + _transactionRepository = transactionRepository; + _organizationRepository = organizationRepository; + _userRepository = userRepository; + _mailService = mailService; + _paymentService = paymentService; + _logger = logger; + } + + [HttpPost("ipn")] + public async Task PostIpn([FromBody] BitPayEventModel model, [FromQuery] string key) + { + if (!CoreHelpers.FixedTimeEquals(key, _billingSettings.BitPayWebhookKey)) { - _billingSettings = billingSettings?.Value; - _bitPayClient = bitPayClient; - _transactionRepository = transactionRepository; - _organizationRepository = organizationRepository; - _userRepository = userRepository; - _mailService = mailService; - _paymentService = paymentService; - _logger = logger; + return new BadRequestResult(); + } + if (model == null || string.IsNullOrWhiteSpace(model.Data?.Id) || + string.IsNullOrWhiteSpace(model.Event?.Name)) + { + return new BadRequestResult(); } - [HttpPost("ipn")] - public async Task PostIpn([FromBody] BitPayEventModel model, [FromQuery] string key) + if (model.Event.Name != "invoice_confirmed") { - if (!CoreHelpers.FixedTimeEquals(key, _billingSettings.BitPayWebhookKey)) - { - return new BadRequestResult(); - } - if (model == null || string.IsNullOrWhiteSpace(model.Data?.Id) || - string.IsNullOrWhiteSpace(model.Event?.Name)) - { - return new BadRequestResult(); - } - - if (model.Event.Name != "invoice_confirmed") - { - // Only processing confirmed invoice events for now. - return new OkResult(); - } - - var invoice = await _bitPayClient.GetInvoiceAsync(model.Data.Id); - if (invoice == null) - { - // Request forged...? - _logger.LogWarning("Invoice not found. #" + model.Data.Id); - return new BadRequestResult(); - } - - if (invoice.Status != "confirmed" && invoice.Status != "completed") - { - _logger.LogWarning("Invoice status of '" + invoice.Status + "' is not acceptable. #" + invoice.Id); - return new BadRequestResult(); - } - - if (invoice.Currency != "USD") - { - // Only process USD payments - _logger.LogWarning("Non USD payment received. #" + invoice.Id); - return new OkResult(); - } - - var ids = GetIdsFromPosData(invoice); - if (!ids.Item1.HasValue && !ids.Item2.HasValue) - { - return new OkResult(); - } - - var isAccountCredit = IsAccountCredit(invoice); - if (!isAccountCredit) - { - // Only processing credits - _logger.LogWarning("Non-credit payment received. #" + invoice.Id); - return new OkResult(); - } - - var transaction = await _transactionRepository.GetByGatewayIdAsync(GatewayType.BitPay, invoice.Id); - if (transaction != null) - { - _logger.LogWarning("Already processed this invoice. #" + invoice.Id); - return new OkResult(); - } - - try - { - var tx = new Transaction - { - Amount = Convert.ToDecimal(invoice.Price), - CreationDate = GetTransactionDate(invoice), - OrganizationId = ids.Item1, - UserId = ids.Item2, - Type = TransactionType.Credit, - Gateway = GatewayType.BitPay, - GatewayId = invoice.Id, - PaymentMethodType = PaymentMethodType.BitPay, - Details = $"{invoice.Currency}, BitPay {invoice.Id}" - }; - await _transactionRepository.CreateAsync(tx); - - if (isAccountCredit) - { - string billingEmail = null; - if (tx.OrganizationId.HasValue) - { - var org = await _organizationRepository.GetByIdAsync(tx.OrganizationId.Value); - if (org != null) - { - billingEmail = org.BillingEmailAddress(); - if (await _paymentService.CreditAccountAsync(org, tx.Amount)) - { - await _organizationRepository.ReplaceAsync(org); - } - } - } - else - { - var user = await _userRepository.GetByIdAsync(tx.UserId.Value); - if (user != null) - { - billingEmail = user.BillingEmailAddress(); - if (await _paymentService.CreditAccountAsync(user, tx.Amount)) - { - await _userRepository.ReplaceAsync(user); - } - } - } - - if (!string.IsNullOrWhiteSpace(billingEmail)) - { - await _mailService.SendAddedCreditAsync(billingEmail, tx.Amount); - } - } - } - // Catch foreign key violations because user/org could have been deleted. - catch (SqlException e) when (e.Number == 547) { } - + // Only processing confirmed invoice events for now. return new OkResult(); } - private bool IsAccountCredit(BitPayLight.Models.Invoice.Invoice invoice) + var invoice = await _bitPayClient.GetInvoiceAsync(model.Data.Id); + if (invoice == null) { - return invoice != null && invoice.PosData != null && invoice.PosData.Contains("accountCredit:1"); + // Request forged...? + _logger.LogWarning("Invoice not found. #" + model.Data.Id); + return new BadRequestResult(); } - private DateTime GetTransactionDate(BitPayLight.Models.Invoice.Invoice invoice) + if (invoice.Status != "confirmed" && invoice.Status != "completed") { - var transactions = invoice.Transactions?.Where(t => t.Type == null && - !string.IsNullOrWhiteSpace(t.Confirmations) && t.Confirmations != "0"); - if (transactions != null && transactions.Count() == 1) - { - return DateTime.Parse(transactions.First().ReceivedTime, CultureInfo.InvariantCulture, - DateTimeStyles.RoundtripKind); - } - return CoreHelpers.FromEpocMilliseconds(invoice.CurrentTime); + _logger.LogWarning("Invoice status of '" + invoice.Status + "' is not acceptable. #" + invoice.Id); + return new BadRequestResult(); } - public Tuple GetIdsFromPosData(BitPayLight.Models.Invoice.Invoice invoice) + if (invoice.Currency != "USD") { - Guid? orgId = null; - Guid? userId = null; + // Only process USD payments + _logger.LogWarning("Non USD payment received. #" + invoice.Id); + return new OkResult(); + } - if (invoice != null && !string.IsNullOrWhiteSpace(invoice.PosData) && invoice.PosData.Contains(":")) + var ids = GetIdsFromPosData(invoice); + if (!ids.Item1.HasValue && !ids.Item2.HasValue) + { + return new OkResult(); + } + + var isAccountCredit = IsAccountCredit(invoice); + if (!isAccountCredit) + { + // Only processing credits + _logger.LogWarning("Non-credit payment received. #" + invoice.Id); + return new OkResult(); + } + + var transaction = await _transactionRepository.GetByGatewayIdAsync(GatewayType.BitPay, invoice.Id); + if (transaction != null) + { + _logger.LogWarning("Already processed this invoice. #" + invoice.Id); + return new OkResult(); + } + + try + { + var tx = new Transaction { - var mainParts = invoice.PosData.Split(','); - foreach (var mainPart in mainParts) + Amount = Convert.ToDecimal(invoice.Price), + CreationDate = GetTransactionDate(invoice), + OrganizationId = ids.Item1, + UserId = ids.Item2, + Type = TransactionType.Credit, + Gateway = GatewayType.BitPay, + GatewayId = invoice.Id, + PaymentMethodType = PaymentMethodType.BitPay, + Details = $"{invoice.Currency}, BitPay {invoice.Id}" + }; + await _transactionRepository.CreateAsync(tx); + + if (isAccountCredit) + { + string billingEmail = null; + if (tx.OrganizationId.HasValue) { - var parts = mainPart.Split(':'); - if (parts.Length > 1 && Guid.TryParse(parts[1], out var id)) + var org = await _organizationRepository.GetByIdAsync(tx.OrganizationId.Value); + if (org != null) { - if (parts[0] == "userId") + billingEmail = org.BillingEmailAddress(); + if (await _paymentService.CreditAccountAsync(org, tx.Amount)) { - userId = id; - } - else if (parts[0] == "organizationId") - { - orgId = id; + await _organizationRepository.ReplaceAsync(org); + } + } + } + else + { + var user = await _userRepository.GetByIdAsync(tx.UserId.Value); + if (user != null) + { + billingEmail = user.BillingEmailAddress(); + if (await _paymentService.CreditAccountAsync(user, tx.Amount)) + { + await _userRepository.ReplaceAsync(user); } } } - } - return new Tuple(orgId, userId); + if (!string.IsNullOrWhiteSpace(billingEmail)) + { + await _mailService.SendAddedCreditAsync(billingEmail, tx.Amount); + } + } } + // Catch foreign key violations because user/org could have been deleted. + catch (SqlException e) when (e.Number == 547) { } + + return new OkResult(); + } + + private bool IsAccountCredit(BitPayLight.Models.Invoice.Invoice invoice) + { + return invoice != null && invoice.PosData != null && invoice.PosData.Contains("accountCredit:1"); + } + + private DateTime GetTransactionDate(BitPayLight.Models.Invoice.Invoice invoice) + { + var transactions = invoice.Transactions?.Where(t => t.Type == null && + !string.IsNullOrWhiteSpace(t.Confirmations) && t.Confirmations != "0"); + if (transactions != null && transactions.Count() == 1) + { + return DateTime.Parse(transactions.First().ReceivedTime, CultureInfo.InvariantCulture, + DateTimeStyles.RoundtripKind); + } + return CoreHelpers.FromEpocMilliseconds(invoice.CurrentTime); + } + + public Tuple GetIdsFromPosData(BitPayLight.Models.Invoice.Invoice invoice) + { + Guid? orgId = null; + Guid? userId = null; + + if (invoice != null && !string.IsNullOrWhiteSpace(invoice.PosData) && invoice.PosData.Contains(":")) + { + var mainParts = invoice.PosData.Split(','); + foreach (var mainPart in mainParts) + { + var parts = mainPart.Split(':'); + if (parts.Length > 1 && Guid.TryParse(parts[1], out var id)) + { + if (parts[0] == "userId") + { + userId = id; + } + else if (parts[0] == "organizationId") + { + orgId = id; + } + } + } + } + + return new Tuple(orgId, userId); } } diff --git a/src/Billing/Controllers/FreshdeskController.cs b/src/Billing/Controllers/FreshdeskController.cs index 7e7b0a6b4..e38a89242 100644 --- a/src/Billing/Controllers/FreshdeskController.cs +++ b/src/Billing/Controllers/FreshdeskController.cs @@ -8,166 +8,165 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.Options; -namespace Bit.Billing.Controllers +namespace Bit.Billing.Controllers; + +[Route("freshdesk")] +public class FreshdeskController : Controller { - [Route("freshdesk")] - public class FreshdeskController : Controller + private readonly BillingSettings _billingSettings; + private readonly IUserRepository _userRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly ILogger _logger; + private readonly GlobalSettings _globalSettings; + private readonly IHttpClientFactory _httpClientFactory; + + public FreshdeskController( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOptions billingSettings, + ILogger logger, + GlobalSettings globalSettings, + IHttpClientFactory httpClientFactory) { - private readonly BillingSettings _billingSettings; - private readonly IUserRepository _userRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly ILogger _logger; - private readonly GlobalSettings _globalSettings; - private readonly IHttpClientFactory _httpClientFactory; + _billingSettings = billingSettings?.Value; + _userRepository = userRepository; + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _logger = logger; + _globalSettings = globalSettings; + _httpClientFactory = httpClientFactory; + } - public FreshdeskController( - IUserRepository userRepository, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IOptions billingSettings, - ILogger logger, - GlobalSettings globalSettings, - IHttpClientFactory httpClientFactory) + [HttpPost("webhook")] + public async Task PostWebhook([FromQuery, Required] string key, + [FromBody, Required] FreshdeskWebhookModel model) + { + if (string.IsNullOrWhiteSpace(key) || !CoreHelpers.FixedTimeEquals(key, _billingSettings.FreshdeskWebhookKey)) { - _billingSettings = billingSettings?.Value; - _userRepository = userRepository; - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _logger = logger; - _globalSettings = globalSettings; - _httpClientFactory = httpClientFactory; + return new BadRequestResult(); } - [HttpPost("webhook")] - public async Task PostWebhook([FromQuery, Required] string key, - [FromBody, Required] FreshdeskWebhookModel model) + try { - if (string.IsNullOrWhiteSpace(key) || !CoreHelpers.FixedTimeEquals(key, _billingSettings.FreshdeskWebhookKey)) + var ticketId = model.TicketId; + var ticketContactEmail = model.TicketContactEmail; + var ticketTags = model.TicketTags; + if (string.IsNullOrWhiteSpace(ticketId) || string.IsNullOrWhiteSpace(ticketContactEmail)) { return new BadRequestResult(); } - try + var updateBody = new Dictionary(); + var note = string.Empty; + var customFields = new Dictionary(); + var user = await _userRepository.GetByEmailAsync(ticketContactEmail); + if (user != null) { - var ticketId = model.TicketId; - var ticketContactEmail = model.TicketContactEmail; - var ticketTags = model.TicketTags; - if (string.IsNullOrWhiteSpace(ticketId) || string.IsNullOrWhiteSpace(ticketContactEmail)) + var userLink = $"{_globalSettings.BaseServiceUri.Admin}/users/edit/{user.Id}"; + note += $"
  • User, {user.Email}: {userLink}
  • "; + customFields.Add("cf_user", userLink); + var tags = new HashSet(); + if (user.Premium) { - return new BadRequestResult(); + tags.Add("Premium"); + } + var orgs = await _organizationRepository.GetManyByUserIdAsync(user.Id); + + foreach (var org in orgs) + { + var orgNote = $"{org.Name} ({org.Seats.GetValueOrDefault()}): " + + $"{_globalSettings.BaseServiceUri.Admin}/organizations/edit/{org.Id}"; + note += $"
  • Org, {orgNote}
  • "; + if (!customFields.Any(kvp => kvp.Key == "cf_org")) + { + customFields.Add("cf_org", orgNote); + } + else + { + customFields["cf_org"] += $"\n{orgNote}"; + } + + var planName = GetAttribute(org.PlanType).Name.Split(" ").FirstOrDefault(); + if (!string.IsNullOrWhiteSpace(planName)) + { + tags.Add(string.Format("Org: {0}", planName)); + } + } + if (tags.Any()) + { + var tagsToUpdate = tags.ToList(); + if (!string.IsNullOrWhiteSpace(ticketTags)) + { + var splitTicketTags = ticketTags.Split(','); + for (var i = 0; i < splitTicketTags.Length; i++) + { + tagsToUpdate.Insert(i, splitTicketTags[i]); + } + } + updateBody.Add("tags", tagsToUpdate); } - var updateBody = new Dictionary(); - var note = string.Empty; - var customFields = new Dictionary(); - var user = await _userRepository.GetByEmailAsync(ticketContactEmail); - if (user != null) + if (customFields.Any()) { - var userLink = $"{_globalSettings.BaseServiceUri.Admin}/users/edit/{user.Id}"; - note += $"
  • User, {user.Email}: {userLink}
  • "; - customFields.Add("cf_user", userLink); - var tags = new HashSet(); - if (user.Premium) - { - tags.Add("Premium"); - } - var orgs = await _organizationRepository.GetManyByUserIdAsync(user.Id); - - foreach (var org in orgs) - { - var orgNote = $"{org.Name} ({org.Seats.GetValueOrDefault()}): " + - $"{_globalSettings.BaseServiceUri.Admin}/organizations/edit/{org.Id}"; - note += $"
  • Org, {orgNote}
  • "; - if (!customFields.Any(kvp => kvp.Key == "cf_org")) - { - customFields.Add("cf_org", orgNote); - } - else - { - customFields["cf_org"] += $"\n{orgNote}"; - } - - var planName = GetAttribute(org.PlanType).Name.Split(" ").FirstOrDefault(); - if (!string.IsNullOrWhiteSpace(planName)) - { - tags.Add(string.Format("Org: {0}", planName)); - } - } - if (tags.Any()) - { - var tagsToUpdate = tags.ToList(); - if (!string.IsNullOrWhiteSpace(ticketTags)) - { - var splitTicketTags = ticketTags.Split(','); - for (var i = 0; i < splitTicketTags.Length; i++) - { - tagsToUpdate.Insert(i, splitTicketTags[i]); - } - } - updateBody.Add("tags", tagsToUpdate); - } - - if (customFields.Any()) - { - updateBody.Add("custom_fields", customFields); - } - var updateRequest = new HttpRequestMessage(HttpMethod.Put, - string.Format("https://bitwarden.freshdesk.com/api/v2/tickets/{0}", ticketId)) - { - Content = JsonContent.Create(updateBody), - }; - await CallFreshdeskApiAsync(updateRequest); - - var noteBody = new Dictionary - { - { "body", $"
      {note}
    " }, - { "private", true } - }; - var noteRequest = new HttpRequestMessage(HttpMethod.Post, - string.Format("https://bitwarden.freshdesk.com/api/v2/tickets/{0}/notes", ticketId)) - { - Content = JsonContent.Create(noteBody), - }; - await CallFreshdeskApiAsync(noteRequest); + updateBody.Add("custom_fields", customFields); } + var updateRequest = new HttpRequestMessage(HttpMethod.Put, + string.Format("https://bitwarden.freshdesk.com/api/v2/tickets/{0}", ticketId)) + { + Content = JsonContent.Create(updateBody), + }; + await CallFreshdeskApiAsync(updateRequest); - return new OkResult(); - } - catch (Exception e) - { - _logger.LogError(e, "Error processing freshdesk webhook."); - return new BadRequestResult(); + var noteBody = new Dictionary + { + { "body", $"
      {note}
    " }, + { "private", true } + }; + var noteRequest = new HttpRequestMessage(HttpMethod.Post, + string.Format("https://bitwarden.freshdesk.com/api/v2/tickets/{0}/notes", ticketId)) + { + Content = JsonContent.Create(noteBody), + }; + await CallFreshdeskApiAsync(noteRequest); } + + return new OkResult(); } - - private async Task CallFreshdeskApiAsync(HttpRequestMessage request, int retriedCount = 0) + catch (Exception e) { - try - { - var freshdeskAuthkey = Convert.ToBase64String(Encoding.UTF8.GetBytes($"{_billingSettings.FreshdeskApiKey}:X")); - var httpClient = _httpClientFactory.CreateClient("FreshdeskApi"); - request.Headers.Add("Authorization", freshdeskAuthkey); - var response = await httpClient.SendAsync(request); - if (response.StatusCode != System.Net.HttpStatusCode.TooManyRequests || retriedCount > 3) - { - return response; - } - } - catch - { - if (retriedCount > 3) - { - throw; - } - } - await Task.Delay(30000 * (retriedCount + 1)); - return await CallFreshdeskApiAsync(request, retriedCount++); - } - - private TAttribute GetAttribute(Enum enumValue) where TAttribute : Attribute - { - return enumValue.GetType().GetMember(enumValue.ToString()).First().GetCustomAttribute(); + _logger.LogError(e, "Error processing freshdesk webhook."); + return new BadRequestResult(); } } + + private async Task CallFreshdeskApiAsync(HttpRequestMessage request, int retriedCount = 0) + { + try + { + var freshdeskAuthkey = Convert.ToBase64String(Encoding.UTF8.GetBytes($"{_billingSettings.FreshdeskApiKey}:X")); + var httpClient = _httpClientFactory.CreateClient("FreshdeskApi"); + request.Headers.Add("Authorization", freshdeskAuthkey); + var response = await httpClient.SendAsync(request); + if (response.StatusCode != System.Net.HttpStatusCode.TooManyRequests || retriedCount > 3) + { + return response; + } + } + catch + { + if (retriedCount > 3) + { + throw; + } + } + await Task.Delay(30000 * (retriedCount + 1)); + return await CallFreshdeskApiAsync(request, retriedCount++); + } + + private TAttribute GetAttribute(Enum enumValue) where TAttribute : Attribute + { + return enumValue.GetType().GetMember(enumValue.ToString()).First().GetCustomAttribute(); + } } diff --git a/src/Billing/Controllers/FreshsalesController.cs b/src/Billing/Controllers/FreshsalesController.cs index 866b95d17..95b9e2506 100644 --- a/src/Billing/Controllers/FreshsalesController.cs +++ b/src/Billing/Controllers/FreshsalesController.cs @@ -7,229 +7,228 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.Options; -namespace Bit.Billing.Controllers +namespace Bit.Billing.Controllers; + +[Route("freshsales")] +public class FreshsalesController : Controller { - [Route("freshsales")] - public class FreshsalesController : Controller + private readonly IUserRepository _userRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly ILogger _logger; + private readonly GlobalSettings _globalSettings; + + private readonly string _freshsalesApiKey; + + private readonly HttpClient _httpClient; + + public FreshsalesController(IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOptions billingSettings, + ILogger logger, + GlobalSettings globalSettings) { - private readonly IUserRepository _userRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly ILogger _logger; - private readonly GlobalSettings _globalSettings; + _userRepository = userRepository; + _organizationRepository = organizationRepository; + _logger = logger; + _globalSettings = globalSettings; - private readonly string _freshsalesApiKey; - - private readonly HttpClient _httpClient; - - public FreshsalesController(IUserRepository userRepository, - IOrganizationRepository organizationRepository, - IOptions billingSettings, - ILogger logger, - GlobalSettings globalSettings) + _httpClient = new HttpClient { - _userRepository = userRepository; - _organizationRepository = organizationRepository; - _logger = logger; - _globalSettings = globalSettings; + BaseAddress = new Uri("https://bitwarden.freshsales.io/api/") + }; - _httpClient = new HttpClient - { - BaseAddress = new Uri("https://bitwarden.freshsales.io/api/") - }; + _freshsalesApiKey = billingSettings.Value.FreshsalesApiKey; - _freshsalesApiKey = billingSettings.Value.FreshsalesApiKey; + _httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue( + "Token", + $"token={_freshsalesApiKey}"); + } - _httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue( - "Token", - $"token={_freshsalesApiKey}"); + + [HttpPost("webhook")] + public async Task PostWebhook([FromHeader(Name = "Authorization")] string key, + [FromBody] CustomWebhookRequestModel request, + CancellationToken cancellationToken) + { + if (string.IsNullOrWhiteSpace(key) || !CoreHelpers.FixedTimeEquals(_freshsalesApiKey, key)) + { + return Unauthorized(); } - - [HttpPost("webhook")] - public async Task PostWebhook([FromHeader(Name = "Authorization")] string key, - [FromBody] CustomWebhookRequestModel request, - CancellationToken cancellationToken) + try { - if (string.IsNullOrWhiteSpace(key) || !CoreHelpers.FixedTimeEquals(_freshsalesApiKey, key)) + var leadResponse = await _httpClient.GetFromJsonAsync>( + $"leads/{request.LeadId}", + cancellationToken); + + var lead = leadResponse.Lead; + + var primaryEmail = lead.Emails + .Where(e => e.IsPrimary) + .FirstOrDefault(); + + if (primaryEmail == null) { - return Unauthorized(); + return BadRequest(new { Message = "Lead has not primary email." }); } - try + var user = await _userRepository.GetByEmailAsync(primaryEmail.Value); + + if (user == null) { - var leadResponse = await _httpClient.GetFromJsonAsync>( - $"leads/{request.LeadId}", - cancellationToken); - - var lead = leadResponse.Lead; - - var primaryEmail = lead.Emails - .Where(e => e.IsPrimary) - .FirstOrDefault(); - - if (primaryEmail == null) - { - return BadRequest(new { Message = "Lead has not primary email." }); - } - - var user = await _userRepository.GetByEmailAsync(primaryEmail.Value); - - if (user == null) - { - return NoContent(); - } - - var newTags = new HashSet(); - - if (user.Premium) - { - newTags.Add("Premium"); - } - - var noteItems = new List - { - $"User, {user.Email}: {_globalSettings.BaseServiceUri.Admin}/users/edit/{user.Id}" - }; - - var orgs = await _organizationRepository.GetManyByUserIdAsync(user.Id); - - foreach (var org in orgs) - { - noteItems.Add($"Org, {org.Name}: {_globalSettings.BaseServiceUri.Admin}/organizations/edit/{org.Id}"); - if (TryGetPlanName(org.PlanType, out var planName)) - { - newTags.Add($"Org: {planName}"); - } - } - - if (newTags.Any()) - { - var allTags = newTags.Concat(lead.Tags); - var updateLeadResponse = await _httpClient.PutAsJsonAsync( - $"leads/{request.LeadId}", - CreateWrapper(new { tags = allTags }), - cancellationToken); - updateLeadResponse.EnsureSuccessStatusCode(); - } - - var createNoteResponse = await _httpClient.PostAsJsonAsync( - "notes", - CreateNoteRequestModel(request.LeadId, string.Join('\n', noteItems)), cancellationToken); - createNoteResponse.EnsureSuccessStatusCode(); return NoContent(); } - catch (Exception ex) + + var newTags = new HashSet(); + + if (user.Premium) { - Console.WriteLine(ex); - _logger.LogError(ex, "Error processing freshsales webhook"); - return BadRequest(new { ex.Message }); + newTags.Add("Premium"); } - } - private static LeadWrapper CreateWrapper(T lead) - { - return new LeadWrapper + var noteItems = new List { - Lead = lead, + $"User, {user.Email}: {_globalSettings.BaseServiceUri.Admin}/users/edit/{user.Id}" }; - } - private static CreateNoteRequestModel CreateNoteRequestModel(long leadId, string content) - { - return new CreateNoteRequestModel + var orgs = await _organizationRepository.GetManyByUserIdAsync(user.Id); + + foreach (var org in orgs) { - Note = new EditNoteModel + noteItems.Add($"Org, {org.Name}: {_globalSettings.BaseServiceUri.Admin}/organizations/edit/{org.Id}"); + if (TryGetPlanName(org.PlanType, out var planName)) { - Description = content, - TargetableType = "Lead", - TargetableId = leadId, - }, - }; - } - - private static bool TryGetPlanName(PlanType planType, out string planName) - { - switch (planType) - { - case PlanType.Free: - planName = "Free"; - return true; - case PlanType.FamiliesAnnually: - case PlanType.FamiliesAnnually2019: - planName = "Families"; - return true; - case PlanType.TeamsAnnually: - case PlanType.TeamsAnnually2019: - case PlanType.TeamsMonthly: - case PlanType.TeamsMonthly2019: - planName = "Teams"; - return true; - case PlanType.EnterpriseAnnually: - case PlanType.EnterpriseAnnually2019: - case PlanType.EnterpriseMonthly: - case PlanType.EnterpriseMonthly2019: - planName = "Enterprise"; - return true; - case PlanType.Custom: - planName = "Custom"; - return true; - default: - planName = null; - return false; + newTags.Add($"Org: {planName}"); + } } - } - } - public class CustomWebhookRequestModel - { - [JsonPropertyName("leadId")] - public long LeadId { get; set; } - } - - public class LeadWrapper - { - [JsonPropertyName("lead")] - public T Lead { get; set; } - - public static LeadWrapper Create(TItem lead) - { - return new LeadWrapper + if (newTags.Any()) { - Lead = lead, - }; + var allTags = newTags.Concat(lead.Tags); + var updateLeadResponse = await _httpClient.PutAsJsonAsync( + $"leads/{request.LeadId}", + CreateWrapper(new { tags = allTags }), + cancellationToken); + updateLeadResponse.EnsureSuccessStatusCode(); + } + + var createNoteResponse = await _httpClient.PostAsJsonAsync( + "notes", + CreateNoteRequestModel(request.LeadId, string.Join('\n', noteItems)), cancellationToken); + createNoteResponse.EnsureSuccessStatusCode(); + return NoContent(); + } + catch (Exception ex) + { + Console.WriteLine(ex); + _logger.LogError(ex, "Error processing freshsales webhook"); + return BadRequest(new { ex.Message }); } } - public class FreshsalesLeadModel + private static LeadWrapper CreateWrapper(T lead) { - public string[] Tags { get; set; } - public FreshsalesEmailModel[] Emails { get; set; } + return new LeadWrapper + { + Lead = lead, + }; } - public class FreshsalesEmailModel + private static CreateNoteRequestModel CreateNoteRequestModel(long leadId, string content) { - [JsonPropertyName("value")] - public string Value { get; set; } - - [JsonPropertyName("is_primary")] - public bool IsPrimary { get; set; } + return new CreateNoteRequestModel + { + Note = new EditNoteModel + { + Description = content, + TargetableType = "Lead", + TargetableId = leadId, + }, + }; } - public class CreateNoteRequestModel + private static bool TryGetPlanName(PlanType planType, out string planName) { - [JsonPropertyName("note")] - public EditNoteModel Note { get; set; } - } - - public class EditNoteModel - { - [JsonPropertyName("description")] - public string Description { get; set; } - - [JsonPropertyName("targetable_type")] - public string TargetableType { get; set; } - - [JsonPropertyName("targetable_id")] - public long TargetableId { get; set; } + switch (planType) + { + case PlanType.Free: + planName = "Free"; + return true; + case PlanType.FamiliesAnnually: + case PlanType.FamiliesAnnually2019: + planName = "Families"; + return true; + case PlanType.TeamsAnnually: + case PlanType.TeamsAnnually2019: + case PlanType.TeamsMonthly: + case PlanType.TeamsMonthly2019: + planName = "Teams"; + return true; + case PlanType.EnterpriseAnnually: + case PlanType.EnterpriseAnnually2019: + case PlanType.EnterpriseMonthly: + case PlanType.EnterpriseMonthly2019: + planName = "Enterprise"; + return true; + case PlanType.Custom: + planName = "Custom"; + return true; + default: + planName = null; + return false; + } } } + +public class CustomWebhookRequestModel +{ + [JsonPropertyName("leadId")] + public long LeadId { get; set; } +} + +public class LeadWrapper +{ + [JsonPropertyName("lead")] + public T Lead { get; set; } + + public static LeadWrapper Create(TItem lead) + { + return new LeadWrapper + { + Lead = lead, + }; + } +} + +public class FreshsalesLeadModel +{ + public string[] Tags { get; set; } + public FreshsalesEmailModel[] Emails { get; set; } +} + +public class FreshsalesEmailModel +{ + [JsonPropertyName("value")] + public string Value { get; set; } + + [JsonPropertyName("is_primary")] + public bool IsPrimary { get; set; } +} + +public class CreateNoteRequestModel +{ + [JsonPropertyName("note")] + public EditNoteModel Note { get; set; } +} + +public class EditNoteModel +{ + [JsonPropertyName("description")] + public string Description { get; set; } + + [JsonPropertyName("targetable_type")] + public string TargetableType { get; set; } + + [JsonPropertyName("targetable_id")] + public long TargetableId { get; set; } +} diff --git a/src/Billing/Controllers/InfoController.cs b/src/Billing/Controllers/InfoController.cs index 5d7ce5754..58b29f4c4 100644 --- a/src/Billing/Controllers/InfoController.cs +++ b/src/Billing/Controllers/InfoController.cs @@ -1,21 +1,20 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; -namespace Bit.Billing.Controllers -{ - public class InfoController : Controller - { - [HttpGet("~/alive")] - [HttpGet("~/now")] - public DateTime GetAlive() - { - return DateTime.UtcNow; - } +namespace Bit.Billing.Controllers; - [HttpGet("~/version")] - public JsonResult GetVersion() - { - return Json(CoreHelpers.GetVersion()); - } +public class InfoController : Controller +{ + [HttpGet("~/alive")] + [HttpGet("~/now")] + public DateTime GetAlive() + { + return DateTime.UtcNow; + } + + [HttpGet("~/version")] + public JsonResult GetVersion() + { + return Json(CoreHelpers.GetVersion()); } } diff --git a/src/Billing/Controllers/LoginController.cs b/src/Billing/Controllers/LoginController.cs index 448e2b9b2..c2df41b92 100644 --- a/src/Billing/Controllers/LoginController.cs +++ b/src/Billing/Controllers/LoginController.cs @@ -1,54 +1,53 @@ using Microsoft.AspNetCore.Mvc; -namespace Billing.Controllers +namespace Billing.Controllers; + +public class LoginController : Controller { - public class LoginController : Controller + /* + private readonly PasswordlessSignInManager _signInManager; + + public LoginController( + PasswordlessSignInManager signInManager) { - /* - private readonly PasswordlessSignInManager _signInManager; - - public LoginController( - PasswordlessSignInManager signInManager) - { - _signInManager = signInManager; - } - - public IActionResult Index() - { - return View(); - } - - [HttpPost] - [ValidateAntiForgeryToken] - public async Task Index(LoginModel model) - { - if (ModelState.IsValid) - { - var result = await _signInManager.PasswordlessSignInAsync(model.Email, - Url.Action("Confirm", "Login", null, Request.Scheme)); - if (result.Succeeded) - { - return RedirectToAction("Index", "Home"); - } - else - { - ModelState.AddModelError(string.Empty, "Account not found."); - } - } - - return View(model); - } - - public async Task Confirm(string email, string token) - { - var result = await _signInManager.PasswordlessSignInAsync(email, token, false); - if (!result.Succeeded) - { - return View("Error"); - } - - return RedirectToAction("Index", "Home"); - } - */ + _signInManager = signInManager; } + + public IActionResult Index() + { + return View(); + } + + [HttpPost] + [ValidateAntiForgeryToken] + public async Task Index(LoginModel model) + { + if (ModelState.IsValid) + { + var result = await _signInManager.PasswordlessSignInAsync(model.Email, + Url.Action("Confirm", "Login", null, Request.Scheme)); + if (result.Succeeded) + { + return RedirectToAction("Index", "Home"); + } + else + { + ModelState.AddModelError(string.Empty, "Account not found."); + } + } + + return View(model); + } + + public async Task Confirm(string email, string token) + { + var result = await _signInManager.PasswordlessSignInAsync(email, token, false); + if (!result.Succeeded) + { + return View("Error"); + } + + return RedirectToAction("Index", "Home"); + } + */ } diff --git a/src/Billing/Controllers/PayPalController.cs b/src/Billing/Controllers/PayPalController.cs index 64811b5ae..67826afc6 100644 --- a/src/Billing/Controllers/PayPalController.cs +++ b/src/Billing/Controllers/PayPalController.cs @@ -9,227 +9,226 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.Options; -namespace Bit.Billing.Controllers -{ - [Route("paypal")] - public class PayPalController : Controller - { - private readonly BillingSettings _billingSettings; - private readonly PayPalIpnClient _paypalIpnClient; - private readonly ITransactionRepository _transactionRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly IUserRepository _userRepository; - private readonly IMailService _mailService; - private readonly IPaymentService _paymentService; - private readonly ILogger _logger; +namespace Bit.Billing.Controllers; - public PayPalController( - IOptions billingSettings, - PayPalIpnClient paypalIpnClient, - ITransactionRepository transactionRepository, - IOrganizationRepository organizationRepository, - IUserRepository userRepository, - IMailService mailService, - IPaymentService paymentService, - ILogger logger) +[Route("paypal")] +public class PayPalController : Controller +{ + private readonly BillingSettings _billingSettings; + private readonly PayPalIpnClient _paypalIpnClient; + private readonly ITransactionRepository _transactionRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly IUserRepository _userRepository; + private readonly IMailService _mailService; + private readonly IPaymentService _paymentService; + private readonly ILogger _logger; + + public PayPalController( + IOptions billingSettings, + PayPalIpnClient paypalIpnClient, + ITransactionRepository transactionRepository, + IOrganizationRepository organizationRepository, + IUserRepository userRepository, + IMailService mailService, + IPaymentService paymentService, + ILogger logger) + { + _billingSettings = billingSettings?.Value; + _paypalIpnClient = paypalIpnClient; + _transactionRepository = transactionRepository; + _organizationRepository = organizationRepository; + _userRepository = userRepository; + _mailService = mailService; + _paymentService = paymentService; + _logger = logger; + } + + [HttpPost("ipn")] + public async Task PostIpn() + { + _logger.LogDebug("PayPal webhook has been hit."); + if (HttpContext?.Request?.Query == null) { - _billingSettings = billingSettings?.Value; - _paypalIpnClient = paypalIpnClient; - _transactionRepository = transactionRepository; - _organizationRepository = organizationRepository; - _userRepository = userRepository; - _mailService = mailService; - _paymentService = paymentService; - _logger = logger; + return new BadRequestResult(); } - [HttpPost("ipn")] - public async Task PostIpn() + var key = HttpContext.Request.Query.ContainsKey("key") ? + HttpContext.Request.Query["key"].ToString() : null; + if (!CoreHelpers.FixedTimeEquals(key, _billingSettings.PayPal.WebhookKey)) { - _logger.LogDebug("PayPal webhook has been hit."); - if (HttpContext?.Request?.Query == null) - { - return new BadRequestResult(); - } + _logger.LogWarning("PayPal webhook key is incorrect or does not exist."); + return new BadRequestResult(); + } - var key = HttpContext.Request.Query.ContainsKey("key") ? - HttpContext.Request.Query["key"].ToString() : null; - if (!CoreHelpers.FixedTimeEquals(key, _billingSettings.PayPal.WebhookKey)) - { - _logger.LogWarning("PayPal webhook key is incorrect or does not exist."); - return new BadRequestResult(); - } + string body = null; + using (var reader = new StreamReader(HttpContext.Request.Body, Encoding.UTF8)) + { + body = await reader.ReadToEndAsync(); + } - string body = null; - using (var reader = new StreamReader(HttpContext.Request.Body, Encoding.UTF8)) - { - body = await reader.ReadToEndAsync(); - } + if (string.IsNullOrWhiteSpace(body)) + { + return new BadRequestResult(); + } - if (string.IsNullOrWhiteSpace(body)) - { - return new BadRequestResult(); - } - - var verified = await _paypalIpnClient.VerifyIpnAsync(body); - if (!verified) - { - _logger.LogWarning("Unverified IPN received."); - return new BadRequestResult(); - } - - var ipnTransaction = new PayPalIpnClient.IpnTransaction(body); - if (ipnTransaction.TxnType != "web_accept" && ipnTransaction.TxnType != "merch_pmt" && - ipnTransaction.PaymentStatus != "Refunded") - { - // Only processing billing agreement payments, buy now button payments, and refunds for now. - return new OkResult(); - } - - if (ipnTransaction.ReceiverId != _billingSettings.PayPal.BusinessId) - { - _logger.LogWarning("Receiver was not proper business id. " + ipnTransaction.ReceiverId); - return new BadRequestResult(); - } - - if (ipnTransaction.PaymentStatus == "Refunded" && ipnTransaction.ParentTxnId == null) - { - // Refunds require parent transaction - return new OkResult(); - } - - if (ipnTransaction.PaymentType == "echeck" && ipnTransaction.PaymentStatus != "Refunded") - { - // Not accepting eChecks, unless it is a refund - _logger.LogWarning("Got an eCheck payment. " + ipnTransaction.TxnId); - return new OkResult(); - } - - if (ipnTransaction.McCurrency != "USD") - { - // Only process USD payments - _logger.LogWarning("Received a payment not in USD. " + ipnTransaction.TxnId); - return new OkResult(); - } - - var ids = ipnTransaction.GetIdsFromCustom(); - if (!ids.Item1.HasValue && !ids.Item2.HasValue) - { - return new OkResult(); - } - - if (ipnTransaction.PaymentStatus == "Completed") - { - var transaction = await _transactionRepository.GetByGatewayIdAsync( - GatewayType.PayPal, ipnTransaction.TxnId); - if (transaction != null) - { - _logger.LogWarning("Already processed this completed transaction. #" + ipnTransaction.TxnId); - return new OkResult(); - } - - var isAccountCredit = ipnTransaction.IsAccountCredit(); - try - { - var tx = new Transaction - { - Amount = ipnTransaction.McGross, - CreationDate = ipnTransaction.PaymentDate, - OrganizationId = ids.Item1, - UserId = ids.Item2, - Type = isAccountCredit ? TransactionType.Credit : TransactionType.Charge, - Gateway = GatewayType.PayPal, - GatewayId = ipnTransaction.TxnId, - PaymentMethodType = PaymentMethodType.PayPal, - Details = ipnTransaction.TxnId - }; - await _transactionRepository.CreateAsync(tx); - - if (isAccountCredit) - { - string billingEmail = null; - if (tx.OrganizationId.HasValue) - { - var org = await _organizationRepository.GetByIdAsync(tx.OrganizationId.Value); - if (org != null) - { - billingEmail = org.BillingEmailAddress(); - if (await _paymentService.CreditAccountAsync(org, tx.Amount)) - { - await _organizationRepository.ReplaceAsync(org); - } - } - } - else - { - var user = await _userRepository.GetByIdAsync(tx.UserId.Value); - if (user != null) - { - billingEmail = user.BillingEmailAddress(); - if (await _paymentService.CreditAccountAsync(user, tx.Amount)) - { - await _userRepository.ReplaceAsync(user); - } - } - } - - if (!string.IsNullOrWhiteSpace(billingEmail)) - { - await _mailService.SendAddedCreditAsync(billingEmail, tx.Amount); - } - } - } - // Catch foreign key violations because user/org could have been deleted. - catch (SqlException e) when (e.Number == 547) { } - } - else if (ipnTransaction.PaymentStatus == "Refunded" || ipnTransaction.PaymentStatus == "Reversed") - { - var refundTransaction = await _transactionRepository.GetByGatewayIdAsync( - GatewayType.PayPal, ipnTransaction.TxnId); - if (refundTransaction != null) - { - _logger.LogWarning("Already processed this refunded transaction. #" + ipnTransaction.TxnId); - return new OkResult(); - } - - var parentTransaction = await _transactionRepository.GetByGatewayIdAsync( - GatewayType.PayPal, ipnTransaction.ParentTxnId); - if (parentTransaction == null) - { - _logger.LogWarning("Parent transaction was not found. " + ipnTransaction.TxnId); - return new BadRequestResult(); - } - - var refundAmount = System.Math.Abs(ipnTransaction.McGross); - var remainingAmount = parentTransaction.Amount - - parentTransaction.RefundedAmount.GetValueOrDefault(); - if (refundAmount > 0 && !parentTransaction.Refunded.GetValueOrDefault() && - remainingAmount >= refundAmount) - { - parentTransaction.RefundedAmount = - parentTransaction.RefundedAmount.GetValueOrDefault() + refundAmount; - if (parentTransaction.RefundedAmount == parentTransaction.Amount) - { - parentTransaction.Refunded = true; - } - - await _transactionRepository.ReplaceAsync(parentTransaction); - await _transactionRepository.CreateAsync(new Transaction - { - Amount = refundAmount, - CreationDate = ipnTransaction.PaymentDate, - OrganizationId = ids.Item1, - UserId = ids.Item2, - Type = TransactionType.Refund, - Gateway = GatewayType.PayPal, - GatewayId = ipnTransaction.TxnId, - PaymentMethodType = PaymentMethodType.PayPal, - Details = ipnTransaction.TxnId - }); - } - } + var verified = await _paypalIpnClient.VerifyIpnAsync(body); + if (!verified) + { + _logger.LogWarning("Unverified IPN received."); + return new BadRequestResult(); + } + var ipnTransaction = new PayPalIpnClient.IpnTransaction(body); + if (ipnTransaction.TxnType != "web_accept" && ipnTransaction.TxnType != "merch_pmt" && + ipnTransaction.PaymentStatus != "Refunded") + { + // Only processing billing agreement payments, buy now button payments, and refunds for now. return new OkResult(); } + + if (ipnTransaction.ReceiverId != _billingSettings.PayPal.BusinessId) + { + _logger.LogWarning("Receiver was not proper business id. " + ipnTransaction.ReceiverId); + return new BadRequestResult(); + } + + if (ipnTransaction.PaymentStatus == "Refunded" && ipnTransaction.ParentTxnId == null) + { + // Refunds require parent transaction + return new OkResult(); + } + + if (ipnTransaction.PaymentType == "echeck" && ipnTransaction.PaymentStatus != "Refunded") + { + // Not accepting eChecks, unless it is a refund + _logger.LogWarning("Got an eCheck payment. " + ipnTransaction.TxnId); + return new OkResult(); + } + + if (ipnTransaction.McCurrency != "USD") + { + // Only process USD payments + _logger.LogWarning("Received a payment not in USD. " + ipnTransaction.TxnId); + return new OkResult(); + } + + var ids = ipnTransaction.GetIdsFromCustom(); + if (!ids.Item1.HasValue && !ids.Item2.HasValue) + { + return new OkResult(); + } + + if (ipnTransaction.PaymentStatus == "Completed") + { + var transaction = await _transactionRepository.GetByGatewayIdAsync( + GatewayType.PayPal, ipnTransaction.TxnId); + if (transaction != null) + { + _logger.LogWarning("Already processed this completed transaction. #" + ipnTransaction.TxnId); + return new OkResult(); + } + + var isAccountCredit = ipnTransaction.IsAccountCredit(); + try + { + var tx = new Transaction + { + Amount = ipnTransaction.McGross, + CreationDate = ipnTransaction.PaymentDate, + OrganizationId = ids.Item1, + UserId = ids.Item2, + Type = isAccountCredit ? TransactionType.Credit : TransactionType.Charge, + Gateway = GatewayType.PayPal, + GatewayId = ipnTransaction.TxnId, + PaymentMethodType = PaymentMethodType.PayPal, + Details = ipnTransaction.TxnId + }; + await _transactionRepository.CreateAsync(tx); + + if (isAccountCredit) + { + string billingEmail = null; + if (tx.OrganizationId.HasValue) + { + var org = await _organizationRepository.GetByIdAsync(tx.OrganizationId.Value); + if (org != null) + { + billingEmail = org.BillingEmailAddress(); + if (await _paymentService.CreditAccountAsync(org, tx.Amount)) + { + await _organizationRepository.ReplaceAsync(org); + } + } + } + else + { + var user = await _userRepository.GetByIdAsync(tx.UserId.Value); + if (user != null) + { + billingEmail = user.BillingEmailAddress(); + if (await _paymentService.CreditAccountAsync(user, tx.Amount)) + { + await _userRepository.ReplaceAsync(user); + } + } + } + + if (!string.IsNullOrWhiteSpace(billingEmail)) + { + await _mailService.SendAddedCreditAsync(billingEmail, tx.Amount); + } + } + } + // Catch foreign key violations because user/org could have been deleted. + catch (SqlException e) when (e.Number == 547) { } + } + else if (ipnTransaction.PaymentStatus == "Refunded" || ipnTransaction.PaymentStatus == "Reversed") + { + var refundTransaction = await _transactionRepository.GetByGatewayIdAsync( + GatewayType.PayPal, ipnTransaction.TxnId); + if (refundTransaction != null) + { + _logger.LogWarning("Already processed this refunded transaction. #" + ipnTransaction.TxnId); + return new OkResult(); + } + + var parentTransaction = await _transactionRepository.GetByGatewayIdAsync( + GatewayType.PayPal, ipnTransaction.ParentTxnId); + if (parentTransaction == null) + { + _logger.LogWarning("Parent transaction was not found. " + ipnTransaction.TxnId); + return new BadRequestResult(); + } + + var refundAmount = System.Math.Abs(ipnTransaction.McGross); + var remainingAmount = parentTransaction.Amount - + parentTransaction.RefundedAmount.GetValueOrDefault(); + if (refundAmount > 0 && !parentTransaction.Refunded.GetValueOrDefault() && + remainingAmount >= refundAmount) + { + parentTransaction.RefundedAmount = + parentTransaction.RefundedAmount.GetValueOrDefault() + refundAmount; + if (parentTransaction.RefundedAmount == parentTransaction.Amount) + { + parentTransaction.Refunded = true; + } + + await _transactionRepository.ReplaceAsync(parentTransaction); + await _transactionRepository.CreateAsync(new Transaction + { + Amount = refundAmount, + CreationDate = ipnTransaction.PaymentDate, + OrganizationId = ids.Item1, + UserId = ids.Item2, + Type = TransactionType.Refund, + Gateway = GatewayType.PayPal, + GatewayId = ipnTransaction.TxnId, + PaymentMethodType = PaymentMethodType.PayPal, + Details = ipnTransaction.TxnId + }); + } + } + + return new OkResult(); } } diff --git a/src/Billing/Controllers/StripeController.cs b/src/Billing/Controllers/StripeController.cs index 4cabb9645..d9f3bc744 100644 --- a/src/Billing/Controllers/StripeController.cs +++ b/src/Billing/Controllers/StripeController.cs @@ -13,826 +13,825 @@ using Microsoft.Extensions.Options; using Stripe; using TaxRate = Bit.Core.Entities.TaxRate; -namespace Bit.Billing.Controllers +namespace Bit.Billing.Controllers; + +[Route("stripe")] +public class StripeController : Controller { - [Route("stripe")] - public class StripeController : Controller + private const decimal PremiumPlanAppleIapPrice = 14.99M; + private const string PremiumPlanId = "premium-annually"; + private const string PremiumPlanIdAppStore = "premium-annually-app"; + + private readonly BillingSettings _billingSettings; + private readonly IWebHostEnvironment _hostingEnvironment; + private readonly IOrganizationService _organizationService; + private readonly IValidateSponsorshipCommand _validateSponsorshipCommand; + private readonly IOrganizationSponsorshipRenewCommand _organizationSponsorshipRenewCommand; + private readonly IOrganizationRepository _organizationRepository; + private readonly ITransactionRepository _transactionRepository; + private readonly IUserService _userService; + private readonly IAppleIapService _appleIapService; + private readonly IMailService _mailService; + private readonly ILogger _logger; + private readonly Braintree.BraintreeGateway _btGateway; + private readonly IReferenceEventService _referenceEventService; + private readonly ITaxRateRepository _taxRateRepository; + private readonly IUserRepository _userRepository; + + public StripeController( + GlobalSettings globalSettings, + IOptions billingSettings, + IWebHostEnvironment hostingEnvironment, + IOrganizationService organizationService, + IValidateSponsorshipCommand validateSponsorshipCommand, + IOrganizationSponsorshipRenewCommand organizationSponsorshipRenewCommand, + IOrganizationRepository organizationRepository, + ITransactionRepository transactionRepository, + IUserService userService, + IAppleIapService appleIapService, + IMailService mailService, + IReferenceEventService referenceEventService, + ILogger logger, + ITaxRateRepository taxRateRepository, + IUserRepository userRepository) { - private const decimal PremiumPlanAppleIapPrice = 14.99M; - private const string PremiumPlanId = "premium-annually"; - private const string PremiumPlanIdAppStore = "premium-annually-app"; - - private readonly BillingSettings _billingSettings; - private readonly IWebHostEnvironment _hostingEnvironment; - private readonly IOrganizationService _organizationService; - private readonly IValidateSponsorshipCommand _validateSponsorshipCommand; - private readonly IOrganizationSponsorshipRenewCommand _organizationSponsorshipRenewCommand; - private readonly IOrganizationRepository _organizationRepository; - private readonly ITransactionRepository _transactionRepository; - private readonly IUserService _userService; - private readonly IAppleIapService _appleIapService; - private readonly IMailService _mailService; - private readonly ILogger _logger; - private readonly Braintree.BraintreeGateway _btGateway; - private readonly IReferenceEventService _referenceEventService; - private readonly ITaxRateRepository _taxRateRepository; - private readonly IUserRepository _userRepository; - - public StripeController( - GlobalSettings globalSettings, - IOptions billingSettings, - IWebHostEnvironment hostingEnvironment, - IOrganizationService organizationService, - IValidateSponsorshipCommand validateSponsorshipCommand, - IOrganizationSponsorshipRenewCommand organizationSponsorshipRenewCommand, - IOrganizationRepository organizationRepository, - ITransactionRepository transactionRepository, - IUserService userService, - IAppleIapService appleIapService, - IMailService mailService, - IReferenceEventService referenceEventService, - ILogger logger, - ITaxRateRepository taxRateRepository, - IUserRepository userRepository) + _billingSettings = billingSettings?.Value; + _hostingEnvironment = hostingEnvironment; + _organizationService = organizationService; + _validateSponsorshipCommand = validateSponsorshipCommand; + _organizationSponsorshipRenewCommand = organizationSponsorshipRenewCommand; + _organizationRepository = organizationRepository; + _transactionRepository = transactionRepository; + _userService = userService; + _appleIapService = appleIapService; + _mailService = mailService; + _referenceEventService = referenceEventService; + _taxRateRepository = taxRateRepository; + _userRepository = userRepository; + _logger = logger; + _btGateway = new Braintree.BraintreeGateway { - _billingSettings = billingSettings?.Value; - _hostingEnvironment = hostingEnvironment; - _organizationService = organizationService; - _validateSponsorshipCommand = validateSponsorshipCommand; - _organizationSponsorshipRenewCommand = organizationSponsorshipRenewCommand; - _organizationRepository = organizationRepository; - _transactionRepository = transactionRepository; - _userService = userService; - _appleIapService = appleIapService; - _mailService = mailService; - _referenceEventService = referenceEventService; - _taxRateRepository = taxRateRepository; - _userRepository = userRepository; - _logger = logger; - _btGateway = new Braintree.BraintreeGateway - { - Environment = globalSettings.Braintree.Production ? - Braintree.Environment.PRODUCTION : Braintree.Environment.SANDBOX, - MerchantId = globalSettings.Braintree.MerchantId, - PublicKey = globalSettings.Braintree.PublicKey, - PrivateKey = globalSettings.Braintree.PrivateKey - }; + Environment = globalSettings.Braintree.Production ? + Braintree.Environment.PRODUCTION : Braintree.Environment.SANDBOX, + MerchantId = globalSettings.Braintree.MerchantId, + PublicKey = globalSettings.Braintree.PublicKey, + PrivateKey = globalSettings.Braintree.PrivateKey + }; + } + + [HttpPost("webhook")] + public async Task PostWebhook([FromQuery] string key) + { + if (!CoreHelpers.FixedTimeEquals(key, _billingSettings.StripeWebhookKey)) + { + return new BadRequestResult(); } - [HttpPost("webhook")] - public async Task PostWebhook([FromQuery] string key) + Stripe.Event parsedEvent; + using (var sr = new StreamReader(HttpContext.Request.Body)) { - if (!CoreHelpers.FixedTimeEquals(key, _billingSettings.StripeWebhookKey)) + var json = await sr.ReadToEndAsync(); + parsedEvent = EventUtility.ConstructEvent(json, Request.Headers["Stripe-Signature"], + _billingSettings.StripeWebhookSecret, + throwOnApiVersionMismatch: _billingSettings.StripeEventParseThrowMismatch); + } + + if (string.IsNullOrWhiteSpace(parsedEvent?.Id)) + { + _logger.LogWarning("No event id."); + return new BadRequestResult(); + } + + if (_hostingEnvironment.IsProduction() && !parsedEvent.Livemode) + { + _logger.LogWarning("Getting test events in production."); + return new BadRequestResult(); + } + + var subDeleted = parsedEvent.Type.Equals(HandledStripeWebhook.SubscriptionDeleted); + var subUpdated = parsedEvent.Type.Equals(HandledStripeWebhook.SubscriptionUpdated); + + if (subDeleted || subUpdated) + { + var subscription = await GetSubscriptionAsync(parsedEvent, true); + var ids = GetIdsFromMetaData(subscription.Metadata); + + var subCanceled = subDeleted && subscription.Status == "canceled"; + var subUnpaid = subUpdated && subscription.Status == "unpaid"; + var subIncompleteExpired = subUpdated && subscription.Status == "incomplete_expired"; + + if (subCanceled || subUnpaid || subIncompleteExpired) { - return new BadRequestResult(); - } - - Stripe.Event parsedEvent; - using (var sr = new StreamReader(HttpContext.Request.Body)) - { - var json = await sr.ReadToEndAsync(); - parsedEvent = EventUtility.ConstructEvent(json, Request.Headers["Stripe-Signature"], - _billingSettings.StripeWebhookSecret, - throwOnApiVersionMismatch: _billingSettings.StripeEventParseThrowMismatch); - } - - if (string.IsNullOrWhiteSpace(parsedEvent?.Id)) - { - _logger.LogWarning("No event id."); - return new BadRequestResult(); - } - - if (_hostingEnvironment.IsProduction() && !parsedEvent.Livemode) - { - _logger.LogWarning("Getting test events in production."); - return new BadRequestResult(); - } - - var subDeleted = parsedEvent.Type.Equals(HandledStripeWebhook.SubscriptionDeleted); - var subUpdated = parsedEvent.Type.Equals(HandledStripeWebhook.SubscriptionUpdated); - - if (subDeleted || subUpdated) - { - var subscription = await GetSubscriptionAsync(parsedEvent, true); - var ids = GetIdsFromMetaData(subscription.Metadata); - - var subCanceled = subDeleted && subscription.Status == "canceled"; - var subUnpaid = subUpdated && subscription.Status == "unpaid"; - var subIncompleteExpired = subUpdated && subscription.Status == "incomplete_expired"; - - if (subCanceled || subUnpaid || subIncompleteExpired) - { - // org - if (ids.Item1.HasValue) - { - await _organizationService.DisableAsync(ids.Item1.Value, subscription.CurrentPeriodEnd); - } - // user - else if (ids.Item2.HasValue) - { - await _userService.DisablePremiumAsync(ids.Item2.Value, subscription.CurrentPeriodEnd); - } - } - - if (subUpdated) - { - // org - if (ids.Item1.HasValue) - { - await _organizationService.UpdateExpirationDateAsync(ids.Item1.Value, - subscription.CurrentPeriodEnd); - if (IsSponsoredSubscription(subscription)) - { - await _organizationSponsorshipRenewCommand.UpdateExpirationDateAsync(ids.Item1.Value, subscription.CurrentPeriodEnd); - } - } - // user - else if (ids.Item2.HasValue) - { - await _userService.UpdatePremiumExpirationAsync(ids.Item2.Value, - subscription.CurrentPeriodEnd); - } - } - } - else if (parsedEvent.Type.Equals(HandledStripeWebhook.UpcomingInvoice)) - { - var invoice = await GetInvoiceAsync(parsedEvent); - var subscriptionService = new SubscriptionService(); - var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); - if (subscription == null) - { - throw new Exception("Invoice subscription is null. " + invoice.Id); - } - - subscription = await VerifyCorrectTaxRateForCharge(invoice, subscription); - - string email = null; - var ids = GetIdsFromMetaData(subscription.Metadata); // org if (ids.Item1.HasValue) { - // sponsored org + await _organizationService.DisableAsync(ids.Item1.Value, subscription.CurrentPeriodEnd); + } + // user + else if (ids.Item2.HasValue) + { + await _userService.DisablePremiumAsync(ids.Item2.Value, subscription.CurrentPeriodEnd); + } + } + + if (subUpdated) + { + // org + if (ids.Item1.HasValue) + { + await _organizationService.UpdateExpirationDateAsync(ids.Item1.Value, + subscription.CurrentPeriodEnd); if (IsSponsoredSubscription(subscription)) { - await _validateSponsorshipCommand.ValidateSponsorshipAsync(ids.Item1.Value); - } - - var org = await _organizationRepository.GetByIdAsync(ids.Item1.Value); - if (org != null && OrgPlanForInvoiceNotifications(org)) - { - email = org.BillingEmail; + await _organizationSponsorshipRenewCommand.UpdateExpirationDateAsync(ids.Item1.Value, subscription.CurrentPeriodEnd); } } // user else if (ids.Item2.HasValue) { - var user = await _userService.GetUserByIdAsync(ids.Item2.Value); - if (user.Premium) - { - email = user.Email; - } - } - - if (!string.IsNullOrWhiteSpace(email) && invoice.NextPaymentAttempt.HasValue) - { - var items = invoice.Lines.Select(i => i.Description).ToList(); - await _mailService.SendInvoiceUpcomingAsync(email, invoice.AmountDue / 100M, - invoice.NextPaymentAttempt.Value, items, true); + await _userService.UpdatePremiumExpirationAsync(ids.Item2.Value, + subscription.CurrentPeriodEnd); } } - else if (parsedEvent.Type.Equals(HandledStripeWebhook.ChargeSucceeded)) + } + else if (parsedEvent.Type.Equals(HandledStripeWebhook.UpcomingInvoice)) + { + var invoice = await GetInvoiceAsync(parsedEvent); + var subscriptionService = new SubscriptionService(); + var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); + if (subscription == null) { - var charge = await GetChargeAsync(parsedEvent); - var chargeTransaction = await _transactionRepository.GetByGatewayIdAsync( - GatewayType.Stripe, charge.Id); - if (chargeTransaction != null) + throw new Exception("Invoice subscription is null. " + invoice.Id); + } + + subscription = await VerifyCorrectTaxRateForCharge(invoice, subscription); + + string email = null; + var ids = GetIdsFromMetaData(subscription.Metadata); + // org + if (ids.Item1.HasValue) + { + // sponsored org + if (IsSponsoredSubscription(subscription)) { - _logger.LogWarning("Charge success already processed. " + charge.Id); - return new OkResult(); + await _validateSponsorshipCommand.ValidateSponsorshipAsync(ids.Item1.Value); } - Tuple ids = null; - Subscription subscription = null; - var subscriptionService = new SubscriptionService(); - - if (charge.InvoiceId != null) + var org = await _organizationRepository.GetByIdAsync(ids.Item1.Value); + if (org != null && OrgPlanForInvoiceNotifications(org)) { - var invoiceService = new InvoiceService(); - var invoice = await invoiceService.GetAsync(charge.InvoiceId); - if (invoice?.SubscriptionId != null) - { - subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); - ids = GetIdsFromMetaData(subscription?.Metadata); - } + email = org.BillingEmail; } - - if (subscription == null || ids == null || (ids.Item1.HasValue && ids.Item2.HasValue)) + } + // user + else if (ids.Item2.HasValue) + { + var user = await _userService.GetUserByIdAsync(ids.Item2.Value); + if (user.Premium) { - var subscriptions = await subscriptionService.ListAsync(new SubscriptionListOptions + email = user.Email; + } + } + + if (!string.IsNullOrWhiteSpace(email) && invoice.NextPaymentAttempt.HasValue) + { + var items = invoice.Lines.Select(i => i.Description).ToList(); + await _mailService.SendInvoiceUpcomingAsync(email, invoice.AmountDue / 100M, + invoice.NextPaymentAttempt.Value, items, true); + } + } + else if (parsedEvent.Type.Equals(HandledStripeWebhook.ChargeSucceeded)) + { + var charge = await GetChargeAsync(parsedEvent); + var chargeTransaction = await _transactionRepository.GetByGatewayIdAsync( + GatewayType.Stripe, charge.Id); + if (chargeTransaction != null) + { + _logger.LogWarning("Charge success already processed. " + charge.Id); + return new OkResult(); + } + + Tuple ids = null; + Subscription subscription = null; + var subscriptionService = new SubscriptionService(); + + if (charge.InvoiceId != null) + { + var invoiceService = new InvoiceService(); + var invoice = await invoiceService.GetAsync(charge.InvoiceId); + if (invoice?.SubscriptionId != null) + { + subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); + ids = GetIdsFromMetaData(subscription?.Metadata); + } + } + + if (subscription == null || ids == null || (ids.Item1.HasValue && ids.Item2.HasValue)) + { + var subscriptions = await subscriptionService.ListAsync(new SubscriptionListOptions + { + Customer = charge.CustomerId + }); + foreach (var sub in subscriptions) + { + if (sub.Status != "canceled" && sub.Status != "incomplete_expired") { - Customer = charge.CustomerId - }); - foreach (var sub in subscriptions) - { - if (sub.Status != "canceled" && sub.Status != "incomplete_expired") + ids = GetIdsFromMetaData(sub.Metadata); + if (ids.Item1.HasValue || ids.Item2.HasValue) { - ids = GetIdsFromMetaData(sub.Metadata); - if (ids.Item1.HasValue || ids.Item2.HasValue) - { - subscription = sub; - break; - } + subscription = sub; + break; } } } + } - if (!ids.Item1.HasValue && !ids.Item2.HasValue) - { - _logger.LogWarning("Charge success has no subscriber ids. " + charge.Id); - return new BadRequestResult(); - } + if (!ids.Item1.HasValue && !ids.Item2.HasValue) + { + _logger.LogWarning("Charge success has no subscriber ids. " + charge.Id); + return new BadRequestResult(); + } - var tx = new Transaction - { - Amount = charge.Amount / 100M, - CreationDate = charge.Created, - OrganizationId = ids.Item1, - UserId = ids.Item2, - Type = TransactionType.Charge, - Gateway = GatewayType.Stripe, - GatewayId = charge.Id - }; + var tx = new Transaction + { + Amount = charge.Amount / 100M, + CreationDate = charge.Created, + OrganizationId = ids.Item1, + UserId = ids.Item2, + Type = TransactionType.Charge, + Gateway = GatewayType.Stripe, + GatewayId = charge.Id + }; - if (charge.Source != null && charge.Source is Card card) + if (charge.Source != null && charge.Source is Card card) + { + tx.PaymentMethodType = PaymentMethodType.Card; + tx.Details = $"{card.Brand}, *{card.Last4}"; + } + else if (charge.Source != null && charge.Source is BankAccount bankAccount) + { + tx.PaymentMethodType = PaymentMethodType.BankAccount; + tx.Details = $"{bankAccount.BankName}, *{bankAccount.Last4}"; + } + else if (charge.Source != null && charge.Source is Source source) + { + if (source.Card != null) { tx.PaymentMethodType = PaymentMethodType.Card; - tx.Details = $"{card.Brand}, *{card.Last4}"; + tx.Details = $"{source.Card.Brand}, *{source.Card.Last4}"; } - else if (charge.Source != null && charge.Source is BankAccount bankAccount) + else if (source.AchDebit != null) { tx.PaymentMethodType = PaymentMethodType.BankAccount; - tx.Details = $"{bankAccount.BankName}, *{bankAccount.Last4}"; + tx.Details = $"{source.AchDebit.BankName}, *{source.AchDebit.Last4}"; } - else if (charge.Source != null && charge.Source is Source source) + else if (source.AchCreditTransfer != null) { - if (source.Card != null) - { - tx.PaymentMethodType = PaymentMethodType.Card; - tx.Details = $"{source.Card.Brand}, *{source.Card.Last4}"; - } - else if (source.AchDebit != null) - { - tx.PaymentMethodType = PaymentMethodType.BankAccount; - tx.Details = $"{source.AchDebit.BankName}, *{source.AchDebit.Last4}"; - } - else if (source.AchCreditTransfer != null) - { - tx.PaymentMethodType = PaymentMethodType.BankAccount; - tx.Details = $"ACH => {source.AchCreditTransfer.BankName}, " + - $"{source.AchCreditTransfer.AccountNumber}"; - } - } - else if (charge.PaymentMethodDetails != null) - { - if (charge.PaymentMethodDetails.Card != null) - { - tx.PaymentMethodType = PaymentMethodType.Card; - tx.Details = $"{charge.PaymentMethodDetails.Card.Brand?.ToUpperInvariant()}, " + - $"*{charge.PaymentMethodDetails.Card.Last4}"; - } - else if (charge.PaymentMethodDetails.AchDebit != null) - { - tx.PaymentMethodType = PaymentMethodType.BankAccount; - tx.Details = $"{charge.PaymentMethodDetails.AchDebit.BankName}, " + - $"*{charge.PaymentMethodDetails.AchDebit.Last4}"; - } - else if (charge.PaymentMethodDetails.AchCreditTransfer != null) - { - tx.PaymentMethodType = PaymentMethodType.BankAccount; - tx.Details = $"ACH => {charge.PaymentMethodDetails.AchCreditTransfer.BankName}, " + - $"{charge.PaymentMethodDetails.AchCreditTransfer.AccountNumber}"; - } - } - - if (!tx.PaymentMethodType.HasValue) - { - _logger.LogWarning("Charge success from unsupported source/method. " + charge.Id); - return new OkResult(); - } - - try - { - await _transactionRepository.CreateAsync(tx); - } - // Catch foreign key violations because user/org could have been deleted. - catch (SqlException e) when (e.Number == 547) { } - } - else if (parsedEvent.Type.Equals(HandledStripeWebhook.ChargeRefunded)) - { - var charge = await GetChargeAsync(parsedEvent); - var chargeTransaction = await _transactionRepository.GetByGatewayIdAsync( - GatewayType.Stripe, charge.Id); - if (chargeTransaction == null) - { - throw new Exception("Cannot find refunded charge. " + charge.Id); - } - - var amountRefunded = charge.AmountRefunded / 100M; - - if (!chargeTransaction.Refunded.GetValueOrDefault() && - chargeTransaction.RefundedAmount.GetValueOrDefault() < amountRefunded) - { - chargeTransaction.RefundedAmount = amountRefunded; - if (charge.Refunded) - { - chargeTransaction.Refunded = true; - } - await _transactionRepository.ReplaceAsync(chargeTransaction); - - foreach (var refund in charge.Refunds) - { - var refundTransaction = await _transactionRepository.GetByGatewayIdAsync( - GatewayType.Stripe, refund.Id); - if (refundTransaction != null) - { - continue; - } - - await _transactionRepository.CreateAsync(new Transaction - { - Amount = refund.Amount / 100M, - CreationDate = refund.Created, - OrganizationId = chargeTransaction.OrganizationId, - UserId = chargeTransaction.UserId, - Type = TransactionType.Refund, - Gateway = GatewayType.Stripe, - GatewayId = refund.Id, - PaymentMethodType = chargeTransaction.PaymentMethodType, - Details = chargeTransaction.Details - }); - } - } - else - { - _logger.LogWarning("Charge refund amount doesn't seem correct. " + charge.Id); + tx.PaymentMethodType = PaymentMethodType.BankAccount; + tx.Details = $"ACH => {source.AchCreditTransfer.BankName}, " + + $"{source.AchCreditTransfer.AccountNumber}"; } } - else if (parsedEvent.Type.Equals(HandledStripeWebhook.PaymentSucceeded)) + else if (charge.PaymentMethodDetails != null) { - var invoice = await GetInvoiceAsync(parsedEvent, true); - if (invoice.Paid && invoice.BillingReason == "subscription_create") + if (charge.PaymentMethodDetails.Card != null) { - var subscriptionService = new SubscriptionService(); - var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); - if (subscription?.Status == "active") - { - if (DateTime.UtcNow - invoice.Created < TimeSpan.FromMinutes(1)) - { - await Task.Delay(5000); - } - - var ids = GetIdsFromMetaData(subscription.Metadata); - // org - if (ids.Item1.HasValue) - { - if (subscription.Items.Any(i => StaticStore.Plans.Any(p => p.StripePlanId == i.Plan.Id))) - { - await _organizationService.EnableAsync(ids.Item1.Value, subscription.CurrentPeriodEnd); - - var organization = await _organizationRepository.GetByIdAsync(ids.Item1.Value); - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.Rebilled, organization) - { - PlanName = organization?.Plan, - PlanType = organization?.PlanType, - Seats = organization?.Seats, - Storage = organization?.MaxStorageGb, - }); - } - } - // user - else if (ids.Item2.HasValue) - { - if (subscription.Items.Any(i => i.Plan.Id == PremiumPlanId)) - { - await _userService.EnablePremiumAsync(ids.Item2.Value, subscription.CurrentPeriodEnd); - - var user = await _userRepository.GetByIdAsync(ids.Item2.Value); - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.Rebilled, user) - { - PlanName = PremiumPlanId, - Storage = user?.MaxStorageGb, - }); - } - } - } + tx.PaymentMethodType = PaymentMethodType.Card; + tx.Details = $"{charge.PaymentMethodDetails.Card.Brand?.ToUpperInvariant()}, " + + $"*{charge.PaymentMethodDetails.Card.Last4}"; + } + else if (charge.PaymentMethodDetails.AchDebit != null) + { + tx.PaymentMethodType = PaymentMethodType.BankAccount; + tx.Details = $"{charge.PaymentMethodDetails.AchDebit.BankName}, " + + $"*{charge.PaymentMethodDetails.AchDebit.Last4}"; + } + else if (charge.PaymentMethodDetails.AchCreditTransfer != null) + { + tx.PaymentMethodType = PaymentMethodType.BankAccount; + tx.Details = $"ACH => {charge.PaymentMethodDetails.AchCreditTransfer.BankName}, " + + $"{charge.PaymentMethodDetails.AchCreditTransfer.AccountNumber}"; } } - else if (parsedEvent.Type.Equals(HandledStripeWebhook.PaymentFailed)) + + if (!tx.PaymentMethodType.HasValue) { - await HandlePaymentFailed(await GetInvoiceAsync(parsedEvent, true)); + _logger.LogWarning("Charge success from unsupported source/method. " + charge.Id); + return new OkResult(); } - else if (parsedEvent.Type.Equals(HandledStripeWebhook.InvoiceCreated)) + + try { - var invoice = await GetInvoiceAsync(parsedEvent, true); - if (!invoice.Paid && UnpaidAutoChargeInvoiceForSubscriptionCycle(invoice)) + await _transactionRepository.CreateAsync(tx); + } + // Catch foreign key violations because user/org could have been deleted. + catch (SqlException e) when (e.Number == 547) { } + } + else if (parsedEvent.Type.Equals(HandledStripeWebhook.ChargeRefunded)) + { + var charge = await GetChargeAsync(parsedEvent); + var chargeTransaction = await _transactionRepository.GetByGatewayIdAsync( + GatewayType.Stripe, charge.Id); + if (chargeTransaction == null) + { + throw new Exception("Cannot find refunded charge. " + charge.Id); + } + + var amountRefunded = charge.AmountRefunded / 100M; + + if (!chargeTransaction.Refunded.GetValueOrDefault() && + chargeTransaction.RefundedAmount.GetValueOrDefault() < amountRefunded) + { + chargeTransaction.RefundedAmount = amountRefunded; + if (charge.Refunded) { - await AttemptToPayInvoiceAsync(invoice); + chargeTransaction.Refunded = true; + } + await _transactionRepository.ReplaceAsync(chargeTransaction); + + foreach (var refund in charge.Refunds) + { + var refundTransaction = await _transactionRepository.GetByGatewayIdAsync( + GatewayType.Stripe, refund.Id); + if (refundTransaction != null) + { + continue; + } + + await _transactionRepository.CreateAsync(new Transaction + { + Amount = refund.Amount / 100M, + CreationDate = refund.Created, + OrganizationId = chargeTransaction.OrganizationId, + UserId = chargeTransaction.UserId, + Type = TransactionType.Refund, + Gateway = GatewayType.Stripe, + GatewayId = refund.Id, + PaymentMethodType = chargeTransaction.PaymentMethodType, + Details = chargeTransaction.Details + }); } } else { - _logger.LogWarning("Unsupported event received. " + parsedEvent.Type); + _logger.LogWarning("Charge refund amount doesn't seem correct. " + charge.Id); } - - return new OkResult(); } - - private Tuple GetIdsFromMetaData(IDictionary metaData) + else if (parsedEvent.Type.Equals(HandledStripeWebhook.PaymentSucceeded)) { - if (metaData == null || !metaData.Any()) + var invoice = await GetInvoiceAsync(parsedEvent, true); + if (invoice.Paid && invoice.BillingReason == "subscription_create") { - return new Tuple(null, null); - } - - Guid? orgId = null; - Guid? userId = null; - - if (metaData.ContainsKey("organizationId")) - { - orgId = new Guid(metaData["organizationId"]); - } - else if (metaData.ContainsKey("userId")) - { - userId = new Guid(metaData["userId"]); - } - - if (userId == null && orgId == null) - { - var orgIdKey = metaData.Keys.FirstOrDefault(k => k.ToLowerInvariant() == "organizationid"); - if (!string.IsNullOrWhiteSpace(orgIdKey)) + var subscriptionService = new SubscriptionService(); + var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); + if (subscription?.Status == "active") { - orgId = new Guid(metaData[orgIdKey]); - } - else - { - var userIdKey = metaData.Keys.FirstOrDefault(k => k.ToLowerInvariant() == "userid"); - if (!string.IsNullOrWhiteSpace(userIdKey)) + if (DateTime.UtcNow - invoice.Created < TimeSpan.FromMinutes(1)) { - userId = new Guid(metaData[userIdKey]); + await Task.Delay(5000); + } + + var ids = GetIdsFromMetaData(subscription.Metadata); + // org + if (ids.Item1.HasValue) + { + if (subscription.Items.Any(i => StaticStore.Plans.Any(p => p.StripePlanId == i.Plan.Id))) + { + await _organizationService.EnableAsync(ids.Item1.Value, subscription.CurrentPeriodEnd); + + var organization = await _organizationRepository.GetByIdAsync(ids.Item1.Value); + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.Rebilled, organization) + { + PlanName = organization?.Plan, + PlanType = organization?.PlanType, + Seats = organization?.Seats, + Storage = organization?.MaxStorageGb, + }); + } + } + // user + else if (ids.Item2.HasValue) + { + if (subscription.Items.Any(i => i.Plan.Id == PremiumPlanId)) + { + await _userService.EnablePremiumAsync(ids.Item2.Value, subscription.CurrentPeriodEnd); + + var user = await _userRepository.GetByIdAsync(ids.Item2.Value); + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.Rebilled, user) + { + PlanName = PremiumPlanId, + Storage = user?.MaxStorageGb, + }); + } } } } - - return new Tuple(orgId, userId); + } + else if (parsedEvent.Type.Equals(HandledStripeWebhook.PaymentFailed)) + { + await HandlePaymentFailed(await GetInvoiceAsync(parsedEvent, true)); + } + else if (parsedEvent.Type.Equals(HandledStripeWebhook.InvoiceCreated)) + { + var invoice = await GetInvoiceAsync(parsedEvent, true); + if (!invoice.Paid && UnpaidAutoChargeInvoiceForSubscriptionCycle(invoice)) + { + await AttemptToPayInvoiceAsync(invoice); + } + } + else + { + _logger.LogWarning("Unsupported event received. " + parsedEvent.Type); } - private bool OrgPlanForInvoiceNotifications(Organization org) + return new OkResult(); + } + + private Tuple GetIdsFromMetaData(IDictionary metaData) + { + if (metaData == null || !metaData.Any()) { - switch (org.PlanType) + return new Tuple(null, null); + } + + Guid? orgId = null; + Guid? userId = null; + + if (metaData.ContainsKey("organizationId")) + { + orgId = new Guid(metaData["organizationId"]); + } + else if (metaData.ContainsKey("userId")) + { + userId = new Guid(metaData["userId"]); + } + + if (userId == null && orgId == null) + { + var orgIdKey = metaData.Keys.FirstOrDefault(k => k.ToLowerInvariant() == "organizationid"); + if (!string.IsNullOrWhiteSpace(orgIdKey)) { - case PlanType.FamiliesAnnually: - case PlanType.TeamsAnnually: - case PlanType.EnterpriseAnnually: - return true; - default: - return false; + orgId = new Guid(metaData[orgIdKey]); + } + else + { + var userIdKey = metaData.Keys.FirstOrDefault(k => k.ToLowerInvariant() == "userid"); + if (!string.IsNullOrWhiteSpace(userIdKey)) + { + userId = new Guid(metaData[userIdKey]); + } } } - private async Task AttemptToPayInvoiceAsync(Invoice invoice) + return new Tuple(orgId, userId); + } + + private bool OrgPlanForInvoiceNotifications(Organization org) + { + switch (org.PlanType) { - var customerService = new CustomerService(); - var customer = await customerService.GetAsync(invoice.CustomerId); - if (customer?.Metadata?.ContainsKey("appleReceipt") ?? false) + case PlanType.FamiliesAnnually: + case PlanType.TeamsAnnually: + case PlanType.EnterpriseAnnually: + return true; + default: + return false; + } + } + + private async Task AttemptToPayInvoiceAsync(Invoice invoice) + { + var customerService = new CustomerService(); + var customer = await customerService.GetAsync(invoice.CustomerId); + if (customer?.Metadata?.ContainsKey("appleReceipt") ?? false) + { + return await AttemptToPayInvoiceWithAppleReceiptAsync(invoice, customer); + } + else if (customer?.Metadata?.ContainsKey("btCustomerId") ?? false) + { + return await AttemptToPayInvoiceWithBraintreeAsync(invoice, customer); + } + return false; + } + + private async Task AttemptToPayInvoiceWithAppleReceiptAsync(Invoice invoice, Customer customer) + { + if (!customer?.Metadata?.ContainsKey("appleReceipt") ?? true) + { + return false; + } + + var originalAppleReceiptTransactionId = customer.Metadata["appleReceipt"]; + var appleReceiptRecord = await _appleIapService.GetReceiptAsync(originalAppleReceiptTransactionId); + if (string.IsNullOrWhiteSpace(appleReceiptRecord?.Item1) || !appleReceiptRecord.Item2.HasValue) + { + return false; + } + + var subscriptionService = new SubscriptionService(); + var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); + var ids = GetIdsFromMetaData(subscription?.Metadata); + if (!ids.Item2.HasValue) + { + // Apple receipt is only for user subscriptions + return false; + } + + if (appleReceiptRecord.Item2.Value != ids.Item2.Value) + { + _logger.LogError("User Ids for Apple Receipt and subscription do not match: {0} != {1}.", + appleReceiptRecord.Item2.Value, ids.Item2.Value); + return false; + } + + var appleReceiptStatus = await _appleIapService.GetVerifiedReceiptStatusAsync(appleReceiptRecord.Item1); + if (appleReceiptStatus == null) + { + // TODO: cancel sub if receipt is cancelled? + return false; + } + + var receiptExpiration = appleReceiptStatus.GetLastExpiresDate().GetValueOrDefault(DateTime.MinValue); + var invoiceDue = invoice.DueDate.GetValueOrDefault(DateTime.MinValue); + if (receiptExpiration <= invoiceDue) + { + _logger.LogWarning("Apple receipt expiration is before invoice due date. {0} <= {1}", + receiptExpiration, invoiceDue); + return false; + } + + var receiptLastTransactionId = appleReceiptStatus.GetLastTransactionId(); + var existingTransaction = await _transactionRepository.GetByGatewayIdAsync( + GatewayType.AppStore, receiptLastTransactionId); + if (existingTransaction != null) + { + _logger.LogWarning("There is already an existing transaction for this Apple receipt.", + receiptLastTransactionId); + return false; + } + + var appleTransaction = appleReceiptStatus.BuildTransactionFromLastTransaction( + PremiumPlanAppleIapPrice, ids.Item2.Value); + appleTransaction.Type = TransactionType.Charge; + + var invoiceService = new InvoiceService(); + try + { + await invoiceService.UpdateAsync(invoice.Id, new InvoiceUpdateOptions { - return await AttemptToPayInvoiceWithAppleReceiptAsync(invoice, customer); + Metadata = new Dictionary + { + ["appleReceipt"] = appleReceiptStatus.GetOriginalTransactionId(), + ["appleReceiptTransactionId"] = receiptLastTransactionId + } + }); + + await _transactionRepository.CreateAsync(appleTransaction); + await invoiceService.PayAsync(invoice.Id, new InvoicePayOptions { PaidOutOfBand = true }); + } + catch (Exception e) + { + if (e.Message.Contains("Invoice is already paid")) + { + await invoiceService.UpdateAsync(invoice.Id, new InvoiceUpdateOptions + { + Metadata = invoice.Metadata + }); } - else if (customer?.Metadata?.ContainsKey("btCustomerId") ?? false) + else { - return await AttemptToPayInvoiceWithBraintreeAsync(invoice, customer); + throw; + } + } + + return true; + } + + private async Task AttemptToPayInvoiceWithBraintreeAsync(Invoice invoice, Customer customer) + { + if (!customer?.Metadata?.ContainsKey("btCustomerId") ?? true) + { + return false; + } + + var subscriptionService = new SubscriptionService(); + var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); + var ids = GetIdsFromMetaData(subscription?.Metadata); + if (!ids.Item1.HasValue && !ids.Item2.HasValue) + { + return false; + } + + var orgTransaction = ids.Item1.HasValue; + var btObjIdField = orgTransaction ? "organization_id" : "user_id"; + var btObjId = ids.Item1 ?? ids.Item2.Value; + var btInvoiceAmount = (invoice.AmountDue / 100M); + + var existingTransactions = orgTransaction ? + await _transactionRepository.GetManyByOrganizationIdAsync(ids.Item1.Value) : + await _transactionRepository.GetManyByUserIdAsync(ids.Item2.Value); + var duplicateTimeSpan = TimeSpan.FromHours(24); + var now = DateTime.UtcNow; + var duplicateTransaction = existingTransactions? + .FirstOrDefault(t => (now - t.CreationDate) < duplicateTimeSpan); + if (duplicateTransaction != null) + { + _logger.LogWarning("There is already a recent PayPal transaction ({0}). " + + "Do not charge again to prevent possible duplicate.", duplicateTransaction.GatewayId); + return false; + } + + var transactionResult = await _btGateway.Transaction.SaleAsync( + new Braintree.TransactionRequest + { + Amount = btInvoiceAmount, + CustomerId = customer.Metadata["btCustomerId"], + Options = new Braintree.TransactionOptionsRequest + { + SubmitForSettlement = true, + PayPal = new Braintree.TransactionOptionsPayPalRequest + { + CustomField = $"{btObjIdField}:{btObjId}" + } + }, + CustomFields = new Dictionary + { + [btObjIdField] = btObjId.ToString() + } + }); + + if (!transactionResult.IsSuccess()) + { + if (invoice.AttemptCount < 4) + { + await _mailService.SendPaymentFailedAsync(customer.Email, btInvoiceAmount, true); } return false; } - private async Task AttemptToPayInvoiceWithAppleReceiptAsync(Invoice invoice, Customer customer) + var invoiceService = new InvoiceService(); + try { - if (!customer?.Metadata?.ContainsKey("appleReceipt") ?? true) + await invoiceService.UpdateAsync(invoice.Id, new InvoiceUpdateOptions { - return false; - } - - var originalAppleReceiptTransactionId = customer.Metadata["appleReceipt"]; - var appleReceiptRecord = await _appleIapService.GetReceiptAsync(originalAppleReceiptTransactionId); - if (string.IsNullOrWhiteSpace(appleReceiptRecord?.Item1) || !appleReceiptRecord.Item2.HasValue) - { - return false; - } - - var subscriptionService = new SubscriptionService(); - var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); - var ids = GetIdsFromMetaData(subscription?.Metadata); - if (!ids.Item2.HasValue) - { - // Apple receipt is only for user subscriptions - return false; - } - - if (appleReceiptRecord.Item2.Value != ids.Item2.Value) - { - _logger.LogError("User Ids for Apple Receipt and subscription do not match: {0} != {1}.", - appleReceiptRecord.Item2.Value, ids.Item2.Value); - return false; - } - - var appleReceiptStatus = await _appleIapService.GetVerifiedReceiptStatusAsync(appleReceiptRecord.Item1); - if (appleReceiptStatus == null) - { - // TODO: cancel sub if receipt is cancelled? - return false; - } - - var receiptExpiration = appleReceiptStatus.GetLastExpiresDate().GetValueOrDefault(DateTime.MinValue); - var invoiceDue = invoice.DueDate.GetValueOrDefault(DateTime.MinValue); - if (receiptExpiration <= invoiceDue) - { - _logger.LogWarning("Apple receipt expiration is before invoice due date. {0} <= {1}", - receiptExpiration, invoiceDue); - return false; - } - - var receiptLastTransactionId = appleReceiptStatus.GetLastTransactionId(); - var existingTransaction = await _transactionRepository.GetByGatewayIdAsync( - GatewayType.AppStore, receiptLastTransactionId); - if (existingTransaction != null) - { - _logger.LogWarning("There is already an existing transaction for this Apple receipt.", - receiptLastTransactionId); - return false; - } - - var appleTransaction = appleReceiptStatus.BuildTransactionFromLastTransaction( - PremiumPlanAppleIapPrice, ids.Item2.Value); - appleTransaction.Type = TransactionType.Charge; - - var invoiceService = new InvoiceService(); - try + Metadata = new Dictionary + { + ["btTransactionId"] = transactionResult.Target.Id, + ["btPayPalTransactionId"] = + transactionResult.Target.PayPalDetails?.AuthorizationId + } + }); + await invoiceService.PayAsync(invoice.Id, new InvoicePayOptions { PaidOutOfBand = true }); + } + catch (Exception e) + { + await _btGateway.Transaction.RefundAsync(transactionResult.Target.Id); + if (e.Message.Contains("Invoice is already paid")) { await invoiceService.UpdateAsync(invoice.Id, new InvoiceUpdateOptions { - Metadata = new Dictionary - { - ["appleReceipt"] = appleReceiptStatus.GetOriginalTransactionId(), - ["appleReceiptTransactionId"] = receiptLastTransactionId - } + Metadata = invoice.Metadata }); - - await _transactionRepository.CreateAsync(appleTransaction); - await invoiceService.PayAsync(invoice.Id, new InvoicePayOptions { PaidOutOfBand = true }); } - catch (Exception e) + else { - if (e.Message.Contains("Invoice is already paid")) - { - await invoiceService.UpdateAsync(invoice.Id, new InvoiceUpdateOptions - { - Metadata = invoice.Metadata - }); - } - else - { - throw; - } + throw; } - - return true; } - private async Task AttemptToPayInvoiceWithBraintreeAsync(Invoice invoice, Customer customer) - { - if (!customer?.Metadata?.ContainsKey("btCustomerId") ?? true) - { - return false; - } + return true; + } + private bool UnpaidAutoChargeInvoiceForSubscriptionCycle(Invoice invoice) + { + return invoice.AmountDue > 0 && !invoice.Paid && invoice.CollectionMethod == "charge_automatically" && + invoice.BillingReason == "subscription_cycle" && invoice.SubscriptionId != null; + } + + private async Task GetChargeAsync(Stripe.Event parsedEvent, bool fresh = false) + { + if (!(parsedEvent.Data.Object is Charge eventCharge)) + { + throw new Exception("Charge is null (from parsed event). " + parsedEvent.Id); + } + if (!fresh) + { + return eventCharge; + } + var chargeService = new ChargeService(); + var charge = await chargeService.GetAsync(eventCharge.Id); + if (charge == null) + { + throw new Exception("Charge is null. " + eventCharge.Id); + } + return charge; + } + + private async Task GetInvoiceAsync(Stripe.Event parsedEvent, bool fresh = false) + { + if (!(parsedEvent.Data.Object is Invoice eventInvoice)) + { + throw new Exception("Invoice is null (from parsed event). " + parsedEvent.Id); + } + if (!fresh) + { + return eventInvoice; + } + var invoiceService = new InvoiceService(); + var invoice = await invoiceService.GetAsync(eventInvoice.Id); + if (invoice == null) + { + throw new Exception("Invoice is null. " + eventInvoice.Id); + } + return invoice; + } + + private async Task GetSubscriptionAsync(Stripe.Event parsedEvent, bool fresh = false) + { + if (!(parsedEvent.Data.Object is Subscription eventSubscription)) + { + throw new Exception("Subscription is null (from parsed event). " + parsedEvent.Id); + } + if (!fresh) + { + return eventSubscription; + } + var subscriptionService = new SubscriptionService(); + var subscription = await subscriptionService.GetAsync(eventSubscription.Id); + if (subscription == null) + { + throw new Exception("Subscription is null. " + eventSubscription.Id); + } + return subscription; + } + + private async Task VerifyCorrectTaxRateForCharge(Invoice invoice, Subscription subscription) + { + if (!string.IsNullOrWhiteSpace(invoice?.CustomerAddress?.Country) && !string.IsNullOrWhiteSpace(invoice?.CustomerAddress?.PostalCode)) + { + var localBitwardenTaxRates = await _taxRateRepository.GetByLocationAsync( + new TaxRate() + { + Country = invoice.CustomerAddress.Country, + PostalCode = invoice.CustomerAddress.PostalCode + } + ); + + if (localBitwardenTaxRates.Any()) + { + var stripeTaxRate = await new TaxRateService().GetAsync(localBitwardenTaxRates.First().Id); + if (stripeTaxRate != null && !subscription.DefaultTaxRates.Any(x => x == stripeTaxRate)) + { + subscription.DefaultTaxRates = new List { stripeTaxRate }; + var subscriptionOptions = new SubscriptionUpdateOptions() { DefaultTaxRates = new List() { stripeTaxRate.Id } }; + subscription = await new SubscriptionService().UpdateAsync(subscription.Id, subscriptionOptions); + } + } + } + return subscription; + } + + private static bool IsSponsoredSubscription(Subscription subscription) => + StaticStore.SponsoredPlans.Any(p => p.StripePlanId == subscription.Id); + + private async Task HandlePaymentFailed(Invoice invoice) + { + if (!invoice.Paid && invoice.AttemptCount > 1 && UnpaidAutoChargeInvoiceForSubscriptionCycle(invoice)) + { var subscriptionService = new SubscriptionService(); var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); - var ids = GetIdsFromMetaData(subscription?.Metadata); - if (!ids.Item1.HasValue && !ids.Item2.HasValue) + // attempt count 4 = 11 days after initial failure + if (invoice.AttemptCount > 3 && subscription.Items.Any(i => i.Price.Id == PremiumPlanId || i.Price.Id == PremiumPlanIdAppStore)) { - return false; + await CancelSubscription(invoice.SubscriptionId); + await VoidOpenInvoices(invoice.SubscriptionId); } - - var orgTransaction = ids.Item1.HasValue; - var btObjIdField = orgTransaction ? "organization_id" : "user_id"; - var btObjId = ids.Item1 ?? ids.Item2.Value; - var btInvoiceAmount = (invoice.AmountDue / 100M); - - var existingTransactions = orgTransaction ? - await _transactionRepository.GetManyByOrganizationIdAsync(ids.Item1.Value) : - await _transactionRepository.GetManyByUserIdAsync(ids.Item2.Value); - var duplicateTimeSpan = TimeSpan.FromHours(24); - var now = DateTime.UtcNow; - var duplicateTransaction = existingTransactions? - .FirstOrDefault(t => (now - t.CreationDate) < duplicateTimeSpan); - if (duplicateTransaction != null) + else { - _logger.LogWarning("There is already a recent PayPal transaction ({0}). " + - "Do not charge again to prevent possible duplicate.", duplicateTransaction.GatewayId); - return false; - } - - var transactionResult = await _btGateway.Transaction.SaleAsync( - new Braintree.TransactionRequest - { - Amount = btInvoiceAmount, - CustomerId = customer.Metadata["btCustomerId"], - Options = new Braintree.TransactionOptionsRequest - { - SubmitForSettlement = true, - PayPal = new Braintree.TransactionOptionsPayPalRequest - { - CustomField = $"{btObjIdField}:{btObjId}" - } - }, - CustomFields = new Dictionary - { - [btObjIdField] = btObjId.ToString() - } - }); - - if (!transactionResult.IsSuccess()) - { - if (invoice.AttemptCount < 4) - { - await _mailService.SendPaymentFailedAsync(customer.Email, btInvoiceAmount, true); - } - return false; - } - - var invoiceService = new InvoiceService(); - try - { - await invoiceService.UpdateAsync(invoice.Id, new InvoiceUpdateOptions - { - Metadata = new Dictionary - { - ["btTransactionId"] = transactionResult.Target.Id, - ["btPayPalTransactionId"] = - transactionResult.Target.PayPalDetails?.AuthorizationId - } - }); - await invoiceService.PayAsync(invoice.Id, new InvoicePayOptions { PaidOutOfBand = true }); - } - catch (Exception e) - { - await _btGateway.Transaction.RefundAsync(transactionResult.Target.Id); - if (e.Message.Contains("Invoice is already paid")) - { - await invoiceService.UpdateAsync(invoice.Id, new InvoiceUpdateOptions - { - Metadata = invoice.Metadata - }); - } - else - { - throw; - } - } - - return true; - } - - private bool UnpaidAutoChargeInvoiceForSubscriptionCycle(Invoice invoice) - { - return invoice.AmountDue > 0 && !invoice.Paid && invoice.CollectionMethod == "charge_automatically" && - invoice.BillingReason == "subscription_cycle" && invoice.SubscriptionId != null; - } - - private async Task GetChargeAsync(Stripe.Event parsedEvent, bool fresh = false) - { - if (!(parsedEvent.Data.Object is Charge eventCharge)) - { - throw new Exception("Charge is null (from parsed event). " + parsedEvent.Id); - } - if (!fresh) - { - return eventCharge; - } - var chargeService = new ChargeService(); - var charge = await chargeService.GetAsync(eventCharge.Id); - if (charge == null) - { - throw new Exception("Charge is null. " + eventCharge.Id); - } - return charge; - } - - private async Task GetInvoiceAsync(Stripe.Event parsedEvent, bool fresh = false) - { - if (!(parsedEvent.Data.Object is Invoice eventInvoice)) - { - throw new Exception("Invoice is null (from parsed event). " + parsedEvent.Id); - } - if (!fresh) - { - return eventInvoice; - } - var invoiceService = new InvoiceService(); - var invoice = await invoiceService.GetAsync(eventInvoice.Id); - if (invoice == null) - { - throw new Exception("Invoice is null. " + eventInvoice.Id); - } - return invoice; - } - - private async Task GetSubscriptionAsync(Stripe.Event parsedEvent, bool fresh = false) - { - if (!(parsedEvent.Data.Object is Subscription eventSubscription)) - { - throw new Exception("Subscription is null (from parsed event). " + parsedEvent.Id); - } - if (!fresh) - { - return eventSubscription; - } - var subscriptionService = new SubscriptionService(); - var subscription = await subscriptionService.GetAsync(eventSubscription.Id); - if (subscription == null) - { - throw new Exception("Subscription is null. " + eventSubscription.Id); - } - return subscription; - } - - private async Task VerifyCorrectTaxRateForCharge(Invoice invoice, Subscription subscription) - { - if (!string.IsNullOrWhiteSpace(invoice?.CustomerAddress?.Country) && !string.IsNullOrWhiteSpace(invoice?.CustomerAddress?.PostalCode)) - { - var localBitwardenTaxRates = await _taxRateRepository.GetByLocationAsync( - new TaxRate() - { - Country = invoice.CustomerAddress.Country, - PostalCode = invoice.CustomerAddress.PostalCode - } - ); - - if (localBitwardenTaxRates.Any()) - { - var stripeTaxRate = await new TaxRateService().GetAsync(localBitwardenTaxRates.First().Id); - if (stripeTaxRate != null && !subscription.DefaultTaxRates.Any(x => x == stripeTaxRate)) - { - subscription.DefaultTaxRates = new List { stripeTaxRate }; - var subscriptionOptions = new SubscriptionUpdateOptions() { DefaultTaxRates = new List() { stripeTaxRate.Id } }; - subscription = await new SubscriptionService().UpdateAsync(subscription.Id, subscriptionOptions); - } - } - } - return subscription; - } - - private static bool IsSponsoredSubscription(Subscription subscription) => - StaticStore.SponsoredPlans.Any(p => p.StripePlanId == subscription.Id); - - private async Task HandlePaymentFailed(Invoice invoice) - { - if (!invoice.Paid && invoice.AttemptCount > 1 && UnpaidAutoChargeInvoiceForSubscriptionCycle(invoice)) - { - var subscriptionService = new SubscriptionService(); - var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); - // attempt count 4 = 11 days after initial failure - if (invoice.AttemptCount > 3 && subscription.Items.Any(i => i.Price.Id == PremiumPlanId || i.Price.Id == PremiumPlanIdAppStore)) - { - await CancelSubscription(invoice.SubscriptionId); - await VoidOpenInvoices(invoice.SubscriptionId); - } - else - { - await AttemptToPayInvoiceAsync(invoice); - } - } - } - - private async Task CancelSubscription(string subscriptionId) - { - await new SubscriptionService().CancelAsync(subscriptionId, new SubscriptionCancelOptions()); - } - - private async Task VoidOpenInvoices(string subscriptionId) - { - var invoiceService = new InvoiceService(); - var options = new InvoiceListOptions - { - Status = "open", - Subscription = subscriptionId - }; - var invoices = invoiceService.List(options); - foreach (var invoice in invoices) - { - await invoiceService.VoidInvoiceAsync(invoice.Id); + await AttemptToPayInvoiceAsync(invoice); } } } + + private async Task CancelSubscription(string subscriptionId) + { + await new SubscriptionService().CancelAsync(subscriptionId, new SubscriptionCancelOptions()); + } + + private async Task VoidOpenInvoices(string subscriptionId) + { + var invoiceService = new InvoiceService(); + var options = new InvoiceListOptions + { + Status = "open", + Subscription = subscriptionId + }; + var invoices = invoiceService.List(options); + foreach (var invoice in invoices) + { + await invoiceService.VoidInvoiceAsync(invoice.Id); + } + } } diff --git a/src/Billing/Jobs/JobsHostedService.cs b/src/Billing/Jobs/JobsHostedService.cs index ea91924a1..1a5c80774 100644 --- a/src/Billing/Jobs/JobsHostedService.cs +++ b/src/Billing/Jobs/JobsHostedService.cs @@ -3,43 +3,42 @@ using Bit.Core.Jobs; using Bit.Core.Settings; using Quartz; -namespace Bit.Billing.Jobs +namespace Bit.Billing.Jobs; + +public class JobsHostedService : BaseJobsHostedService { - public class JobsHostedService : BaseJobsHostedService + public JobsHostedService( + GlobalSettings globalSettings, + IServiceProvider serviceProvider, + ILogger logger, + ILogger listenerLogger) + : base(globalSettings, serviceProvider, logger, listenerLogger) { } + + public override async Task StartAsync(CancellationToken cancellationToken) { - public JobsHostedService( - GlobalSettings globalSettings, - IServiceProvider serviceProvider, - ILogger logger, - ILogger listenerLogger) - : base(globalSettings, serviceProvider, logger, listenerLogger) { } - - public override async Task StartAsync(CancellationToken cancellationToken) + var timeZone = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? + TimeZoneInfo.FindSystemTimeZoneById("Eastern Standard Time") : + TimeZoneInfo.FindSystemTimeZoneById("America/New_York"); + if (_globalSettings.SelfHosted) { - var timeZone = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? - TimeZoneInfo.FindSystemTimeZoneById("Eastern Standard Time") : - TimeZoneInfo.FindSystemTimeZoneById("America/New_York"); - if (_globalSettings.SelfHosted) - { - timeZone = TimeZoneInfo.Local; - } - - var everyDayAtNinePmTrigger = TriggerBuilder.Create() - .WithIdentity("EveryDayAtNinePmTrigger") - .StartNow() - .WithCronSchedule("0 0 21 * * ?", x => x.InTimeZone(timeZone)) - .Build(); - - Jobs = new List>(); - - // Add jobs here - - await base.StartAsync(cancellationToken); + timeZone = TimeZoneInfo.Local; } - public static void AddJobsServices(IServiceCollection services) - { - // Register jobs here - } + var everyDayAtNinePmTrigger = TriggerBuilder.Create() + .WithIdentity("EveryDayAtNinePmTrigger") + .StartNow() + .WithCronSchedule("0 0 21 * * ?", x => x.InTimeZone(timeZone)) + .Build(); + + Jobs = new List>(); + + // Add jobs here + + await base.StartAsync(cancellationToken); + } + + public static void AddJobsServices(IServiceCollection services) + { + // Register jobs here } } diff --git a/src/Billing/Models/BitPayEventModel.cs b/src/Billing/Models/BitPayEventModel.cs index b7ed06462..e16391317 100644 --- a/src/Billing/Models/BitPayEventModel.cs +++ b/src/Billing/Models/BitPayEventModel.cs @@ -1,28 +1,27 @@ -namespace Bit.Billing.Models +namespace Bit.Billing.Models; + +public class BitPayEventModel { - public class BitPayEventModel + public EventModel Event { get; set; } + public InvoiceDataModel Data { get; set; } + + public class EventModel { - public EventModel Event { get; set; } - public InvoiceDataModel Data { get; set; } + public int Code { get; set; } + public string Name { get; set; } + } - public class EventModel - { - public int Code { get; set; } - public string Name { get; set; } - } - - public class InvoiceDataModel - { - public string Id { get; set; } - public string Url { get; set; } - public string Status { get; set; } - public string Currency { get; set; } - public decimal Price { get; set; } - public string PosData { get; set; } - public bool ExceptionStatus { get; set; } - public long CurrentTime { get; set; } - public long AmountPaid { get; set; } - public string TransactionCurrency { get; set; } - } + public class InvoiceDataModel + { + public string Id { get; set; } + public string Url { get; set; } + public string Status { get; set; } + public string Currency { get; set; } + public decimal Price { get; set; } + public string PosData { get; set; } + public bool ExceptionStatus { get; set; } + public long CurrentTime { get; set; } + public long AmountPaid { get; set; } + public string TransactionCurrency { get; set; } } } diff --git a/src/Billing/Models/FreshdeskWebhookModel.cs b/src/Billing/Models/FreshdeskWebhookModel.cs index c371c70fb..e9fe8e026 100644 --- a/src/Billing/Models/FreshdeskWebhookModel.cs +++ b/src/Billing/Models/FreshdeskWebhookModel.cs @@ -1,16 +1,15 @@ using System.Text.Json.Serialization; -namespace Bit.Billing.Models +namespace Bit.Billing.Models; + +public class FreshdeskWebhookModel { - public class FreshdeskWebhookModel - { - [JsonPropertyName("ticket_id")] - public string TicketId { get; set; } + [JsonPropertyName("ticket_id")] + public string TicketId { get; set; } - [JsonPropertyName("ticket_contact_email")] - public string TicketContactEmail { get; set; } + [JsonPropertyName("ticket_contact_email")] + public string TicketContactEmail { get; set; } - [JsonPropertyName("ticket_tags")] - public string TicketTags { get; set; } - } + [JsonPropertyName("ticket_tags")] + public string TicketTags { get; set; } } diff --git a/src/Billing/Models/LoginModel.cs b/src/Billing/Models/LoginModel.cs index 51fdf0915..5fe04ad45 100644 --- a/src/Billing/Models/LoginModel.cs +++ b/src/Billing/Models/LoginModel.cs @@ -1,11 +1,10 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Billing.Models +namespace Bit.Billing.Models; + +public class LoginModel { - public class LoginModel - { - [Required] - [EmailAddress] - public string Email { get; set; } - } + [Required] + [EmailAddress] + public string Email { get; set; } } diff --git a/src/Billing/Program.cs b/src/Billing/Program.cs index 7b42ad73f..d7ebadd92 100644 --- a/src/Billing/Program.cs +++ b/src/Billing/Program.cs @@ -1,39 +1,38 @@ using Bit.Core.Utilities; using Serilog.Events; -namespace Bit.Billing +namespace Bit.Billing; + +public class Program { - public class Program + public static void Main(string[] args) { - public static void Main(string[] args) - { - Host - .CreateDefaultBuilder(args) - .ConfigureWebHostDefaults(webBuilder => - { - webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, e => + Host + .CreateDefaultBuilder(args) + .ConfigureWebHostDefaults(webBuilder => + { + webBuilder.UseStartup(); + webBuilder.ConfigureLogging((hostingContext, logging) => + logging.AddSerilog(hostingContext, e => + { + var context = e.Properties["SourceContext"].ToString(); + if (e.Level == LogEventLevel.Information && + (context.StartsWith("\"Bit.Billing.Jobs") || context.StartsWith("\"Bit.Core.Jobs"))) { - var context = e.Properties["SourceContext"].ToString(); - if (e.Level == LogEventLevel.Information && - (context.StartsWith("\"Bit.Billing.Jobs") || context.StartsWith("\"Bit.Core.Jobs"))) - { - return true; - } + return true; + } - if (e.Properties.ContainsKey("RequestPath") && - !string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) && - (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) - { - return false; - } + if (e.Properties.ContainsKey("RequestPath") && + !string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) && + (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) + { + return false; + } - return e.Level >= LogEventLevel.Warning; - })); - }) - .Build() - .Run(); - } + return e.Level >= LogEventLevel.Warning; + })); + }) + .Build() + .Run(); } } diff --git a/src/Billing/Startup.cs b/src/Billing/Startup.cs index a2a161a88..328e6133d 100644 --- a/src/Billing/Startup.cs +++ b/src/Billing/Startup.cs @@ -6,94 +6,93 @@ using Bit.SharedWeb.Utilities; using Microsoft.Extensions.DependencyInjection.Extensions; using Stripe; -namespace Bit.Billing +namespace Bit.Billing; + +public class Startup { - public class Startup + public Startup(IWebHostEnvironment env, IConfiguration configuration) { - public Startup(IWebHostEnvironment env, IConfiguration configuration) + CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); + Configuration = configuration; + Environment = env; + } + + public IConfiguration Configuration { get; } + public IWebHostEnvironment Environment { get; set; } + + public void ConfigureServices(IServiceCollection services) + { + // Options + services.AddOptions(); + + // Settings + var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); + services.Configure(Configuration.GetSection("BillingSettings")); + + // Stripe Billing + StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey; + StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries; + + // Repositories + services.AddSqlServerRepositories(globalSettings); + + // PayPal Client + services.AddSingleton(); + + // BitPay Client + services.AddSingleton(); + + // Context + services.AddScoped(); + + // Identity + services.AddCustomIdentityServices(globalSettings); + //services.AddPasswordlessIdentityServices(globalSettings); + + // Services + services.AddBaseServices(globalSettings); + services.AddDefaultServices(globalSettings); + + services.TryAddSingleton(); + + // Mvc + services.AddMvc(config => { - CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); - Configuration = configuration; - Environment = env; + config.Filters.Add(new LoggingExceptionHandlerFilterAttribute()); + }); + services.Configure(options => options.LowercaseUrls = true); + + // Authentication + services.AddAuthentication(); + + // Jobs service, uncomment when we have some jobs to run + // Jobs.JobsHostedService.AddJobsServices(services); + // services.AddHostedService(); + + // Set up HttpClients + services.AddHttpClient("FreshdeskApi"); + } + + public void Configure( + IApplicationBuilder app, + IWebHostEnvironment env, + IHostApplicationLifetime appLifetime, + GlobalSettings globalSettings) + { + app.UseSerilog(env, appLifetime, globalSettings); + + // Add general security headers + app.UseMiddleware(); + + if (env.IsDevelopment()) + { + app.UseDeveloperExceptionPage(); } - public IConfiguration Configuration { get; } - public IWebHostEnvironment Environment { get; set; } - - public void ConfigureServices(IServiceCollection services) - { - // Options - services.AddOptions(); - - // Settings - var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); - services.Configure(Configuration.GetSection("BillingSettings")); - - // Stripe Billing - StripeConfiguration.ApiKey = globalSettings.Stripe.ApiKey; - StripeConfiguration.MaxNetworkRetries = globalSettings.Stripe.MaxNetworkRetries; - - // Repositories - services.AddSqlServerRepositories(globalSettings); - - // PayPal Client - services.AddSingleton(); - - // BitPay Client - services.AddSingleton(); - - // Context - services.AddScoped(); - - // Identity - services.AddCustomIdentityServices(globalSettings); - //services.AddPasswordlessIdentityServices(globalSettings); - - // Services - services.AddBaseServices(globalSettings); - services.AddDefaultServices(globalSettings); - - services.TryAddSingleton(); - - // Mvc - services.AddMvc(config => - { - config.Filters.Add(new LoggingExceptionHandlerFilterAttribute()); - }); - services.Configure(options => options.LowercaseUrls = true); - - // Authentication - services.AddAuthentication(); - - // Jobs service, uncomment when we have some jobs to run - // Jobs.JobsHostedService.AddJobsServices(services); - // services.AddHostedService(); - - // Set up HttpClients - services.AddHttpClient("FreshdeskApi"); - } - - public void Configure( - IApplicationBuilder app, - IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, - GlobalSettings globalSettings) - { - app.UseSerilog(env, appLifetime, globalSettings); - - // Add general security headers - app.UseMiddleware(); - - if (env.IsDevelopment()) - { - app.UseDeveloperExceptionPage(); - } - - app.UseStaticFiles(); - app.UseRouting(); - app.UseAuthentication(); - app.UseAuthorization(); - app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); - } + app.UseStaticFiles(); + app.UseRouting(); + app.UseAuthentication(); + app.UseAuthorization(); + app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); } } diff --git a/src/Billing/Utilities/PayPalIpnClient.cs b/src/Billing/Utilities/PayPalIpnClient.cs index 2a7b0e852..15f7a2f15 100644 --- a/src/Billing/Utilities/PayPalIpnClient.cs +++ b/src/Billing/Utilities/PayPalIpnClient.cs @@ -4,171 +4,170 @@ using System.Text; using System.Web; using Microsoft.Extensions.Options; -namespace Bit.Billing.Utilities +namespace Bit.Billing.Utilities; + +public class PayPalIpnClient { - public class PayPalIpnClient + private readonly HttpClient _httpClient = new HttpClient(); + private readonly Uri _ipnUri; + + public PayPalIpnClient(IOptions billingSettings) { - private readonly HttpClient _httpClient = new HttpClient(); - private readonly Uri _ipnUri; + var bSettings = billingSettings?.Value; + _ipnUri = new Uri(bSettings.PayPal.Production ? "https://www.paypal.com/cgi-bin/webscr" : + "https://www.sandbox.paypal.com/cgi-bin/webscr"); + } - public PayPalIpnClient(IOptions billingSettings) + public async Task VerifyIpnAsync(string ipnBody) + { + if (ipnBody == null) { - var bSettings = billingSettings?.Value; - _ipnUri = new Uri(bSettings.PayPal.Production ? "https://www.paypal.com/cgi-bin/webscr" : - "https://www.sandbox.paypal.com/cgi-bin/webscr"); + throw new ArgumentException("No IPN body."); } - public async Task VerifyIpnAsync(string ipnBody) + var request = new HttpRequestMessage { - if (ipnBody == null) + Method = HttpMethod.Post, + RequestUri = _ipnUri + }; + var cmdIpnBody = string.Concat("cmd=_notify-validate&", ipnBody); + request.Content = new StringContent(cmdIpnBody, Encoding.UTF8, "application/x-www-form-urlencoded"); + var response = await _httpClient.SendAsync(request); + if (!response.IsSuccessStatusCode) + { + throw new Exception("Failed to verify IPN, status: " + response.StatusCode); + } + var responseContent = await response.Content.ReadAsStringAsync(); + if (responseContent.Equals("VERIFIED")) + { + return true; + } + else if (responseContent.Equals("INVALID")) + { + return false; + } + else + { + throw new Exception("Failed to verify IPN."); + } + } + + public class IpnTransaction + { + private string[] _dateFormats = new string[] + { + "HH:mm:ss dd MMM yyyy PDT", "HH:mm:ss dd MMM yyyy PST", "HH:mm:ss dd MMM, yyyy PST", + "HH:mm:ss dd MMM, yyyy PDT","HH:mm:ss MMM dd, yyyy PST", "HH:mm:ss MMM dd, yyyy PDT" + }; + + public IpnTransaction(string ipnFormData) + { + if (string.IsNullOrWhiteSpace(ipnFormData)) { - throw new ArgumentException("No IPN body."); + return; } - var request = new HttpRequestMessage + var qsData = HttpUtility.ParseQueryString(ipnFormData); + var dataDict = qsData.Keys.Cast().ToDictionary(k => k, v => qsData[v].ToString()); + + TxnId = GetDictValue(dataDict, "txn_id"); + TxnType = GetDictValue(dataDict, "txn_type"); + ParentTxnId = GetDictValue(dataDict, "parent_txn_id"); + PaymentStatus = GetDictValue(dataDict, "payment_status"); + PaymentType = GetDictValue(dataDict, "payment_type"); + McCurrency = GetDictValue(dataDict, "mc_currency"); + Custom = GetDictValue(dataDict, "custom"); + ItemName = GetDictValue(dataDict, "item_name"); + ItemNumber = GetDictValue(dataDict, "item_number"); + PayerId = GetDictValue(dataDict, "payer_id"); + PayerEmail = GetDictValue(dataDict, "payer_email"); + ReceiverId = GetDictValue(dataDict, "receiver_id"); + ReceiverEmail = GetDictValue(dataDict, "receiver_email"); + + PaymentDate = ConvertDate(GetDictValue(dataDict, "payment_date")); + + var mcGrossString = GetDictValue(dataDict, "mc_gross"); + if (!string.IsNullOrWhiteSpace(mcGrossString) && decimal.TryParse(mcGrossString, out var mcGross)) { - Method = HttpMethod.Post, - RequestUri = _ipnUri - }; - var cmdIpnBody = string.Concat("cmd=_notify-validate&", ipnBody); - request.Content = new StringContent(cmdIpnBody, Encoding.UTF8, "application/x-www-form-urlencoded"); - var response = await _httpClient.SendAsync(request); - if (!response.IsSuccessStatusCode) - { - throw new Exception("Failed to verify IPN, status: " + response.StatusCode); + McGross = mcGross; } - var responseContent = await response.Content.ReadAsStringAsync(); - if (responseContent.Equals("VERIFIED")) + var mcFeeString = GetDictValue(dataDict, "mc_fee"); + if (!string.IsNullOrWhiteSpace(mcFeeString) && decimal.TryParse(mcFeeString, out var mcFee)) { - return true; - } - else if (responseContent.Equals("INVALID")) - { - return false; - } - else - { - throw new Exception("Failed to verify IPN."); + McFee = mcFee; } } - public class IpnTransaction + public string TxnId { get; set; } + public string TxnType { get; set; } + public string ParentTxnId { get; set; } + public string PaymentStatus { get; set; } + public string PaymentType { get; set; } + public decimal McGross { get; set; } + public decimal McFee { get; set; } + public string McCurrency { get; set; } + public string Custom { get; set; } + public string ItemName { get; set; } + public string ItemNumber { get; set; } + public string PayerId { get; set; } + public string PayerEmail { get; set; } + public string ReceiverId { get; set; } + public string ReceiverEmail { get; set; } + public DateTime PaymentDate { get; set; } + + public Tuple GetIdsFromCustom() { - private string[] _dateFormats = new string[] + Guid? orgId = null; + Guid? userId = null; + + if (!string.IsNullOrWhiteSpace(Custom) && Custom.Contains(":")) { - "HH:mm:ss dd MMM yyyy PDT", "HH:mm:ss dd MMM yyyy PST", "HH:mm:ss dd MMM, yyyy PST", - "HH:mm:ss dd MMM, yyyy PDT","HH:mm:ss MMM dd, yyyy PST", "HH:mm:ss MMM dd, yyyy PDT" - }; - - public IpnTransaction(string ipnFormData) - { - if (string.IsNullOrWhiteSpace(ipnFormData)) + var mainParts = Custom.Split(','); + foreach (var mainPart in mainParts) { - return; - } - - var qsData = HttpUtility.ParseQueryString(ipnFormData); - var dataDict = qsData.Keys.Cast().ToDictionary(k => k, v => qsData[v].ToString()); - - TxnId = GetDictValue(dataDict, "txn_id"); - TxnType = GetDictValue(dataDict, "txn_type"); - ParentTxnId = GetDictValue(dataDict, "parent_txn_id"); - PaymentStatus = GetDictValue(dataDict, "payment_status"); - PaymentType = GetDictValue(dataDict, "payment_type"); - McCurrency = GetDictValue(dataDict, "mc_currency"); - Custom = GetDictValue(dataDict, "custom"); - ItemName = GetDictValue(dataDict, "item_name"); - ItemNumber = GetDictValue(dataDict, "item_number"); - PayerId = GetDictValue(dataDict, "payer_id"); - PayerEmail = GetDictValue(dataDict, "payer_email"); - ReceiverId = GetDictValue(dataDict, "receiver_id"); - ReceiverEmail = GetDictValue(dataDict, "receiver_email"); - - PaymentDate = ConvertDate(GetDictValue(dataDict, "payment_date")); - - var mcGrossString = GetDictValue(dataDict, "mc_gross"); - if (!string.IsNullOrWhiteSpace(mcGrossString) && decimal.TryParse(mcGrossString, out var mcGross)) - { - McGross = mcGross; - } - var mcFeeString = GetDictValue(dataDict, "mc_fee"); - if (!string.IsNullOrWhiteSpace(mcFeeString) && decimal.TryParse(mcFeeString, out var mcFee)) - { - McFee = mcFee; - } - } - - public string TxnId { get; set; } - public string TxnType { get; set; } - public string ParentTxnId { get; set; } - public string PaymentStatus { get; set; } - public string PaymentType { get; set; } - public decimal McGross { get; set; } - public decimal McFee { get; set; } - public string McCurrency { get; set; } - public string Custom { get; set; } - public string ItemName { get; set; } - public string ItemNumber { get; set; } - public string PayerId { get; set; } - public string PayerEmail { get; set; } - public string ReceiverId { get; set; } - public string ReceiverEmail { get; set; } - public DateTime PaymentDate { get; set; } - - public Tuple GetIdsFromCustom() - { - Guid? orgId = null; - Guid? userId = null; - - if (!string.IsNullOrWhiteSpace(Custom) && Custom.Contains(":")) - { - var mainParts = Custom.Split(','); - foreach (var mainPart in mainParts) + var parts = mainPart.Split(':'); + if (parts.Length > 1 && Guid.TryParse(parts[1], out var id)) { - var parts = mainPart.Split(':'); - if (parts.Length > 1 && Guid.TryParse(parts[1], out var id)) + if (parts[0] == "user_id") { - if (parts[0] == "user_id") - { - userId = id; - } - else if (parts[0] == "organization_id") - { - orgId = id; - } + userId = id; + } + else if (parts[0] == "organization_id") + { + orgId = id; } } } - - return new Tuple(orgId, userId); } - public bool IsAccountCredit() - { - return !string.IsNullOrWhiteSpace(Custom) && Custom.Contains("account_credit:1"); - } + return new Tuple(orgId, userId); + } - private string GetDictValue(IDictionary dict, string key) - { - return dict.ContainsKey(key) ? dict[key] : null; - } + public bool IsAccountCredit() + { + return !string.IsNullOrWhiteSpace(Custom) && Custom.Contains("account_credit:1"); + } - private DateTime ConvertDate(string dateString) + private string GetDictValue(IDictionary dict, string key) + { + return dict.ContainsKey(key) ? dict[key] : null; + } + + private DateTime ConvertDate(string dateString) + { + if (!string.IsNullOrWhiteSpace(dateString)) { - if (!string.IsNullOrWhiteSpace(dateString)) + var parsed = DateTime.TryParseExact(dateString, _dateFormats, + CultureInfo.InvariantCulture, DateTimeStyles.None, out var paymentDate); + if (parsed) { - var parsed = DateTime.TryParseExact(dateString, _dateFormats, - CultureInfo.InvariantCulture, DateTimeStyles.None, out var paymentDate); - if (parsed) - { - var pacificTime = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? - TimeZoneInfo.FindSystemTimeZoneById("Pacific Standard Time") : - TimeZoneInfo.FindSystemTimeZoneById("America/Los_Angeles"); - return TimeZoneInfo.ConvertTimeToUtc(paymentDate, pacificTime); - } + var pacificTime = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? + TimeZoneInfo.FindSystemTimeZoneById("Pacific Standard Time") : + TimeZoneInfo.FindSystemTimeZoneById("America/Los_Angeles"); + return TimeZoneInfo.ConvertTimeToUtc(paymentDate, pacificTime); } - return default(DateTime); } + return default(DateTime); } } } diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 68fd94295..8d1f009d4 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -1,23 +1,22 @@ -namespace Bit.Core +namespace Bit.Core; + +public static class Constants { - public static class Constants - { - public const int BypassFiltersEventId = 12482444; + public const int BypassFiltersEventId = 12482444; - // File size limits - give 1 MB extra for cushion. - // Note: if request size limits are changed, 'client_max_body_size' - // in nginx/proxy.conf may also need to be updated accordingly. - public const long FileSize101mb = 101L * 1024L * 1024L; - public const long FileSize501mb = 501L * 1024L * 1024L; - } - - public static class TokenPurposes - { - public const string LinkSso = "LinkSso"; - } - - public static class AuthenticationSchemes - { - public const string BitwardenExternalCookieAuthenticationScheme = "bw.external"; - } + // File size limits - give 1 MB extra for cushion. + // Note: if request size limits are changed, 'client_max_body_size' + // in nginx/proxy.conf may also need to be updated accordingly. + public const long FileSize101mb = 101L * 1024L * 1024L; + public const long FileSize501mb = 501L * 1024L * 1024L; +} + +public static class TokenPurposes +{ + public const string LinkSso = "LinkSso"; +} + +public static class AuthenticationSchemes +{ + public const string BitwardenExternalCookieAuthenticationScheme = "bw.external"; } diff --git a/src/Core/Context/CurrentContentOrganization.cs b/src/Core/Context/CurrentContentOrganization.cs index 7a54b2727..040c1ece4 100644 --- a/src/Core/Context/CurrentContentOrganization.cs +++ b/src/Core/Context/CurrentContentOrganization.cs @@ -3,21 +3,20 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Core.Context +namespace Bit.Core.Context; + +public class CurrentContentOrganization { - public class CurrentContentOrganization + public CurrentContentOrganization() { } + + public CurrentContentOrganization(OrganizationUser orgUser) { - public CurrentContentOrganization() { } - - public CurrentContentOrganization(OrganizationUser orgUser) - { - Id = orgUser.OrganizationId; - Type = orgUser.Type; - Permissions = CoreHelpers.LoadClassFromJsonData(orgUser.Permissions); - } - - public Guid Id { get; set; } - public OrganizationUserType Type { get; set; } - public Permissions Permissions { get; set; } + Id = orgUser.OrganizationId; + Type = orgUser.Type; + Permissions = CoreHelpers.LoadClassFromJsonData(orgUser.Permissions); } + + public Guid Id { get; set; } + public OrganizationUserType Type { get; set; } + public Permissions Permissions { get; set; } } diff --git a/src/Core/Context/CurrentContentProvider.cs b/src/Core/Context/CurrentContentProvider.cs index f1925f551..f089be7b8 100644 --- a/src/Core/Context/CurrentContentProvider.cs +++ b/src/Core/Context/CurrentContentProvider.cs @@ -3,21 +3,20 @@ using Bit.Core.Enums.Provider; using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Core.Context +namespace Bit.Core.Context; + +public class CurrentContentProvider { - public class CurrentContentProvider + public CurrentContentProvider() { } + + public CurrentContentProvider(ProviderUser providerUser) { - public CurrentContentProvider() { } - - public CurrentContentProvider(ProviderUser providerUser) - { - Id = providerUser.ProviderId; - Type = providerUser.Type; - Permissions = CoreHelpers.LoadClassFromJsonData(providerUser.Permissions); - } - - public Guid Id { get; set; } - public ProviderUserType Type { get; set; } - public Permissions Permissions { get; set; } + Id = providerUser.ProviderId; + Type = providerUser.Type; + Permissions = CoreHelpers.LoadClassFromJsonData(providerUser.Permissions); } + + public Guid Id { get; set; } + public ProviderUserType Type { get; set; } + public Permissions Permissions { get; set; } } diff --git a/src/Core/Context/CurrentContext.cs b/src/Core/Context/CurrentContext.cs index 47effcab1..d78340d70 100644 --- a/src/Core/Context/CurrentContext.cs +++ b/src/Core/Context/CurrentContext.cs @@ -8,486 +8,485 @@ using Bit.Core.Settings; using Bit.Core.Utilities; using Microsoft.AspNetCore.Http; -namespace Bit.Core.Context +namespace Bit.Core.Context; + +public class CurrentContext : ICurrentContext { - public class CurrentContext : ICurrentContext + private readonly IProviderUserRepository _providerUserRepository; + private bool _builtHttpContext; + private bool _builtClaimsPrincipal; + private IEnumerable _providerUserOrganizations; + + public virtual HttpContext HttpContext { get; set; } + public virtual Guid? UserId { get; set; } + public virtual User User { get; set; } + public virtual string DeviceIdentifier { get; set; } + public virtual DeviceType? DeviceType { get; set; } + public virtual string IpAddress { get; set; } + public virtual List Organizations { get; set; } + public virtual List Providers { get; set; } + public virtual Guid? InstallationId { get; set; } + public virtual Guid? OrganizationId { get; set; } + public virtual bool CloudflareWorkerProxied { get; set; } + public virtual bool IsBot { get; set; } + public virtual bool MaybeBot { get; set; } + public virtual int? BotScore { get; set; } + public virtual string ClientId { get; set; } + + public CurrentContext(IProviderUserRepository providerUserRepository) { - private readonly IProviderUserRepository _providerUserRepository; - private bool _builtHttpContext; - private bool _builtClaimsPrincipal; - private IEnumerable _providerUserOrganizations; + _providerUserRepository = providerUserRepository; + } - public virtual HttpContext HttpContext { get; set; } - public virtual Guid? UserId { get; set; } - public virtual User User { get; set; } - public virtual string DeviceIdentifier { get; set; } - public virtual DeviceType? DeviceType { get; set; } - public virtual string IpAddress { get; set; } - public virtual List Organizations { get; set; } - public virtual List Providers { get; set; } - public virtual Guid? InstallationId { get; set; } - public virtual Guid? OrganizationId { get; set; } - public virtual bool CloudflareWorkerProxied { get; set; } - public virtual bool IsBot { get; set; } - public virtual bool MaybeBot { get; set; } - public virtual int? BotScore { get; set; } - public virtual string ClientId { get; set; } - - public CurrentContext(IProviderUserRepository providerUserRepository) + public async virtual Task BuildAsync(HttpContext httpContext, GlobalSettings globalSettings) + { + if (_builtHttpContext) { - _providerUserRepository = providerUserRepository; + return; } - public async virtual Task BuildAsync(HttpContext httpContext, GlobalSettings globalSettings) + _builtHttpContext = true; + HttpContext = httpContext; + await BuildAsync(httpContext.User, globalSettings); + + if (DeviceIdentifier == null && httpContext.Request.Headers.ContainsKey("Device-Identifier")) { - if (_builtHttpContext) - { - return; - } - - _builtHttpContext = true; - HttpContext = httpContext; - await BuildAsync(httpContext.User, globalSettings); - - if (DeviceIdentifier == null && httpContext.Request.Headers.ContainsKey("Device-Identifier")) - { - DeviceIdentifier = httpContext.Request.Headers["Device-Identifier"]; - } - - if (httpContext.Request.Headers.ContainsKey("Device-Type") && - Enum.TryParse(httpContext.Request.Headers["Device-Type"].ToString(), out DeviceType dType)) - { - DeviceType = dType; - } - - if (!BotScore.HasValue && httpContext.Request.Headers.ContainsKey("X-Cf-Bot-Score") && - int.TryParse(httpContext.Request.Headers["X-Cf-Bot-Score"], out var parsedBotScore)) - { - BotScore = parsedBotScore; - } - - if (httpContext.Request.Headers.ContainsKey("X-Cf-Worked-Proxied")) - { - CloudflareWorkerProxied = httpContext.Request.Headers["X-Cf-Worked-Proxied"] == "1"; - } - - if (httpContext.Request.Headers.ContainsKey("X-Cf-Is-Bot")) - { - IsBot = httpContext.Request.Headers["X-Cf-Is-Bot"] == "1"; - } - - if (httpContext.Request.Headers.ContainsKey("X-Cf-Maybe-Bot")) - { - MaybeBot = httpContext.Request.Headers["X-Cf-Maybe-Bot"] == "1"; - } + DeviceIdentifier = httpContext.Request.Headers["Device-Identifier"]; } - public async virtual Task BuildAsync(ClaimsPrincipal user, GlobalSettings globalSettings) + if (httpContext.Request.Headers.ContainsKey("Device-Type") && + Enum.TryParse(httpContext.Request.Headers["Device-Type"].ToString(), out DeviceType dType)) { - if (_builtClaimsPrincipal) - { - return; - } - - _builtClaimsPrincipal = true; - IpAddress = HttpContext.GetIpAddress(globalSettings); - await SetContextAsync(user); + DeviceType = dType; } - public virtual Task SetContextAsync(ClaimsPrincipal user) + if (!BotScore.HasValue && httpContext.Request.Headers.ContainsKey("X-Cf-Bot-Score") && + int.TryParse(httpContext.Request.Headers["X-Cf-Bot-Score"], out var parsedBotScore)) { - if (user == null || !user.Claims.Any()) - { - return Task.FromResult(0); - } + BotScore = parsedBotScore; + } - var claimsDict = user.Claims.GroupBy(c => c.Type).ToDictionary(c => c.Key, c => c.Select(v => v)); + if (httpContext.Request.Headers.ContainsKey("X-Cf-Worked-Proxied")) + { + CloudflareWorkerProxied = httpContext.Request.Headers["X-Cf-Worked-Proxied"] == "1"; + } - var subject = GetClaimValue(claimsDict, "sub"); - if (Guid.TryParse(subject, out var subIdGuid)) - { - UserId = subIdGuid; - } + if (httpContext.Request.Headers.ContainsKey("X-Cf-Is-Bot")) + { + IsBot = httpContext.Request.Headers["X-Cf-Is-Bot"] == "1"; + } - ClientId = GetClaimValue(claimsDict, "client_id"); - var clientSubject = GetClaimValue(claimsDict, "client_sub"); - var orgApi = false; - if (clientSubject != null) - { - if (ClientId?.StartsWith("installation.") ?? false) - { - if (Guid.TryParse(clientSubject, out var idGuid)) - { - InstallationId = idGuid; - } - } - else if (ClientId?.StartsWith("organization.") ?? false) - { - if (Guid.TryParse(clientSubject, out var idGuid)) - { - OrganizationId = idGuid; - orgApi = true; - } - } - } + if (httpContext.Request.Headers.ContainsKey("X-Cf-Maybe-Bot")) + { + MaybeBot = httpContext.Request.Headers["X-Cf-Maybe-Bot"] == "1"; + } + } - DeviceIdentifier = GetClaimValue(claimsDict, "device"); + public async virtual Task BuildAsync(ClaimsPrincipal user, GlobalSettings globalSettings) + { + if (_builtClaimsPrincipal) + { + return; + } - Organizations = GetOrganizations(claimsDict, orgApi); - - Providers = GetProviders(claimsDict); + _builtClaimsPrincipal = true; + IpAddress = HttpContext.GetIpAddress(globalSettings); + await SetContextAsync(user); + } + public virtual Task SetContextAsync(ClaimsPrincipal user) + { + if (user == null || !user.Claims.Any()) + { return Task.FromResult(0); } - private List GetOrganizations(Dictionary> claimsDict, bool orgApi) + var claimsDict = user.Claims.GroupBy(c => c.Type).ToDictionary(c => c.Key, c => c.Select(v => v)); + + var subject = GetClaimValue(claimsDict, "sub"); + if (Guid.TryParse(subject, out var subIdGuid)) { - var organizations = new List(); - if (claimsDict.ContainsKey("orgowner")) + UserId = subIdGuid; + } + + ClientId = GetClaimValue(claimsDict, "client_id"); + var clientSubject = GetClaimValue(claimsDict, "client_sub"); + var orgApi = false; + if (clientSubject != null) + { + if (ClientId?.StartsWith("installation.") ?? false) { - organizations.AddRange(claimsDict["orgowner"].Select(c => - new CurrentContentOrganization - { - Id = new Guid(c.Value), - Type = OrganizationUserType.Owner - })); - } - else if (orgApi && OrganizationId.HasValue) - { - organizations.Add(new CurrentContentOrganization + if (Guid.TryParse(clientSubject, out var idGuid)) { - Id = OrganizationId.Value, + InstallationId = idGuid; + } + } + else if (ClientId?.StartsWith("organization.") ?? false) + { + if (Guid.TryParse(clientSubject, out var idGuid)) + { + OrganizationId = idGuid; + orgApi = true; + } + } + } + + DeviceIdentifier = GetClaimValue(claimsDict, "device"); + + Organizations = GetOrganizations(claimsDict, orgApi); + + Providers = GetProviders(claimsDict); + + return Task.FromResult(0); + } + + private List GetOrganizations(Dictionary> claimsDict, bool orgApi) + { + var organizations = new List(); + if (claimsDict.ContainsKey("orgowner")) + { + organizations.AddRange(claimsDict["orgowner"].Select(c => + new CurrentContentOrganization + { + Id = new Guid(c.Value), Type = OrganizationUserType.Owner - }); - } - - if (claimsDict.ContainsKey("orgadmin")) + })); + } + else if (orgApi && OrganizationId.HasValue) + { + organizations.Add(new CurrentContentOrganization { - organizations.AddRange(claimsDict["orgadmin"].Select(c => - new CurrentContentOrganization - { - Id = new Guid(c.Value), - Type = OrganizationUserType.Admin - })); - } - - if (claimsDict.ContainsKey("orguser")) - { - organizations.AddRange(claimsDict["orguser"].Select(c => - new CurrentContentOrganization - { - Id = new Guid(c.Value), - Type = OrganizationUserType.User - })); - } - - if (claimsDict.ContainsKey("orgmanager")) - { - organizations.AddRange(claimsDict["orgmanager"].Select(c => - new CurrentContentOrganization - { - Id = new Guid(c.Value), - Type = OrganizationUserType.Manager - })); - } - - if (claimsDict.ContainsKey("orgcustom")) - { - organizations.AddRange(claimsDict["orgcustom"].Select(c => - new CurrentContentOrganization - { - Id = new Guid(c.Value), - Type = OrganizationUserType.Custom, - Permissions = SetOrganizationPermissionsFromClaims(c.Value, claimsDict) - })); - } - - return organizations; + Id = OrganizationId.Value, + Type = OrganizationUserType.Owner + }); } - private List GetProviders(Dictionary> claimsDict) + if (claimsDict.ContainsKey("orgadmin")) { - var providers = new List(); - if (claimsDict.ContainsKey("providerprovideradmin")) - { - providers.AddRange(claimsDict["providerprovideradmin"].Select(c => - new CurrentContentProvider - { - Id = new Guid(c.Value), - Type = ProviderUserType.ProviderAdmin - })); - } - - if (claimsDict.ContainsKey("providerserviceuser")) - { - providers.AddRange(claimsDict["providerserviceuser"].Select(c => - new CurrentContentProvider - { - Id = new Guid(c.Value), - Type = ProviderUserType.ServiceUser - })); - } - - return providers; + organizations.AddRange(claimsDict["orgadmin"].Select(c => + new CurrentContentOrganization + { + Id = new Guid(c.Value), + Type = OrganizationUserType.Admin + })); } - public async Task OrganizationUser(Guid orgId) + if (claimsDict.ContainsKey("orguser")) { - return (Organizations?.Any(o => o.Id == orgId) ?? false) || await OrganizationOwner(orgId); + organizations.AddRange(claimsDict["orguser"].Select(c => + new CurrentContentOrganization + { + Id = new Guid(c.Value), + Type = OrganizationUserType.User + })); } - public async Task OrganizationManager(Guid orgId) + if (claimsDict.ContainsKey("orgmanager")) { - return await OrganizationAdmin(orgId) || - (Organizations?.Any(o => o.Id == orgId && o.Type == OrganizationUserType.Manager) ?? false); + organizations.AddRange(claimsDict["orgmanager"].Select(c => + new CurrentContentOrganization + { + Id = new Guid(c.Value), + Type = OrganizationUserType.Manager + })); } - public async Task OrganizationAdmin(Guid orgId) + if (claimsDict.ContainsKey("orgcustom")) { - return await OrganizationOwner(orgId) || - (Organizations?.Any(o => o.Id == orgId && o.Type == OrganizationUserType.Admin) ?? false); + organizations.AddRange(claimsDict["orgcustom"].Select(c => + new CurrentContentOrganization + { + Id = new Guid(c.Value), + Type = OrganizationUserType.Custom, + Permissions = SetOrganizationPermissionsFromClaims(c.Value, claimsDict) + })); } - public async Task OrganizationOwner(Guid orgId) + return organizations; + } + + private List GetProviders(Dictionary> claimsDict) + { + var providers = new List(); + if (claimsDict.ContainsKey("providerprovideradmin")) { - if (Organizations?.Any(o => o.Id == orgId && o.Type == OrganizationUserType.Owner) ?? false) - { - return true; - } - - if (Providers.Any()) - { - return await ProviderUserForOrgAsync(orgId); - } - - return false; + providers.AddRange(claimsDict["providerprovideradmin"].Select(c => + new CurrentContentProvider + { + Id = new Guid(c.Value), + Type = ProviderUserType.ProviderAdmin + })); } - public Task OrganizationCustom(Guid orgId) + if (claimsDict.ContainsKey("providerserviceuser")) { - return Task.FromResult(Organizations?.Any(o => o.Id == orgId && o.Type == OrganizationUserType.Custom) ?? false); + providers.AddRange(claimsDict["providerserviceuser"].Select(c => + new CurrentContentProvider + { + Id = new Guid(c.Value), + Type = ProviderUserType.ServiceUser + })); } - public async Task AccessEventLogs(Guid orgId) + return providers; + } + + public async Task OrganizationUser(Guid orgId) + { + return (Organizations?.Any(o => o.Id == orgId) ?? false) || await OrganizationOwner(orgId); + } + + public async Task OrganizationManager(Guid orgId) + { + return await OrganizationAdmin(orgId) || + (Organizations?.Any(o => o.Id == orgId && o.Type == OrganizationUserType.Manager) ?? false); + } + + public async Task OrganizationAdmin(Guid orgId) + { + return await OrganizationOwner(orgId) || + (Organizations?.Any(o => o.Id == orgId && o.Type == OrganizationUserType.Admin) ?? false); + } + + public async Task OrganizationOwner(Guid orgId) + { + if (Organizations?.Any(o => o.Id == orgId && o.Type == OrganizationUserType.Owner) ?? false) { - return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.AccessEventLogs ?? false)) ?? false); + return true; } - public async Task AccessImportExport(Guid orgId) + if (Providers.Any()) { - return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.AccessImportExport ?? false)) ?? false); + return await ProviderUserForOrgAsync(orgId); } - public async Task AccessReports(Guid orgId) + return false; + } + + public Task OrganizationCustom(Guid orgId) + { + return Task.FromResult(Organizations?.Any(o => o.Id == orgId && o.Type == OrganizationUserType.Custom) ?? false); + } + + public async Task AccessEventLogs(Guid orgId) + { + return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.AccessEventLogs ?? false)) ?? false); + } + + public async Task AccessImportExport(Guid orgId) + { + return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.AccessImportExport ?? false)) ?? false); + } + + public async Task AccessReports(Guid orgId) + { + return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.AccessReports ?? false)) ?? false); + } + + public async Task CreateNewCollections(Guid orgId) + { + return await OrganizationManager(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.CreateNewCollections ?? false)) ?? false); + } + + public async Task EditAnyCollection(Guid orgId) + { + return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.EditAnyCollection ?? false)) ?? false); + } + + public async Task DeleteAnyCollection(Guid orgId) + { + return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.DeleteAnyCollection ?? false)) ?? false); + } + + public async Task ViewAllCollections(Guid orgId) + { + return await CreateNewCollections(orgId) || await EditAnyCollection(orgId) || await DeleteAnyCollection(orgId); + } + + public async Task EditAssignedCollections(Guid orgId) + { + return await OrganizationManager(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.EditAssignedCollections ?? false)) ?? false); + } + + public async Task DeleteAssignedCollections(Guid orgId) + { + return await OrganizationManager(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.DeleteAssignedCollections ?? false)) ?? false); + } + + public async Task ViewAssignedCollections(Guid orgId) + { + return await EditAssignedCollections(orgId) || await DeleteAssignedCollections(orgId); + } + + public async Task ManageGroups(Guid orgId) + { + return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.ManageGroups ?? false)) ?? false); + } + + public async Task ManagePolicies(Guid orgId) + { + return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.ManagePolicies ?? false)) ?? false); + } + + public async Task ManageSso(Guid orgId) + { + return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.ManageSso ?? false)) ?? false); + } + + public async Task ManageScim(Guid orgId) + { + return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.ManageScim ?? false)) ?? false); + } + + public async Task ManageUsers(Guid orgId) + { + return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.ManageUsers ?? false)) ?? false); + } + + public async Task ManageResetPassword(Guid orgId) + { + return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId + && (o.Permissions?.ManageResetPassword ?? false)) ?? false); + } + + public async Task ManageBilling(Guid orgId) + { + var orgManagedByProvider = await ProviderIdForOrg(orgId) != null; + + return orgManagedByProvider + ? await ProviderUserForOrgAsync(orgId) + : await OrganizationOwner(orgId); + } + + public bool ProviderProviderAdmin(Guid providerId) + { + return Providers?.Any(o => o.Id == providerId && o.Type == ProviderUserType.ProviderAdmin) ?? false; + } + + public bool ProviderManageUsers(Guid providerId) + { + return ProviderProviderAdmin(providerId); + } + + public bool ProviderAccessEventLogs(Guid providerId) + { + return ProviderProviderAdmin(providerId); + } + + public bool AccessProviderOrganizations(Guid providerId) + { + return ProviderUser(providerId); + } + + public bool ManageProviderOrganizations(Guid providerId) + { + return ProviderProviderAdmin(providerId); + } + + public bool ProviderUser(Guid providerId) + { + return Providers?.Any(o => o.Id == providerId) ?? false; + } + + public async Task ProviderUserForOrgAsync(Guid orgId) + { + return (await GetProviderOrganizations()).Any(po => po.OrganizationId == orgId); + } + + public async Task ProviderIdForOrg(Guid orgId) + { + if (Organizations?.Any(org => org.Id == orgId) ?? false) { - return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.AccessReports ?? false)) ?? false); + return null; } - public async Task CreateNewCollections(Guid orgId) + var po = (await GetProviderOrganizations()) + ?.FirstOrDefault(po => po.OrganizationId == orgId); + + return po?.ProviderId; + } + + public async Task> OrganizationMembershipAsync( + IOrganizationUserRepository organizationUserRepository, Guid userId) + { + if (Organizations == null) { - return await OrganizationManager(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.CreateNewCollections ?? false)) ?? false); + var userOrgs = await organizationUserRepository.GetManyByUserAsync(userId); + Organizations = userOrgs.Where(ou => ou.Status == OrganizationUserStatusType.Confirmed) + .Select(ou => new CurrentContentOrganization(ou)).ToList(); } + return Organizations; + } - public async Task EditAnyCollection(Guid orgId) + public async Task> ProviderMembershipAsync( + IProviderUserRepository providerUserRepository, Guid userId) + { + if (Providers == null) { - return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.EditAnyCollection ?? false)) ?? false); + var userProviders = await providerUserRepository.GetManyByUserAsync(userId); + Providers = userProviders.Where(ou => ou.Status == ProviderUserStatusType.Confirmed) + .Select(ou => new CurrentContentProvider(ou)).ToList(); } + return Providers; + } - public async Task DeleteAnyCollection(Guid orgId) + private string GetClaimValue(Dictionary> claims, string type) + { + if (!claims.ContainsKey(type)) { - return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.DeleteAnyCollection ?? false)) ?? false); + return null; } - public async Task ViewAllCollections(Guid orgId) + return claims[type].FirstOrDefault()?.Value; + } + + private Permissions SetOrganizationPermissionsFromClaims(string organizationId, Dictionary> claimsDict) + { + bool hasClaim(string claimKey) { - return await CreateNewCollections(orgId) || await EditAnyCollection(orgId) || await DeleteAnyCollection(orgId); + return claimsDict.ContainsKey(claimKey) ? + claimsDict[claimKey].Any(x => x.Value == organizationId) : false; } - public async Task EditAssignedCollections(Guid orgId) + return new Permissions { - return await OrganizationManager(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.EditAssignedCollections ?? false)) ?? false); - } + AccessEventLogs = hasClaim("accesseventlogs"), + AccessImportExport = hasClaim("accessimportexport"), + AccessReports = hasClaim("accessreports"), + CreateNewCollections = hasClaim("createnewcollections"), + EditAnyCollection = hasClaim("editanycollection"), + DeleteAnyCollection = hasClaim("deleteanycollection"), + EditAssignedCollections = hasClaim("editassignedcollections"), + DeleteAssignedCollections = hasClaim("deleteassignedcollections"), + ManageGroups = hasClaim("managegroups"), + ManagePolicies = hasClaim("managepolicies"), + ManageSso = hasClaim("managesso"), + ManageUsers = hasClaim("manageusers"), + ManageResetPassword = hasClaim("manageresetpassword"), + ManageScim = hasClaim("managescim"), + }; + } - public async Task DeleteAssignedCollections(Guid orgId) + protected async Task> GetProviderOrganizations() + { + if (_providerUserOrganizations == null && UserId.HasValue) { - return await OrganizationManager(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.DeleteAssignedCollections ?? false)) ?? false); + _providerUserOrganizations = await _providerUserRepository.GetManyOrganizationDetailsByUserAsync(UserId.Value, ProviderUserStatusType.Confirmed); } - public async Task ViewAssignedCollections(Guid orgId) - { - return await EditAssignedCollections(orgId) || await DeleteAssignedCollections(orgId); - } - - public async Task ManageGroups(Guid orgId) - { - return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.ManageGroups ?? false)) ?? false); - } - - public async Task ManagePolicies(Guid orgId) - { - return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.ManagePolicies ?? false)) ?? false); - } - - public async Task ManageSso(Guid orgId) - { - return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.ManageSso ?? false)) ?? false); - } - - public async Task ManageScim(Guid orgId) - { - return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.ManageScim ?? false)) ?? false); - } - - public async Task ManageUsers(Guid orgId) - { - return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.ManageUsers ?? false)) ?? false); - } - - public async Task ManageResetPassword(Guid orgId) - { - return await OrganizationAdmin(orgId) || (Organizations?.Any(o => o.Id == orgId - && (o.Permissions?.ManageResetPassword ?? false)) ?? false); - } - - public async Task ManageBilling(Guid orgId) - { - var orgManagedByProvider = await ProviderIdForOrg(orgId) != null; - - return orgManagedByProvider - ? await ProviderUserForOrgAsync(orgId) - : await OrganizationOwner(orgId); - } - - public bool ProviderProviderAdmin(Guid providerId) - { - return Providers?.Any(o => o.Id == providerId && o.Type == ProviderUserType.ProviderAdmin) ?? false; - } - - public bool ProviderManageUsers(Guid providerId) - { - return ProviderProviderAdmin(providerId); - } - - public bool ProviderAccessEventLogs(Guid providerId) - { - return ProviderProviderAdmin(providerId); - } - - public bool AccessProviderOrganizations(Guid providerId) - { - return ProviderUser(providerId); - } - - public bool ManageProviderOrganizations(Guid providerId) - { - return ProviderProviderAdmin(providerId); - } - - public bool ProviderUser(Guid providerId) - { - return Providers?.Any(o => o.Id == providerId) ?? false; - } - - public async Task ProviderUserForOrgAsync(Guid orgId) - { - return (await GetProviderOrganizations()).Any(po => po.OrganizationId == orgId); - } - - public async Task ProviderIdForOrg(Guid orgId) - { - if (Organizations?.Any(org => org.Id == orgId) ?? false) - { - return null; - } - - var po = (await GetProviderOrganizations()) - ?.FirstOrDefault(po => po.OrganizationId == orgId); - - return po?.ProviderId; - } - - public async Task> OrganizationMembershipAsync( - IOrganizationUserRepository organizationUserRepository, Guid userId) - { - if (Organizations == null) - { - var userOrgs = await organizationUserRepository.GetManyByUserAsync(userId); - Organizations = userOrgs.Where(ou => ou.Status == OrganizationUserStatusType.Confirmed) - .Select(ou => new CurrentContentOrganization(ou)).ToList(); - } - return Organizations; - } - - public async Task> ProviderMembershipAsync( - IProviderUserRepository providerUserRepository, Guid userId) - { - if (Providers == null) - { - var userProviders = await providerUserRepository.GetManyByUserAsync(userId); - Providers = userProviders.Where(ou => ou.Status == ProviderUserStatusType.Confirmed) - .Select(ou => new CurrentContentProvider(ou)).ToList(); - } - return Providers; - } - - private string GetClaimValue(Dictionary> claims, string type) - { - if (!claims.ContainsKey(type)) - { - return null; - } - - return claims[type].FirstOrDefault()?.Value; - } - - private Permissions SetOrganizationPermissionsFromClaims(string organizationId, Dictionary> claimsDict) - { - bool hasClaim(string claimKey) - { - return claimsDict.ContainsKey(claimKey) ? - claimsDict[claimKey].Any(x => x.Value == organizationId) : false; - } - - return new Permissions - { - AccessEventLogs = hasClaim("accesseventlogs"), - AccessImportExport = hasClaim("accessimportexport"), - AccessReports = hasClaim("accessreports"), - CreateNewCollections = hasClaim("createnewcollections"), - EditAnyCollection = hasClaim("editanycollection"), - DeleteAnyCollection = hasClaim("deleteanycollection"), - EditAssignedCollections = hasClaim("editassignedcollections"), - DeleteAssignedCollections = hasClaim("deleteassignedcollections"), - ManageGroups = hasClaim("managegroups"), - ManagePolicies = hasClaim("managepolicies"), - ManageSso = hasClaim("managesso"), - ManageUsers = hasClaim("manageusers"), - ManageResetPassword = hasClaim("manageresetpassword"), - ManageScim = hasClaim("managescim"), - }; - } - - protected async Task> GetProviderOrganizations() - { - if (_providerUserOrganizations == null && UserId.HasValue) - { - _providerUserOrganizations = await _providerUserRepository.GetManyOrganizationDetailsByUserAsync(UserId.Value, ProviderUserStatusType.Confirmed); - } - - return _providerUserOrganizations; - } + return _providerUserOrganizations; } } diff --git a/src/Core/Context/ICurrentContext.cs b/src/Core/Context/ICurrentContext.cs index d82ad12e4..b53e43dfa 100644 --- a/src/Core/Context/ICurrentContext.cs +++ b/src/Core/Context/ICurrentContext.cs @@ -5,65 +5,64 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Microsoft.AspNetCore.Http; -namespace Bit.Core.Context +namespace Bit.Core.Context; + +public interface ICurrentContext { - public interface ICurrentContext - { - HttpContext HttpContext { get; set; } - Guid? UserId { get; set; } - User User { get; set; } - string DeviceIdentifier { get; set; } - DeviceType? DeviceType { get; set; } - string IpAddress { get; set; } - List Organizations { get; set; } - Guid? InstallationId { get; set; } - Guid? OrganizationId { get; set; } - bool IsBot { get; set; } - bool MaybeBot { get; set; } - int? BotScore { get; set; } - string ClientId { get; set; } - Task BuildAsync(HttpContext httpContext, GlobalSettings globalSettings); - Task BuildAsync(ClaimsPrincipal user, GlobalSettings globalSettings); + HttpContext HttpContext { get; set; } + Guid? UserId { get; set; } + User User { get; set; } + string DeviceIdentifier { get; set; } + DeviceType? DeviceType { get; set; } + string IpAddress { get; set; } + List Organizations { get; set; } + Guid? InstallationId { get; set; } + Guid? OrganizationId { get; set; } + bool IsBot { get; set; } + bool MaybeBot { get; set; } + int? BotScore { get; set; } + string ClientId { get; set; } + Task BuildAsync(HttpContext httpContext, GlobalSettings globalSettings); + Task BuildAsync(ClaimsPrincipal user, GlobalSettings globalSettings); - Task SetContextAsync(ClaimsPrincipal user); + Task SetContextAsync(ClaimsPrincipal user); - Task OrganizationUser(Guid orgId); - Task OrganizationManager(Guid orgId); - Task OrganizationAdmin(Guid orgId); - Task OrganizationOwner(Guid orgId); - Task OrganizationCustom(Guid orgId); - Task AccessEventLogs(Guid orgId); - Task AccessImportExport(Guid orgId); - Task AccessReports(Guid orgId); - Task CreateNewCollections(Guid orgId); - Task EditAnyCollection(Guid orgId); - Task DeleteAnyCollection(Guid orgId); - Task ViewAllCollections(Guid orgId); - Task EditAssignedCollections(Guid orgId); - Task DeleteAssignedCollections(Guid orgId); - Task ViewAssignedCollections(Guid orgId); - Task ManageGroups(Guid orgId); - Task ManagePolicies(Guid orgId); - Task ManageSso(Guid orgId); - Task ManageUsers(Guid orgId); - Task ManageScim(Guid orgId); - Task ManageResetPassword(Guid orgId); - Task ManageBilling(Guid orgId); - Task ProviderUserForOrgAsync(Guid orgId); - bool ProviderProviderAdmin(Guid providerId); - bool ProviderUser(Guid providerId); - bool ProviderManageUsers(Guid providerId); - bool ProviderAccessEventLogs(Guid providerId); - bool AccessProviderOrganizations(Guid providerId); - bool ManageProviderOrganizations(Guid providerId); + Task OrganizationUser(Guid orgId); + Task OrganizationManager(Guid orgId); + Task OrganizationAdmin(Guid orgId); + Task OrganizationOwner(Guid orgId); + Task OrganizationCustom(Guid orgId); + Task AccessEventLogs(Guid orgId); + Task AccessImportExport(Guid orgId); + Task AccessReports(Guid orgId); + Task CreateNewCollections(Guid orgId); + Task EditAnyCollection(Guid orgId); + Task DeleteAnyCollection(Guid orgId); + Task ViewAllCollections(Guid orgId); + Task EditAssignedCollections(Guid orgId); + Task DeleteAssignedCollections(Guid orgId); + Task ViewAssignedCollections(Guid orgId); + Task ManageGroups(Guid orgId); + Task ManagePolicies(Guid orgId); + Task ManageSso(Guid orgId); + Task ManageUsers(Guid orgId); + Task ManageScim(Guid orgId); + Task ManageResetPassword(Guid orgId); + Task ManageBilling(Guid orgId); + Task ProviderUserForOrgAsync(Guid orgId); + bool ProviderProviderAdmin(Guid providerId); + bool ProviderUser(Guid providerId); + bool ProviderManageUsers(Guid providerId); + bool ProviderAccessEventLogs(Guid providerId); + bool AccessProviderOrganizations(Guid providerId); + bool ManageProviderOrganizations(Guid providerId); - Task> OrganizationMembershipAsync( - IOrganizationUserRepository organizationUserRepository, Guid userId); + Task> OrganizationMembershipAsync( + IOrganizationUserRepository organizationUserRepository, Guid userId); - Task> ProviderMembershipAsync( - IProviderUserRepository providerUserRepository, Guid userId); + Task> ProviderMembershipAsync( + IProviderUserRepository providerUserRepository, Guid userId); - Task ProviderIdForOrg(Guid orgId); - } + Task ProviderIdForOrg(Guid orgId); } diff --git a/src/Core/Entities/Cipher.cs b/src/Core/Entities/Cipher.cs index c4e57aa76..186a7c5b8 100644 --- a/src/Core/Entities/Cipher.cs +++ b/src/Core/Entities/Cipher.cs @@ -2,108 +2,107 @@ using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Core.Entities +namespace Bit.Core.Entities; + +public class Cipher : ITableObject, ICloneable { - public class Cipher : ITableObject, ICloneable + private Dictionary _attachmentData; + + public Guid Id { get; set; } + public Guid? UserId { get; set; } + public Guid? OrganizationId { get; set; } + public Enums.CipherType Type { get; set; } + public string Data { get; set; } + public string Favorites { get; set; } + public string Folders { get; set; } + public string Attachments { get; set; } + public DateTime CreationDate { get; set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; set; } = DateTime.UtcNow; + public DateTime? DeletedDate { get; set; } + public Enums.CipherRepromptType? Reprompt { get; set; } + + public void SetNewId() { - private Dictionary _attachmentData; + Id = CoreHelpers.GenerateComb(); + } - public Guid Id { get; set; } - public Guid? UserId { get; set; } - public Guid? OrganizationId { get; set; } - public Enums.CipherType Type { get; set; } - public string Data { get; set; } - public string Favorites { get; set; } - public string Folders { get; set; } - public string Attachments { get; set; } - public DateTime CreationDate { get; set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; set; } = DateTime.UtcNow; - public DateTime? DeletedDate { get; set; } - public Enums.CipherRepromptType? Reprompt { get; set; } - - public void SetNewId() + public Dictionary GetAttachments() + { + if (string.IsNullOrWhiteSpace(Attachments)) { - Id = CoreHelpers.GenerateComb(); + return null; } - public Dictionary GetAttachments() + if (_attachmentData != null) { - if (string.IsNullOrWhiteSpace(Attachments)) - { - return null; - } - - if (_attachmentData != null) - { - return _attachmentData; - } - - try - { - _attachmentData = JsonSerializer.Deserialize>(Attachments); - foreach (var kvp in _attachmentData) - { - kvp.Value.AttachmentId = kvp.Key; - } - return _attachmentData; - } - catch - { - return null; - } + return _attachmentData; } - public void SetAttachments(Dictionary data) + try { - if (data == null || data.Count == 0) + _attachmentData = JsonSerializer.Deserialize>(Attachments); + foreach (var kvp in _attachmentData) { - _attachmentData = null; - Attachments = null; - return; + kvp.Value.AttachmentId = kvp.Key; } - - _attachmentData = data; - Attachments = JsonSerializer.Serialize(_attachmentData); + return _attachmentData; } - - public void AddAttachment(string id, CipherAttachment.MetaData data) + catch { - var attachments = GetAttachments(); - if (attachments == null) - { - attachments = new Dictionary(); - } - - attachments.Add(id, data); - SetAttachments(attachments); - } - - public void DeleteAttachment(string id) - { - var attachments = GetAttachments(); - if (!attachments?.ContainsKey(id) ?? true) - { - return; - } - - attachments.Remove(id); - SetAttachments(attachments); - } - - public bool ContainsAttachment(string id) - { - var attachments = GetAttachments(); - return attachments?.ContainsKey(id) ?? false; - } - - object ICloneable.Clone() => Clone(); - public Cipher Clone() - { - var clone = CoreHelpers.CloneObject(this); - clone.CreationDate = CreationDate; - clone.RevisionDate = RevisionDate; - - return clone; + return null; } } + + public void SetAttachments(Dictionary data) + { + if (data == null || data.Count == 0) + { + _attachmentData = null; + Attachments = null; + return; + } + + _attachmentData = data; + Attachments = JsonSerializer.Serialize(_attachmentData); + } + + public void AddAttachment(string id, CipherAttachment.MetaData data) + { + var attachments = GetAttachments(); + if (attachments == null) + { + attachments = new Dictionary(); + } + + attachments.Add(id, data); + SetAttachments(attachments); + } + + public void DeleteAttachment(string id) + { + var attachments = GetAttachments(); + if (!attachments?.ContainsKey(id) ?? true) + { + return; + } + + attachments.Remove(id); + SetAttachments(attachments); + } + + public bool ContainsAttachment(string id) + { + var attachments = GetAttachments(); + return attachments?.ContainsKey(id) ?? false; + } + + object ICloneable.Clone() => Clone(); + public Cipher Clone() + { + var clone = CoreHelpers.CloneObject(this); + clone.CreationDate = CreationDate; + clone.RevisionDate = RevisionDate; + + return clone; + } } diff --git a/src/Core/Entities/Collection.cs b/src/Core/Entities/Collection.cs index fb6e646fc..fb7225fc2 100644 --- a/src/Core/Entities/Collection.cs +++ b/src/Core/Entities/Collection.cs @@ -1,21 +1,20 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Utilities; -namespace Bit.Core.Entities -{ - public class Collection : ITableObject - { - public Guid Id { get; set; } - public Guid OrganizationId { get; set; } - public string Name { get; set; } - [MaxLength(300)] - public string ExternalId { get; set; } - public DateTime CreationDate { get; set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; set; } = DateTime.UtcNow; +namespace Bit.Core.Entities; - public void SetNewId() - { - Id = CoreHelpers.GenerateComb(); - } +public class Collection : ITableObject +{ + public Guid Id { get; set; } + public Guid OrganizationId { get; set; } + public string Name { get; set; } + [MaxLength(300)] + public string ExternalId { get; set; } + public DateTime CreationDate { get; set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; set; } = DateTime.UtcNow; + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); } } diff --git a/src/Core/Entities/CollectionCipher.cs b/src/Core/Entities/CollectionCipher.cs index f04c2bdf4..d212ced51 100644 --- a/src/Core/Entities/CollectionCipher.cs +++ b/src/Core/Entities/CollectionCipher.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Entities +namespace Bit.Core.Entities; + +public class CollectionCipher { - public class CollectionCipher - { - public Guid CollectionId { get; set; } - public Guid CipherId { get; set; } - } + public Guid CollectionId { get; set; } + public Guid CipherId { get; set; } } diff --git a/src/Core/Entities/CollectionGroup.cs b/src/Core/Entities/CollectionGroup.cs index c68ae3005..8224aed46 100644 --- a/src/Core/Entities/CollectionGroup.cs +++ b/src/Core/Entities/CollectionGroup.cs @@ -1,10 +1,9 @@ -namespace Bit.Core.Entities +namespace Bit.Core.Entities; + +public class CollectionGroup { - public class CollectionGroup - { - public Guid CollectionId { get; set; } - public Guid GroupId { get; set; } - public bool ReadOnly { get; set; } - public bool HidePasswords { get; set; } - } + public Guid CollectionId { get; set; } + public Guid GroupId { get; set; } + public bool ReadOnly { get; set; } + public bool HidePasswords { get; set; } } diff --git a/src/Core/Entities/CollectionUser.cs b/src/Core/Entities/CollectionUser.cs index 5b5d01fcc..bb22e7b7c 100644 --- a/src/Core/Entities/CollectionUser.cs +++ b/src/Core/Entities/CollectionUser.cs @@ -1,10 +1,9 @@ -namespace Bit.Core.Entities +namespace Bit.Core.Entities; + +public class CollectionUser { - public class CollectionUser - { - public Guid CollectionId { get; set; } - public Guid OrganizationUserId { get; set; } - public bool ReadOnly { get; set; } - public bool HidePasswords { get; set; } - } + public Guid CollectionId { get; set; } + public Guid OrganizationUserId { get; set; } + public bool ReadOnly { get; set; } + public bool HidePasswords { get; set; } } diff --git a/src/Core/Entities/Device.cs b/src/Core/Entities/Device.cs index 9cca56c3f..3b5fb1a24 100644 --- a/src/Core/Entities/Device.cs +++ b/src/Core/Entities/Device.cs @@ -1,25 +1,24 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Utilities; -namespace Bit.Core.Entities -{ - public class Device : ITableObject - { - public Guid Id { get; set; } - public Guid UserId { get; set; } - [MaxLength(50)] - public string Name { get; set; } - public Enums.DeviceType Type { get; set; } - [MaxLength(50)] - public string Identifier { get; set; } - [MaxLength(255)] - public string PushToken { get; set; } - public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; +namespace Bit.Core.Entities; - public void SetNewId() - { - Id = CoreHelpers.GenerateComb(); - } +public class Device : ITableObject +{ + public Guid Id { get; set; } + public Guid UserId { get; set; } + [MaxLength(50)] + public string Name { get; set; } + public Enums.DeviceType Type { get; set; } + [MaxLength(50)] + public string Identifier { get; set; } + [MaxLength(255)] + public string PushToken { get; set; } + public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); } } diff --git a/src/Core/Entities/EmergencyAccess.cs b/src/Core/Entities/EmergencyAccess.cs index eafd9ee8e..e78f90e66 100644 --- a/src/Core/Entities/EmergencyAccess.cs +++ b/src/Core/Entities/EmergencyAccess.cs @@ -2,46 +2,45 @@ using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Core.Entities +namespace Bit.Core.Entities; + +public class EmergencyAccess : ITableObject { - public class EmergencyAccess : ITableObject + public Guid Id { get; set; } + public Guid GrantorId { get; set; } + public Guid? GranteeId { get; set; } + [MaxLength(256)] + public string Email { get; set; } + public string KeyEncrypted { get; set; } + public EmergencyAccessType Type { get; set; } + public EmergencyAccessStatusType Status { get; set; } + public int WaitTimeDays { get; set; } + public DateTime? RecoveryInitiatedDate { get; set; } + public DateTime? LastNotificationDate { get; set; } + public DateTime CreationDate { get; set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; set; } = DateTime.UtcNow; + + public void SetNewId() { - public Guid Id { get; set; } - public Guid GrantorId { get; set; } - public Guid? GranteeId { get; set; } - [MaxLength(256)] - public string Email { get; set; } - public string KeyEncrypted { get; set; } - public EmergencyAccessType Type { get; set; } - public EmergencyAccessStatusType Status { get; set; } - public int WaitTimeDays { get; set; } - public DateTime? RecoveryInitiatedDate { get; set; } - public DateTime? LastNotificationDate { get; set; } - public DateTime CreationDate { get; set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; set; } = DateTime.UtcNow; + Id = CoreHelpers.GenerateComb(); + } - public void SetNewId() + public EmergencyAccess ToEmergencyAccess() + { + return new EmergencyAccess { - Id = CoreHelpers.GenerateComb(); - } - - public EmergencyAccess ToEmergencyAccess() - { - return new EmergencyAccess - { - Id = Id, - GrantorId = GrantorId, - GranteeId = GranteeId, - Email = Email, - KeyEncrypted = KeyEncrypted, - Type = Type, - Status = Status, - WaitTimeDays = WaitTimeDays, - RecoveryInitiatedDate = RecoveryInitiatedDate, - LastNotificationDate = LastNotificationDate, - CreationDate = CreationDate, - RevisionDate = RevisionDate, - }; - } + Id = Id, + GrantorId = GrantorId, + GranteeId = GranteeId, + Email = Email, + KeyEncrypted = KeyEncrypted, + Type = Type, + Status = Status, + WaitTimeDays = WaitTimeDays, + RecoveryInitiatedDate = RecoveryInitiatedDate, + LastNotificationDate = LastNotificationDate, + CreationDate = CreationDate, + RevisionDate = RevisionDate, + }; } } diff --git a/src/Core/Entities/Event.cs b/src/Core/Entities/Event.cs index d17116ce6..99e2091c9 100644 --- a/src/Core/Entities/Event.cs +++ b/src/Core/Entities/Event.cs @@ -3,54 +3,53 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; using Bit.Core.Utilities; -namespace Bit.Core.Entities +namespace Bit.Core.Entities; + +public class Event : ITableObject, IEvent { - public class Event : ITableObject, IEvent + public Event() { } + + public Event(IEvent e) { - public Event() { } + Date = e.Date; + Type = e.Type; + UserId = e.UserId; + OrganizationId = e.OrganizationId; + ProviderId = e.ProviderId; + CipherId = e.CipherId; + CollectionId = e.CollectionId; + PolicyId = e.PolicyId; + GroupId = e.GroupId; + OrganizationUserId = e.OrganizationUserId; + InstallationId = e.InstallationId; + ProviderUserId = e.ProviderUserId; + ProviderOrganizationId = e.ProviderOrganizationId; + DeviceType = e.DeviceType; + IpAddress = e.IpAddress; + ActingUserId = e.ActingUserId; + } - public Event(IEvent e) - { - Date = e.Date; - Type = e.Type; - UserId = e.UserId; - OrganizationId = e.OrganizationId; - ProviderId = e.ProviderId; - CipherId = e.CipherId; - CollectionId = e.CollectionId; - PolicyId = e.PolicyId; - GroupId = e.GroupId; - OrganizationUserId = e.OrganizationUserId; - InstallationId = e.InstallationId; - ProviderUserId = e.ProviderUserId; - ProviderOrganizationId = e.ProviderOrganizationId; - DeviceType = e.DeviceType; - IpAddress = e.IpAddress; - ActingUserId = e.ActingUserId; - } + public Guid Id { get; set; } + public DateTime Date { get; set; } + public EventType Type { get; set; } + public Guid? UserId { get; set; } + public Guid? OrganizationId { get; set; } + public Guid? InstallationId { get; set; } + public Guid? ProviderId { get; set; } + public Guid? CipherId { get; set; } + public Guid? CollectionId { get; set; } + public Guid? PolicyId { get; set; } + public Guid? GroupId { get; set; } + public Guid? OrganizationUserId { get; set; } + public Guid? ProviderUserId { get; set; } + public Guid? ProviderOrganizationId { get; set; } + public DeviceType? DeviceType { get; set; } + [MaxLength(50)] + public string IpAddress { get; set; } + public Guid? ActingUserId { get; set; } - public Guid Id { get; set; } - public DateTime Date { get; set; } - public EventType Type { get; set; } - public Guid? UserId { get; set; } - public Guid? OrganizationId { get; set; } - public Guid? InstallationId { get; set; } - public Guid? ProviderId { get; set; } - public Guid? CipherId { get; set; } - public Guid? CollectionId { get; set; } - public Guid? PolicyId { get; set; } - public Guid? GroupId { get; set; } - public Guid? OrganizationUserId { get; set; } - public Guid? ProviderUserId { get; set; } - public Guid? ProviderOrganizationId { get; set; } - public DeviceType? DeviceType { get; set; } - [MaxLength(50)] - public string IpAddress { get; set; } - public Guid? ActingUserId { get; set; } - - public void SetNewId() - { - Id = CoreHelpers.GenerateComb(); - } + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); } } diff --git a/src/Core/Entities/Folder.cs b/src/Core/Entities/Folder.cs index 5fc97a3e5..fd6d4dafa 100644 --- a/src/Core/Entities/Folder.cs +++ b/src/Core/Entities/Folder.cs @@ -1,18 +1,17 @@ using Bit.Core.Utilities; -namespace Bit.Core.Entities -{ - public class Folder : ITableObject - { - public Guid Id { get; set; } - public Guid UserId { get; set; } - public string Name { get; set; } - public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; +namespace Bit.Core.Entities; - public void SetNewId() - { - Id = CoreHelpers.GenerateComb(); - } +public class Folder : ITableObject +{ + public Guid Id { get; set; } + public Guid UserId { get; set; } + public string Name { get; set; } + public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); } } diff --git a/src/Core/Entities/Grant.cs b/src/Core/Entities/Grant.cs index f2bd464fb..f66ff1134 100644 --- a/src/Core/Entities/Grant.cs +++ b/src/Core/Entities/Grant.cs @@ -1,24 +1,23 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Entities +namespace Bit.Core.Entities; + +public class Grant { - public class Grant - { - [MaxLength(200)] - public string Key { get; set; } - [MaxLength(50)] - public string Type { get; set; } - [MaxLength(200)] - public string SubjectId { get; set; } - [MaxLength(100)] - public string SessionId { get; set; } - [MaxLength(200)] - public string ClientId { get; set; } - [MaxLength(200)] - public string Description { get; set; } - public DateTime CreationDate { get; set; } - public DateTime? ExpirationDate { get; set; } - public DateTime? ConsumedDate { get; set; } - public string Data { get; set; } - } + [MaxLength(200)] + public string Key { get; set; } + [MaxLength(50)] + public string Type { get; set; } + [MaxLength(200)] + public string SubjectId { get; set; } + [MaxLength(100)] + public string SessionId { get; set; } + [MaxLength(200)] + public string ClientId { get; set; } + [MaxLength(200)] + public string Description { get; set; } + public DateTime CreationDate { get; set; } + public DateTime? ExpirationDate { get; set; } + public DateTime? ConsumedDate { get; set; } + public string Data { get; set; } } diff --git a/src/Core/Entities/Group.cs b/src/Core/Entities/Group.cs index 0ca760cff..3c15380fa 100644 --- a/src/Core/Entities/Group.cs +++ b/src/Core/Entities/Group.cs @@ -2,23 +2,22 @@ using Bit.Core.Models; using Bit.Core.Utilities; -namespace Bit.Core.Entities -{ - public class Group : ITableObject, IExternal - { - public Guid Id { get; set; } - public Guid OrganizationId { get; set; } - [MaxLength(100)] - public string Name { get; set; } - public bool AccessAll { get; set; } - [MaxLength(300)] - public string ExternalId { get; set; } - public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; +namespace Bit.Core.Entities; - public void SetNewId() - { - Id = CoreHelpers.GenerateComb(); - } +public class Group : ITableObject, IExternal +{ + public Guid Id { get; set; } + public Guid OrganizationId { get; set; } + [MaxLength(100)] + public string Name { get; set; } + public bool AccessAll { get; set; } + [MaxLength(300)] + public string ExternalId { get; set; } + public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); } } diff --git a/src/Core/Entities/GroupUser.cs b/src/Core/Entities/GroupUser.cs index c7933d5e7..3497c2c74 100644 --- a/src/Core/Entities/GroupUser.cs +++ b/src/Core/Entities/GroupUser.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Entities +namespace Bit.Core.Entities; + +public class GroupUser { - public class GroupUser - { - public Guid GroupId { get; set; } - public Guid OrganizationUserId { get; set; } - } + public Guid GroupId { get; set; } + public Guid OrganizationUserId { get; set; } } diff --git a/src/Core/Entities/IReferenceable.cs b/src/Core/Entities/IReferenceable.cs index a5373978d..79837781e 100644 --- a/src/Core/Entities/IReferenceable.cs +++ b/src/Core/Entities/IReferenceable.cs @@ -1,9 +1,8 @@ -namespace Bit.Core.Entities +namespace Bit.Core.Entities; + +public interface IReferenceable { - public interface IReferenceable - { - Guid Id { get; set; } - string ReferenceData { get; set; } - bool IsUser(); - } + Guid Id { get; set; } + string ReferenceData { get; set; } + bool IsUser(); } diff --git a/src/Core/Entities/IRevisable.cs b/src/Core/Entities/IRevisable.cs index 6de7478c0..bba3b3c94 100644 --- a/src/Core/Entities/IRevisable.cs +++ b/src/Core/Entities/IRevisable.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Entities +namespace Bit.Core.Entities; + +public interface IRevisable { - public interface IRevisable - { - DateTime CreationDate { get; } - DateTime RevisionDate { get; } - } + DateTime CreationDate { get; } + DateTime RevisionDate { get; } } diff --git a/src/Core/Entities/IStorable.cs b/src/Core/Entities/IStorable.cs index 67c16098f..fd0da49fe 100644 --- a/src/Core/Entities/IStorable.cs +++ b/src/Core/Entities/IStorable.cs @@ -1,10 +1,9 @@ -namespace Bit.Core.Entities +namespace Bit.Core.Entities; + +public interface IStorable { - public interface IStorable - { - long? Storage { get; set; } - short? MaxStorageGb { get; set; } - long StorageBytesRemaining(); - long StorageBytesRemaining(short maxStorageGb); - } + long? Storage { get; set; } + short? MaxStorageGb { get; set; } + long StorageBytesRemaining(); + long StorageBytesRemaining(short maxStorageGb); } diff --git a/src/Core/Entities/IStorableSubscriber.cs b/src/Core/Entities/IStorableSubscriber.cs index e37966dea..27fcb25f6 100644 --- a/src/Core/Entities/IStorableSubscriber.cs +++ b/src/Core/Entities/IStorableSubscriber.cs @@ -1,5 +1,4 @@ -namespace Bit.Core.Entities -{ - public interface IStorableSubscriber : IStorable, ISubscriber - { } -} +namespace Bit.Core.Entities; + +public interface IStorableSubscriber : IStorable, ISubscriber +{ } diff --git a/src/Core/Entities/ISubscriber.cs b/src/Core/Entities/ISubscriber.cs index 1c80ffc20..6753e648e 100644 --- a/src/Core/Entities/ISubscriber.cs +++ b/src/Core/Entities/ISubscriber.cs @@ -1,18 +1,17 @@ using Bit.Core.Enums; -namespace Bit.Core.Entities +namespace Bit.Core.Entities; + +public interface ISubscriber { - public interface ISubscriber - { - Guid Id { get; } - GatewayType? Gateway { get; set; } - string GatewayCustomerId { get; set; } - string GatewaySubscriptionId { get; set; } - string BillingEmailAddress(); - string BillingName(); - string BraintreeCustomerIdPrefix(); - string BraintreeIdField(); - string GatewayIdField(); - bool IsUser(); - } + Guid Id { get; } + GatewayType? Gateway { get; set; } + string GatewayCustomerId { get; set; } + string GatewaySubscriptionId { get; set; } + string BillingEmailAddress(); + string BillingName(); + string BraintreeCustomerIdPrefix(); + string BraintreeIdField(); + string GatewayIdField(); + bool IsUser(); } diff --git a/src/Core/Entities/ITableObject.cs b/src/Core/Entities/ITableObject.cs index f9ecb864b..1f54b8cc1 100644 --- a/src/Core/Entities/ITableObject.cs +++ b/src/Core/Entities/ITableObject.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Entities +namespace Bit.Core.Entities; + +public interface ITableObject where T : IEquatable { - public interface ITableObject where T : IEquatable - { - T Id { get; set; } - void SetNewId(); - } + T Id { get; set; } + void SetNewId(); } diff --git a/src/Core/Entities/Installation.cs b/src/Core/Entities/Installation.cs index 36966d861..a91ecef2e 100644 --- a/src/Core/Entities/Installation.cs +++ b/src/Core/Entities/Installation.cs @@ -1,21 +1,20 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Utilities; -namespace Bit.Core.Entities -{ - public class Installation : ITableObject - { - public Guid Id { get; set; } - [MaxLength(256)] - public string Email { get; set; } - [MaxLength(150)] - public string Key { get; set; } - public bool Enabled { get; set; } - public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; +namespace Bit.Core.Entities; - public void SetNewId() - { - Id = CoreHelpers.GenerateComb(); - } +public class Installation : ITableObject +{ + public Guid Id { get; set; } + [MaxLength(256)] + public string Email { get; set; } + [MaxLength(150)] + public string Key { get; set; } + public bool Enabled { get; set; } + public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); } } diff --git a/src/Core/Entities/Organization.cs b/src/Core/Entities/Organization.cs index 818db3230..823eb5baf 100644 --- a/src/Core/Entities/Organization.cs +++ b/src/Core/Entities/Organization.cs @@ -4,196 +4,195 @@ using Bit.Core.Enums; using Bit.Core.Models; using Bit.Core.Utilities; -namespace Bit.Core.Entities +namespace Bit.Core.Entities; + +public class Organization : ITableObject, ISubscriber, IStorable, IStorableSubscriber, IRevisable, IReferenceable { - public class Organization : ITableObject, ISubscriber, IStorable, IStorableSubscriber, IRevisable, IReferenceable + private Dictionary _twoFactorProviders; + + public Guid Id { get; set; } + [MaxLength(50)] + public string Identifier { get; set; } + [MaxLength(50)] + public string Name { get; set; } + [MaxLength(50)] + public string BusinessName { get; set; } + [MaxLength(50)] + public string BusinessAddress1 { get; set; } + [MaxLength(50)] + public string BusinessAddress2 { get; set; } + [MaxLength(50)] + public string BusinessAddress3 { get; set; } + [MaxLength(2)] + public string BusinessCountry { get; set; } + [MaxLength(30)] + public string BusinessTaxNumber { get; set; } + [MaxLength(256)] + public string BillingEmail { get; set; } + [MaxLength(50)] + public string Plan { get; set; } + public PlanType PlanType { get; set; } + public int? Seats { get; set; } + public short? MaxCollections { get; set; } + public bool UsePolicies { get; set; } + public bool UseSso { get; set; } + public bool UseKeyConnector { get; set; } + public bool UseScim { get; set; } + public bool UseGroups { get; set; } + public bool UseDirectory { get; set; } + public bool UseEvents { get; set; } + public bool UseTotp { get; set; } + public bool Use2fa { get; set; } + public bool UseApi { get; set; } + public bool UseResetPassword { get; set; } + public bool SelfHost { get; set; } + public bool UsersGetPremium { get; set; } + public long? Storage { get; set; } + public short? MaxStorageGb { get; set; } + public GatewayType? Gateway { get; set; } + [MaxLength(50)] + public string GatewayCustomerId { get; set; } + [MaxLength(50)] + public string GatewaySubscriptionId { get; set; } + public string ReferenceData { get; set; } + public bool Enabled { get; set; } = true; + [MaxLength(100)] + public string LicenseKey { get; set; } + public string PublicKey { get; set; } + public string PrivateKey { get; set; } + public string TwoFactorProviders { get; set; } + public DateTime? ExpirationDate { get; set; } + public DateTime CreationDate { get; set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; set; } = DateTime.UtcNow; + public int? MaxAutoscaleSeats { get; set; } = null; + public DateTime? OwnersNotifiedOfAutoscaling { get; set; } = null; + + public void SetNewId() { - private Dictionary _twoFactorProviders; - - public Guid Id { get; set; } - [MaxLength(50)] - public string Identifier { get; set; } - [MaxLength(50)] - public string Name { get; set; } - [MaxLength(50)] - public string BusinessName { get; set; } - [MaxLength(50)] - public string BusinessAddress1 { get; set; } - [MaxLength(50)] - public string BusinessAddress2 { get; set; } - [MaxLength(50)] - public string BusinessAddress3 { get; set; } - [MaxLength(2)] - public string BusinessCountry { get; set; } - [MaxLength(30)] - public string BusinessTaxNumber { get; set; } - [MaxLength(256)] - public string BillingEmail { get; set; } - [MaxLength(50)] - public string Plan { get; set; } - public PlanType PlanType { get; set; } - public int? Seats { get; set; } - public short? MaxCollections { get; set; } - public bool UsePolicies { get; set; } - public bool UseSso { get; set; } - public bool UseKeyConnector { get; set; } - public bool UseScim { get; set; } - public bool UseGroups { get; set; } - public bool UseDirectory { get; set; } - public bool UseEvents { get; set; } - public bool UseTotp { get; set; } - public bool Use2fa { get; set; } - public bool UseApi { get; set; } - public bool UseResetPassword { get; set; } - public bool SelfHost { get; set; } - public bool UsersGetPremium { get; set; } - public long? Storage { get; set; } - public short? MaxStorageGb { get; set; } - public GatewayType? Gateway { get; set; } - [MaxLength(50)] - public string GatewayCustomerId { get; set; } - [MaxLength(50)] - public string GatewaySubscriptionId { get; set; } - public string ReferenceData { get; set; } - public bool Enabled { get; set; } = true; - [MaxLength(100)] - public string LicenseKey { get; set; } - public string PublicKey { get; set; } - public string PrivateKey { get; set; } - public string TwoFactorProviders { get; set; } - public DateTime? ExpirationDate { get; set; } - public DateTime CreationDate { get; set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; set; } = DateTime.UtcNow; - public int? MaxAutoscaleSeats { get; set; } = null; - public DateTime? OwnersNotifiedOfAutoscaling { get; set; } = null; - - public void SetNewId() + if (Id == default(Guid)) { - if (Id == default(Guid)) + Id = CoreHelpers.GenerateComb(); + } + } + + public string BillingEmailAddress() + { + return BillingEmail?.ToLowerInvariant()?.Trim(); + } + + public string BillingName() + { + return BusinessName; + } + + public string BraintreeCustomerIdPrefix() + { + return "o"; + } + + public string BraintreeIdField() + { + return "organization_id"; + } + + public string GatewayIdField() + { + return "organizationId"; + } + + public bool IsUser() + { + return false; + } + + public long StorageBytesRemaining() + { + if (!MaxStorageGb.HasValue) + { + return 0; + } + + return StorageBytesRemaining(MaxStorageGb.Value); + } + + public long StorageBytesRemaining(short maxStorageGb) + { + var maxStorageBytes = maxStorageGb * 1073741824L; + if (!Storage.HasValue) + { + return maxStorageBytes; + } + + return maxStorageBytes - Storage.Value; + } + + public Dictionary GetTwoFactorProviders() + { + if (string.IsNullOrWhiteSpace(TwoFactorProviders)) + { + return null; + } + + try + { + if (_twoFactorProviders == null) { - Id = CoreHelpers.GenerateComb(); + _twoFactorProviders = + JsonHelpers.LegacyDeserialize>( + TwoFactorProviders); } - } - public string BillingEmailAddress() + return _twoFactorProviders; + } + catch (JsonException) { - return BillingEmail?.ToLowerInvariant()?.Trim(); + return null; } + } - public string BillingName() + public void SetTwoFactorProviders(Dictionary providers) + { + if (!providers.Any()) { - return BusinessName; + TwoFactorProviders = null; + _twoFactorProviders = null; + return; } - public string BraintreeCustomerIdPrefix() - { - return "o"; - } + TwoFactorProviders = JsonHelpers.LegacySerialize(providers, JsonHelpers.LegacyEnumKeyResolver); + _twoFactorProviders = providers; + } - public string BraintreeIdField() - { - return "organization_id"; - } - - public string GatewayIdField() - { - return "organizationId"; - } - - public bool IsUser() + public bool TwoFactorProviderIsEnabled(TwoFactorProviderType provider) + { + var providers = GetTwoFactorProviders(); + if (providers == null || !providers.ContainsKey(provider)) { return false; } - public long StorageBytesRemaining() - { - if (!MaxStorageGb.HasValue) - { - return 0; - } + return providers[provider].Enabled && Use2fa; + } - return StorageBytesRemaining(MaxStorageGb.Value); + public bool TwoFactorIsEnabled() + { + var providers = GetTwoFactorProviders(); + if (providers == null) + { + return false; } - public long StorageBytesRemaining(short maxStorageGb) - { - var maxStorageBytes = maxStorageGb * 1073741824L; - if (!Storage.HasValue) - { - return maxStorageBytes; - } + return providers.Any(p => (p.Value?.Enabled ?? false) && Use2fa); + } - return maxStorageBytes - Storage.Value; + public TwoFactorProvider GetTwoFactorProvider(TwoFactorProviderType provider) + { + var providers = GetTwoFactorProviders(); + if (providers == null || !providers.ContainsKey(provider)) + { + return null; } - public Dictionary GetTwoFactorProviders() - { - if (string.IsNullOrWhiteSpace(TwoFactorProviders)) - { - return null; - } - - try - { - if (_twoFactorProviders == null) - { - _twoFactorProviders = - JsonHelpers.LegacyDeserialize>( - TwoFactorProviders); - } - - return _twoFactorProviders; - } - catch (JsonException) - { - return null; - } - } - - public void SetTwoFactorProviders(Dictionary providers) - { - if (!providers.Any()) - { - TwoFactorProviders = null; - _twoFactorProviders = null; - return; - } - - TwoFactorProviders = JsonHelpers.LegacySerialize(providers, JsonHelpers.LegacyEnumKeyResolver); - _twoFactorProviders = providers; - } - - public bool TwoFactorProviderIsEnabled(TwoFactorProviderType provider) - { - var providers = GetTwoFactorProviders(); - if (providers == null || !providers.ContainsKey(provider)) - { - return false; - } - - return providers[provider].Enabled && Use2fa; - } - - public bool TwoFactorIsEnabled() - { - var providers = GetTwoFactorProviders(); - if (providers == null) - { - return false; - } - - return providers.Any(p => (p.Value?.Enabled ?? false) && Use2fa); - } - - public TwoFactorProvider GetTwoFactorProvider(TwoFactorProviderType provider) - { - var providers = GetTwoFactorProviders(); - if (providers == null || !providers.ContainsKey(provider)) - { - return null; - } - - return providers[provider]; - } + return providers[provider]; } } diff --git a/src/Core/Entities/OrganizationApiKey.cs b/src/Core/Entities/OrganizationApiKey.cs index f3a71bde2..af9f3c912 100644 --- a/src/Core/Entities/OrganizationApiKey.cs +++ b/src/Core/Entities/OrganizationApiKey.cs @@ -2,20 +2,19 @@ using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Core.Entities -{ - public class OrganizationApiKey : ITableObject - { - public Guid Id { get; set; } - public Guid OrganizationId { get; set; } - public OrganizationApiKeyType Type { get; set; } - [MaxLength(30)] - public string ApiKey { get; set; } - public DateTime RevisionDate { get; set; } +namespace Bit.Core.Entities; - public void SetNewId() - { - Id = CoreHelpers.GenerateComb(); - } +public class OrganizationApiKey : ITableObject +{ + public Guid Id { get; set; } + public Guid OrganizationId { get; set; } + public OrganizationApiKeyType Type { get; set; } + [MaxLength(30)] + public string ApiKey { get; set; } + public DateTime RevisionDate { get; set; } + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); } } diff --git a/src/Core/Entities/OrganizationConnection.cs b/src/Core/Entities/OrganizationConnection.cs index 804913fd6..cc0717738 100644 --- a/src/Core/Entities/OrganizationConnection.cs +++ b/src/Core/Entities/OrganizationConnection.cs @@ -2,45 +2,44 @@ using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Core.Entities +namespace Bit.Core.Entities; + +public class OrganizationConnection : OrganizationConnection where T : new() { - public class OrganizationConnection : OrganizationConnection where T : new() + public new T Config { - public new T Config - { - get => base.GetConfig(); - set => base.SetConfig(value); - } - } - - public class OrganizationConnection : ITableObject - { - public Guid Id { get; set; } - public OrganizationConnectionType Type { get; set; } - public Guid OrganizationId { get; set; } - public bool Enabled { get; set; } - public string Config { get; set; } - - public void SetNewId() - { - Id = CoreHelpers.GenerateComb(); - } - - public T GetConfig() where T : new() - { - try - { - return JsonSerializer.Deserialize(Config); - } - catch (JsonException) - { - return default; - } - } - - public void SetConfig(T config) where T : new() - { - Config = JsonSerializer.Serialize(config); - } + get => base.GetConfig(); + set => base.SetConfig(value); + } +} + +public class OrganizationConnection : ITableObject +{ + public Guid Id { get; set; } + public OrganizationConnectionType Type { get; set; } + public Guid OrganizationId { get; set; } + public bool Enabled { get; set; } + public string Config { get; set; } + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); + } + + public T GetConfig() where T : new() + { + try + { + return JsonSerializer.Deserialize(Config); + } + catch (JsonException) + { + return default; + } + } + + public void SetConfig(T config) where T : new() + { + Config = JsonSerializer.Serialize(config); } } diff --git a/src/Core/Entities/OrganizationSponsorship.cs b/src/Core/Entities/OrganizationSponsorship.cs index 27d07e8f7..8d747bd62 100644 --- a/src/Core/Entities/OrganizationSponsorship.cs +++ b/src/Core/Entities/OrganizationSponsorship.cs @@ -2,26 +2,25 @@ using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Core.Entities -{ - public class OrganizationSponsorship : ITableObject - { - public Guid Id { get; set; } - public Guid? SponsoringOrganizationId { get; set; } - public Guid SponsoringOrganizationUserId { get; set; } - public Guid? SponsoredOrganizationId { get; set; } - [MaxLength(256)] - public string FriendlyName { get; set; } - [MaxLength(256)] - public string OfferedToEmail { get; set; } - public PlanSponsorshipType? PlanSponsorshipType { get; set; } - public DateTime? LastSyncDate { get; set; } - public DateTime? ValidUntil { get; set; } - public bool ToDelete { get; set; } +namespace Bit.Core.Entities; - public void SetNewId() - { - Id = CoreHelpers.GenerateComb(); - } +public class OrganizationSponsorship : ITableObject +{ + public Guid Id { get; set; } + public Guid? SponsoringOrganizationId { get; set; } + public Guid SponsoringOrganizationUserId { get; set; } + public Guid? SponsoredOrganizationId { get; set; } + [MaxLength(256)] + public string FriendlyName { get; set; } + [MaxLength(256)] + public string OfferedToEmail { get; set; } + public PlanSponsorshipType? PlanSponsorshipType { get; set; } + public DateTime? LastSyncDate { get; set; } + public DateTime? ValidUntil { get; set; } + public bool ToDelete { get; set; } + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); } } diff --git a/src/Core/Entities/OrganizationUser.cs b/src/Core/Entities/OrganizationUser.cs index 390374dec..ee1bdc15d 100644 --- a/src/Core/Entities/OrganizationUser.cs +++ b/src/Core/Entities/OrganizationUser.cs @@ -3,29 +3,28 @@ using Bit.Core.Enums; using Bit.Core.Models; using Bit.Core.Utilities; -namespace Bit.Core.Entities -{ - public class OrganizationUser : ITableObject, IExternal - { - public Guid Id { get; set; } - public Guid OrganizationId { get; set; } - public Guid? UserId { get; set; } - [MaxLength(256)] - public string Email { get; set; } - public string Key { get; set; } - public string ResetPasswordKey { get; set; } - public OrganizationUserStatusType Status { get; set; } - public OrganizationUserType Type { get; set; } - public bool AccessAll { get; set; } - [MaxLength(300)] - public string ExternalId { get; set; } - public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; - public string Permissions { get; set; } +namespace Bit.Core.Entities; - public void SetNewId() - { - Id = CoreHelpers.GenerateComb(); - } +public class OrganizationUser : ITableObject, IExternal +{ + public Guid Id { get; set; } + public Guid OrganizationId { get; set; } + public Guid? UserId { get; set; } + [MaxLength(256)] + public string Email { get; set; } + public string Key { get; set; } + public string ResetPasswordKey { get; set; } + public OrganizationUserStatusType Status { get; set; } + public OrganizationUserType Type { get; set; } + public bool AccessAll { get; set; } + [MaxLength(300)] + public string ExternalId { get; set; } + public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; + public string Permissions { get; set; } + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); } } diff --git a/src/Core/Entities/Policy.cs b/src/Core/Entities/Policy.cs index 7a5f95871..4863b8ccc 100644 --- a/src/Core/Entities/Policy.cs +++ b/src/Core/Entities/Policy.cs @@ -2,31 +2,30 @@ using Bit.Core.Models.Data.Organizations.Policies; using Bit.Core.Utilities; -namespace Bit.Core.Entities +namespace Bit.Core.Entities; + +public class Policy : ITableObject { - public class Policy : ITableObject + public Guid Id { get; set; } + public Guid OrganizationId { get; set; } + public PolicyType Type { get; set; } + public string Data { get; set; } + public bool Enabled { get; set; } + public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; + + public void SetNewId() { - public Guid Id { get; set; } - public Guid OrganizationId { get; set; } - public PolicyType Type { get; set; } - public string Data { get; set; } - public bool Enabled { get; set; } - public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; + Id = CoreHelpers.GenerateComb(); + } - public void SetNewId() - { - Id = CoreHelpers.GenerateComb(); - } + public T GetDataModel() where T : IPolicyDataModel, new() + { + return CoreHelpers.LoadClassFromJsonData(Data); + } - public T GetDataModel() where T : IPolicyDataModel, new() - { - return CoreHelpers.LoadClassFromJsonData(Data); - } - - public void SetDataModel(T dataModel) where T : IPolicyDataModel, new() - { - Data = CoreHelpers.ClassToJsonData(dataModel); - } + public void SetDataModel(T dataModel) where T : IPolicyDataModel, new() + { + Data = CoreHelpers.ClassToJsonData(dataModel); } } diff --git a/src/Core/Entities/Provider/Provider.cs b/src/Core/Entities/Provider/Provider.cs index 95da01f93..440be7d43 100644 --- a/src/Core/Entities/Provider/Provider.cs +++ b/src/Core/Entities/Provider/Provider.cs @@ -1,31 +1,30 @@ using Bit.Core.Enums.Provider; using Bit.Core.Utilities; -namespace Bit.Core.Entities.Provider -{ - public class Provider : ITableObject - { - public Guid Id { get; set; } - public string Name { get; set; } - public string BusinessName { get; set; } - public string BusinessAddress1 { get; set; } - public string BusinessAddress2 { get; set; } - public string BusinessAddress3 { get; set; } - public string BusinessCountry { get; set; } - public string BusinessTaxNumber { get; set; } - public string BillingEmail { get; set; } - public ProviderStatusType Status { get; set; } - public bool UseEvents { get; set; } - public bool Enabled { get; set; } = true; - public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; +namespace Bit.Core.Entities.Provider; - public void SetNewId() +public class Provider : ITableObject +{ + public Guid Id { get; set; } + public string Name { get; set; } + public string BusinessName { get; set; } + public string BusinessAddress1 { get; set; } + public string BusinessAddress2 { get; set; } + public string BusinessAddress3 { get; set; } + public string BusinessCountry { get; set; } + public string BusinessTaxNumber { get; set; } + public string BillingEmail { get; set; } + public ProviderStatusType Status { get; set; } + public bool UseEvents { get; set; } + public bool Enabled { get; set; } = true; + public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; + + public void SetNewId() + { + if (Id == default) { - if (Id == default) - { - Id = CoreHelpers.GenerateComb(); - } + Id = CoreHelpers.GenerateComb(); } } } diff --git a/src/Core/Entities/Provider/ProviderOrganization.cs b/src/Core/Entities/Provider/ProviderOrganization.cs index 6bb1eec54..6cafef67b 100644 --- a/src/Core/Entities/Provider/ProviderOrganization.cs +++ b/src/Core/Entities/Provider/ProviderOrganization.cs @@ -1,23 +1,22 @@ using Bit.Core.Utilities; -namespace Bit.Core.Entities.Provider -{ - public class ProviderOrganization : ITableObject - { - public Guid Id { get; set; } - public Guid ProviderId { get; set; } - public Guid OrganizationId { get; set; } - public string Key { get; set; } - public string Settings { get; set; } - public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; +namespace Bit.Core.Entities.Provider; - public void SetNewId() +public class ProviderOrganization : ITableObject +{ + public Guid Id { get; set; } + public Guid ProviderId { get; set; } + public Guid OrganizationId { get; set; } + public string Key { get; set; } + public string Settings { get; set; } + public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; + + public void SetNewId() + { + if (Id == default) { - if (Id == default) - { - Id = CoreHelpers.GenerateComb(); - } + Id = CoreHelpers.GenerateComb(); } } } diff --git a/src/Core/Entities/Provider/ProviderUser.cs b/src/Core/Entities/Provider/ProviderUser.cs index c3d0582da..9b86d591c 100644 --- a/src/Core/Entities/Provider/ProviderUser.cs +++ b/src/Core/Entities/Provider/ProviderUser.cs @@ -1,27 +1,26 @@ using Bit.Core.Enums.Provider; using Bit.Core.Utilities; -namespace Bit.Core.Entities.Provider -{ - public class ProviderUser : ITableObject - { - public Guid Id { get; set; } - public Guid ProviderId { get; set; } - public Guid? UserId { get; set; } - public string Email { get; set; } - public string Key { get; set; } - public ProviderUserStatusType Status { get; set; } - public ProviderUserType Type { get; set; } - public string Permissions { get; set; } - public DateTime CreationDate { get; set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; set; } = DateTime.UtcNow; +namespace Bit.Core.Entities.Provider; - public void SetNewId() +public class ProviderUser : ITableObject +{ + public Guid Id { get; set; } + public Guid ProviderId { get; set; } + public Guid? UserId { get; set; } + public string Email { get; set; } + public string Key { get; set; } + public ProviderUserStatusType Status { get; set; } + public ProviderUserType Type { get; set; } + public string Permissions { get; set; } + public DateTime CreationDate { get; set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; set; } = DateTime.UtcNow; + + public void SetNewId() + { + if (Id == default) { - if (Id == default) - { - Id = CoreHelpers.GenerateComb(); - } + Id = CoreHelpers.GenerateComb(); } } } diff --git a/src/Core/Entities/Role.cs b/src/Core/Entities/Role.cs index 2acdb1c65..5e1f6319c 100644 --- a/src/Core/Entities/Role.cs +++ b/src/Core/Entities/Role.cs @@ -1,10 +1,9 @@ -namespace Bit.Core.Entities +namespace Bit.Core.Entities; + +/// +/// This class is not used. It is implemented to make the Identity provider happy. +/// +public class Role { - /// - /// This class is not used. It is implemented to make the Identity provider happy. - /// - public class Role - { - public string Name { get; set; } - } + public string Name { get; set; } } diff --git a/src/Core/Entities/Send.cs b/src/Core/Entities/Send.cs index cbe2006e8..7cc8f3b25 100644 --- a/src/Core/Entities/Send.cs +++ b/src/Core/Entities/Send.cs @@ -2,30 +2,29 @@ using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Core.Entities -{ - public class Send : ITableObject - { - public Guid Id { get; set; } - public Guid? UserId { get; set; } - public Guid? OrganizationId { get; set; } - public SendType Type { get; set; } - public string Data { get; set; } - public string Key { get; set; } - [MaxLength(300)] - public string Password { get; set; } - public int? MaxAccessCount { get; set; } - public int AccessCount { get; set; } - public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; - public DateTime? ExpirationDate { get; set; } - public DateTime DeletionDate { get; set; } - public bool Disabled { get; set; } - public bool? HideEmail { get; set; } +namespace Bit.Core.Entities; - public void SetNewId() - { - Id = CoreHelpers.GenerateComb(); - } +public class Send : ITableObject +{ + public Guid Id { get; set; } + public Guid? UserId { get; set; } + public Guid? OrganizationId { get; set; } + public SendType Type { get; set; } + public string Data { get; set; } + public string Key { get; set; } + [MaxLength(300)] + public string Password { get; set; } + public int? MaxAccessCount { get; set; } + public int AccessCount { get; set; } + public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; + public DateTime? ExpirationDate { get; set; } + public DateTime DeletionDate { get; set; } + public bool Disabled { get; set; } + public bool? HideEmail { get; set; } + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); } } diff --git a/src/Core/Entities/SsoConfig.cs b/src/Core/Entities/SsoConfig.cs index 63bf9173c..09f3697b7 100644 --- a/src/Core/Entities/SsoConfig.cs +++ b/src/Core/Entities/SsoConfig.cs @@ -1,30 +1,29 @@ using Bit.Core.Models.Data; -namespace Bit.Core.Entities +namespace Bit.Core.Entities; + +public class SsoConfig : ITableObject { - public class SsoConfig : ITableObject + public long Id { get; set; } + public bool Enabled { get; set; } = true; + public Guid OrganizationId { get; set; } + public string Data { get; set; } + public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; + + public void SetNewId() { - public long Id { get; set; } - public bool Enabled { get; set; } = true; - public Guid OrganizationId { get; set; } - public string Data { get; set; } - public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; + // int will be auto-populated + Id = 0; + } - public void SetNewId() - { - // int will be auto-populated - Id = 0; - } + public SsoConfigurationData GetData() + { + return SsoConfigurationData.Deserialize(Data); + } - public SsoConfigurationData GetData() - { - return SsoConfigurationData.Deserialize(Data); - } - - public void SetData(SsoConfigurationData data) - { - Data = data.Serialize(); - } + public void SetData(SsoConfigurationData data) + { + Data = data.Serialize(); } } diff --git a/src/Core/Entities/SsoUser.cs b/src/Core/Entities/SsoUser.cs index 47818e2bd..6bc32c20d 100644 --- a/src/Core/Entities/SsoUser.cs +++ b/src/Core/Entities/SsoUser.cs @@ -1,20 +1,19 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Entities -{ - public class SsoUser : ITableObject - { - public long Id { get; set; } - public Guid UserId { get; set; } - public Guid? OrganizationId { get; set; } - [MaxLength(50)] - public string ExternalId { get; set; } - public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; +namespace Bit.Core.Entities; - public void SetNewId() - { - // int will be auto-populated - Id = 0; - } +public class SsoUser : ITableObject +{ + public long Id { get; set; } + public Guid UserId { get; set; } + public Guid? OrganizationId { get; set; } + [MaxLength(50)] + public string ExternalId { get; set; } + public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; + + public void SetNewId() + { + // int will be auto-populated + Id = 0; } } diff --git a/src/Core/Entities/TaxRate.cs b/src/Core/Entities/TaxRate.cs index bf53c8cf0..a04ccf445 100644 --- a/src/Core/Entities/TaxRate.cs +++ b/src/Core/Entities/TaxRate.cs @@ -1,24 +1,23 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Entities -{ - public class TaxRate : ITableObject - { - [MaxLength(40)] - public string Id { get; set; } - [MaxLength(50)] - public string Country { get; set; } - [MaxLength(2)] - public string State { get; set; } - [MaxLength(10)] - public string PostalCode { get; set; } - public decimal Rate { get; set; } - public bool Active { get; set; } +namespace Bit.Core.Entities; - public void SetNewId() - { - // Id is created by Stripe, should exist before this gets called - return; - } +public class TaxRate : ITableObject +{ + [MaxLength(40)] + public string Id { get; set; } + [MaxLength(50)] + public string Country { get; set; } + [MaxLength(2)] + public string State { get; set; } + [MaxLength(10)] + public string PostalCode { get; set; } + public decimal Rate { get; set; } + public bool Active { get; set; } + + public void SetNewId() + { + // Id is created by Stripe, should exist before this gets called + return; } } diff --git a/src/Core/Entities/Transaction.cs b/src/Core/Entities/Transaction.cs index b2a01908c..f82b76a12 100644 --- a/src/Core/Entities/Transaction.cs +++ b/src/Core/Entities/Transaction.cs @@ -2,28 +2,27 @@ using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Core.Entities -{ - public class Transaction : ITableObject - { - public Guid Id { get; set; } - public Guid? UserId { get; set; } - public Guid? OrganizationId { get; set; } - public TransactionType Type { get; set; } - public decimal Amount { get; set; } - public bool? Refunded { get; set; } - public decimal? RefundedAmount { get; set; } - [MaxLength(100)] - public string Details { get; set; } - public PaymentMethodType? PaymentMethodType { get; set; } - public GatewayType? Gateway { get; set; } - [MaxLength(50)] - public string GatewayId { get; set; } - public DateTime CreationDate { get; set; } = DateTime.UtcNow; +namespace Bit.Core.Entities; - public void SetNewId() - { - Id = CoreHelpers.GenerateComb(); - } +public class Transaction : ITableObject +{ + public Guid Id { get; set; } + public Guid? UserId { get; set; } + public Guid? OrganizationId { get; set; } + public TransactionType Type { get; set; } + public decimal Amount { get; set; } + public bool? Refunded { get; set; } + public decimal? RefundedAmount { get; set; } + [MaxLength(100)] + public string Details { get; set; } + public PaymentMethodType? PaymentMethodType { get; set; } + public GatewayType? Gateway { get; set; } + [MaxLength(50)] + public string GatewayId { get; set; } + public DateTime CreationDate { get; set; } = DateTime.UtcNow; + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); } } diff --git a/src/Core/Entities/User.cs b/src/Core/Entities/User.cs index e5d79c722..5236fe249 100644 --- a/src/Core/Entities/User.cs +++ b/src/Core/Entities/User.cs @@ -5,189 +5,188 @@ using Bit.Core.Models; using Bit.Core.Utilities; using Microsoft.AspNetCore.Identity; -namespace Bit.Core.Entities +namespace Bit.Core.Entities; + +public class User : ITableObject, ISubscriber, IStorable, IStorableSubscriber, IRevisable, ITwoFactorProvidersUser, IReferenceable { - public class User : ITableObject, ISubscriber, IStorable, IStorableSubscriber, IRevisable, ITwoFactorProvidersUser, IReferenceable + private Dictionary _twoFactorProviders; + + public Guid Id { get; set; } + [MaxLength(50)] + public string Name { get; set; } + [Required] + [MaxLength(256)] + public string Email { get; set; } + public bool EmailVerified { get; set; } + [MaxLength(300)] + public string MasterPassword { get; set; } + [MaxLength(50)] + public string MasterPasswordHint { get; set; } + [MaxLength(10)] + public string Culture { get; set; } = "en-US"; + [Required] + [MaxLength(50)] + public string SecurityStamp { get; set; } + public string TwoFactorProviders { get; set; } + [MaxLength(32)] + public string TwoFactorRecoveryCode { get; set; } + public string EquivalentDomains { get; set; } + public string ExcludedGlobalEquivalentDomains { get; set; } + public DateTime AccountRevisionDate { get; set; } = DateTime.UtcNow; + public string Key { get; set; } + public string PublicKey { get; set; } + public string PrivateKey { get; set; } + public bool Premium { get; set; } + public DateTime? PremiumExpirationDate { get; set; } + public DateTime? RenewalReminderDate { get; set; } + public long? Storage { get; set; } + public short? MaxStorageGb { get; set; } + public GatewayType? Gateway { get; set; } + [MaxLength(50)] + public string GatewayCustomerId { get; set; } + [MaxLength(50)] + public string GatewaySubscriptionId { get; set; } + public string ReferenceData { get; set; } + [MaxLength(100)] + public string LicenseKey { get; set; } + [Required] + [MaxLength(30)] + public string ApiKey { get; set; } + public KdfType Kdf { get; set; } = KdfType.PBKDF2_SHA256; + public int KdfIterations { get; set; } = 5000; + public DateTime CreationDate { get; set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; set; } = DateTime.UtcNow; + public bool ForcePasswordReset { get; set; } + public bool UsesKeyConnector { get; set; } + public int FailedLoginCount { get; set; } + public DateTime? LastFailedLoginDate { get; set; } + public bool UnknownDeviceVerificationEnabled { get; set; } + + public void SetNewId() { - private Dictionary _twoFactorProviders; + Id = CoreHelpers.GenerateComb(); + } - public Guid Id { get; set; } - [MaxLength(50)] - public string Name { get; set; } - [Required] - [MaxLength(256)] - public string Email { get; set; } - public bool EmailVerified { get; set; } - [MaxLength(300)] - public string MasterPassword { get; set; } - [MaxLength(50)] - public string MasterPasswordHint { get; set; } - [MaxLength(10)] - public string Culture { get; set; } = "en-US"; - [Required] - [MaxLength(50)] - public string SecurityStamp { get; set; } - public string TwoFactorProviders { get; set; } - [MaxLength(32)] - public string TwoFactorRecoveryCode { get; set; } - public string EquivalentDomains { get; set; } - public string ExcludedGlobalEquivalentDomains { get; set; } - public DateTime AccountRevisionDate { get; set; } = DateTime.UtcNow; - public string Key { get; set; } - public string PublicKey { get; set; } - public string PrivateKey { get; set; } - public bool Premium { get; set; } - public DateTime? PremiumExpirationDate { get; set; } - public DateTime? RenewalReminderDate { get; set; } - public long? Storage { get; set; } - public short? MaxStorageGb { get; set; } - public GatewayType? Gateway { get; set; } - [MaxLength(50)] - public string GatewayCustomerId { get; set; } - [MaxLength(50)] - public string GatewaySubscriptionId { get; set; } - public string ReferenceData { get; set; } - [MaxLength(100)] - public string LicenseKey { get; set; } - [Required] - [MaxLength(30)] - public string ApiKey { get; set; } - public KdfType Kdf { get; set; } = KdfType.PBKDF2_SHA256; - public int KdfIterations { get; set; } = 5000; - public DateTime CreationDate { get; set; } = DateTime.UtcNow; - public DateTime RevisionDate { get; set; } = DateTime.UtcNow; - public bool ForcePasswordReset { get; set; } - public bool UsesKeyConnector { get; set; } - public int FailedLoginCount { get; set; } - public DateTime? LastFailedLoginDate { get; set; } - public bool UnknownDeviceVerificationEnabled { get; set; } + public string BillingEmailAddress() + { + return Email?.ToLowerInvariant()?.Trim(); + } - public void SetNewId() + public string BillingName() + { + return Name; + } + + public string BraintreeCustomerIdPrefix() + { + return "u"; + } + + public string BraintreeIdField() + { + return "user_id"; + } + + public string GatewayIdField() + { + return "userId"; + } + + public bool IsUser() + { + return true; + } + + public Dictionary GetTwoFactorProviders() + { + if (string.IsNullOrWhiteSpace(TwoFactorProviders)) { - Id = CoreHelpers.GenerateComb(); + return null; } - public string BillingEmailAddress() + try { - return Email?.ToLowerInvariant()?.Trim(); - } - - public string BillingName() - { - return Name; - } - - public string BraintreeCustomerIdPrefix() - { - return "u"; - } - - public string BraintreeIdField() - { - return "user_id"; - } - - public string GatewayIdField() - { - return "userId"; - } - - public bool IsUser() - { - return true; - } - - public Dictionary GetTwoFactorProviders() - { - if (string.IsNullOrWhiteSpace(TwoFactorProviders)) + if (_twoFactorProviders == null) { - return null; + _twoFactorProviders = + JsonHelpers.LegacyDeserialize>( + TwoFactorProviders); } - try - { - if (_twoFactorProviders == null) - { - _twoFactorProviders = - JsonHelpers.LegacyDeserialize>( - TwoFactorProviders); - } - - return _twoFactorProviders; - } - catch (JsonException) - { - return null; - } + return _twoFactorProviders; } - - public Guid? GetUserId() + catch (JsonException) { - return Id; - } - - public bool GetPremium() - { - return Premium; - } - - public void SetTwoFactorProviders(Dictionary providers) - { - // When replacing with system.text remember to remove the extra serialization in WebAuthnTokenProvider. - TwoFactorProviders = JsonHelpers.LegacySerialize(providers, JsonHelpers.LegacyEnumKeyResolver); - _twoFactorProviders = providers; - } - - public void ClearTwoFactorProviders() - { - SetTwoFactorProviders(new Dictionary()); - } - - public TwoFactorProvider GetTwoFactorProvider(TwoFactorProviderType provider) - { - var providers = GetTwoFactorProviders(); - if (providers == null || !providers.ContainsKey(provider)) - { - return null; - } - - return providers[provider]; - } - - public long StorageBytesRemaining() - { - if (!MaxStorageGb.HasValue) - { - return 0; - } - - return StorageBytesRemaining(MaxStorageGb.Value); - } - - public long StorageBytesRemaining(short maxStorageGb) - { - var maxStorageBytes = maxStorageGb * 1073741824L; - if (!Storage.HasValue) - { - return maxStorageBytes; - } - - return maxStorageBytes - Storage.Value; - } - - public IdentityUser ToIdentityUser(bool twoFactorEnabled) - { - return new IdentityUser - { - Id = Id.ToString(), - Email = Email, - NormalizedEmail = Email, - EmailConfirmed = EmailVerified, - UserName = Email, - NormalizedUserName = Email, - TwoFactorEnabled = twoFactorEnabled, - SecurityStamp = SecurityStamp - }; + return null; } } + + public Guid? GetUserId() + { + return Id; + } + + public bool GetPremium() + { + return Premium; + } + + public void SetTwoFactorProviders(Dictionary providers) + { + // When replacing with system.text remember to remove the extra serialization in WebAuthnTokenProvider. + TwoFactorProviders = JsonHelpers.LegacySerialize(providers, JsonHelpers.LegacyEnumKeyResolver); + _twoFactorProviders = providers; + } + + public void ClearTwoFactorProviders() + { + SetTwoFactorProviders(new Dictionary()); + } + + public TwoFactorProvider GetTwoFactorProvider(TwoFactorProviderType provider) + { + var providers = GetTwoFactorProviders(); + if (providers == null || !providers.ContainsKey(provider)) + { + return null; + } + + return providers[provider]; + } + + public long StorageBytesRemaining() + { + if (!MaxStorageGb.HasValue) + { + return 0; + } + + return StorageBytesRemaining(MaxStorageGb.Value); + } + + public long StorageBytesRemaining(short maxStorageGb) + { + var maxStorageBytes = maxStorageGb * 1073741824L; + if (!Storage.HasValue) + { + return maxStorageBytes; + } + + return maxStorageBytes - Storage.Value; + } + + public IdentityUser ToIdentityUser(bool twoFactorEnabled) + { + return new IdentityUser + { + Id = Id.ToString(), + Email = Email, + NormalizedEmail = Email, + EmailConfirmed = EmailVerified, + UserName = Email, + NormalizedUserName = Email, + TwoFactorEnabled = twoFactorEnabled, + SecurityStamp = SecurityStamp + }; + } } diff --git a/src/Core/Enums/ApplicationCacheMessageType.cs b/src/Core/Enums/ApplicationCacheMessageType.cs index b91b07995..94889ed4e 100644 --- a/src/Core/Enums/ApplicationCacheMessageType.cs +++ b/src/Core/Enums/ApplicationCacheMessageType.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum ApplicationCacheMessageType : byte { - public enum ApplicationCacheMessageType : byte - { - UpsertOrganizationAbility = 0, - DeleteOrganizationAbility = 1 - } + UpsertOrganizationAbility = 0, + DeleteOrganizationAbility = 1 } diff --git a/src/Core/Enums/BitwardenClient.cs b/src/Core/Enums/BitwardenClient.cs index 067eef92b..6a1244c0c 100644 --- a/src/Core/Enums/BitwardenClient.cs +++ b/src/Core/Enums/BitwardenClient.cs @@ -1,13 +1,12 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public static class BitwardenClient { - public static class BitwardenClient - { - public const string - Web = "web", - Browser = "browser", - Desktop = "desktop", - Mobile = "mobile", - Cli = "cli", - DirectoryConnector = "connector"; - } + public const string + Web = "web", + Browser = "browser", + Desktop = "desktop", + Mobile = "mobile", + Cli = "cli", + DirectoryConnector = "connector"; } diff --git a/src/Core/Enums/CipherRepromptType.cs b/src/Core/Enums/CipherRepromptType.cs index 0e5b60ff2..3c64c1945 100644 --- a/src/Core/Enums/CipherRepromptType.cs +++ b/src/Core/Enums/CipherRepromptType.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum CipherRepromptType : byte { - public enum CipherRepromptType : byte - { - None = 0, - Password = 1, - } + None = 0, + Password = 1, } diff --git a/src/Core/Enums/CipherStateAction.cs b/src/Core/Enums/CipherStateAction.cs index 87b73a41c..926c8b06c 100644 --- a/src/Core/Enums/CipherStateAction.cs +++ b/src/Core/Enums/CipherStateAction.cs @@ -1,9 +1,8 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum CipherStateAction { - public enum CipherStateAction - { - Restore, - SoftDelete, - HardDelete, - } + Restore, + SoftDelete, + HardDelete, } diff --git a/src/Core/Enums/CipherType.cs b/src/Core/Enums/CipherType.cs index 0aca94864..d9f37bcbc 100644 --- a/src/Core/Enums/CipherType.cs +++ b/src/Core/Enums/CipherType.cs @@ -1,12 +1,11 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum CipherType : byte { - public enum CipherType : byte - { - // Folder is deprecated - //Folder = 0, - Login = 1, - SecureNote = 2, - Card = 3, - Identity = 4 - } + // Folder is deprecated + //Folder = 0, + Login = 1, + SecureNote = 2, + Card = 3, + Identity = 4 } diff --git a/src/Core/Enums/DeviceType.cs b/src/Core/Enums/DeviceType.cs index 53aa21c76..361d9ac38 100644 --- a/src/Core/Enums/DeviceType.cs +++ b/src/Core/Enums/DeviceType.cs @@ -1,50 +1,49 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum DeviceType : byte { - public enum DeviceType : byte - { - [Display(Name = "Android")] - Android = 0, - [Display(Name = "iOS")] - iOS = 1, - [Display(Name = "Chrome Extension")] - ChromeExtension = 2, - [Display(Name = "Firefox Extension")] - FirefoxExtension = 3, - [Display(Name = "Opera Extension")] - OperaExtension = 4, - [Display(Name = "Edge Extension")] - EdgeExtension = 5, - [Display(Name = "Windows")] - WindowsDesktop = 6, - [Display(Name = "macOS")] - MacOsDesktop = 7, - [Display(Name = "Linux")] - LinuxDesktop = 8, - [Display(Name = "Chrome")] - ChromeBrowser = 9, - [Display(Name = "Firefox")] - FirefoxBrowser = 10, - [Display(Name = "Opera")] - OperaBrowser = 11, - [Display(Name = "Edge")] - EdgeBrowser = 12, - [Display(Name = "Internet Explorer")] - IEBrowser = 13, - [Display(Name = "Unknown Browser")] - UnknownBrowser = 14, - [Display(Name = "Android")] - AndroidAmazon = 15, - [Display(Name = "UWP")] - UWP = 16, - [Display(Name = "Safari")] - SafariBrowser = 17, - [Display(Name = "Vivaldi")] - VivaldiBrowser = 18, - [Display(Name = "Vivaldi Extension")] - VivaldiExtension = 19, - [Display(Name = "Safari Extension")] - SafariExtension = 20 - } + [Display(Name = "Android")] + Android = 0, + [Display(Name = "iOS")] + iOS = 1, + [Display(Name = "Chrome Extension")] + ChromeExtension = 2, + [Display(Name = "Firefox Extension")] + FirefoxExtension = 3, + [Display(Name = "Opera Extension")] + OperaExtension = 4, + [Display(Name = "Edge Extension")] + EdgeExtension = 5, + [Display(Name = "Windows")] + WindowsDesktop = 6, + [Display(Name = "macOS")] + MacOsDesktop = 7, + [Display(Name = "Linux")] + LinuxDesktop = 8, + [Display(Name = "Chrome")] + ChromeBrowser = 9, + [Display(Name = "Firefox")] + FirefoxBrowser = 10, + [Display(Name = "Opera")] + OperaBrowser = 11, + [Display(Name = "Edge")] + EdgeBrowser = 12, + [Display(Name = "Internet Explorer")] + IEBrowser = 13, + [Display(Name = "Unknown Browser")] + UnknownBrowser = 14, + [Display(Name = "Android")] + AndroidAmazon = 15, + [Display(Name = "UWP")] + UWP = 16, + [Display(Name = "Safari")] + SafariBrowser = 17, + [Display(Name = "Vivaldi")] + VivaldiBrowser = 18, + [Display(Name = "Vivaldi Extension")] + VivaldiExtension = 19, + [Display(Name = "Safari Extension")] + SafariExtension = 20 } diff --git a/src/Core/Enums/EmergencyAccessStatusType.cs b/src/Core/Enums/EmergencyAccessStatusType.cs index 2c5b472a9..79fca334e 100644 --- a/src/Core/Enums/EmergencyAccessStatusType.cs +++ b/src/Core/Enums/EmergencyAccessStatusType.cs @@ -1,11 +1,10 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum EmergencyAccessStatusType : byte { - public enum EmergencyAccessStatusType : byte - { - Invited = 0, - Accepted = 1, - Confirmed = 2, - RecoveryInitiated = 3, - RecoveryApproved = 4, - } + Invited = 0, + Accepted = 1, + Confirmed = 2, + RecoveryInitiated = 3, + RecoveryApproved = 4, } diff --git a/src/Core/Enums/EmergencyAccessType.cs b/src/Core/Enums/EmergencyAccessType.cs index d622857aa..5742bb531 100644 --- a/src/Core/Enums/EmergencyAccessType.cs +++ b/src/Core/Enums/EmergencyAccessType.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum EmergencyAccessType : byte { - public enum EmergencyAccessType : byte - { - View = 0, - Takeover = 1, - } + View = 0, + Takeover = 1, } diff --git a/src/Core/Enums/EncryptionType.cs b/src/Core/Enums/EncryptionType.cs index 2b6eaf086..a37110911 100644 --- a/src/Core/Enums/EncryptionType.cs +++ b/src/Core/Enums/EncryptionType.cs @@ -1,13 +1,12 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum EncryptionType : byte { - public enum EncryptionType : byte - { - AesCbc256_B64 = 0, - AesCbc128_HmacSha256_B64 = 1, - AesCbc256_HmacSha256_B64 = 2, - Rsa2048_OaepSha256_B64 = 3, - Rsa2048_OaepSha1_B64 = 4, - Rsa2048_OaepSha256_HmacSha256_B64 = 5, - Rsa2048_OaepSha1_HmacSha256_B64 = 6 - } + AesCbc256_B64 = 0, + AesCbc128_HmacSha256_B64 = 1, + AesCbc256_HmacSha256_B64 = 2, + Rsa2048_OaepSha256_B64 = 3, + Rsa2048_OaepSha1_B64 = 4, + Rsa2048_OaepSha256_HmacSha256_B64 = 5, + Rsa2048_OaepSha1_HmacSha256_B64 = 6 } diff --git a/src/Core/Enums/EventType.cs b/src/Core/Enums/EventType.cs index 98d844008..09a1afffd 100644 --- a/src/Core/Enums/EventType.cs +++ b/src/Core/Enums/EventType.cs @@ -1,79 +1,78 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum EventType : int { - public enum EventType : int - { - User_LoggedIn = 1000, - User_ChangedPassword = 1001, - User_Updated2fa = 1002, - User_Disabled2fa = 1003, - User_Recovered2fa = 1004, - User_FailedLogIn = 1005, - User_FailedLogIn2fa = 1006, - User_ClientExportedVault = 1007, - User_UpdatedTempPassword = 1008, - User_MigratedKeyToKeyConnector = 1009, + User_LoggedIn = 1000, + User_ChangedPassword = 1001, + User_Updated2fa = 1002, + User_Disabled2fa = 1003, + User_Recovered2fa = 1004, + User_FailedLogIn = 1005, + User_FailedLogIn2fa = 1006, + User_ClientExportedVault = 1007, + User_UpdatedTempPassword = 1008, + User_MigratedKeyToKeyConnector = 1009, - Cipher_Created = 1100, - Cipher_Updated = 1101, - Cipher_Deleted = 1102, - Cipher_AttachmentCreated = 1103, - Cipher_AttachmentDeleted = 1104, - Cipher_Shared = 1105, - Cipher_UpdatedCollections = 1106, - Cipher_ClientViewed = 1107, - Cipher_ClientToggledPasswordVisible = 1108, - Cipher_ClientToggledHiddenFieldVisible = 1109, - Cipher_ClientToggledCardCodeVisible = 1110, - Cipher_ClientCopiedPassword = 1111, - Cipher_ClientCopiedHiddenField = 1112, - Cipher_ClientCopiedCardCode = 1113, - Cipher_ClientAutofilled = 1114, - Cipher_SoftDeleted = 1115, - Cipher_Restored = 1116, - Cipher_ClientToggledCardNumberVisible = 1117, + Cipher_Created = 1100, + Cipher_Updated = 1101, + Cipher_Deleted = 1102, + Cipher_AttachmentCreated = 1103, + Cipher_AttachmentDeleted = 1104, + Cipher_Shared = 1105, + Cipher_UpdatedCollections = 1106, + Cipher_ClientViewed = 1107, + Cipher_ClientToggledPasswordVisible = 1108, + Cipher_ClientToggledHiddenFieldVisible = 1109, + Cipher_ClientToggledCardCodeVisible = 1110, + Cipher_ClientCopiedPassword = 1111, + Cipher_ClientCopiedHiddenField = 1112, + Cipher_ClientCopiedCardCode = 1113, + Cipher_ClientAutofilled = 1114, + Cipher_SoftDeleted = 1115, + Cipher_Restored = 1116, + Cipher_ClientToggledCardNumberVisible = 1117, - Collection_Created = 1300, - Collection_Updated = 1301, - Collection_Deleted = 1302, + Collection_Created = 1300, + Collection_Updated = 1301, + Collection_Deleted = 1302, - Group_Created = 1400, - Group_Updated = 1401, - Group_Deleted = 1402, + Group_Created = 1400, + Group_Updated = 1401, + Group_Deleted = 1402, - OrganizationUser_Invited = 1500, - OrganizationUser_Confirmed = 1501, - OrganizationUser_Updated = 1502, - OrganizationUser_Removed = 1503, - OrganizationUser_UpdatedGroups = 1504, - OrganizationUser_UnlinkedSso = 1505, - OrganizationUser_ResetPassword_Enroll = 1506, - OrganizationUser_ResetPassword_Withdraw = 1507, - OrganizationUser_AdminResetPassword = 1508, - OrganizationUser_ResetSsoLink = 1509, - OrganizationUser_FirstSsoLogin = 1510, - OrganizationUser_Revoked = 1511, - OrganizationUser_Restored = 1512, + OrganizationUser_Invited = 1500, + OrganizationUser_Confirmed = 1501, + OrganizationUser_Updated = 1502, + OrganizationUser_Removed = 1503, + OrganizationUser_UpdatedGroups = 1504, + OrganizationUser_UnlinkedSso = 1505, + OrganizationUser_ResetPassword_Enroll = 1506, + OrganizationUser_ResetPassword_Withdraw = 1507, + OrganizationUser_AdminResetPassword = 1508, + OrganizationUser_ResetSsoLink = 1509, + OrganizationUser_FirstSsoLogin = 1510, + OrganizationUser_Revoked = 1511, + OrganizationUser_Restored = 1512, - Organization_Updated = 1600, - Organization_PurgedVault = 1601, - Organization_ClientExportedVault = 1602, - Organization_VaultAccessed = 1603, - Organization_EnabledSso = 1604, - Organization_DisabledSso = 1605, - Organization_EnabledKeyConnector = 1606, - Organization_DisabledKeyConnector = 1607, - Organization_SponsorshipsSynced = 1608, + Organization_Updated = 1600, + Organization_PurgedVault = 1601, + Organization_ClientExportedVault = 1602, + Organization_VaultAccessed = 1603, + Organization_EnabledSso = 1604, + Organization_DisabledSso = 1605, + Organization_EnabledKeyConnector = 1606, + Organization_DisabledKeyConnector = 1607, + Organization_SponsorshipsSynced = 1608, - Policy_Updated = 1700, + Policy_Updated = 1700, - ProviderUser_Invited = 1800, - ProviderUser_Confirmed = 1801, - ProviderUser_Updated = 1802, - ProviderUser_Removed = 1803, + ProviderUser_Invited = 1800, + ProviderUser_Confirmed = 1801, + ProviderUser_Updated = 1802, + ProviderUser_Removed = 1803, - ProviderOrganization_Created = 1900, - ProviderOrganization_Added = 1901, - ProviderOrganization_Removed = 1902, - ProviderOrganization_VaultAccessed = 1903, - } + ProviderOrganization_Created = 1900, + ProviderOrganization_Added = 1901, + ProviderOrganization_Removed = 1902, + ProviderOrganization_VaultAccessed = 1903, } diff --git a/src/Core/Enums/FieldType.cs b/src/Core/Enums/FieldType.cs index 5eef485b7..4642b63a8 100644 --- a/src/Core/Enums/FieldType.cs +++ b/src/Core/Enums/FieldType.cs @@ -1,10 +1,9 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum FieldType : byte { - public enum FieldType : byte - { - Text = 0, - Hidden = 1, - Boolean = 2, - Linked = 3, - } + Text = 0, + Hidden = 1, + Boolean = 2, + Linked = 3, } diff --git a/src/Core/Enums/FileUploadType.cs b/src/Core/Enums/FileUploadType.cs index 4bdefd4dd..4d32589b6 100644 --- a/src/Core/Enums/FileUploadType.cs +++ b/src/Core/Enums/FileUploadType.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum FileUploadType { - public enum FileUploadType - { - Direct = 0, - Azure = 1, - } + Direct = 0, + Azure = 1, } diff --git a/src/Core/Enums/GatewayType.cs b/src/Core/Enums/GatewayType.cs index 68c959ad7..5ad73cf0f 100644 --- a/src/Core/Enums/GatewayType.cs +++ b/src/Core/Enums/GatewayType.cs @@ -1,22 +1,21 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum GatewayType : byte { - public enum GatewayType : byte - { - [Display(Name = "Stripe")] - Stripe = 0, - [Display(Name = "Braintree")] - Braintree = 1, - [Display(Name = "Apple App Store")] - AppStore = 2, - [Display(Name = "Google Play Store")] - PlayStore = 3, - [Display(Name = "BitPay")] - BitPay = 4, - [Display(Name = "PayPal")] - PayPal = 5, - [Display(Name = "Bank")] - Bank = 6, - } + [Display(Name = "Stripe")] + Stripe = 0, + [Display(Name = "Braintree")] + Braintree = 1, + [Display(Name = "Apple App Store")] + AppStore = 2, + [Display(Name = "Google Play Store")] + PlayStore = 3, + [Display(Name = "BitPay")] + BitPay = 4, + [Display(Name = "PayPal")] + PayPal = 5, + [Display(Name = "Bank")] + Bank = 6, } diff --git a/src/Core/Enums/GlobalEquivalentDomainsType.cs b/src/Core/Enums/GlobalEquivalentDomainsType.cs index 22b0cdd3a..1291736d7 100644 --- a/src/Core/Enums/GlobalEquivalentDomainsType.cs +++ b/src/Core/Enums/GlobalEquivalentDomainsType.cs @@ -1,95 +1,94 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum GlobalEquivalentDomainsType : byte { - public enum GlobalEquivalentDomainsType : byte - { - Google = 0, - Apple = 1, - Ameritrade = 2, - BoA = 3, - Sprint = 4, - WellsFargo = 5, - Merrill = 6, - Citi = 7, - Cnet = 8, - Gap = 9, - Microsoft = 10, - United = 11, - Yahoo = 12, - Zonelabs = 13, - PayPal = 14, - Avon = 15, - Diapers = 16, - Contacts = 17, - Amazon = 18, - Cox = 19, - Norton = 20, - Verizon = 21, - Buy = 22, - Sirius = 23, - Ea = 24, - Basecamp = 25, - Steam = 26, - Chart = 27, - Gotomeeting = 28, - Gogo = 29, - Oracle = 30, - Discover = 31, - Dcu = 32, - Healthcare = 33, - Pepco = 34, - Century21 = 35, - Comcast = 36, - Cricket = 37, - Mtb = 38, - Dropbox = 39, - Snapfish = 40, - Alibaba = 41, - Playstation = 42, - Mercado = 43, - Zendesk = 44, - Autodesk = 45, - RailNation = 46, - Wpcu = 47, - Mathletics = 48, - Discountbank = 49, - Mi = 50, - Facebook = 51, - Postepay = 52, - Skysports = 53, - Disney = 54, - Pokemon = 55, - Uv = 56, - Yahavo = 57, - Mdsol = 58, - Sears = 59, - Xiami = 60, - Belkin = 61, - Turbotax = 62, - Shopify = 63, - Ebay = 64, - Techdata = 65, - Schwab = 66, - Mozilla = 67, // deprecated - Tesla = 68, - MorganStanley = 69, - TaxAct = 70, - Wikimedia = 71, - Airbnb = 72, - Eventbrite = 73, - StackExchange = 74, - Docusign = 75, - Envato = 76, - X10Hosting = 77, - Cisco = 78, - CedarFair = 79, - Ubiquiti = 80, - Discord = 81, - Netcup = 82, - Yandex = 83, - Sony = 84, - Proton = 85, - Ubisoft = 86, - TransferWise = 87, - TakeawayEU = 88, - } + Google = 0, + Apple = 1, + Ameritrade = 2, + BoA = 3, + Sprint = 4, + WellsFargo = 5, + Merrill = 6, + Citi = 7, + Cnet = 8, + Gap = 9, + Microsoft = 10, + United = 11, + Yahoo = 12, + Zonelabs = 13, + PayPal = 14, + Avon = 15, + Diapers = 16, + Contacts = 17, + Amazon = 18, + Cox = 19, + Norton = 20, + Verizon = 21, + Buy = 22, + Sirius = 23, + Ea = 24, + Basecamp = 25, + Steam = 26, + Chart = 27, + Gotomeeting = 28, + Gogo = 29, + Oracle = 30, + Discover = 31, + Dcu = 32, + Healthcare = 33, + Pepco = 34, + Century21 = 35, + Comcast = 36, + Cricket = 37, + Mtb = 38, + Dropbox = 39, + Snapfish = 40, + Alibaba = 41, + Playstation = 42, + Mercado = 43, + Zendesk = 44, + Autodesk = 45, + RailNation = 46, + Wpcu = 47, + Mathletics = 48, + Discountbank = 49, + Mi = 50, + Facebook = 51, + Postepay = 52, + Skysports = 53, + Disney = 54, + Pokemon = 55, + Uv = 56, + Yahavo = 57, + Mdsol = 58, + Sears = 59, + Xiami = 60, + Belkin = 61, + Turbotax = 62, + Shopify = 63, + Ebay = 64, + Techdata = 65, + Schwab = 66, + Mozilla = 67, // deprecated + Tesla = 68, + MorganStanley = 69, + TaxAct = 70, + Wikimedia = 71, + Airbnb = 72, + Eventbrite = 73, + StackExchange = 74, + Docusign = 75, + Envato = 76, + X10Hosting = 77, + Cisco = 78, + CedarFair = 79, + Ubiquiti = 80, + Discord = 81, + Netcup = 82, + Yandex = 83, + Sony = 84, + Proton = 85, + Ubisoft = 86, + TransferWise = 87, + TakeawayEU = 88, } diff --git a/src/Core/Enums/KdfType.cs b/src/Core/Enums/KdfType.cs index 1c845846a..212794eac 100644 --- a/src/Core/Enums/KdfType.cs +++ b/src/Core/Enums/KdfType.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum KdfType : byte { - public enum KdfType : byte - { - PBKDF2_SHA256 = 0 - } + PBKDF2_SHA256 = 0 } diff --git a/src/Core/Enums/LicenseType.cs b/src/Core/Enums/LicenseType.cs index 60d622b9c..90ca0d7a6 100644 --- a/src/Core/Enums/LicenseType.cs +++ b/src/Core/Enums/LicenseType.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum LicenseType : byte { - public enum LicenseType : byte - { - User = 0, - Organization = 1, - } + User = 0, + Organization = 1, } diff --git a/src/Core/Enums/OrganizationApiKeyType.cs b/src/Core/Enums/OrganizationApiKeyType.cs index 153079cf2..8fdbf931a 100644 --- a/src/Core/Enums/OrganizationApiKeyType.cs +++ b/src/Core/Enums/OrganizationApiKeyType.cs @@ -1,9 +1,8 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum OrganizationApiKeyType : byte { - public enum OrganizationApiKeyType : byte - { - Default = 0, - BillingSync = 1, - Scim = 2, - } + Default = 0, + BillingSync = 1, + Scim = 2, } diff --git a/src/Core/Enums/OrganizationConnectionType.cs b/src/Core/Enums/OrganizationConnectionType.cs index e998e5532..995cfc866 100644 --- a/src/Core/Enums/OrganizationConnectionType.cs +++ b/src/Core/Enums/OrganizationConnectionType.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum OrganizationConnectionType : byte { - public enum OrganizationConnectionType : byte - { - CloudBillingSync = 1, - Scim = 2, - } + CloudBillingSync = 1, + Scim = 2, } diff --git a/src/Core/Enums/OrganizationUserStatusType.cs b/src/Core/Enums/OrganizationUserStatusType.cs index 8c39c053f..576e98ea7 100644 --- a/src/Core/Enums/OrganizationUserStatusType.cs +++ b/src/Core/Enums/OrganizationUserStatusType.cs @@ -1,10 +1,9 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum OrganizationUserStatusType : short { - public enum OrganizationUserStatusType : short - { - Invited = 0, - Accepted = 1, - Confirmed = 2, - Revoked = -1, - } + Invited = 0, + Accepted = 1, + Confirmed = 2, + Revoked = -1, } diff --git a/src/Core/Enums/OrganizationUserType.cs b/src/Core/Enums/OrganizationUserType.cs index 738c80657..620eaeb33 100644 --- a/src/Core/Enums/OrganizationUserType.cs +++ b/src/Core/Enums/OrganizationUserType.cs @@ -1,11 +1,10 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum OrganizationUserType : byte { - public enum OrganizationUserType : byte - { - Owner = 0, - Admin = 1, - User = 2, - Manager = 3, - Custom = 4, - } + Owner = 0, + Admin = 1, + User = 2, + Manager = 3, + Custom = 4, } diff --git a/src/Core/Enums/PaymentMethodType.cs b/src/Core/Enums/PaymentMethodType.cs index b0290f92b..0b6c235b3 100644 --- a/src/Core/Enums/PaymentMethodType.cs +++ b/src/Core/Enums/PaymentMethodType.cs @@ -1,28 +1,27 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum PaymentMethodType : byte { - public enum PaymentMethodType : byte - { - [Display(Name = "Card")] - Card = 0, - [Display(Name = "Bank Account")] - BankAccount = 1, - [Display(Name = "PayPal")] - PayPal = 2, - [Display(Name = "BitPay")] - BitPay = 3, - [Display(Name = "Credit")] - Credit = 4, - [Display(Name = "Wire Transfer")] - WireTransfer = 5, - [Display(Name = "Apple In-App Purchase")] - AppleInApp = 6, - [Display(Name = "Google In-App Purchase")] - GoogleInApp = 7, - [Display(Name = "Check")] - Check = 8, - [Display(Name = "None")] - None = 255, - } + [Display(Name = "Card")] + Card = 0, + [Display(Name = "Bank Account")] + BankAccount = 1, + [Display(Name = "PayPal")] + PayPal = 2, + [Display(Name = "BitPay")] + BitPay = 3, + [Display(Name = "Credit")] + Credit = 4, + [Display(Name = "Wire Transfer")] + WireTransfer = 5, + [Display(Name = "Apple In-App Purchase")] + AppleInApp = 6, + [Display(Name = "Google In-App Purchase")] + GoogleInApp = 7, + [Display(Name = "Check")] + Check = 8, + [Display(Name = "None")] + None = 255, } diff --git a/src/Core/Enums/PlanSponsorshipType.cs b/src/Core/Enums/PlanSponsorshipType.cs index 59f778e10..2bb7a15b1 100644 --- a/src/Core/Enums/PlanSponsorshipType.cs +++ b/src/Core/Enums/PlanSponsorshipType.cs @@ -1,10 +1,9 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum PlanSponsorshipType : byte { - public enum PlanSponsorshipType : byte - { - [Display(Name = "Families For Enterprise")] - FamiliesForEnterprise = 0, - } + [Display(Name = "Families For Enterprise")] + FamiliesForEnterprise = 0, } diff --git a/src/Core/Enums/PlanType.cs b/src/Core/Enums/PlanType.cs index 037f1f893..ac32f217e 100644 --- a/src/Core/Enums/PlanType.cs +++ b/src/Core/Enums/PlanType.cs @@ -1,32 +1,31 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum PlanType : byte { - public enum PlanType : byte - { - [Display(Name = "Free")] - Free = 0, - [Display(Name = "Families 2019")] - FamiliesAnnually2019 = 1, - [Display(Name = "Teams (Monthly) 2019")] - TeamsMonthly2019 = 2, - [Display(Name = "Teams (Annually) 2019")] - TeamsAnnually2019 = 3, - [Display(Name = "Enterprise (Monthly) 2019")] - EnterpriseMonthly2019 = 4, - [Display(Name = "Enterprise (Annually) 2019")] - EnterpriseAnnually2019 = 5, - [Display(Name = "Custom")] - Custom = 6, - [Display(Name = "Families")] - FamiliesAnnually = 7, - [Display(Name = "Teams (Monthly)")] - TeamsMonthly = 8, - [Display(Name = "Teams (Annually)")] - TeamsAnnually = 9, - [Display(Name = "Enterprise (Monthly)")] - EnterpriseMonthly = 10, - [Display(Name = "Enterprise (Annually)")] - EnterpriseAnnually = 11, - } + [Display(Name = "Free")] + Free = 0, + [Display(Name = "Families 2019")] + FamiliesAnnually2019 = 1, + [Display(Name = "Teams (Monthly) 2019")] + TeamsMonthly2019 = 2, + [Display(Name = "Teams (Annually) 2019")] + TeamsAnnually2019 = 3, + [Display(Name = "Enterprise (Monthly) 2019")] + EnterpriseMonthly2019 = 4, + [Display(Name = "Enterprise (Annually) 2019")] + EnterpriseAnnually2019 = 5, + [Display(Name = "Custom")] + Custom = 6, + [Display(Name = "Families")] + FamiliesAnnually = 7, + [Display(Name = "Teams (Monthly)")] + TeamsMonthly = 8, + [Display(Name = "Teams (Annually)")] + TeamsAnnually = 9, + [Display(Name = "Enterprise (Monthly)")] + EnterpriseMonthly = 10, + [Display(Name = "Enterprise (Annually)")] + EnterpriseAnnually = 11, } diff --git a/src/Core/Enums/PolicyType.cs b/src/Core/Enums/PolicyType.cs index ac7669995..e4c120836 100644 --- a/src/Core/Enums/PolicyType.cs +++ b/src/Core/Enums/PolicyType.cs @@ -1,17 +1,16 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum PolicyType : byte { - public enum PolicyType : byte - { - TwoFactorAuthentication = 0, - MasterPassword = 1, - PasswordGenerator = 2, - SingleOrg = 3, - RequireSso = 4, - PersonalOwnership = 5, - DisableSend = 6, - SendOptions = 7, - ResetPassword = 8, - MaximumVaultTimeout = 9, - DisablePersonalVaultExport = 10, - } + TwoFactorAuthentication = 0, + MasterPassword = 1, + PasswordGenerator = 2, + SingleOrg = 3, + RequireSso = 4, + PersonalOwnership = 5, + DisableSend = 6, + SendOptions = 7, + ResetPassword = 8, + MaximumVaultTimeout = 9, + DisablePersonalVaultExport = 10, } diff --git a/src/Core/Enums/ProductType.cs b/src/Core/Enums/ProductType.cs index 2f9b1d478..1e443f56f 100644 --- a/src/Core/Enums/ProductType.cs +++ b/src/Core/Enums/ProductType.cs @@ -1,17 +1,16 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum ProductType : byte { - public enum ProductType : byte - { - [Display(Name = "Free")] - Free = 0, - [Display(Name = "Families")] - Families = 1, - [Display(Name = "Teams")] - Teams = 2, - [Display(Name = "Enterprise")] - Enterprise = 3, - } + [Display(Name = "Free")] + Free = 0, + [Display(Name = "Families")] + Families = 1, + [Display(Name = "Teams")] + Teams = 2, + [Display(Name = "Enterprise")] + Enterprise = 3, } diff --git a/src/Core/Enums/Provider/ProviderStatusType.cs b/src/Core/Enums/Provider/ProviderStatusType.cs index 16d8d6330..bcb1f8cd2 100644 --- a/src/Core/Enums/Provider/ProviderStatusType.cs +++ b/src/Core/Enums/Provider/ProviderStatusType.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Enums.Provider +namespace Bit.Core.Enums.Provider; + +public enum ProviderStatusType : byte { - public enum ProviderStatusType : byte - { - Pending = 0, - Created = 1, - } + Pending = 0, + Created = 1, } diff --git a/src/Core/Enums/Provider/ProviderUserStatusType.cs b/src/Core/Enums/Provider/ProviderUserStatusType.cs index 73e9c8e33..60571386d 100644 --- a/src/Core/Enums/Provider/ProviderUserStatusType.cs +++ b/src/Core/Enums/Provider/ProviderUserStatusType.cs @@ -1,9 +1,8 @@ -namespace Bit.Core.Enums.Provider +namespace Bit.Core.Enums.Provider; + +public enum ProviderUserStatusType : byte { - public enum ProviderUserStatusType : byte - { - Invited = 0, - Accepted = 1, - Confirmed = 2, - } + Invited = 0, + Accepted = 1, + Confirmed = 2, } diff --git a/src/Core/Enums/Provider/ProviderUserType.cs b/src/Core/Enums/Provider/ProviderUserType.cs index 7147d21a3..d13591290 100644 --- a/src/Core/Enums/Provider/ProviderUserType.cs +++ b/src/Core/Enums/Provider/ProviderUserType.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Enums.Provider +namespace Bit.Core.Enums.Provider; + +public enum ProviderUserType : byte { - public enum ProviderUserType : byte - { - ProviderAdmin = 0, - ServiceUser = 1, - } + ProviderAdmin = 0, + ServiceUser = 1, } diff --git a/src/Core/Enums/PushType.cs b/src/Core/Enums/PushType.cs index 7899656b5..9054d1d40 100644 --- a/src/Core/Enums/PushType.cs +++ b/src/Core/Enums/PushType.cs @@ -1,24 +1,23 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum PushType : byte { - public enum PushType : byte - { - SyncCipherUpdate = 0, - SyncCipherCreate = 1, - SyncLoginDelete = 2, - SyncFolderDelete = 3, - SyncCiphers = 4, + SyncCipherUpdate = 0, + SyncCipherCreate = 1, + SyncLoginDelete = 2, + SyncFolderDelete = 3, + SyncCiphers = 4, - SyncVault = 5, - SyncOrgKeys = 6, - SyncFolderCreate = 7, - SyncFolderUpdate = 8, - SyncCipherDelete = 9, - SyncSettings = 10, + SyncVault = 5, + SyncOrgKeys = 6, + SyncFolderCreate = 7, + SyncFolderUpdate = 8, + SyncCipherDelete = 9, + SyncSettings = 10, - LogOut = 11, + LogOut = 11, - SyncSendCreate = 12, - SyncSendUpdate = 13, - SyncSendDelete = 14, - } + SyncSendCreate = 12, + SyncSendUpdate = 13, + SyncSendDelete = 14, } diff --git a/src/Core/Enums/ReferenceEventSource.cs b/src/Core/Enums/ReferenceEventSource.cs index 0a19b0772..3d7ad85ff 100644 --- a/src/Core/Enums/ReferenceEventSource.cs +++ b/src/Core/Enums/ReferenceEventSource.cs @@ -1,12 +1,11 @@ using System.Runtime.Serialization; -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum ReferenceEventSource { - public enum ReferenceEventSource - { - [EnumMember(Value = "organization")] - Organization, - [EnumMember(Value = "user")] - User, - } + [EnumMember(Value = "organization")] + Organization, + [EnumMember(Value = "user")] + User, } diff --git a/src/Core/Enums/ReferenceEventType.cs b/src/Core/Enums/ReferenceEventType.cs index efd631f32..1a925736c 100644 --- a/src/Core/Enums/ReferenceEventType.cs +++ b/src/Core/Enums/ReferenceEventType.cs @@ -1,44 +1,43 @@ using System.Runtime.Serialization; -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum ReferenceEventType { - public enum ReferenceEventType - { - [EnumMember(Value = "signup")] - Signup, - [EnumMember(Value = "upgrade-plan")] - UpgradePlan, - [EnumMember(Value = "adjust-storage")] - AdjustStorage, - [EnumMember(Value = "adjust-seats")] - AdjustSeats, - [EnumMember(Value = "cancel-subscription")] - CancelSubscription, - [EnumMember(Value = "reinstate-subscription")] - ReinstateSubscription, - [EnumMember(Value = "delete-account")] - DeleteAccount, - [EnumMember(Value = "confirm-email")] - ConfirmEmailAddress, - [EnumMember(Value = "invited-users")] - InvitedUsers, - [EnumMember(Value = "rebilled")] - Rebilled, - [EnumMember(Value = "send-created")] - SendCreated, - [EnumMember(Value = "send-accessed")] - SendAccessed, - [EnumMember(Value = "directory-synced")] - DirectorySynced, - [EnumMember(Value = "vault-imported")] - VaultImported, - [EnumMember(Value = "cipher-created")] - CipherCreated, - [EnumMember(Value = "group-created")] - GroupCreated, - [EnumMember(Value = "collection-created")] - CollectionCreated, - [EnumMember(Value = "organization-edited-by-admin")] - OrganizationEditedByAdmin - } + [EnumMember(Value = "signup")] + Signup, + [EnumMember(Value = "upgrade-plan")] + UpgradePlan, + [EnumMember(Value = "adjust-storage")] + AdjustStorage, + [EnumMember(Value = "adjust-seats")] + AdjustSeats, + [EnumMember(Value = "cancel-subscription")] + CancelSubscription, + [EnumMember(Value = "reinstate-subscription")] + ReinstateSubscription, + [EnumMember(Value = "delete-account")] + DeleteAccount, + [EnumMember(Value = "confirm-email")] + ConfirmEmailAddress, + [EnumMember(Value = "invited-users")] + InvitedUsers, + [EnumMember(Value = "rebilled")] + Rebilled, + [EnumMember(Value = "send-created")] + SendCreated, + [EnumMember(Value = "send-accessed")] + SendAccessed, + [EnumMember(Value = "directory-synced")] + DirectorySynced, + [EnumMember(Value = "vault-imported")] + VaultImported, + [EnumMember(Value = "cipher-created")] + CipherCreated, + [EnumMember(Value = "group-created")] + GroupCreated, + [EnumMember(Value = "collection-created")] + CollectionCreated, + [EnumMember(Value = "organization-edited-by-admin")] + OrganizationEditedByAdmin } diff --git a/src/Core/Enums/Saml2BindingType.cs b/src/Core/Enums/Saml2BindingType.cs index 0c0882bc4..c02a5d7cc 100644 --- a/src/Core/Enums/Saml2BindingType.cs +++ b/src/Core/Enums/Saml2BindingType.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum Saml2BindingType : byte { - public enum Saml2BindingType : byte - { - HttpRedirect = 1, - HttpPost = 2, - } + HttpRedirect = 1, + HttpPost = 2, } diff --git a/src/Core/Enums/Saml2NameIdFormat.cs b/src/Core/Enums/Saml2NameIdFormat.cs index 9ba83e58f..f90426e5c 100644 --- a/src/Core/Enums/Saml2NameIdFormat.cs +++ b/src/Core/Enums/Saml2NameIdFormat.cs @@ -1,15 +1,14 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum Saml2NameIdFormat : byte { - public enum Saml2NameIdFormat : byte - { - NotConfigured = 0, - Unspecified = 1, - EmailAddress = 2, - X509SubjectName = 3, - WindowsDomainQualifiedName = 4, - KerberosPrincipalName = 5, - EntityIdentifier = 6, - Persistent = 7, - Transient = 8, - } + NotConfigured = 0, + Unspecified = 1, + EmailAddress = 2, + X509SubjectName = 3, + WindowsDomainQualifiedName = 4, + KerberosPrincipalName = 5, + EntityIdentifier = 6, + Persistent = 7, + Transient = 8, } diff --git a/src/Core/Enums/Saml2SigningBehavior.cs b/src/Core/Enums/Saml2SigningBehavior.cs index a02e5b1d9..25344dbc8 100644 --- a/src/Core/Enums/Saml2SigningBehavior.cs +++ b/src/Core/Enums/Saml2SigningBehavior.cs @@ -1,9 +1,8 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum Saml2SigningBehavior : byte { - public enum Saml2SigningBehavior : byte - { - IfIdpWantAuthnRequestsSigned = 0, - Always = 1, - Never = 3 - } + IfIdpWantAuthnRequestsSigned = 0, + Always = 1, + Never = 3 } diff --git a/src/Core/Enums/ScimProviderType.cs b/src/Core/Enums/ScimProviderType.cs index 18039c87c..c1d467039 100644 --- a/src/Core/Enums/ScimProviderType.cs +++ b/src/Core/Enums/ScimProviderType.cs @@ -1,13 +1,12 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum ScimProviderType : byte { - public enum ScimProviderType : byte - { - Default = 0, - AzureAd = 1, - Okta = 2, - OneLogin = 3, - JumpCloud = 4, - GoogleWorkspace = 5, - Rippling = 6, - } + Default = 0, + AzureAd = 1, + Okta = 2, + OneLogin = 3, + JumpCloud = 4, + GoogleWorkspace = 5, + Rippling = 6, } diff --git a/src/Core/Enums/SecureNoteType.cs b/src/Core/Enums/SecureNoteType.cs index cc84edfc3..cdd565e7c 100644 --- a/src/Core/Enums/SecureNoteType.cs +++ b/src/Core/Enums/SecureNoteType.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum SecureNoteType : byte { - public enum SecureNoteType : byte - { - Generic = 0 - } + Generic = 0 } diff --git a/src/Core/Enums/SendType.cs b/src/Core/Enums/SendType.cs index a52008556..ce59df6b3 100644 --- a/src/Core/Enums/SendType.cs +++ b/src/Core/Enums/SendType.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum SendType : byte { - public enum SendType : byte - { - Text = 0, - File = 1 - } + Text = 0, + File = 1 } diff --git a/src/Core/Enums/SsoType.cs b/src/Core/Enums/SsoType.cs index 3c1884bd7..3e890817f 100644 --- a/src/Core/Enums/SsoType.cs +++ b/src/Core/Enums/SsoType.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum SsoType : byte { - public enum SsoType : byte - { - OpenIdConnect = 1, - Saml2 = 2, - } + OpenIdConnect = 1, + Saml2 = 2, } diff --git a/src/Core/Enums/SupportedDatabaseProviders.cs b/src/Core/Enums/SupportedDatabaseProviders.cs index c38a023c4..81e60b58e 100644 --- a/src/Core/Enums/SupportedDatabaseProviders.cs +++ b/src/Core/Enums/SupportedDatabaseProviders.cs @@ -1,10 +1,9 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum SupportedDatabaseProviders { - public enum SupportedDatabaseProviders - { - SqlServer, - MySql, - Postgres, - } + SqlServer, + MySql, + Postgres, } diff --git a/src/Core/Enums/TransactionType.cs b/src/Core/Enums/TransactionType.cs index 02556ae1d..6a5107763 100644 --- a/src/Core/Enums/TransactionType.cs +++ b/src/Core/Enums/TransactionType.cs @@ -1,18 +1,17 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum TransactionType : byte { - public enum TransactionType : byte - { - [Display(Name = "Charge")] - Charge = 0, - [Display(Name = "Credit")] - Credit = 1, - [Display(Name = "Promotional Credit")] - PromotionalCredit = 2, - [Display(Name = "Referral Credit")] - ReferralCredit = 3, - [Display(Name = "Refund")] - Refund = 4, - } + [Display(Name = "Charge")] + Charge = 0, + [Display(Name = "Credit")] + Credit = 1, + [Display(Name = "Promotional Credit")] + PromotionalCredit = 2, + [Display(Name = "Referral Credit")] + ReferralCredit = 3, + [Display(Name = "Refund")] + Refund = 4, } diff --git a/src/Core/Enums/TwoFactorProviderType.cs b/src/Core/Enums/TwoFactorProviderType.cs index 40c4e5511..31d626991 100644 --- a/src/Core/Enums/TwoFactorProviderType.cs +++ b/src/Core/Enums/TwoFactorProviderType.cs @@ -1,14 +1,13 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum TwoFactorProviderType : byte { - public enum TwoFactorProviderType : byte - { - Authenticator = 0, - Email = 1, - Duo = 2, - YubiKey = 3, - U2f = 4, // Deprecated - Remember = 5, - OrganizationDuo = 6, - WebAuthn = 7, - } + Authenticator = 0, + Email = 1, + Duo = 2, + YubiKey = 3, + U2f = 4, // Deprecated + Remember = 5, + OrganizationDuo = 6, + WebAuthn = 7, } diff --git a/src/Core/Enums/UriMatchType.cs b/src/Core/Enums/UriMatchType.cs index 569437298..593caf40c 100644 --- a/src/Core/Enums/UriMatchType.cs +++ b/src/Core/Enums/UriMatchType.cs @@ -1,12 +1,11 @@ -namespace Bit.Core.Enums +namespace Bit.Core.Enums; + +public enum UriMatchType : byte { - public enum UriMatchType : byte - { - Domain = 0, - Host = 1, - StartsWith = 2, - Exact = 3, - RegularExpression = 4, - Never = 5 - } + Domain = 0, + Host = 1, + StartsWith = 2, + Exact = 3, + RegularExpression = 4, + Never = 5 } diff --git a/src/Core/Exceptions/BadRequestException.cs b/src/Core/Exceptions/BadRequestException.cs index 686bf786c..d18bd041e 100644 --- a/src/Core/Exceptions/BadRequestException.cs +++ b/src/Core/Exceptions/BadRequestException.cs @@ -1,31 +1,30 @@ using Microsoft.AspNetCore.Mvc.ModelBinding; -namespace Bit.Core.Exceptions +namespace Bit.Core.Exceptions; + +public class BadRequestException : Exception { - public class BadRequestException : Exception + public BadRequestException(string message) + : base(message) + { } + + public BadRequestException(string key, string errorMessage) + : base("The model state is invalid.") { - public BadRequestException(string message) - : base(message) - { } - - public BadRequestException(string key, string errorMessage) - : base("The model state is invalid.") - { - ModelState = new ModelStateDictionary(); - ModelState.AddModelError(key, errorMessage); - } - - public BadRequestException(ModelStateDictionary modelState) - : base("The model state is invalid.") - { - if (modelState.IsValid || modelState.ErrorCount == 0) - { - return; - } - - ModelState = modelState; - } - - public ModelStateDictionary ModelState { get; set; } + ModelState = new ModelStateDictionary(); + ModelState.AddModelError(key, errorMessage); } + + public BadRequestException(ModelStateDictionary modelState) + : base("The model state is invalid.") + { + if (modelState.IsValid || modelState.ErrorCount == 0) + { + return; + } + + ModelState = modelState; + } + + public ModelStateDictionary ModelState { get; set; } } diff --git a/src/Core/Exceptions/GatewayException.cs b/src/Core/Exceptions/GatewayException.cs index d97511a68..73e8cd761 100644 --- a/src/Core/Exceptions/GatewayException.cs +++ b/src/Core/Exceptions/GatewayException.cs @@ -1,9 +1,8 @@ -namespace Bit.Core.Exceptions +namespace Bit.Core.Exceptions; + +public class GatewayException : Exception { - public class GatewayException : Exception - { - public GatewayException(string message, Exception innerException = null) - : base(message, innerException) - { } - } + public GatewayException(string message, Exception innerException = null) + : base(message, innerException) + { } } diff --git a/src/Core/Exceptions/InvalidEmailException.cs b/src/Core/Exceptions/InvalidEmailException.cs index 64ede1fdb..1f17acf62 100644 --- a/src/Core/Exceptions/InvalidEmailException.cs +++ b/src/Core/Exceptions/InvalidEmailException.cs @@ -1,11 +1,10 @@ -namespace Bit.Core.Exceptions -{ - public class InvalidEmailException : Exception - { - public InvalidEmailException() - : base("Invalid email.") - { +namespace Bit.Core.Exceptions; + +public class InvalidEmailException : Exception +{ + public InvalidEmailException() + : base("Invalid email.") + { - } } } diff --git a/src/Core/Exceptions/InvalidGatewayCustomerIdException.cs b/src/Core/Exceptions/InvalidGatewayCustomerIdException.cs index ad3a4544a..cfc7c56c1 100644 --- a/src/Core/Exceptions/InvalidGatewayCustomerIdException.cs +++ b/src/Core/Exceptions/InvalidGatewayCustomerIdException.cs @@ -1,11 +1,10 @@ -namespace Bit.Core.Exceptions -{ - public class InvalidGatewayCustomerIdException : Exception - { - public InvalidGatewayCustomerIdException() - : base("Invalid gateway customerId.") - { +namespace Bit.Core.Exceptions; + +public class InvalidGatewayCustomerIdException : Exception +{ + public InvalidGatewayCustomerIdException() + : base("Invalid gateway customerId.") + { - } } } diff --git a/src/Core/Exceptions/NotFoundException.cs b/src/Core/Exceptions/NotFoundException.cs index a47023093..3f52f792c 100644 --- a/src/Core/Exceptions/NotFoundException.cs +++ b/src/Core/Exceptions/NotFoundException.cs @@ -1,4 +1,3 @@ -namespace Bit.Core.Exceptions -{ - public class NotFoundException : Exception { } -} +namespace Bit.Core.Exceptions; + +public class NotFoundException : Exception { } diff --git a/src/Core/HostedServices/ApplicationCacheHostedService.cs b/src/Core/HostedServices/ApplicationCacheHostedService.cs index a5a27e5de..d5f4b77e3 100644 --- a/src/Core/HostedServices/ApplicationCacheHostedService.cs +++ b/src/Core/HostedServices/ApplicationCacheHostedService.cs @@ -8,100 +8,99 @@ using Microsoft.Azure.ServiceBus.Management; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; -namespace Bit.Core.HostedServices +namespace Bit.Core.HostedServices; + +public class ApplicationCacheHostedService : IHostedService, IDisposable { - public class ApplicationCacheHostedService : IHostedService, IDisposable + private readonly InMemoryServiceBusApplicationCacheService _applicationCacheService; + private readonly IOrganizationRepository _organizationRepository; + protected readonly ILogger _logger; + private readonly SubscriptionClient _subscriptionClient; + private readonly ManagementClient _managementClient; + private readonly string _subName; + private readonly string _topicName; + + public ApplicationCacheHostedService( + IApplicationCacheService applicationCacheService, + IOrganizationRepository organizationRepository, + ILogger logger, + GlobalSettings globalSettings) { - private readonly InMemoryServiceBusApplicationCacheService _applicationCacheService; - private readonly IOrganizationRepository _organizationRepository; - protected readonly ILogger _logger; - private readonly SubscriptionClient _subscriptionClient; - private readonly ManagementClient _managementClient; - private readonly string _subName; - private readonly string _topicName; + _topicName = globalSettings.ServiceBus.ApplicationCacheTopicName; + _subName = CoreHelpers.GetApplicationCacheServiceBusSubcriptionName(globalSettings); + _applicationCacheService = applicationCacheService as InMemoryServiceBusApplicationCacheService; + _organizationRepository = organizationRepository; + _logger = logger; + _managementClient = new ManagementClient(globalSettings.ServiceBus.ConnectionString); + _subscriptionClient = new SubscriptionClient(globalSettings.ServiceBus.ConnectionString, + _topicName, _subName); + } - public ApplicationCacheHostedService( - IApplicationCacheService applicationCacheService, - IOrganizationRepository organizationRepository, - ILogger logger, - GlobalSettings globalSettings) + public virtual async Task StartAsync(CancellationToken cancellationToken) + { + try { - _topicName = globalSettings.ServiceBus.ApplicationCacheTopicName; - _subName = CoreHelpers.GetApplicationCacheServiceBusSubcriptionName(globalSettings); - _applicationCacheService = applicationCacheService as InMemoryServiceBusApplicationCacheService; - _organizationRepository = organizationRepository; - _logger = logger; - _managementClient = new ManagementClient(globalSettings.ServiceBus.ConnectionString); - _subscriptionClient = new SubscriptionClient(globalSettings.ServiceBus.ConnectionString, - _topicName, _subName); + await _managementClient.CreateSubscriptionAsync(new SubscriptionDescription(_topicName, _subName) + { + DefaultMessageTimeToLive = TimeSpan.FromDays(14), + LockDuration = TimeSpan.FromSeconds(30), + EnableDeadLetteringOnFilterEvaluationExceptions = true, + EnableDeadLetteringOnMessageExpiration = true, + }, new RuleDescription("default", new SqlFilter($"sys.Label != '{_subName}'"))); } - - public virtual async Task StartAsync(CancellationToken cancellationToken) - { - try + catch (MessagingEntityAlreadyExistsException) { } + _subscriptionClient.RegisterMessageHandler(ProcessMessageAsync, + new MessageHandlerOptions(ExceptionReceivedHandlerAsync) { - await _managementClient.CreateSubscriptionAsync(new SubscriptionDescription(_topicName, _subName) - { - DefaultMessageTimeToLive = TimeSpan.FromDays(14), - LockDuration = TimeSpan.FromSeconds(30), - EnableDeadLetteringOnFilterEvaluationExceptions = true, - EnableDeadLetteringOnMessageExpiration = true, - }, new RuleDescription("default", new SqlFilter($"sys.Label != '{_subName}'"))); - } - catch (MessagingEntityAlreadyExistsException) { } - _subscriptionClient.RegisterMessageHandler(ProcessMessageAsync, - new MessageHandlerOptions(ExceptionReceivedHandlerAsync) - { - MaxConcurrentCalls = 2, - AutoComplete = false, - }); + MaxConcurrentCalls = 2, + AutoComplete = false, + }); + } + + public virtual async Task StopAsync(CancellationToken cancellationToken) + { + await _subscriptionClient.CloseAsync(); + try + { + await _managementClient.DeleteSubscriptionAsync(_topicName, _subName, cancellationToken); } + catch { } + } - public virtual async Task StopAsync(CancellationToken cancellationToken) + public virtual void Dispose() + { } + + private async Task ProcessMessageAsync(Message message, CancellationToken cancellationToken) + { + if (message.Label != _subName && _applicationCacheService != null) { - await _subscriptionClient.CloseAsync(); - try + switch ((ApplicationCacheMessageType)message.UserProperties["type"]) { - await _managementClient.DeleteSubscriptionAsync(_topicName, _subName, cancellationToken); - } - catch { } - } - - public virtual void Dispose() - { } - - private async Task ProcessMessageAsync(Message message, CancellationToken cancellationToken) - { - if (message.Label != _subName && _applicationCacheService != null) - { - switch ((ApplicationCacheMessageType)message.UserProperties["type"]) - { - case ApplicationCacheMessageType.UpsertOrganizationAbility: - var upsertedOrgId = (Guid)message.UserProperties["id"]; - var upsertedOrg = await _organizationRepository.GetByIdAsync(upsertedOrgId); - if (upsertedOrg != null) - { - await _applicationCacheService.BaseUpsertOrganizationAbilityAsync(upsertedOrg); - } - break; - case ApplicationCacheMessageType.DeleteOrganizationAbility: - await _applicationCacheService.BaseDeleteOrganizationAbilityAsync( - (Guid)message.UserProperties["id"]); - break; - default: - break; - } - } - if (!cancellationToken.IsCancellationRequested) - { - await _subscriptionClient.CompleteAsync(message.SystemProperties.LockToken); + case ApplicationCacheMessageType.UpsertOrganizationAbility: + var upsertedOrgId = (Guid)message.UserProperties["id"]; + var upsertedOrg = await _organizationRepository.GetByIdAsync(upsertedOrgId); + if (upsertedOrg != null) + { + await _applicationCacheService.BaseUpsertOrganizationAbilityAsync(upsertedOrg); + } + break; + case ApplicationCacheMessageType.DeleteOrganizationAbility: + await _applicationCacheService.BaseDeleteOrganizationAbilityAsync( + (Guid)message.UserProperties["id"]); + break; + default: + break; } } - - private Task ExceptionReceivedHandlerAsync(ExceptionReceivedEventArgs args) + if (!cancellationToken.IsCancellationRequested) { - _logger.LogError(args.Exception, "Message handler encountered an exception."); - return Task.FromResult(0); + await _subscriptionClient.CompleteAsync(message.SystemProperties.LockToken); } } + + private Task ExceptionReceivedHandlerAsync(ExceptionReceivedEventArgs args) + { + _logger.LogError(args.Exception, "Message handler encountered an exception."); + return Task.FromResult(0); + } } diff --git a/src/Core/HostedServices/IpRateLimitSeedStartupService.cs b/src/Core/HostedServices/IpRateLimitSeedStartupService.cs index dd77982cb..a6869d929 100644 --- a/src/Core/HostedServices/IpRateLimitSeedStartupService.cs +++ b/src/Core/HostedServices/IpRateLimitSeedStartupService.cs @@ -1,41 +1,40 @@ using AspNetCoreRateLimit; using Microsoft.Extensions.Hosting; -namespace Bit.Core.HostedServices +namespace Bit.Core.HostedServices; + +/// +/// A startup service that will seed the IP rate limiting stores with any values in the +/// GlobalSettings configuration. +/// +/// +/// Using an here because it runs before the request processing pipeline +/// is configured, so that any rate limiting configuration is seeded/applied before any requests come in. +/// +/// +/// This is a cleaner alternative to modifying Program.cs in every project that requires rate limiting as +/// described/suggested here: +/// https://github.com/stefanprodan/AspNetCoreRateLimit/wiki/Version-3.0.0-Breaking-Changes +/// +/// +public class IpRateLimitSeedStartupService : IHostedService { - /// - /// A startup service that will seed the IP rate limiting stores with any values in the - /// GlobalSettings configuration. - /// - /// - /// Using an here because it runs before the request processing pipeline - /// is configured, so that any rate limiting configuration is seeded/applied before any requests come in. - /// - /// - /// This is a cleaner alternative to modifying Program.cs in every project that requires rate limiting as - /// described/suggested here: - /// https://github.com/stefanprodan/AspNetCoreRateLimit/wiki/Version-3.0.0-Breaking-Changes - /// - /// - public class IpRateLimitSeedStartupService : IHostedService + private readonly IIpPolicyStore _ipPolicyStore; + private readonly IClientPolicyStore _clientPolicyStore; + + public IpRateLimitSeedStartupService(IIpPolicyStore ipPolicyStore, IClientPolicyStore clientPolicyStore) { - private readonly IIpPolicyStore _ipPolicyStore; - private readonly IClientPolicyStore _clientPolicyStore; - - public IpRateLimitSeedStartupService(IIpPolicyStore ipPolicyStore, IClientPolicyStore clientPolicyStore) - { - _ipPolicyStore = ipPolicyStore; - _clientPolicyStore = clientPolicyStore; - } - - public async Task StartAsync(CancellationToken cancellationToken) - { - // Seed the policies from GlobalSettings - await _ipPolicyStore.SeedAsync(); - await _clientPolicyStore.SeedAsync(); - } - - // noop - public Task StopAsync(CancellationToken cancellationToken) => Task.CompletedTask; + _ipPolicyStore = ipPolicyStore; + _clientPolicyStore = clientPolicyStore; } + + public async Task StartAsync(CancellationToken cancellationToken) + { + // Seed the policies from GlobalSettings + await _ipPolicyStore.SeedAsync(); + await _clientPolicyStore.SeedAsync(); + } + + // noop + public Task StopAsync(CancellationToken cancellationToken) => Task.CompletedTask; } diff --git a/src/Core/Identity/AuthenticatorTokenProvider.cs b/src/Core/Identity/AuthenticatorTokenProvider.cs index 5eef3869d..8bda023e5 100644 --- a/src/Core/Identity/AuthenticatorTokenProvider.cs +++ b/src/Core/Identity/AuthenticatorTokenProvider.cs @@ -5,42 +5,41 @@ using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.DependencyInjection; using OtpNet; -namespace Bit.Core.Identity +namespace Bit.Core.Identity; + +public class AuthenticatorTokenProvider : IUserTwoFactorTokenProvider { - public class AuthenticatorTokenProvider : IUserTwoFactorTokenProvider + private readonly IServiceProvider _serviceProvider; + + public AuthenticatorTokenProvider(IServiceProvider serviceProvider) { - private readonly IServiceProvider _serviceProvider; + _serviceProvider = serviceProvider; + } - public AuthenticatorTokenProvider(IServiceProvider serviceProvider) + public async Task CanGenerateTwoFactorTokenAsync(UserManager manager, User user) + { + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Authenticator); + if (string.IsNullOrWhiteSpace((string)provider?.MetaData["Key"])) { - _serviceProvider = serviceProvider; + return false; } + return await _serviceProvider.GetRequiredService() + .TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.Authenticator, user); + } - public async Task CanGenerateTwoFactorTokenAsync(UserManager manager, User user) - { - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Authenticator); - if (string.IsNullOrWhiteSpace((string)provider?.MetaData["Key"])) - { - return false; - } - return await _serviceProvider.GetRequiredService() - .TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.Authenticator, user); - } + public Task GenerateAsync(string purpose, UserManager manager, User user) + { + return Task.FromResult(null); + } - public Task GenerateAsync(string purpose, UserManager manager, User user) - { - return Task.FromResult(null); - } + public Task ValidateAsync(string purpose, string token, UserManager manager, User user) + { + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Authenticator); + var otp = new Totp(Base32Encoding.ToBytes((string)provider.MetaData["Key"])); - public Task ValidateAsync(string purpose, string token, UserManager manager, User user) - { - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Authenticator); - var otp = new Totp(Base32Encoding.ToBytes((string)provider.MetaData["Key"])); + long timeStepMatched; + var valid = otp.VerifyTotp(token, out timeStepMatched, new VerificationWindow(1, 1)); - long timeStepMatched; - var valid = otp.VerifyTotp(token, out timeStepMatched, new VerificationWindow(1, 1)); - - return Task.FromResult(valid); - } + return Task.FromResult(valid); } } diff --git a/src/Core/Identity/CustomIdentityServiceCollectionExtensions.cs b/src/Core/Identity/CustomIdentityServiceCollectionExtensions.cs index 0acb4a3f4..a63bde879 100644 --- a/src/Core/Identity/CustomIdentityServiceCollectionExtensions.cs +++ b/src/Core/Identity/CustomIdentityServiceCollectionExtensions.cs @@ -2,49 +2,48 @@ using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.DependencyInjection.Extensions; -namespace Microsoft.Extensions.DependencyInjection +namespace Microsoft.Extensions.DependencyInjection; + +// ref: https://github.com/aspnet/Identity/blob/dev/src/Microsoft.AspNetCore.Identity/IdentityServiceCollectionExtensions.cs +public static class CustomIdentityServiceCollectionExtensions { - // ref: https://github.com/aspnet/Identity/blob/dev/src/Microsoft.AspNetCore.Identity/IdentityServiceCollectionExtensions.cs - public static class CustomIdentityServiceCollectionExtensions + public static IdentityBuilder AddIdentityWithoutCookieAuth( + this IServiceCollection services) + where TUser : class + where TRole : class { - public static IdentityBuilder AddIdentityWithoutCookieAuth( - this IServiceCollection services) - where TUser : class - where TRole : class + return services.AddIdentityWithoutCookieAuth(setupAction: null); + } + + public static IdentityBuilder AddIdentityWithoutCookieAuth( + this IServiceCollection services, + Action setupAction) + where TUser : class + where TRole : class + { + // Hosting doesn't add IHttpContextAccessor by default + services.AddHttpContextAccessor(); + // Identity services + services.TryAddScoped, UserValidator>(); + services.TryAddScoped, PasswordValidator>(); + services.TryAddScoped, PasswordHasher>(); + services.TryAddScoped(); + services.TryAddScoped, RoleValidator>(); + // No interface for the error describer so we can add errors without rev'ing the interface + services.TryAddScoped(); + services.TryAddScoped>(); + services.TryAddScoped>(); + services.TryAddScoped, UserClaimsPrincipalFactory>(); + services.TryAddScoped, DefaultUserConfirmation>(); + services.TryAddScoped>(); + services.TryAddScoped>(); + services.TryAddScoped>(); + + if (setupAction != null) { - return services.AddIdentityWithoutCookieAuth(setupAction: null); + services.Configure(setupAction); } - public static IdentityBuilder AddIdentityWithoutCookieAuth( - this IServiceCollection services, - Action setupAction) - where TUser : class - where TRole : class - { - // Hosting doesn't add IHttpContextAccessor by default - services.AddHttpContextAccessor(); - // Identity services - services.TryAddScoped, UserValidator>(); - services.TryAddScoped, PasswordValidator>(); - services.TryAddScoped, PasswordHasher>(); - services.TryAddScoped(); - services.TryAddScoped, RoleValidator>(); - // No interface for the error describer so we can add errors without rev'ing the interface - services.TryAddScoped(); - services.TryAddScoped>(); - services.TryAddScoped>(); - services.TryAddScoped, UserClaimsPrincipalFactory>(); - services.TryAddScoped, DefaultUserConfirmation>(); - services.TryAddScoped>(); - services.TryAddScoped>(); - services.TryAddScoped>(); - - if (setupAction != null) - { - services.Configure(setupAction); - } - - return new IdentityBuilder(typeof(TUser), typeof(TRole), services); - } + return new IdentityBuilder(typeof(TUser), typeof(TRole), services); } } diff --git a/src/Core/Identity/DuoWebTokenProvider.cs b/src/Core/Identity/DuoWebTokenProvider.cs index 3ef02df6f..396f3b400 100644 --- a/src/Core/Identity/DuoWebTokenProvider.cs +++ b/src/Core/Identity/DuoWebTokenProvider.cs @@ -7,81 +7,80 @@ using Bit.Core.Utilities.Duo; using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Core.Identity +namespace Bit.Core.Identity; + +public class DuoWebTokenProvider : IUserTwoFactorTokenProvider { - public class DuoWebTokenProvider : IUserTwoFactorTokenProvider + private readonly IServiceProvider _serviceProvider; + private readonly GlobalSettings _globalSettings; + + public DuoWebTokenProvider( + IServiceProvider serviceProvider, + GlobalSettings globalSettings) { - private readonly IServiceProvider _serviceProvider; - private readonly GlobalSettings _globalSettings; + _serviceProvider = serviceProvider; + _globalSettings = globalSettings; + } - public DuoWebTokenProvider( - IServiceProvider serviceProvider, - GlobalSettings globalSettings) + public async Task CanGenerateTwoFactorTokenAsync(UserManager manager, User user) + { + var userService = _serviceProvider.GetRequiredService(); + if (!(await userService.CanAccessPremium(user))) { - _serviceProvider = serviceProvider; - _globalSettings = globalSettings; + return false; } - public async Task CanGenerateTwoFactorTokenAsync(UserManager manager, User user) + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Duo); + if (!HasProperMetaData(provider)) { - var userService = _serviceProvider.GetRequiredService(); - if (!(await userService.CanAccessPremium(user))) - { - return false; - } - - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Duo); - if (!HasProperMetaData(provider)) - { - return false; - } - - return await userService.TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.Duo, user); + return false; } - public async Task GenerateAsync(string purpose, UserManager manager, User user) + return await userService.TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.Duo, user); + } + + public async Task GenerateAsync(string purpose, UserManager manager, User user) + { + var userService = _serviceProvider.GetRequiredService(); + if (!(await userService.CanAccessPremium(user))) { - var userService = _serviceProvider.GetRequiredService(); - if (!(await userService.CanAccessPremium(user))) - { - return null; - } - - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Duo); - if (!HasProperMetaData(provider)) - { - return null; - } - - var signatureRequest = DuoWeb.SignRequest((string)provider.MetaData["IKey"], - (string)provider.MetaData["SKey"], _globalSettings.Duo.AKey, user.Email); - return signatureRequest; + return null; } - public async Task ValidateAsync(string purpose, string token, UserManager manager, User user) + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Duo); + if (!HasProperMetaData(provider)) { - var userService = _serviceProvider.GetRequiredService(); - if (!(await userService.CanAccessPremium(user))) - { - return false; - } - - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Duo); - if (!HasProperMetaData(provider)) - { - return false; - } - - var response = DuoWeb.VerifyResponse((string)provider.MetaData["IKey"], (string)provider.MetaData["SKey"], - _globalSettings.Duo.AKey, token); - - return response == user.Email; + return null; } - private bool HasProperMetaData(TwoFactorProvider provider) + var signatureRequest = DuoWeb.SignRequest((string)provider.MetaData["IKey"], + (string)provider.MetaData["SKey"], _globalSettings.Duo.AKey, user.Email); + return signatureRequest; + } + + public async Task ValidateAsync(string purpose, string token, UserManager manager, User user) + { + var userService = _serviceProvider.GetRequiredService(); + if (!(await userService.CanAccessPremium(user))) { - return provider?.MetaData != null && provider.MetaData.ContainsKey("IKey") && - provider.MetaData.ContainsKey("SKey") && provider.MetaData.ContainsKey("Host"); + return false; } + + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Duo); + if (!HasProperMetaData(provider)) + { + return false; + } + + var response = DuoWeb.VerifyResponse((string)provider.MetaData["IKey"], (string)provider.MetaData["SKey"], + _globalSettings.Duo.AKey, token); + + return response == user.Email; + } + + private bool HasProperMetaData(TwoFactorProvider provider) + { + return provider?.MetaData != null && provider.MetaData.ContainsKey("IKey") && + provider.MetaData.ContainsKey("SKey") && provider.MetaData.ContainsKey("Host"); } } diff --git a/src/Core/Identity/EmailTokenProvider.cs b/src/Core/Identity/EmailTokenProvider.cs index a0002b47f..71987fa86 100644 --- a/src/Core/Identity/EmailTokenProvider.cs +++ b/src/Core/Identity/EmailTokenProvider.cs @@ -5,80 +5,79 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Core.Identity +namespace Bit.Core.Identity; + +public class EmailTokenProvider : IUserTwoFactorTokenProvider { - public class EmailTokenProvider : IUserTwoFactorTokenProvider + private readonly IServiceProvider _serviceProvider; + + public EmailTokenProvider(IServiceProvider serviceProvider) { - private readonly IServiceProvider _serviceProvider; + _serviceProvider = serviceProvider; + } - public EmailTokenProvider(IServiceProvider serviceProvider) + public async Task CanGenerateTwoFactorTokenAsync(UserManager manager, User user) + { + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Email); + if (!HasProperMetaData(provider)) { - _serviceProvider = serviceProvider; + return false; } - public async Task CanGenerateTwoFactorTokenAsync(UserManager manager, User user) - { - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Email); - if (!HasProperMetaData(provider)) - { - return false; - } + return await _serviceProvider.GetRequiredService(). + TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.Email, user); + } - return await _serviceProvider.GetRequiredService(). - TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.Email, user); + public Task GenerateAsync(string purpose, UserManager manager, User user) + { + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Email); + if (!HasProperMetaData(provider)) + { + return null; } - public Task GenerateAsync(string purpose, UserManager manager, User user) - { - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Email); - if (!HasProperMetaData(provider)) - { - return null; - } + return Task.FromResult(RedactEmail((string)provider.MetaData["Email"])); + } - return Task.FromResult(RedactEmail((string)provider.MetaData["Email"])); + public Task ValidateAsync(string purpose, string token, UserManager manager, User user) + { + return _serviceProvider.GetRequiredService().VerifyTwoFactorEmailAsync(user, token); + } + + private bool HasProperMetaData(TwoFactorProvider provider) + { + return provider?.MetaData != null && provider.MetaData.ContainsKey("Email") && + !string.IsNullOrWhiteSpace((string)provider.MetaData["Email"]); + } + + private static string RedactEmail(string email) + { + var emailParts = email.Split('@'); + + string shownPart = null; + if (emailParts[0].Length > 2 && emailParts[0].Length <= 4) + { + shownPart = emailParts[0].Substring(0, 1); + } + else if (emailParts[0].Length > 4) + { + shownPart = emailParts[0].Substring(0, 2); + } + else + { + shownPart = string.Empty; } - public Task ValidateAsync(string purpose, string token, UserManager manager, User user) + string redactedPart = null; + if (emailParts[0].Length > 4) { - return _serviceProvider.GetRequiredService().VerifyTwoFactorEmailAsync(user, token); + redactedPart = new string('*', emailParts[0].Length - 2); + } + else + { + redactedPart = new string('*', emailParts[0].Length - shownPart.Length); } - private bool HasProperMetaData(TwoFactorProvider provider) - { - return provider?.MetaData != null && provider.MetaData.ContainsKey("Email") && - !string.IsNullOrWhiteSpace((string)provider.MetaData["Email"]); - } - - private static string RedactEmail(string email) - { - var emailParts = email.Split('@'); - - string shownPart = null; - if (emailParts[0].Length > 2 && emailParts[0].Length <= 4) - { - shownPart = emailParts[0].Substring(0, 1); - } - else if (emailParts[0].Length > 4) - { - shownPart = emailParts[0].Substring(0, 2); - } - else - { - shownPart = string.Empty; - } - - string redactedPart = null; - if (emailParts[0].Length > 4) - { - redactedPart = new string('*', emailParts[0].Length - 2); - } - else - { - redactedPart = new string('*', emailParts[0].Length - shownPart.Length); - } - - return $"{shownPart}{redactedPart}@{emailParts[1]}"; - } + return $"{shownPart}{redactedPart}@{emailParts[1]}"; } } diff --git a/src/Core/Identity/IOrganizationTwoFactorTokenProvider.cs b/src/Core/Identity/IOrganizationTwoFactorTokenProvider.cs index 11157a783..0046add96 100644 --- a/src/Core/Identity/IOrganizationTwoFactorTokenProvider.cs +++ b/src/Core/Identity/IOrganizationTwoFactorTokenProvider.cs @@ -1,11 +1,10 @@ using Bit.Core.Entities; -namespace Bit.Core.Identity +namespace Bit.Core.Identity; + +public interface IOrganizationTwoFactorTokenProvider { - public interface IOrganizationTwoFactorTokenProvider - { - Task CanGenerateTwoFactorTokenAsync(Organization organization); - Task GenerateAsync(Organization organization, User user); - Task ValidateAsync(string token, Organization organization, User user); - } + Task CanGenerateTwoFactorTokenAsync(Organization organization); + Task GenerateAsync(Organization organization, User user); + Task ValidateAsync(string token, Organization organization, User user); } diff --git a/src/Core/Identity/LowerInvariantLookupNormalizer.cs b/src/Core/Identity/LowerInvariantLookupNormalizer.cs index 591b840a4..880a2bbfb 100644 --- a/src/Core/Identity/LowerInvariantLookupNormalizer.cs +++ b/src/Core/Identity/LowerInvariantLookupNormalizer.cs @@ -1,22 +1,21 @@ using Microsoft.AspNetCore.Identity; -namespace Bit.Core.Identity +namespace Bit.Core.Identity; + +public class LowerInvariantLookupNormalizer : ILookupNormalizer { - public class LowerInvariantLookupNormalizer : ILookupNormalizer + public string NormalizeEmail(string email) { - public string NormalizeEmail(string email) - { - return Normalize(email); - } + return Normalize(email); + } - public string NormalizeName(string name) - { - return Normalize(name); - } + public string NormalizeName(string name) + { + return Normalize(name); + } - private string Normalize(string key) - { - return key?.Normalize().ToLowerInvariant(); - } + private string Normalize(string key) + { + return key?.Normalize().ToLowerInvariant(); } } diff --git a/src/Core/Identity/OrganizationDuoWebTokenProvider.cs b/src/Core/Identity/OrganizationDuoWebTokenProvider.cs index cd3f27184..53d979d90 100644 --- a/src/Core/Identity/OrganizationDuoWebTokenProvider.cs +++ b/src/Core/Identity/OrganizationDuoWebTokenProvider.cs @@ -4,73 +4,72 @@ using Bit.Core.Models; using Bit.Core.Settings; using Bit.Core.Utilities.Duo; -namespace Bit.Core.Identity +namespace Bit.Core.Identity; + +public interface IOrganizationDuoWebTokenProvider : IOrganizationTwoFactorTokenProvider { } + +public class OrganizationDuoWebTokenProvider : IOrganizationDuoWebTokenProvider { - public interface IOrganizationDuoWebTokenProvider : IOrganizationTwoFactorTokenProvider { } + private readonly GlobalSettings _globalSettings; - public class OrganizationDuoWebTokenProvider : IOrganizationDuoWebTokenProvider + public OrganizationDuoWebTokenProvider(GlobalSettings globalSettings) { - private readonly GlobalSettings _globalSettings; + _globalSettings = globalSettings; + } - public OrganizationDuoWebTokenProvider(GlobalSettings globalSettings) + public Task CanGenerateTwoFactorTokenAsync(Organization organization) + { + if (organization == null || !organization.Enabled || !organization.Use2fa) { - _globalSettings = globalSettings; + return Task.FromResult(false); } - public Task CanGenerateTwoFactorTokenAsync(Organization organization) - { - if (organization == null || !organization.Enabled || !organization.Use2fa) - { - return Task.FromResult(false); - } + var provider = organization.GetTwoFactorProvider(TwoFactorProviderType.OrganizationDuo); + var canGenerate = organization.TwoFactorProviderIsEnabled(TwoFactorProviderType.OrganizationDuo) + && HasProperMetaData(provider); + return Task.FromResult(canGenerate); + } - var provider = organization.GetTwoFactorProvider(TwoFactorProviderType.OrganizationDuo); - var canGenerate = organization.TwoFactorProviderIsEnabled(TwoFactorProviderType.OrganizationDuo) - && HasProperMetaData(provider); - return Task.FromResult(canGenerate); + public Task GenerateAsync(Organization organization, User user) + { + if (organization == null || !organization.Enabled || !organization.Use2fa) + { + return Task.FromResult(null); } - public Task GenerateAsync(Organization organization, User user) + var provider = organization.GetTwoFactorProvider(TwoFactorProviderType.OrganizationDuo); + if (!HasProperMetaData(provider)) { - if (organization == null || !organization.Enabled || !organization.Use2fa) - { - return Task.FromResult(null); - } - - var provider = organization.GetTwoFactorProvider(TwoFactorProviderType.OrganizationDuo); - if (!HasProperMetaData(provider)) - { - return Task.FromResult(null); - } - - var signatureRequest = DuoWeb.SignRequest(provider.MetaData["IKey"].ToString(), - provider.MetaData["SKey"].ToString(), _globalSettings.Duo.AKey, user.Email); - return Task.FromResult(signatureRequest); + return Task.FromResult(null); } - public Task ValidateAsync(string token, Organization organization, User user) + var signatureRequest = DuoWeb.SignRequest(provider.MetaData["IKey"].ToString(), + provider.MetaData["SKey"].ToString(), _globalSettings.Duo.AKey, user.Email); + return Task.FromResult(signatureRequest); + } + + public Task ValidateAsync(string token, Organization organization, User user) + { + if (organization == null || !organization.Enabled || !organization.Use2fa) { - if (organization == null || !organization.Enabled || !organization.Use2fa) - { - return Task.FromResult(false); - } - - var provider = organization.GetTwoFactorProvider(TwoFactorProviderType.OrganizationDuo); - if (!HasProperMetaData(provider)) - { - return Task.FromResult(false); - } - - var response = DuoWeb.VerifyResponse(provider.MetaData["IKey"].ToString(), - provider.MetaData["SKey"].ToString(), _globalSettings.Duo.AKey, token); - - return Task.FromResult(response == user.Email); + return Task.FromResult(false); } - private bool HasProperMetaData(TwoFactorProvider provider) + var provider = organization.GetTwoFactorProvider(TwoFactorProviderType.OrganizationDuo); + if (!HasProperMetaData(provider)) { - return provider?.MetaData != null && provider.MetaData.ContainsKey("IKey") && - provider.MetaData.ContainsKey("SKey") && provider.MetaData.ContainsKey("Host"); + return Task.FromResult(false); } + + var response = DuoWeb.VerifyResponse(provider.MetaData["IKey"].ToString(), + provider.MetaData["SKey"].ToString(), _globalSettings.Duo.AKey, token); + + return Task.FromResult(response == user.Email); + } + + private bool HasProperMetaData(TwoFactorProvider provider) + { + return provider?.MetaData != null && provider.MetaData.ContainsKey("IKey") && + provider.MetaData.ContainsKey("SKey") && provider.MetaData.ContainsKey("Host"); } } diff --git a/src/Core/Identity/PasswordlessSignInManager.cs b/src/Core/Identity/PasswordlessSignInManager.cs index a9f400058..1ca010835 100644 --- a/src/Core/Identity/PasswordlessSignInManager.cs +++ b/src/Core/Identity/PasswordlessSignInManager.cs @@ -5,86 +5,85 @@ using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -namespace Bit.Core.Identity +namespace Bit.Core.Identity; + +public class PasswordlessSignInManager : SignInManager where TUser : class { - public class PasswordlessSignInManager : SignInManager where TUser : class + public const string PasswordlessSignInPurpose = "PasswordlessSignIn"; + + private readonly IMailService _mailService; + + public PasswordlessSignInManager(UserManager userManager, + IHttpContextAccessor contextAccessor, + IUserClaimsPrincipalFactory claimsFactory, + IOptions optionsAccessor, + ILogger> logger, + IAuthenticationSchemeProvider schemes, + IUserConfirmation confirmation, + IMailService mailService) + : base(userManager, contextAccessor, claimsFactory, optionsAccessor, logger, schemes, confirmation) { - public const string PasswordlessSignInPurpose = "PasswordlessSignIn"; + _mailService = mailService; + } - private readonly IMailService _mailService; - - public PasswordlessSignInManager(UserManager userManager, - IHttpContextAccessor contextAccessor, - IUserClaimsPrincipalFactory claimsFactory, - IOptions optionsAccessor, - ILogger> logger, - IAuthenticationSchemeProvider schemes, - IUserConfirmation confirmation, - IMailService mailService) - : base(userManager, contextAccessor, claimsFactory, optionsAccessor, logger, schemes, confirmation) + public async Task PasswordlessSignInAsync(string email, string returnUrl) + { + var user = await UserManager.FindByEmailAsync(email); + if (user == null) { - _mailService = mailService; + return SignInResult.Failed; } - public async Task PasswordlessSignInAsync(string email, string returnUrl) - { - var user = await UserManager.FindByEmailAsync(email); - if (user == null) - { - return SignInResult.Failed; - } + var token = await UserManager.GenerateUserTokenAsync(user, Options.Tokens.PasswordResetTokenProvider, + PasswordlessSignInPurpose); + await _mailService.SendPasswordlessSignInAsync(returnUrl, token, email); + return SignInResult.Success; + } - var token = await UserManager.GenerateUserTokenAsync(user, Options.Tokens.PasswordResetTokenProvider, - PasswordlessSignInPurpose); - await _mailService.SendPasswordlessSignInAsync(returnUrl, token, email); + public async Task PasswordlessSignInAsync(TUser user, string token, bool isPersistent) + { + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + var attempt = await CheckPasswordlessSignInAsync(user, token); + return attempt.Succeeded ? + await SignInOrTwoFactorAsync(user, isPersistent, bypassTwoFactor: true) : attempt; + } + + public async Task PasswordlessSignInAsync(string email, string token, bool isPersistent) + { + var user = await UserManager.FindByEmailAsync(email); + if (user == null) + { + return SignInResult.Failed; + } + + return await PasswordlessSignInAsync(user, token, isPersistent); + } + + public virtual async Task CheckPasswordlessSignInAsync(TUser user, string token) + { + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + var error = await PreSignInCheck(user); + if (error != null) + { + return error; + } + + if (await UserManager.VerifyUserTokenAsync(user, Options.Tokens.PasswordResetTokenProvider, + PasswordlessSignInPurpose, token)) + { return SignInResult.Success; } - public async Task PasswordlessSignInAsync(TUser user, string token, bool isPersistent) - { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - var attempt = await CheckPasswordlessSignInAsync(user, token); - return attempt.Succeeded ? - await SignInOrTwoFactorAsync(user, isPersistent, bypassTwoFactor: true) : attempt; - } - - public async Task PasswordlessSignInAsync(string email, string token, bool isPersistent) - { - var user = await UserManager.FindByEmailAsync(email); - if (user == null) - { - return SignInResult.Failed; - } - - return await PasswordlessSignInAsync(user, token, isPersistent); - } - - public virtual async Task CheckPasswordlessSignInAsync(TUser user, string token) - { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - var error = await PreSignInCheck(user); - if (error != null) - { - return error; - } - - if (await UserManager.VerifyUserTokenAsync(user, Options.Tokens.PasswordResetTokenProvider, - PasswordlessSignInPurpose, token)) - { - return SignInResult.Success; - } - - Logger.LogWarning(2, "User {userId} failed to provide the correct token.", - await UserManager.GetUserIdAsync(user)); - return SignInResult.Failed; - } + Logger.LogWarning(2, "User {userId} failed to provide the correct token.", + await UserManager.GetUserIdAsync(user)); + return SignInResult.Failed; } } diff --git a/src/Core/Identity/ReadOnlyDatabaseIdentityUserStore.cs b/src/Core/Identity/ReadOnlyDatabaseIdentityUserStore.cs index 7f4b76755..70d3da007 100644 --- a/src/Core/Identity/ReadOnlyDatabaseIdentityUserStore.cs +++ b/src/Core/Identity/ReadOnlyDatabaseIdentityUserStore.cs @@ -2,38 +2,37 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Identity; -namespace Bit.Core.Identity +namespace Bit.Core.Identity; + +public class ReadOnlyDatabaseIdentityUserStore : ReadOnlyIdentityUserStore { - public class ReadOnlyDatabaseIdentityUserStore : ReadOnlyIdentityUserStore + private readonly IUserService _userService; + private readonly IUserRepository _userRepository; + + public ReadOnlyDatabaseIdentityUserStore( + IUserService userService, + IUserRepository userRepository) { - private readonly IUserService _userService; - private readonly IUserRepository _userRepository; + _userService = userService; + _userRepository = userRepository; + } - public ReadOnlyDatabaseIdentityUserStore( - IUserService userService, - IUserRepository userRepository) + public override async Task FindByEmailAsync(string normalizedEmail, + CancellationToken cancellationToken = default(CancellationToken)) + { + var user = await _userRepository.GetByEmailAsync(normalizedEmail); + return user?.ToIdentityUser(await _userService.TwoFactorIsEnabledAsync(user)); + } + + public override async Task FindByIdAsync(string userId, + CancellationToken cancellationToken = default(CancellationToken)) + { + if (!Guid.TryParse(userId, out var userIdGuid)) { - _userService = userService; - _userRepository = userRepository; + return null; } - public override async Task FindByEmailAsync(string normalizedEmail, - CancellationToken cancellationToken = default(CancellationToken)) - { - var user = await _userRepository.GetByEmailAsync(normalizedEmail); - return user?.ToIdentityUser(await _userService.TwoFactorIsEnabledAsync(user)); - } - - public override async Task FindByIdAsync(string userId, - CancellationToken cancellationToken = default(CancellationToken)) - { - if (!Guid.TryParse(userId, out var userIdGuid)) - { - return null; - } - - var user = await _userRepository.GetByIdAsync(userIdGuid); - return user?.ToIdentityUser(await _userService.TwoFactorIsEnabledAsync(user)); - } + var user = await _userRepository.GetByIdAsync(userIdGuid); + return user?.ToIdentityUser(await _userService.TwoFactorIsEnabledAsync(user)); } } diff --git a/src/Core/Identity/ReadOnlyEnvIdentityUserStore.cs b/src/Core/Identity/ReadOnlyEnvIdentityUserStore.cs index 26cc7a3c8..341bcd38a 100644 --- a/src/Core/Identity/ReadOnlyEnvIdentityUserStore.cs +++ b/src/Core/Identity/ReadOnlyEnvIdentityUserStore.cs @@ -2,66 +2,65 @@ using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.Configuration; -namespace Bit.Core.Identity +namespace Bit.Core.Identity; + +public class ReadOnlyEnvIdentityUserStore : ReadOnlyIdentityUserStore { - public class ReadOnlyEnvIdentityUserStore : ReadOnlyIdentityUserStore + private readonly IConfiguration _configuration; + + public ReadOnlyEnvIdentityUserStore(IConfiguration configuration) { - private readonly IConfiguration _configuration; + _configuration = configuration; + } - public ReadOnlyEnvIdentityUserStore(IConfiguration configuration) + public override Task FindByEmailAsync(string normalizedEmail, + CancellationToken cancellationToken = default(CancellationToken)) + { + var usersCsv = _configuration["adminSettings:admins"]; + if (!CoreHelpers.SettingHasValue(usersCsv)) { - _configuration = configuration; + return Task.FromResult(null); } - public override Task FindByEmailAsync(string normalizedEmail, - CancellationToken cancellationToken = default(CancellationToken)) + var users = usersCsv.ToLowerInvariant().Split(','); + var usersDict = new Dictionary(); + foreach (var u in users) { - var usersCsv = _configuration["adminSettings:admins"]; - if (!CoreHelpers.SettingHasValue(usersCsv)) + var parts = u.Split(':'); + if (parts.Length == 2) { - return Task.FromResult(null); + var email = parts[0].Trim(); + var stamp = parts[1].Trim(); + usersDict.Add(email, stamp); } - - var users = usersCsv.ToLowerInvariant().Split(','); - var usersDict = new Dictionary(); - foreach (var u in users) + else { - var parts = u.Split(':'); - if (parts.Length == 2) - { - var email = parts[0].Trim(); - var stamp = parts[1].Trim(); - usersDict.Add(email, stamp); - } - else - { - var email = parts[0].Trim(); - usersDict.Add(email, email); - } + var email = parts[0].Trim(); + usersDict.Add(email, email); } - - var userStamp = usersDict.ContainsKey(normalizedEmail) ? usersDict[normalizedEmail] : null; - if (userStamp == null) - { - return Task.FromResult(null); - } - - return Task.FromResult(new IdentityUser - { - Id = normalizedEmail, - Email = normalizedEmail, - NormalizedEmail = normalizedEmail, - EmailConfirmed = true, - UserName = normalizedEmail, - NormalizedUserName = normalizedEmail, - SecurityStamp = userStamp - }); } - public override Task FindByIdAsync(string userId, - CancellationToken cancellationToken = default(CancellationToken)) + var userStamp = usersDict.ContainsKey(normalizedEmail) ? usersDict[normalizedEmail] : null; + if (userStamp == null) { - return FindByEmailAsync(userId, cancellationToken); + return Task.FromResult(null); } + + return Task.FromResult(new IdentityUser + { + Id = normalizedEmail, + Email = normalizedEmail, + NormalizedEmail = normalizedEmail, + EmailConfirmed = true, + UserName = normalizedEmail, + NormalizedUserName = normalizedEmail, + SecurityStamp = userStamp + }); + } + + public override Task FindByIdAsync(string userId, + CancellationToken cancellationToken = default(CancellationToken)) + { + return FindByEmailAsync(userId, cancellationToken); } } diff --git a/src/Core/Identity/ReadOnlyIdentityUserStore.cs b/src/Core/Identity/ReadOnlyIdentityUserStore.cs index d27b0a32f..50c42c819 100644 --- a/src/Core/Identity/ReadOnlyIdentityUserStore.cs +++ b/src/Core/Identity/ReadOnlyIdentityUserStore.cs @@ -1,120 +1,119 @@ using Microsoft.AspNetCore.Identity; -namespace Bit.Core.Identity +namespace Bit.Core.Identity; + +public abstract class ReadOnlyIdentityUserStore : + IUserStore, + IUserEmailStore, + IUserSecurityStampStore { - public abstract class ReadOnlyIdentityUserStore : - IUserStore, - IUserEmailStore, - IUserSecurityStampStore + public void Dispose() { } + + public Task CreateAsync(IdentityUser user, + CancellationToken cancellationToken = default(CancellationToken)) { - public void Dispose() { } + throw new NotImplementedException(); + } - public Task CreateAsync(IdentityUser user, - CancellationToken cancellationToken = default(CancellationToken)) - { - throw new NotImplementedException(); - } + public Task DeleteAsync(IdentityUser user, + CancellationToken cancellationToken = default(CancellationToken)) + { + throw new NotImplementedException(); + } - public Task DeleteAsync(IdentityUser user, - CancellationToken cancellationToken = default(CancellationToken)) - { - throw new NotImplementedException(); - } + public abstract Task FindByEmailAsync(string normalizedEmail, + CancellationToken cancellationToken = default(CancellationToken)); - public abstract Task FindByEmailAsync(string normalizedEmail, - CancellationToken cancellationToken = default(CancellationToken)); + public abstract Task FindByIdAsync(string userId, + CancellationToken cancellationToken = default(CancellationToken)); - public abstract Task FindByIdAsync(string userId, - CancellationToken cancellationToken = default(CancellationToken)); + public async Task FindByNameAsync(string normalizedUserName, + CancellationToken cancellationToken = default(CancellationToken)) + { + return await FindByEmailAsync(normalizedUserName, cancellationToken); + } - public async Task FindByNameAsync(string normalizedUserName, - CancellationToken cancellationToken = default(CancellationToken)) - { - return await FindByEmailAsync(normalizedUserName, cancellationToken); - } + public Task GetEmailAsync(IdentityUser user, + CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.Email); + } - public Task GetEmailAsync(IdentityUser user, - CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.Email); - } + public Task GetEmailConfirmedAsync(IdentityUser user, + CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.EmailConfirmed); + } - public Task GetEmailConfirmedAsync(IdentityUser user, - CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.EmailConfirmed); - } + public Task GetNormalizedEmailAsync(IdentityUser user, + CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.Email); + } - public Task GetNormalizedEmailAsync(IdentityUser user, - CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.Email); - } + public Task GetNormalizedUserNameAsync(IdentityUser user, + CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.Email); + } - public Task GetNormalizedUserNameAsync(IdentityUser user, - CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.Email); - } + public Task GetUserIdAsync(IdentityUser user, + CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.Id); + } - public Task GetUserIdAsync(IdentityUser user, - CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.Id); - } + public Task GetUserNameAsync(IdentityUser user, + CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.Email); + } - public Task GetUserNameAsync(IdentityUser user, - CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.Email); - } + public Task SetEmailAsync(IdentityUser user, string email, + CancellationToken cancellationToken = default(CancellationToken)) + { + throw new NotImplementedException(); + } - public Task SetEmailAsync(IdentityUser user, string email, - CancellationToken cancellationToken = default(CancellationToken)) - { - throw new NotImplementedException(); - } + public Task SetEmailConfirmedAsync(IdentityUser user, bool confirmed, + CancellationToken cancellationToken = default(CancellationToken)) + { + throw new NotImplementedException(); + } - public Task SetEmailConfirmedAsync(IdentityUser user, bool confirmed, - CancellationToken cancellationToken = default(CancellationToken)) - { - throw new NotImplementedException(); - } + public Task SetNormalizedEmailAsync(IdentityUser user, string normalizedEmail, + CancellationToken cancellationToken = default(CancellationToken)) + { + user.NormalizedEmail = normalizedEmail; + return Task.FromResult(0); + } - public Task SetNormalizedEmailAsync(IdentityUser user, string normalizedEmail, - CancellationToken cancellationToken = default(CancellationToken)) - { - user.NormalizedEmail = normalizedEmail; - return Task.FromResult(0); - } + public Task SetNormalizedUserNameAsync(IdentityUser user, string normalizedName, + CancellationToken cancellationToken = default(CancellationToken)) + { + user.NormalizedUserName = normalizedName; + return Task.FromResult(0); + } - public Task SetNormalizedUserNameAsync(IdentityUser user, string normalizedName, - CancellationToken cancellationToken = default(CancellationToken)) - { - user.NormalizedUserName = normalizedName; - return Task.FromResult(0); - } + public Task SetUserNameAsync(IdentityUser user, string userName, + CancellationToken cancellationToken = default(CancellationToken)) + { + throw new NotImplementedException(); + } - public Task SetUserNameAsync(IdentityUser user, string userName, - CancellationToken cancellationToken = default(CancellationToken)) - { - throw new NotImplementedException(); - } + public Task UpdateAsync(IdentityUser user, + CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(IdentityResult.Success); + } - public Task UpdateAsync(IdentityUser user, - CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(IdentityResult.Success); - } + public Task SetSecurityStampAsync(IdentityUser user, string stamp, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } - public Task SetSecurityStampAsync(IdentityUser user, string stamp, CancellationToken cancellationToken) - { - throw new NotImplementedException(); - } - - public Task GetSecurityStampAsync(IdentityUser user, CancellationToken cancellationToken) - { - return Task.FromResult(user.SecurityStamp); - } + public Task GetSecurityStampAsync(IdentityUser user, CancellationToken cancellationToken) + { + return Task.FromResult(user.SecurityStamp); } } diff --git a/src/Core/Identity/RoleStore.cs b/src/Core/Identity/RoleStore.cs index f96748f24..d6fe3f42f 100644 --- a/src/Core/Identity/RoleStore.cs +++ b/src/Core/Identity/RoleStore.cs @@ -1,61 +1,60 @@ using Bit.Core.Entities; using Microsoft.AspNetCore.Identity; -namespace Bit.Core.Identity +namespace Bit.Core.Identity; + +public class RoleStore : IRoleStore { - public class RoleStore : IRoleStore + public void Dispose() { } + + public Task CreateAsync(Role role, CancellationToken cancellationToken) { - public void Dispose() { } + throw new NotImplementedException(); + } - public Task CreateAsync(Role role, CancellationToken cancellationToken) - { - throw new NotImplementedException(); - } + public Task DeleteAsync(Role role, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } - public Task DeleteAsync(Role role, CancellationToken cancellationToken) - { - throw new NotImplementedException(); - } + public Task FindByIdAsync(string roleId, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } - public Task FindByIdAsync(string roleId, CancellationToken cancellationToken) - { - throw new NotImplementedException(); - } + public Task FindByNameAsync(string normalizedRoleName, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } - public Task FindByNameAsync(string normalizedRoleName, CancellationToken cancellationToken) - { - throw new NotImplementedException(); - } + public Task GetNormalizedRoleNameAsync(Role role, CancellationToken cancellationToken) + { + return Task.FromResult(role.Name); + } - public Task GetNormalizedRoleNameAsync(Role role, CancellationToken cancellationToken) - { - return Task.FromResult(role.Name); - } + public Task GetRoleIdAsync(Role role, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } - public Task GetRoleIdAsync(Role role, CancellationToken cancellationToken) - { - throw new NotImplementedException(); - } + public Task GetRoleNameAsync(Role role, CancellationToken cancellationToken) + { + return Task.FromResult(role.Name); + } - public Task GetRoleNameAsync(Role role, CancellationToken cancellationToken) - { - return Task.FromResult(role.Name); - } + public Task SetNormalizedRoleNameAsync(Role role, string normalizedName, CancellationToken cancellationToken) + { + return Task.FromResult(0); + } - public Task SetNormalizedRoleNameAsync(Role role, string normalizedName, CancellationToken cancellationToken) - { - return Task.FromResult(0); - } + public Task SetRoleNameAsync(Role role, string roleName, CancellationToken cancellationToken) + { + role.Name = roleName; + return Task.FromResult(0); + } - public Task SetRoleNameAsync(Role role, string roleName, CancellationToken cancellationToken) - { - role.Name = roleName; - return Task.FromResult(0); - } - - public Task UpdateAsync(Role role, CancellationToken cancellationToken) - { - throw new NotImplementedException(); - } + public Task UpdateAsync(Role role, CancellationToken cancellationToken) + { + throw new NotImplementedException(); } } diff --git a/src/Core/Identity/TwoFactorRememberTokenProvider.cs b/src/Core/Identity/TwoFactorRememberTokenProvider.cs index 2902280ff..711c8c933 100644 --- a/src/Core/Identity/TwoFactorRememberTokenProvider.cs +++ b/src/Core/Identity/TwoFactorRememberTokenProvider.cs @@ -4,18 +4,17 @@ using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -namespace Bit.Core.Identity -{ - public class TwoFactorRememberTokenProvider : DataProtectorTokenProvider - { - public TwoFactorRememberTokenProvider( - IDataProtectionProvider dataProtectionProvider, - IOptions options, - ILogger> logger) - : base(dataProtectionProvider, options, logger) - { } - } +namespace Bit.Core.Identity; - public class TwoFactorRememberTokenProviderOptions : DataProtectionTokenProviderOptions +public class TwoFactorRememberTokenProvider : DataProtectorTokenProvider +{ + public TwoFactorRememberTokenProvider( + IDataProtectionProvider dataProtectionProvider, + IOptions options, + ILogger> logger) + : base(dataProtectionProvider, options, logger) { } } + +public class TwoFactorRememberTokenProviderOptions : DataProtectionTokenProviderOptions +{ } diff --git a/src/Core/Identity/UserStore.cs b/src/Core/Identity/UserStore.cs index 53bd74484..afa0656c1 100644 --- a/src/Core/Identity/UserStore.cs +++ b/src/Core/Identity/UserStore.cs @@ -5,180 +5,179 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Core.Identity +namespace Bit.Core.Identity; + +public class UserStore : + IUserStore, + IUserPasswordStore, + IUserEmailStore, + IUserTwoFactorStore, + IUserSecurityStampStore { - public class UserStore : - IUserStore, - IUserPasswordStore, - IUserEmailStore, - IUserTwoFactorStore, - IUserSecurityStampStore + private readonly IServiceProvider _serviceProvider; + private readonly IUserRepository _userRepository; + private readonly ICurrentContext _currentContext; + + public UserStore( + IServiceProvider serviceProvider, + IUserRepository userRepository, + ICurrentContext currentContext) { - private readonly IServiceProvider _serviceProvider; - private readonly IUserRepository _userRepository; - private readonly ICurrentContext _currentContext; + _serviceProvider = serviceProvider; + _userRepository = userRepository; + _currentContext = currentContext; + } - public UserStore( - IServiceProvider serviceProvider, - IUserRepository userRepository, - ICurrentContext currentContext) + public void Dispose() { } + + public async Task CreateAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) + { + await _userRepository.CreateAsync(user); + return IdentityResult.Success; + } + + public async Task DeleteAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) + { + await _userRepository.DeleteAsync(user); + return IdentityResult.Success; + } + + public async Task FindByEmailAsync(string normalizedEmail, CancellationToken cancellationToken = default(CancellationToken)) + { + if (_currentContext?.User != null && _currentContext.User.Email == normalizedEmail) { - _serviceProvider = serviceProvider; - _userRepository = userRepository; - _currentContext = currentContext; - } - - public void Dispose() { } - - public async Task CreateAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) - { - await _userRepository.CreateAsync(user); - return IdentityResult.Success; - } - - public async Task DeleteAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) - { - await _userRepository.DeleteAsync(user); - return IdentityResult.Success; - } - - public async Task FindByEmailAsync(string normalizedEmail, CancellationToken cancellationToken = default(CancellationToken)) - { - if (_currentContext?.User != null && _currentContext.User.Email == normalizedEmail) - { - return _currentContext.User; - } - - _currentContext.User = await _userRepository.GetByEmailAsync(normalizedEmail); return _currentContext.User; } - public async Task FindByIdAsync(string userId, CancellationToken cancellationToken = default(CancellationToken)) + _currentContext.User = await _userRepository.GetByEmailAsync(normalizedEmail); + return _currentContext.User; + } + + public async Task FindByIdAsync(string userId, CancellationToken cancellationToken = default(CancellationToken)) + { + if (_currentContext?.User != null && + string.Equals(_currentContext.User.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase)) { - if (_currentContext?.User != null && - string.Equals(_currentContext.User.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase)) - { - return _currentContext.User; - } - - Guid userIdGuid; - if (!Guid.TryParse(userId, out userIdGuid)) - { - return null; - } - - _currentContext.User = await _userRepository.GetByIdAsync(userIdGuid); return _currentContext.User; } - public async Task FindByNameAsync(string normalizedUserName, CancellationToken cancellationToken = default(CancellationToken)) + Guid userIdGuid; + if (!Guid.TryParse(userId, out userIdGuid)) { - return await FindByEmailAsync(normalizedUserName, cancellationToken); + return null; } - public Task GetEmailAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.Email); - } + _currentContext.User = await _userRepository.GetByIdAsync(userIdGuid); + return _currentContext.User; + } - public Task GetEmailConfirmedAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.EmailVerified); - } + public async Task FindByNameAsync(string normalizedUserName, CancellationToken cancellationToken = default(CancellationToken)) + { + return await FindByEmailAsync(normalizedUserName, cancellationToken); + } - public Task GetNormalizedEmailAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.Email); - } + public Task GetEmailAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.Email); + } - public Task GetNormalizedUserNameAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.Email); - } + public Task GetEmailConfirmedAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.EmailVerified); + } - public Task GetPasswordHashAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.MasterPassword); - } + public Task GetNormalizedEmailAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.Email); + } - public Task GetUserIdAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.Id.ToString()); - } + public Task GetNormalizedUserNameAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.Email); + } - public Task GetUserNameAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(user.Email); - } + public Task GetPasswordHashAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.MasterPassword); + } - public Task HasPasswordAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(!string.IsNullOrWhiteSpace(user.MasterPassword)); - } + public Task GetUserIdAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.Id.ToString()); + } - public Task SetEmailAsync(User user, string email, CancellationToken cancellationToken = default(CancellationToken)) - { - user.Email = email; - return Task.FromResult(0); - } + public Task GetUserNameAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(user.Email); + } - public Task SetEmailConfirmedAsync(User user, bool confirmed, CancellationToken cancellationToken = default(CancellationToken)) - { - user.EmailVerified = confirmed; - return Task.FromResult(0); - } + public Task HasPasswordAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(!string.IsNullOrWhiteSpace(user.MasterPassword)); + } - public Task SetNormalizedEmailAsync(User user, string normalizedEmail, CancellationToken cancellationToken = default(CancellationToken)) - { - user.Email = normalizedEmail; - return Task.FromResult(0); - } + public Task SetEmailAsync(User user, string email, CancellationToken cancellationToken = default(CancellationToken)) + { + user.Email = email; + return Task.FromResult(0); + } - public Task SetNormalizedUserNameAsync(User user, string normalizedName, CancellationToken cancellationToken = default(CancellationToken)) - { - user.Email = normalizedName; - return Task.FromResult(0); - } + public Task SetEmailConfirmedAsync(User user, bool confirmed, CancellationToken cancellationToken = default(CancellationToken)) + { + user.EmailVerified = confirmed; + return Task.FromResult(0); + } - public Task SetPasswordHashAsync(User user, string passwordHash, CancellationToken cancellationToken = default(CancellationToken)) - { - user.MasterPassword = passwordHash; - return Task.FromResult(0); - } + public Task SetNormalizedEmailAsync(User user, string normalizedEmail, CancellationToken cancellationToken = default(CancellationToken)) + { + user.Email = normalizedEmail; + return Task.FromResult(0); + } - public Task SetUserNameAsync(User user, string userName, CancellationToken cancellationToken = default(CancellationToken)) - { - user.Email = userName; - return Task.FromResult(0); - } + public Task SetNormalizedUserNameAsync(User user, string normalizedName, CancellationToken cancellationToken = default(CancellationToken)) + { + user.Email = normalizedName; + return Task.FromResult(0); + } - public async Task UpdateAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) - { - user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; - await _userRepository.ReplaceAsync(user); - return IdentityResult.Success; - } + public Task SetPasswordHashAsync(User user, string passwordHash, CancellationToken cancellationToken = default(CancellationToken)) + { + user.MasterPassword = passwordHash; + return Task.FromResult(0); + } - public Task SetTwoFactorEnabledAsync(User user, bool enabled, CancellationToken cancellationToken) - { - // Do nothing... - return Task.FromResult(0); - } + public Task SetUserNameAsync(User user, string userName, CancellationToken cancellationToken = default(CancellationToken)) + { + user.Email = userName; + return Task.FromResult(0); + } - public async Task GetTwoFactorEnabledAsync(User user, CancellationToken cancellationToken) - { - return await _serviceProvider.GetRequiredService().TwoFactorIsEnabledAsync(user); - } + public async Task UpdateAsync(User user, CancellationToken cancellationToken = default(CancellationToken)) + { + user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; + await _userRepository.ReplaceAsync(user); + return IdentityResult.Success; + } - public Task SetSecurityStampAsync(User user, string stamp, CancellationToken cancellationToken) - { - user.SecurityStamp = stamp; - return Task.FromResult(0); - } + public Task SetTwoFactorEnabledAsync(User user, bool enabled, CancellationToken cancellationToken) + { + // Do nothing... + return Task.FromResult(0); + } - public Task GetSecurityStampAsync(User user, CancellationToken cancellationToken) - { - return Task.FromResult(user.SecurityStamp); - } + public async Task GetTwoFactorEnabledAsync(User user, CancellationToken cancellationToken) + { + return await _serviceProvider.GetRequiredService().TwoFactorIsEnabledAsync(user); + } + + public Task SetSecurityStampAsync(User user, string stamp, CancellationToken cancellationToken) + { + user.SecurityStamp = stamp; + return Task.FromResult(0); + } + + public Task GetSecurityStampAsync(User user, CancellationToken cancellationToken) + { + return Task.FromResult(user.SecurityStamp); } } diff --git a/src/Core/Identity/WebAuthnTokenProvider.cs b/src/Core/Identity/WebAuthnTokenProvider.cs index ee857422a..b34b6b187 100644 --- a/src/Core/Identity/WebAuthnTokenProvider.cs +++ b/src/Core/Identity/WebAuthnTokenProvider.cs @@ -10,146 +10,145 @@ using Fido2NetLib.Objects; using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Core.Identity +namespace Bit.Core.Identity; + +public class WebAuthnTokenProvider : IUserTwoFactorTokenProvider { - public class WebAuthnTokenProvider : IUserTwoFactorTokenProvider + private readonly IServiceProvider _serviceProvider; + private readonly IFido2 _fido2; + private readonly GlobalSettings _globalSettings; + + public WebAuthnTokenProvider(IServiceProvider serviceProvider, IFido2 fido2, GlobalSettings globalSettings) { - private readonly IServiceProvider _serviceProvider; - private readonly IFido2 _fido2; - private readonly GlobalSettings _globalSettings; + _serviceProvider = serviceProvider; + _fido2 = fido2; + _globalSettings = globalSettings; + } - public WebAuthnTokenProvider(IServiceProvider serviceProvider, IFido2 fido2, GlobalSettings globalSettings) + public async Task CanGenerateTwoFactorTokenAsync(UserManager manager, User user) + { + var userService = _serviceProvider.GetRequiredService(); + if (!(await userService.CanAccessPremium(user))) { - _serviceProvider = serviceProvider; - _fido2 = fido2; - _globalSettings = globalSettings; + return false; } - public async Task CanGenerateTwoFactorTokenAsync(UserManager manager, User user) + var webAuthnProvider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); + if (!HasProperMetaData(webAuthnProvider)) { - var userService = _serviceProvider.GetRequiredService(); - if (!(await userService.CanAccessPremium(user))) - { - return false; - } - - var webAuthnProvider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); - if (!HasProperMetaData(webAuthnProvider)) - { - return false; - } - - return await userService.TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.WebAuthn, user); + return false; } - public async Task GenerateAsync(string purpose, UserManager manager, User user) + return await userService.TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.WebAuthn, user); + } + + public async Task GenerateAsync(string purpose, UserManager manager, User user) + { + var userService = _serviceProvider.GetRequiredService(); + if (!(await userService.CanAccessPremium(user))) { - var userService = _serviceProvider.GetRequiredService(); - if (!(await userService.CanAccessPremium(user))) - { - return null; - } - - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); - var keys = LoadKeys(provider); - var existingCredentials = keys.Select(key => key.Item2.Descriptor).ToList(); - - if (existingCredentials.Count == 0) - { - return null; - } - - var exts = new AuthenticationExtensionsClientInputs() - { - UserVerificationMethod = true, - AppID = CoreHelpers.U2fAppIdUrl(_globalSettings), - }; - - var options = _fido2.GetAssertionOptions(existingCredentials, UserVerificationRequirement.Discouraged, exts); - - // TODO: Remove this when newtonsoft legacy converters are gone - provider.MetaData["login"] = JsonSerializer.Serialize(options); - - var providers = user.GetTwoFactorProviders(); - providers[TwoFactorProviderType.WebAuthn] = provider; - user.SetTwoFactorProviders(providers); - await userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.WebAuthn, logEvent: false); - - return options.ToJson(); + return null; } - public async Task ValidateAsync(string purpose, string token, UserManager manager, User user) + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); + var keys = LoadKeys(provider); + var existingCredentials = keys.Select(key => key.Item2.Descriptor).ToList(); + + if (existingCredentials.Count == 0) { - var userService = _serviceProvider.GetRequiredService(); - if (!(await userService.CanAccessPremium(user)) || string.IsNullOrWhiteSpace(token)) - { - return false; - } - - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); - var keys = LoadKeys(provider); - - if (!provider.MetaData.ContainsKey("login")) - { - return false; - } - - var clientResponse = JsonSerializer.Deserialize(token, - new JsonSerializerOptions { PropertyNameCaseInsensitive = true }); - - var jsonOptions = provider.MetaData["login"].ToString(); - var options = AssertionOptions.FromJson(jsonOptions); - - var webAuthCred = keys.Find(k => k.Item2.Descriptor.Id.SequenceEqual(clientResponse.Id)); - - if (webAuthCred == null) - { - return false; - } - - IsUserHandleOwnerOfCredentialIdAsync callback = (args) => Task.FromResult(true); - - var res = await _fido2.MakeAssertionAsync(clientResponse, options, webAuthCred.Item2.PublicKey, webAuthCred.Item2.SignatureCounter, callback); - - provider.MetaData.Remove("login"); - - // Update SignatureCounter - webAuthCred.Item2.SignatureCounter = res.Counter; - - var providers = user.GetTwoFactorProviders(); - providers[TwoFactorProviderType.WebAuthn].MetaData[webAuthCred.Item1] = webAuthCred.Item2; - user.SetTwoFactorProviders(providers); - await userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.WebAuthn, logEvent: false); - - return res.Status == "ok"; + return null; } - private bool HasProperMetaData(TwoFactorProvider provider) + var exts = new AuthenticationExtensionsClientInputs() { - return provider?.MetaData?.Any() ?? false; + UserVerificationMethod = true, + AppID = CoreHelpers.U2fAppIdUrl(_globalSettings), + }; + + var options = _fido2.GetAssertionOptions(existingCredentials, UserVerificationRequirement.Discouraged, exts); + + // TODO: Remove this when newtonsoft legacy converters are gone + provider.MetaData["login"] = JsonSerializer.Serialize(options); + + var providers = user.GetTwoFactorProviders(); + providers[TwoFactorProviderType.WebAuthn] = provider; + user.SetTwoFactorProviders(providers); + await userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.WebAuthn, logEvent: false); + + return options.ToJson(); + } + + public async Task ValidateAsync(string purpose, string token, UserManager manager, User user) + { + var userService = _serviceProvider.GetRequiredService(); + if (!(await userService.CanAccessPremium(user)) || string.IsNullOrWhiteSpace(token)) + { + return false; } - private List> LoadKeys(TwoFactorProvider provider) + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); + var keys = LoadKeys(provider); + + if (!provider.MetaData.ContainsKey("login")) { - var keys = new List>(); - if (!HasProperMetaData(provider)) - { - return keys; - } + return false; + } - // Support up to 5 keys - for (var i = 1; i <= 5; i++) - { - var keyName = $"Key{i}"; - if (provider.MetaData.ContainsKey(keyName)) - { - var key = new TwoFactorProvider.WebAuthnData((dynamic)provider.MetaData[keyName]); + var clientResponse = JsonSerializer.Deserialize(token, + new JsonSerializerOptions { PropertyNameCaseInsensitive = true }); - keys.Add(new Tuple(keyName, key)); - } - } + var jsonOptions = provider.MetaData["login"].ToString(); + var options = AssertionOptions.FromJson(jsonOptions); + var webAuthCred = keys.Find(k => k.Item2.Descriptor.Id.SequenceEqual(clientResponse.Id)); + + if (webAuthCred == null) + { + return false; + } + + IsUserHandleOwnerOfCredentialIdAsync callback = (args) => Task.FromResult(true); + + var res = await _fido2.MakeAssertionAsync(clientResponse, options, webAuthCred.Item2.PublicKey, webAuthCred.Item2.SignatureCounter, callback); + + provider.MetaData.Remove("login"); + + // Update SignatureCounter + webAuthCred.Item2.SignatureCounter = res.Counter; + + var providers = user.GetTwoFactorProviders(); + providers[TwoFactorProviderType.WebAuthn].MetaData[webAuthCred.Item1] = webAuthCred.Item2; + user.SetTwoFactorProviders(providers); + await userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.WebAuthn, logEvent: false); + + return res.Status == "ok"; + } + + private bool HasProperMetaData(TwoFactorProvider provider) + { + return provider?.MetaData?.Any() ?? false; + } + + private List> LoadKeys(TwoFactorProvider provider) + { + var keys = new List>(); + if (!HasProperMetaData(provider)) + { return keys; } + + // Support up to 5 keys + for (var i = 1; i <= 5; i++) + { + var keyName = $"Key{i}"; + if (provider.MetaData.ContainsKey(keyName)) + { + var key = new TwoFactorProvider.WebAuthnData((dynamic)provider.MetaData[keyName]); + + keys.Add(new Tuple(keyName, key)); + } + } + + return keys; } } diff --git a/src/Core/Identity/YubicoOtpTokenProvider.cs b/src/Core/Identity/YubicoOtpTokenProvider.cs index 763cdbf4b..3d7bb9fe7 100644 --- a/src/Core/Identity/YubicoOtpTokenProvider.cs +++ b/src/Core/Identity/YubicoOtpTokenProvider.cs @@ -6,71 +6,70 @@ using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.DependencyInjection; using YubicoDotNetClient; -namespace Bit.Core.Identity +namespace Bit.Core.Identity; + +public class YubicoOtpTokenProvider : IUserTwoFactorTokenProvider { - public class YubicoOtpTokenProvider : IUserTwoFactorTokenProvider + private readonly IServiceProvider _serviceProvider; + private readonly GlobalSettings _globalSettings; + + public YubicoOtpTokenProvider( + IServiceProvider serviceProvider, + GlobalSettings globalSettings) { - private readonly IServiceProvider _serviceProvider; - private readonly GlobalSettings _globalSettings; + _serviceProvider = serviceProvider; + _globalSettings = globalSettings; + } - public YubicoOtpTokenProvider( - IServiceProvider serviceProvider, - GlobalSettings globalSettings) + public async Task CanGenerateTwoFactorTokenAsync(UserManager manager, User user) + { + var userService = _serviceProvider.GetRequiredService(); + if (!(await userService.CanAccessPremium(user))) { - _serviceProvider = serviceProvider; - _globalSettings = globalSettings; + return false; } - public async Task CanGenerateTwoFactorTokenAsync(UserManager manager, User user) + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.YubiKey); + if (!provider?.MetaData.Values.Any(v => !string.IsNullOrWhiteSpace((string)v)) ?? true) { - var userService = _serviceProvider.GetRequiredService(); - if (!(await userService.CanAccessPremium(user))) - { - return false; - } - - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.YubiKey); - if (!provider?.MetaData.Values.Any(v => !string.IsNullOrWhiteSpace((string)v)) ?? true) - { - return false; - } - - return await userService.TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.YubiKey, user); + return false; } - public Task GenerateAsync(string purpose, UserManager manager, User user) + return await userService.TwoFactorProviderIsEnabledAsync(TwoFactorProviderType.YubiKey, user); + } + + public Task GenerateAsync(string purpose, UserManager manager, User user) + { + return Task.FromResult(null); + } + + public async Task ValidateAsync(string purpose, string token, UserManager manager, User user) + { + var userService = _serviceProvider.GetRequiredService(); + if (!(await userService.CanAccessPremium(user))) { - return Task.FromResult(null); + return false; } - public async Task ValidateAsync(string purpose, string token, UserManager manager, User user) + if (string.IsNullOrWhiteSpace(token) || token.Length < 32 || token.Length > 48) { - var userService = _serviceProvider.GetRequiredService(); - if (!(await userService.CanAccessPremium(user))) - { - return false; - } - - if (string.IsNullOrWhiteSpace(token) || token.Length < 32 || token.Length > 48) - { - return false; - } - - var id = token.Substring(0, 12); - - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.YubiKey); - if (!provider.MetaData.ContainsValue(id)) - { - return false; - } - - var client = new YubicoClient(_globalSettings.Yubico.ClientId, _globalSettings.Yubico.Key); - if (_globalSettings.Yubico.ValidationUrls != null && _globalSettings.Yubico.ValidationUrls.Length > 0) - { - client.SetUrls(_globalSettings.Yubico.ValidationUrls); - } - var response = await client.VerifyAsync(token); - return response.Status == YubicoResponseStatus.Ok; + return false; } + + var id = token.Substring(0, 12); + + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.YubiKey); + if (!provider.MetaData.ContainsValue(id)) + { + return false; + } + + var client = new YubicoClient(_globalSettings.Yubico.ClientId, _globalSettings.Yubico.Key); + if (_globalSettings.Yubico.ValidationUrls != null && _globalSettings.Yubico.ValidationUrls.Length > 0) + { + client.SetUrls(_globalSettings.Yubico.ValidationUrls); + } + var response = await client.VerifyAsync(token); + return response.Status == YubicoResponseStatus.Ok; } } diff --git a/src/Core/IdentityServer/ApiClient.cs b/src/Core/IdentityServer/ApiClient.cs index a17bb32f9..b289da001 100644 --- a/src/Core/IdentityServer/ApiClient.cs +++ b/src/Core/IdentityServer/ApiClient.cs @@ -1,78 +1,77 @@ using Bit.Core.Settings; using IdentityServer4.Models; -namespace Bit.Core.IdentityServer +namespace Bit.Core.IdentityServer; + +public class ApiClient : Client { - public class ApiClient : Client + public ApiClient( + GlobalSettings globalSettings, + string id, + int refreshTokenSlidingDays, + int accessTokenLifetimeHours, + string[] scopes = null) { - public ApiClient( - GlobalSettings globalSettings, - string id, - int refreshTokenSlidingDays, - int accessTokenLifetimeHours, - string[] scopes = null) + ClientId = id; + AllowedGrantTypes = new[] { GrantType.ResourceOwnerPassword, GrantType.AuthorizationCode }; + RefreshTokenExpiration = TokenExpiration.Sliding; + RefreshTokenUsage = TokenUsage.ReUse; + SlidingRefreshTokenLifetime = 86400 * refreshTokenSlidingDays; + AbsoluteRefreshTokenLifetime = 0; // forever + UpdateAccessTokenClaimsOnRefresh = true; + AccessTokenLifetime = 3600 * accessTokenLifetimeHours; + AllowOfflineAccess = true; + + RequireConsent = false; + RequirePkce = true; + RequireClientSecret = false; + if (id == "web") { - ClientId = id; - AllowedGrantTypes = new[] { GrantType.ResourceOwnerPassword, GrantType.AuthorizationCode }; - RefreshTokenExpiration = TokenExpiration.Sliding; - RefreshTokenUsage = TokenUsage.ReUse; - SlidingRefreshTokenLifetime = 86400 * refreshTokenSlidingDays; - AbsoluteRefreshTokenLifetime = 0; // forever - UpdateAccessTokenClaimsOnRefresh = true; - AccessTokenLifetime = 3600 * accessTokenLifetimeHours; - AllowOfflineAccess = true; - - RequireConsent = false; - RequirePkce = true; - RequireClientSecret = false; - if (id == "web") - { - RedirectUris = new[] { $"{globalSettings.BaseServiceUri.Vault}/sso-connector.html" }; - PostLogoutRedirectUris = new[] { globalSettings.BaseServiceUri.Vault }; - AllowedCorsOrigins = new[] { globalSettings.BaseServiceUri.Vault }; - } - else if (id == "desktop") - { - RedirectUris = new[] { "bitwarden://sso-callback" }; - PostLogoutRedirectUris = new[] { "bitwarden://logged-out" }; - } - else if (id == "connector") - { - var connectorUris = new List(); - for (var port = 8065; port <= 8070; port++) - { - connectorUris.Add(string.Format("http://localhost:{0}", port)); - } - RedirectUris = connectorUris.Append("bwdc://sso-callback").ToList(); - PostLogoutRedirectUris = connectorUris.Append("bwdc://logged-out").ToList(); - } - else if (id == "browser") - { - RedirectUris = new[] { $"{globalSettings.BaseServiceUri.Vault}/sso-connector.html" }; - PostLogoutRedirectUris = new[] { globalSettings.BaseServiceUri.Vault }; - AllowedCorsOrigins = new[] { globalSettings.BaseServiceUri.Vault }; - } - else if (id == "cli") - { - var cliUris = new List(); - for (var port = 8065; port <= 8070; port++) - { - cliUris.Add(string.Format("http://localhost:{0}", port)); - } - RedirectUris = cliUris; - PostLogoutRedirectUris = cliUris; - } - else if (id == "mobile") - { - RedirectUris = new[] { "bitwarden://sso-callback" }; - PostLogoutRedirectUris = new[] { "bitwarden://logged-out" }; - } - - if (scopes == null) - { - scopes = new string[] { "api" }; - } - AllowedScopes = scopes; + RedirectUris = new[] { $"{globalSettings.BaseServiceUri.Vault}/sso-connector.html" }; + PostLogoutRedirectUris = new[] { globalSettings.BaseServiceUri.Vault }; + AllowedCorsOrigins = new[] { globalSettings.BaseServiceUri.Vault }; } + else if (id == "desktop") + { + RedirectUris = new[] { "bitwarden://sso-callback" }; + PostLogoutRedirectUris = new[] { "bitwarden://logged-out" }; + } + else if (id == "connector") + { + var connectorUris = new List(); + for (var port = 8065; port <= 8070; port++) + { + connectorUris.Add(string.Format("http://localhost:{0}", port)); + } + RedirectUris = connectorUris.Append("bwdc://sso-callback").ToList(); + PostLogoutRedirectUris = connectorUris.Append("bwdc://logged-out").ToList(); + } + else if (id == "browser") + { + RedirectUris = new[] { $"{globalSettings.BaseServiceUri.Vault}/sso-connector.html" }; + PostLogoutRedirectUris = new[] { globalSettings.BaseServiceUri.Vault }; + AllowedCorsOrigins = new[] { globalSettings.BaseServiceUri.Vault }; + } + else if (id == "cli") + { + var cliUris = new List(); + for (var port = 8065; port <= 8070; port++) + { + cliUris.Add(string.Format("http://localhost:{0}", port)); + } + RedirectUris = cliUris; + PostLogoutRedirectUris = cliUris; + } + else if (id == "mobile") + { + RedirectUris = new[] { "bitwarden://sso-callback" }; + PostLogoutRedirectUris = new[] { "bitwarden://logged-out" }; + } + + if (scopes == null) + { + scopes = new string[] { "api" }; + } + AllowedScopes = scopes; } } diff --git a/src/Core/IdentityServer/ApiResources.cs b/src/Core/IdentityServer/ApiResources.cs index 55b3427cd..5a19fa2ca 100644 --- a/src/Core/IdentityServer/ApiResources.cs +++ b/src/Core/IdentityServer/ApiResources.cs @@ -1,36 +1,35 @@ using IdentityModel; using IdentityServer4.Models; -namespace Bit.Core.IdentityServer +namespace Bit.Core.IdentityServer; + +public class ApiResources { - public class ApiResources + public static IEnumerable GetApiResources() { - public static IEnumerable GetApiResources() + return new List { - return new List - { - new ApiResource("api", new string[] { - JwtClaimTypes.Name, - JwtClaimTypes.Email, - JwtClaimTypes.EmailVerified, - "sstamp", // security stamp - "premium", - "device", - "orgowner", - "orgadmin", - "orgmanager", - "orguser", - "orgcustom", - "providerprovideradmin", - "providerserviceuser", - }), - new ApiResource("internal", new string[] { JwtClaimTypes.Subject }), - new ApiResource("api.push", new string[] { JwtClaimTypes.Subject }), - new ApiResource("api.licensing", new string[] { JwtClaimTypes.Subject }), - new ApiResource("api.organization", new string[] { JwtClaimTypes.Subject }), - new ApiResource("api.provider", new string[] { JwtClaimTypes.Subject }), - new ApiResource("api.installation", new string[] { JwtClaimTypes.Subject }), - }; - } + new ApiResource("api", new string[] { + JwtClaimTypes.Name, + JwtClaimTypes.Email, + JwtClaimTypes.EmailVerified, + "sstamp", // security stamp + "premium", + "device", + "orgowner", + "orgadmin", + "orgmanager", + "orguser", + "orgcustom", + "providerprovideradmin", + "providerserviceuser", + }), + new ApiResource("internal", new string[] { JwtClaimTypes.Subject }), + new ApiResource("api.push", new string[] { JwtClaimTypes.Subject }), + new ApiResource("api.licensing", new string[] { JwtClaimTypes.Subject }), + new ApiResource("api.organization", new string[] { JwtClaimTypes.Subject }), + new ApiResource("api.provider", new string[] { JwtClaimTypes.Subject }), + new ApiResource("api.installation", new string[] { JwtClaimTypes.Subject }), + }; } } diff --git a/src/Core/IdentityServer/ApiScopes.cs b/src/Core/IdentityServer/ApiScopes.cs index e98465964..2af512eb8 100644 --- a/src/Core/IdentityServer/ApiScopes.cs +++ b/src/Core/IdentityServer/ApiScopes.cs @@ -1,20 +1,19 @@ using IdentityServer4.Models; -namespace Bit.Core.IdentityServer +namespace Bit.Core.IdentityServer; + +public class ApiScopes { - public class ApiScopes + public static IEnumerable GetApiScopes() { - public static IEnumerable GetApiScopes() + return new List { - return new List - { - new ApiScope("api", "API Access"), - new ApiScope("api.push", "API Push Access"), - new ApiScope("api.licensing", "API Licensing Access"), - new ApiScope("api.organization", "API Organization Access"), - new ApiScope("api.installation", "API Installation Access"), - new ApiScope("internal", "Internal Access") - }; - } + new ApiScope("api", "API Access"), + new ApiScope("api.push", "API Push Access"), + new ApiScope("api.licensing", "API Licensing Access"), + new ApiScope("api.organization", "API Organization Access"), + new ApiScope("api.installation", "API Installation Access"), + new ApiScope("internal", "Internal Access") + }; } } diff --git a/src/Core/IdentityServer/AuthorizationCodeStore.cs b/src/Core/IdentityServer/AuthorizationCodeStore.cs index 7bf01f6eb..fc07f7aa6 100644 --- a/src/Core/IdentityServer/AuthorizationCodeStore.cs +++ b/src/Core/IdentityServer/AuthorizationCodeStore.cs @@ -6,39 +6,38 @@ using IdentityServer4.Stores; using IdentityServer4.Stores.Serialization; using Microsoft.Extensions.Logging; -namespace Bit.Core.IdentityServer +namespace Bit.Core.IdentityServer; + +// ref: https://raw.githubusercontent.com/IdentityServer/IdentityServer4/3.1.3/src/IdentityServer4/src/Stores/Default/DefaultAuthorizationCodeStore.cs +public class AuthorizationCodeStore : DefaultGrantStore, IAuthorizationCodeStore { - // ref: https://raw.githubusercontent.com/IdentityServer/IdentityServer4/3.1.3/src/IdentityServer4/src/Stores/Default/DefaultAuthorizationCodeStore.cs - public class AuthorizationCodeStore : DefaultGrantStore, IAuthorizationCodeStore + public AuthorizationCodeStore( + IPersistedGrantStore store, + IPersistentGrantSerializer serializer, + IHandleGenerationService handleGenerationService, + ILogger logger) + : base(IdentityServerConstants.PersistedGrantTypes.AuthorizationCode, store, serializer, + handleGenerationService, logger) + { } + + public Task StoreAuthorizationCodeAsync(AuthorizationCode code) { - public AuthorizationCodeStore( - IPersistedGrantStore store, - IPersistentGrantSerializer serializer, - IHandleGenerationService handleGenerationService, - ILogger logger) - : base(IdentityServerConstants.PersistedGrantTypes.AuthorizationCode, store, serializer, - handleGenerationService, logger) - { } + return CreateItemAsync(code, code.ClientId, code.Subject.GetSubjectId(), code.SessionId, + code.Description, code.CreationTime, code.Lifetime); + } - public Task StoreAuthorizationCodeAsync(AuthorizationCode code) - { - return CreateItemAsync(code, code.ClientId, code.Subject.GetSubjectId(), code.SessionId, - code.Description, code.CreationTime, code.Lifetime); - } + public Task GetAuthorizationCodeAsync(string code) + { + return GetItemAsync(code); + } - public Task GetAuthorizationCodeAsync(string code) - { - return GetItemAsync(code); - } + public Task RemoveAuthorizationCodeAsync(string code) + { + // return RemoveItemAsync(code); - public Task RemoveAuthorizationCodeAsync(string code) - { - // return RemoveItemAsync(code); - - // We don't want to delete authorization codes during validation. - // We'll rely on the authorization code lifecycle for short term validation and the - // DatabaseExpiredGrantsJob to clean up old authorization codes. - return Task.FromResult(0); - } + // We don't want to delete authorization codes during validation. + // We'll rely on the authorization code lifecycle for short term validation and the + // DatabaseExpiredGrantsJob to clean up old authorization codes. + return Task.FromResult(0); } } diff --git a/src/Core/IdentityServer/BaseRequestValidator.cs b/src/Core/IdentityServer/BaseRequestValidator.cs index 632b548b2..d2c72c132 100644 --- a/src/Core/IdentityServer/BaseRequestValidator.cs +++ b/src/Core/IdentityServer/BaseRequestValidator.cs @@ -17,607 +17,606 @@ using IdentityServer4.Validation; using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.Logging; -namespace Bit.Core.IdentityServer -{ - public abstract class BaseRequestValidator where T : class - { - private UserManager _userManager; - private readonly IDeviceRepository _deviceRepository; - private readonly IDeviceService _deviceService; - private readonly IUserService _userService; - private readonly IEventService _eventService; - private readonly IOrganizationDuoWebTokenProvider _organizationDuoWebTokenProvider; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IApplicationCacheService _applicationCacheService; - private readonly IMailService _mailService; - private readonly ILogger _logger; - private readonly ICurrentContext _currentContext; - private readonly GlobalSettings _globalSettings; - private readonly IPolicyRepository _policyRepository; - private readonly IUserRepository _userRepository; - private readonly ICaptchaValidationService _captchaValidationService; +namespace Bit.Core.IdentityServer; - public BaseRequestValidator( - UserManager userManager, - IDeviceRepository deviceRepository, - IDeviceService deviceService, - IUserService userService, - IEventService eventService, - IOrganizationDuoWebTokenProvider organizationDuoWebTokenProvider, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IApplicationCacheService applicationCacheService, - IMailService mailService, - ILogger logger, - ICurrentContext currentContext, - GlobalSettings globalSettings, - IPolicyRepository policyRepository, - IUserRepository userRepository, - ICaptchaValidationService captchaValidationService) +public abstract class BaseRequestValidator where T : class +{ + private UserManager _userManager; + private readonly IDeviceRepository _deviceRepository; + private readonly IDeviceService _deviceService; + private readonly IUserService _userService; + private readonly IEventService _eventService; + private readonly IOrganizationDuoWebTokenProvider _organizationDuoWebTokenProvider; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IApplicationCacheService _applicationCacheService; + private readonly IMailService _mailService; + private readonly ILogger _logger; + private readonly ICurrentContext _currentContext; + private readonly GlobalSettings _globalSettings; + private readonly IPolicyRepository _policyRepository; + private readonly IUserRepository _userRepository; + private readonly ICaptchaValidationService _captchaValidationService; + + public BaseRequestValidator( + UserManager userManager, + IDeviceRepository deviceRepository, + IDeviceService deviceService, + IUserService userService, + IEventService eventService, + IOrganizationDuoWebTokenProvider organizationDuoWebTokenProvider, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IApplicationCacheService applicationCacheService, + IMailService mailService, + ILogger logger, + ICurrentContext currentContext, + GlobalSettings globalSettings, + IPolicyRepository policyRepository, + IUserRepository userRepository, + ICaptchaValidationService captchaValidationService) + { + _userManager = userManager; + _deviceRepository = deviceRepository; + _deviceService = deviceService; + _userService = userService; + _eventService = eventService; + _organizationDuoWebTokenProvider = organizationDuoWebTokenProvider; + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _applicationCacheService = applicationCacheService; + _mailService = mailService; + _logger = logger; + _currentContext = currentContext; + _globalSettings = globalSettings; + _policyRepository = policyRepository; + _userRepository = userRepository; + _captchaValidationService = captchaValidationService; + } + + protected async Task ValidateAsync(T context, ValidatedTokenRequest request, + CustomValidatorRequestContext validatorContext) + { + var isBot = (validatorContext.CaptchaResponse?.IsBot ?? false); + if (isBot) { - _userManager = userManager; - _deviceRepository = deviceRepository; - _deviceService = deviceService; - _userService = userService; - _eventService = eventService; - _organizationDuoWebTokenProvider = organizationDuoWebTokenProvider; - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _applicationCacheService = applicationCacheService; - _mailService = mailService; - _logger = logger; - _currentContext = currentContext; - _globalSettings = globalSettings; - _policyRepository = policyRepository; - _userRepository = userRepository; - _captchaValidationService = captchaValidationService; + _logger.LogInformation(Constants.BypassFiltersEventId, + "Login attempt for {0} detected as a captcha bot with score {1}.", + request.UserName, validatorContext.CaptchaResponse.Score); } - protected async Task ValidateAsync(T context, ValidatedTokenRequest request, - CustomValidatorRequestContext validatorContext) + var twoFactorToken = request.Raw["TwoFactorToken"]?.ToString(); + var twoFactorProvider = request.Raw["TwoFactorProvider"]?.ToString(); + var twoFactorRemember = request.Raw["TwoFactorRemember"]?.ToString() == "1"; + var twoFactorRequest = !string.IsNullOrWhiteSpace(twoFactorToken) && + !string.IsNullOrWhiteSpace(twoFactorProvider); + + var valid = await ValidateContextAsync(context, validatorContext); + var user = validatorContext.User; + if (!valid) { - var isBot = (validatorContext.CaptchaResponse?.IsBot ?? false); - if (isBot) - { - _logger.LogInformation(Constants.BypassFiltersEventId, - "Login attempt for {0} detected as a captcha bot with score {1}.", - request.UserName, validatorContext.CaptchaResponse.Score); - } + await UpdateFailedAuthDetailsAsync(user, false, !validatorContext.KnownDevice); + } + if (!valid || isBot) + { + await BuildErrorResultAsync("Username or password is incorrect. Try again.", false, context, user); + return; + } - var twoFactorToken = request.Raw["TwoFactorToken"]?.ToString(); - var twoFactorProvider = request.Raw["TwoFactorProvider"]?.ToString(); - var twoFactorRemember = request.Raw["TwoFactorRemember"]?.ToString() == "1"; - var twoFactorRequest = !string.IsNullOrWhiteSpace(twoFactorToken) && - !string.IsNullOrWhiteSpace(twoFactorProvider); - - var valid = await ValidateContextAsync(context, validatorContext); - var user = validatorContext.User; - if (!valid) + var (isTwoFactorRequired, requires2FABecauseNewDevice, twoFactorOrganization) = await RequiresTwoFactorAsync(user, request); + if (isTwoFactorRequired) + { + // Just defaulting it + var twoFactorProviderType = TwoFactorProviderType.Authenticator; + if (!twoFactorRequest || !Enum.TryParse(twoFactorProvider, out twoFactorProviderType)) { - await UpdateFailedAuthDetailsAsync(user, false, !validatorContext.KnownDevice); - } - if (!valid || isBot) - { - await BuildErrorResultAsync("Username or password is incorrect. Try again.", false, context, user); + await BuildTwoFactorResultAsync(user, twoFactorOrganization, context, requires2FABecauseNewDevice); return; } - var (isTwoFactorRequired, requires2FABecauseNewDevice, twoFactorOrganization) = await RequiresTwoFactorAsync(user, request); - if (isTwoFactorRequired) + BeforeVerifyTwoFactor(user, twoFactorProviderType, requires2FABecauseNewDevice); + + var verified = await VerifyTwoFactor(user, twoFactorOrganization, + twoFactorProviderType, twoFactorToken); + + AfterVerifyTwoFactor(user, twoFactorProviderType, requires2FABecauseNewDevice); + + if ((!verified || isBot) && twoFactorProviderType != TwoFactorProviderType.Remember) { - // Just defaulting it - var twoFactorProviderType = TwoFactorProviderType.Authenticator; - if (!twoFactorRequest || !Enum.TryParse(twoFactorProvider, out twoFactorProviderType)) - { - await BuildTwoFactorResultAsync(user, twoFactorOrganization, context, requires2FABecauseNewDevice); - return; - } - - BeforeVerifyTwoFactor(user, twoFactorProviderType, requires2FABecauseNewDevice); - - var verified = await VerifyTwoFactor(user, twoFactorOrganization, - twoFactorProviderType, twoFactorToken); - - AfterVerifyTwoFactor(user, twoFactorProviderType, requires2FABecauseNewDevice); - - if ((!verified || isBot) && twoFactorProviderType != TwoFactorProviderType.Remember) - { - await UpdateFailedAuthDetailsAsync(user, true, !validatorContext.KnownDevice); - await BuildErrorResultAsync("Two-step token is invalid. Try again.", true, context, user); - return; - } - else if ((!verified || isBot) && twoFactorProviderType == TwoFactorProviderType.Remember) - { - // Delay for brute force. - await Task.Delay(2000); - await BuildTwoFactorResultAsync(user, twoFactorOrganization, context, requires2FABecauseNewDevice); - return; - } + await UpdateFailedAuthDetailsAsync(user, true, !validatorContext.KnownDevice); + await BuildErrorResultAsync("Two-step token is invalid. Try again.", true, context, user); + return; } - else + else if ((!verified || isBot) && twoFactorProviderType == TwoFactorProviderType.Remember) { - twoFactorRequest = false; - twoFactorRemember = false; - twoFactorToken = null; + // Delay for brute force. + await Task.Delay(2000); + await BuildTwoFactorResultAsync(user, twoFactorOrganization, context, requires2FABecauseNewDevice); + return; } + } + else + { + twoFactorRequest = false; + twoFactorRemember = false; + twoFactorToken = null; + } - // Returns true if can finish validation process - if (await IsValidAuthTypeAsync(user, request.GrantType)) + // Returns true if can finish validation process + if (await IsValidAuthTypeAsync(user, request.GrantType)) + { + var device = await SaveDeviceAsync(user, request); + if (device == null) { - var device = await SaveDeviceAsync(user, request); - if (device == null) - { - await BuildErrorResultAsync("No device information provided.", false, context, user); - return; - } - await BuildSuccessResultAsync(user, context, device, twoFactorRequest && twoFactorRemember); + await BuildErrorResultAsync("No device information provided.", false, context, user); + return; } - else + await BuildSuccessResultAsync(user, context, device, twoFactorRequest && twoFactorRemember); + } + else + { + SetSsoResult(context, new Dictionary + {{ + "ErrorModel", new ErrorResponseModel("SSO authentication is required.") + }}); + } + } + + protected abstract Task ValidateContextAsync(T context, CustomValidatorRequestContext validatorContext); + + protected async Task BuildSuccessResultAsync(User user, T context, Device device, bool sendRememberToken) + { + await _eventService.LogUserEventAsync(user.Id, EventType.User_LoggedIn); + + var claims = new List(); + + if (device != null) + { + claims.Add(new Claim("device", device.Identifier)); + } + + var customResponse = new Dictionary(); + if (!string.IsNullOrWhiteSpace(user.PrivateKey)) + { + customResponse.Add("PrivateKey", user.PrivateKey); + } + + if (!string.IsNullOrWhiteSpace(user.Key)) + { + customResponse.Add("Key", user.Key); + } + + customResponse.Add("ForcePasswordReset", user.ForcePasswordReset); + customResponse.Add("ResetMasterPassword", string.IsNullOrWhiteSpace(user.MasterPassword)); + customResponse.Add("Kdf", (byte)user.Kdf); + customResponse.Add("KdfIterations", user.KdfIterations); + + if (sendRememberToken) + { + var token = await _userManager.GenerateTwoFactorTokenAsync(user, + CoreHelpers.CustomProviderName(TwoFactorProviderType.Remember)); + customResponse.Add("TwoFactorToken", token); + } + + await ResetFailedAuthDetailsAsync(user); + await SetSuccessResult(context, user, claims, customResponse); + } + + protected async Task BuildTwoFactorResultAsync(User user, Organization organization, T context, bool requires2FABecauseNewDevice) + { + var providerKeys = new List(); + var providers = new Dictionary>(); + + var enabledProviders = new List>(); + if (organization?.GetTwoFactorProviders() != null) + { + enabledProviders.AddRange(organization.GetTwoFactorProviders().Where( + p => organization.TwoFactorProviderIsEnabled(p.Key))); + } + + if (user.GetTwoFactorProviders() != null) + { + foreach (var p in user.GetTwoFactorProviders()) { - SetSsoResult(context, new Dictionary - {{ - "ErrorModel", new ErrorResponseModel("SSO authentication is required.") - }}); + if (await _userService.TwoFactorProviderIsEnabledAsync(p.Key, user)) + { + enabledProviders.Add(p); + } } } - protected abstract Task ValidateContextAsync(T context, CustomValidatorRequestContext validatorContext); - - protected async Task BuildSuccessResultAsync(User user, T context, Device device, bool sendRememberToken) + if (!enabledProviders.Any()) { - await _eventService.LogUserEventAsync(user.Id, EventType.User_LoggedIn); - - var claims = new List(); - - if (device != null) + if (!requires2FABecauseNewDevice) { - claims.Add(new Claim("device", device.Identifier)); + await BuildErrorResultAsync("No two-step providers enabled.", false, context, user); + return; } - var customResponse = new Dictionary(); - if (!string.IsNullOrWhiteSpace(user.PrivateKey)) + var emailProvider = new TwoFactorProvider { - customResponse.Add("PrivateKey", user.PrivateKey); - } - - if (!string.IsNullOrWhiteSpace(user.Key)) + MetaData = new Dictionary { ["Email"] = user.Email.ToLowerInvariant() }, + Enabled = true + }; + enabledProviders.Add(new KeyValuePair( + TwoFactorProviderType.Email, emailProvider)); + user.SetTwoFactorProviders(new Dictionary { - customResponse.Add("Key", user.Key); - } - - customResponse.Add("ForcePasswordReset", user.ForcePasswordReset); - customResponse.Add("ResetMasterPassword", string.IsNullOrWhiteSpace(user.MasterPassword)); - customResponse.Add("Kdf", (byte)user.Kdf); - customResponse.Add("KdfIterations", user.KdfIterations); - - if (sendRememberToken) - { - var token = await _userManager.GenerateTwoFactorTokenAsync(user, - CoreHelpers.CustomProviderName(TwoFactorProviderType.Remember)); - customResponse.Add("TwoFactorToken", token); - } - - await ResetFailedAuthDetailsAsync(user); - await SetSuccessResult(context, user, claims, customResponse); + [TwoFactorProviderType.Email] = emailProvider + }); } - protected async Task BuildTwoFactorResultAsync(User user, Organization organization, T context, bool requires2FABecauseNewDevice) + foreach (var provider in enabledProviders) { - var providerKeys = new List(); - var providers = new Dictionary>(); + providerKeys.Add((byte)provider.Key); + var infoDict = await BuildTwoFactorParams(organization, user, provider.Key, provider.Value); + providers.Add(((byte)provider.Key).ToString(), infoDict); + } - var enabledProviders = new List>(); - if (organization?.GetTwoFactorProviders() != null) + SetTwoFactorResult(context, + new Dictionary { - enabledProviders.AddRange(organization.GetTwoFactorProviders().Where( - p => organization.TwoFactorProviderIsEnabled(p.Key))); - } + { "TwoFactorProviders", providers.Keys }, + { "TwoFactorProviders2", providers } + }); - if (user.GetTwoFactorProviders() != null) + if (enabledProviders.Count() == 1 && enabledProviders.First().Key == TwoFactorProviderType.Email) + { + // Send email now if this is their only 2FA method + await _userService.SendTwoFactorEmailAsync(user, requires2FABecauseNewDevice); + } + } + + protected async Task BuildErrorResultAsync(string message, bool twoFactorRequest, T context, User user) + { + if (user != null) + { + await _eventService.LogUserEventAsync(user.Id, + twoFactorRequest ? EventType.User_FailedLogIn2fa : EventType.User_FailedLogIn); + } + + if (_globalSettings.SelfHosted) + { + _logger.LogWarning(Constants.BypassFiltersEventId, + string.Format("Failed login attempt{0}{1}", twoFactorRequest ? ", 2FA invalid." : ".", + $" {_currentContext.IpAddress}")); + } + + await Task.Delay(2000); // Delay for brute force. + SetErrorResult(context, + new Dictionary + {{ + "ErrorModel", new ErrorResponseModel(message) + }}); + } + + protected abstract void SetTwoFactorResult(T context, Dictionary customResponse); + + protected abstract void SetSsoResult(T context, Dictionary customResponse); + + protected abstract Task SetSuccessResult(T context, User user, List claims, + Dictionary customResponse); + + protected abstract void SetErrorResult(T context, Dictionary customResponse); + + private async Task> RequiresTwoFactorAsync(User user, ValidatedTokenRequest request) + { + if (request.GrantType == "client_credentials") + { + // Do not require MFA for api key logins + return new Tuple(false, false, null); + } + + var individualRequired = _userManager.SupportsUserTwoFactor && + await _userManager.GetTwoFactorEnabledAsync(user) && + (await _userManager.GetValidTwoFactorProvidersAsync(user)).Count > 0; + + Organization firstEnabledOrg = null; + var orgs = (await _currentContext.OrganizationMembershipAsync(_organizationUserRepository, user.Id)) + .ToList(); + if (orgs.Any()) + { + var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + var twoFactorOrgs = orgs.Where(o => OrgUsing2fa(orgAbilities, o.Id)); + if (twoFactorOrgs.Any()) { - foreach (var p in user.GetTwoFactorProviders()) - { - if (await _userService.TwoFactorProviderIsEnabledAsync(p.Key, user)) - { - enabledProviders.Add(p); - } - } - } - - if (!enabledProviders.Any()) - { - if (!requires2FABecauseNewDevice) - { - await BuildErrorResultAsync("No two-step providers enabled.", false, context, user); - return; - } - - var emailProvider = new TwoFactorProvider - { - MetaData = new Dictionary { ["Email"] = user.Email.ToLowerInvariant() }, - Enabled = true - }; - enabledProviders.Add(new KeyValuePair( - TwoFactorProviderType.Email, emailProvider)); - user.SetTwoFactorProviders(new Dictionary - { - [TwoFactorProviderType.Email] = emailProvider - }); - } - - foreach (var provider in enabledProviders) - { - providerKeys.Add((byte)provider.Key); - var infoDict = await BuildTwoFactorParams(organization, user, provider.Key, provider.Value); - providers.Add(((byte)provider.Key).ToString(), infoDict); - } - - SetTwoFactorResult(context, - new Dictionary - { - { "TwoFactorProviders", providers.Keys }, - { "TwoFactorProviders2", providers } - }); - - if (enabledProviders.Count() == 1 && enabledProviders.First().Key == TwoFactorProviderType.Email) - { - // Send email now if this is their only 2FA method - await _userService.SendTwoFactorEmailAsync(user, requires2FABecauseNewDevice); + var userOrgs = await _organizationRepository.GetManyByUserIdAsync(user.Id); + firstEnabledOrg = userOrgs.FirstOrDefault( + o => orgs.Any(om => om.Id == o.Id) && o.TwoFactorIsEnabled()); } } - protected async Task BuildErrorResultAsync(string message, bool twoFactorRequest, T context, User user) + var requires2FA = individualRequired || firstEnabledOrg != null; + var requires2FABecauseNewDevice = !requires2FA + && + await _userService.Needs2FABecauseNewDeviceAsync( + user, + GetDeviceFromRequest(request)?.Identifier, + request.GrantType); + + requires2FA = requires2FA || requires2FABecauseNewDevice; + + return new Tuple(requires2FA, requires2FABecauseNewDevice, firstEnabledOrg); + } + + private async Task IsValidAuthTypeAsync(User user, string grantType) + { + if (grantType == "authorization_code" || grantType == "client_credentials") { - if (user != null) - { - await _eventService.LogUserEventAsync(user.Id, - twoFactorRequest ? EventType.User_FailedLogIn2fa : EventType.User_FailedLogIn); - } - - if (_globalSettings.SelfHosted) - { - _logger.LogWarning(Constants.BypassFiltersEventId, - string.Format("Failed login attempt{0}{1}", twoFactorRequest ? ", 2FA invalid." : ".", - $" {_currentContext.IpAddress}")); - } - - await Task.Delay(2000); // Delay for brute force. - SetErrorResult(context, - new Dictionary - {{ - "ErrorModel", new ErrorResponseModel(message) - }}); - } - - protected abstract void SetTwoFactorResult(T context, Dictionary customResponse); - - protected abstract void SetSsoResult(T context, Dictionary customResponse); - - protected abstract Task SetSuccessResult(T context, User user, List claims, - Dictionary customResponse); - - protected abstract void SetErrorResult(T context, Dictionary customResponse); - - private async Task> RequiresTwoFactorAsync(User user, ValidatedTokenRequest request) - { - if (request.GrantType == "client_credentials") - { - // Do not require MFA for api key logins - return new Tuple(false, false, null); - } - - var individualRequired = _userManager.SupportsUserTwoFactor && - await _userManager.GetTwoFactorEnabledAsync(user) && - (await _userManager.GetValidTwoFactorProvidersAsync(user)).Count > 0; - - Organization firstEnabledOrg = null; - var orgs = (await _currentContext.OrganizationMembershipAsync(_organizationUserRepository, user.Id)) - .ToList(); - if (orgs.Any()) - { - var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); - var twoFactorOrgs = orgs.Where(o => OrgUsing2fa(orgAbilities, o.Id)); - if (twoFactorOrgs.Any()) - { - var userOrgs = await _organizationRepository.GetManyByUserIdAsync(user.Id); - firstEnabledOrg = userOrgs.FirstOrDefault( - o => orgs.Any(om => om.Id == o.Id) && o.TwoFactorIsEnabled()); - } - } - - var requires2FA = individualRequired || firstEnabledOrg != null; - var requires2FABecauseNewDevice = !requires2FA - && - await _userService.Needs2FABecauseNewDeviceAsync( - user, - GetDeviceFromRequest(request)?.Identifier, - request.GrantType); - - requires2FA = requires2FA || requires2FABecauseNewDevice; - - return new Tuple(requires2FA, requires2FABecauseNewDevice, firstEnabledOrg); - } - - private async Task IsValidAuthTypeAsync(User user, string grantType) - { - if (grantType == "authorization_code" || grantType == "client_credentials") - { - // Already using SSO to authorize, finish successfully - // Or login via api key, skip SSO requirement - return true; - } - - // Is user apart of any orgs? Use cache for initial checks. - var orgs = (await _currentContext.OrganizationMembershipAsync(_organizationUserRepository, user.Id)) - .ToList(); - if (orgs.Any()) - { - // Get all org abilities - var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); - // Parse all user orgs that are enabled and have the ability to use sso - var ssoOrgs = orgs.Where(o => OrgCanUseSso(orgAbilities, o.Id)); - if (ssoOrgs.Any()) - { - // Parse users orgs and determine if require sso policy is enabled - var userOrgs = await _organizationUserRepository.GetManyDetailsByUserAsync(user.Id, - OrganizationUserStatusType.Confirmed); - foreach (var userOrg in userOrgs.Where(o => o.Enabled && o.UseSso)) - { - var orgPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(userOrg.OrganizationId, - PolicyType.RequireSso); - // Owners and Admins are exempt from this policy - if (orgPolicy != null && orgPolicy.Enabled && - userOrg.Type != OrganizationUserType.Owner && userOrg.Type != OrganizationUserType.Admin) - { - return false; - } - } - } - } - - // Default - continue validation process + // Already using SSO to authorize, finish successfully + // Or login via api key, skip SSO requirement return true; } - private bool OrgUsing2fa(IDictionary orgAbilities, Guid orgId) + // Is user apart of any orgs? Use cache for initial checks. + var orgs = (await _currentContext.OrganizationMembershipAsync(_organizationUserRepository, user.Id)) + .ToList(); + if (orgs.Any()) { - return orgAbilities != null && orgAbilities.ContainsKey(orgId) && - orgAbilities[orgId].Enabled && orgAbilities[orgId].Using2fa; - } - - private bool OrgCanUseSso(IDictionary orgAbilities, Guid orgId) - { - return orgAbilities != null && orgAbilities.ContainsKey(orgId) && - orgAbilities[orgId].Enabled && orgAbilities[orgId].UseSso; - } - - private Device GetDeviceFromRequest(ValidatedRequest request) - { - var deviceIdentifier = request.Raw["DeviceIdentifier"]?.ToString(); - var deviceType = request.Raw["DeviceType"]?.ToString(); - var deviceName = request.Raw["DeviceName"]?.ToString(); - var devicePushToken = request.Raw["DevicePushToken"]?.ToString(); - - if (string.IsNullOrWhiteSpace(deviceIdentifier) || string.IsNullOrWhiteSpace(deviceType) || - string.IsNullOrWhiteSpace(deviceName) || !Enum.TryParse(deviceType, out DeviceType type)) + // Get all org abilities + var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + // Parse all user orgs that are enabled and have the ability to use sso + var ssoOrgs = orgs.Where(o => OrgCanUseSso(orgAbilities, o.Id)); + if (ssoOrgs.Any()) { - return null; - } - - return new Device - { - Identifier = deviceIdentifier, - Name = deviceName, - Type = type, - PushToken = string.IsNullOrWhiteSpace(devicePushToken) ? null : devicePushToken - }; - } - - private void BeforeVerifyTwoFactor(User user, TwoFactorProviderType type, bool requires2FABecauseNewDevice) - { - if (type == TwoFactorProviderType.Email && requires2FABecauseNewDevice) - { - user.SetTwoFactorProviders(new Dictionary + // Parse users orgs and determine if require sso policy is enabled + var userOrgs = await _organizationUserRepository.GetManyDetailsByUserAsync(user.Id, + OrganizationUserStatusType.Confirmed); + foreach (var userOrg in userOrgs.Where(o => o.Enabled && o.UseSso)) { - [TwoFactorProviderType.Email] = new TwoFactorProvider - { - MetaData = new Dictionary { ["Email"] = user.Email.ToLowerInvariant() }, - Enabled = true - } - }); - } - } - - private void AfterVerifyTwoFactor(User user, TwoFactorProviderType type, bool requires2FABecauseNewDevice) - { - if (type == TwoFactorProviderType.Email && requires2FABecauseNewDevice) - { - user.ClearTwoFactorProviders(); - } - } - - private async Task VerifyTwoFactor(User user, Organization organization, TwoFactorProviderType type, - string token) - { - switch (type) - { - case TwoFactorProviderType.Authenticator: - case TwoFactorProviderType.Email: - case TwoFactorProviderType.Duo: - case TwoFactorProviderType.YubiKey: - case TwoFactorProviderType.WebAuthn: - case TwoFactorProviderType.Remember: - if (type != TwoFactorProviderType.Remember && - !(await _userService.TwoFactorProviderIsEnabledAsync(type, user))) - { - return false; - } - return await _userManager.VerifyTwoFactorTokenAsync(user, - CoreHelpers.CustomProviderName(type), token); - case TwoFactorProviderType.OrganizationDuo: - if (!organization?.TwoFactorProviderIsEnabled(type) ?? true) + var orgPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(userOrg.OrganizationId, + PolicyType.RequireSso); + // Owners and Admins are exempt from this policy + if (orgPolicy != null && orgPolicy.Enabled && + userOrg.Type != OrganizationUserType.Owner && userOrg.Type != OrganizationUserType.Admin) { return false; } + } + } + } - return await _organizationDuoWebTokenProvider.ValidateAsync(token, organization, user); - default: + // Default - continue validation process + return true; + } + + private bool OrgUsing2fa(IDictionary orgAbilities, Guid orgId) + { + return orgAbilities != null && orgAbilities.ContainsKey(orgId) && + orgAbilities[orgId].Enabled && orgAbilities[orgId].Using2fa; + } + + private bool OrgCanUseSso(IDictionary orgAbilities, Guid orgId) + { + return orgAbilities != null && orgAbilities.ContainsKey(orgId) && + orgAbilities[orgId].Enabled && orgAbilities[orgId].UseSso; + } + + private Device GetDeviceFromRequest(ValidatedRequest request) + { + var deviceIdentifier = request.Raw["DeviceIdentifier"]?.ToString(); + var deviceType = request.Raw["DeviceType"]?.ToString(); + var deviceName = request.Raw["DeviceName"]?.ToString(); + var devicePushToken = request.Raw["DevicePushToken"]?.ToString(); + + if (string.IsNullOrWhiteSpace(deviceIdentifier) || string.IsNullOrWhiteSpace(deviceType) || + string.IsNullOrWhiteSpace(deviceName) || !Enum.TryParse(deviceType, out DeviceType type)) + { + return null; + } + + return new Device + { + Identifier = deviceIdentifier, + Name = deviceName, + Type = type, + PushToken = string.IsNullOrWhiteSpace(devicePushToken) ? null : devicePushToken + }; + } + + private void BeforeVerifyTwoFactor(User user, TwoFactorProviderType type, bool requires2FABecauseNewDevice) + { + if (type == TwoFactorProviderType.Email && requires2FABecauseNewDevice) + { + user.SetTwoFactorProviders(new Dictionary + { + [TwoFactorProviderType.Email] = new TwoFactorProvider + { + MetaData = new Dictionary { ["Email"] = user.Email.ToLowerInvariant() }, + Enabled = true + } + }); + } + } + + private void AfterVerifyTwoFactor(User user, TwoFactorProviderType type, bool requires2FABecauseNewDevice) + { + if (type == TwoFactorProviderType.Email && requires2FABecauseNewDevice) + { + user.ClearTwoFactorProviders(); + } + } + + private async Task VerifyTwoFactor(User user, Organization organization, TwoFactorProviderType type, + string token) + { + switch (type) + { + case TwoFactorProviderType.Authenticator: + case TwoFactorProviderType.Email: + case TwoFactorProviderType.Duo: + case TwoFactorProviderType.YubiKey: + case TwoFactorProviderType.WebAuthn: + case TwoFactorProviderType.Remember: + if (type != TwoFactorProviderType.Remember && + !(await _userService.TwoFactorProviderIsEnabledAsync(type, user))) + { return false; - } - } + } + return await _userManager.VerifyTwoFactorTokenAsync(user, + CoreHelpers.CustomProviderName(type), token); + case TwoFactorProviderType.OrganizationDuo: + if (!organization?.TwoFactorProviderIsEnabled(type) ?? true) + { + return false; + } - private async Task> BuildTwoFactorParams(Organization organization, User user, - TwoFactorProviderType type, TwoFactorProvider provider) + return await _organizationDuoWebTokenProvider.ValidateAsync(token, organization, user); + default: + return false; + } + } + + private async Task> BuildTwoFactorParams(Organization organization, User user, + TwoFactorProviderType type, TwoFactorProvider provider) + { + switch (type) { - switch (type) - { - case TwoFactorProviderType.Duo: - case TwoFactorProviderType.WebAuthn: - case TwoFactorProviderType.Email: - case TwoFactorProviderType.YubiKey: - if (!(await _userService.TwoFactorProviderIsEnabledAsync(type, user))) + case TwoFactorProviderType.Duo: + case TwoFactorProviderType.WebAuthn: + case TwoFactorProviderType.Email: + case TwoFactorProviderType.YubiKey: + if (!(await _userService.TwoFactorProviderIsEnabledAsync(type, user))) + { + return null; + } + + var token = await _userManager.GenerateTwoFactorTokenAsync(user, + CoreHelpers.CustomProviderName(type)); + if (type == TwoFactorProviderType.Duo) + { + return new Dictionary + { + ["Host"] = provider.MetaData["Host"], + ["Signature"] = token + }; + } + else if (type == TwoFactorProviderType.WebAuthn) + { + if (token == null) { return null; } - var token = await _userManager.GenerateTwoFactorTokenAsync(user, - CoreHelpers.CustomProviderName(type)); - if (type == TwoFactorProviderType.Duo) - { - return new Dictionary - { - ["Host"] = provider.MetaData["Host"], - ["Signature"] = token - }; - } - else if (type == TwoFactorProviderType.WebAuthn) - { - if (token == null) - { - return null; - } - - return JsonSerializer.Deserialize>(token); - } - else if (type == TwoFactorProviderType.Email) - { - return new Dictionary - { - ["Email"] = token - }; - } - else if (type == TwoFactorProviderType.YubiKey) - { - return new Dictionary - { - ["Nfc"] = (bool)provider.MetaData["Nfc"] - }; - } - return null; - case TwoFactorProviderType.OrganizationDuo: - if (await _organizationDuoWebTokenProvider.CanGenerateTwoFactorTokenAsync(organization)) - { - return new Dictionary - { - ["Host"] = provider.MetaData["Host"], - ["Signature"] = await _organizationDuoWebTokenProvider.GenerateAsync(organization, user) - }; - } - return null; - default: - return null; - } - } - - protected async Task KnownDeviceAsync(User user, ValidatedTokenRequest request) => - (await GetKnownDeviceAsync(user, request)) != default; - - protected async Task GetKnownDeviceAsync(User user, ValidatedTokenRequest request) - { - if (user == null) - { - return default; - } - - return await _deviceRepository.GetByIdentifierAsync(GetDeviceFromRequest(request).Identifier, user.Id); - } - - private async Task SaveDeviceAsync(User user, ValidatedTokenRequest request) - { - var device = GetDeviceFromRequest(request); - if (device != null) - { - var existingDevice = await GetKnownDeviceAsync(user, request); - if (existingDevice == null) - { - device.UserId = user.Id; - await _deviceService.SaveAsync(device); - - var now = DateTime.UtcNow; - if (now - user.CreationDate > TimeSpan.FromMinutes(10)) - { - var deviceType = device.Type.GetType().GetMember(device.Type.ToString()) - .FirstOrDefault()?.GetCustomAttribute()?.GetName(); - if (!_globalSettings.DisableEmailNewDevice) - { - await _mailService.SendNewDeviceLoggedInEmail(user.Email, deviceType, now, - _currentContext.IpAddress); - } - } - - return device; + return JsonSerializer.Deserialize>(token); } - - return existingDevice; - } - - return null; - } - - private async Task ResetFailedAuthDetailsAsync(User user) - { - // Early escape if db hit not necessary - if (user == null || user.FailedLoginCount == 0) - { - return; - } - - user.FailedLoginCount = 0; - user.RevisionDate = DateTime.UtcNow; - await _userRepository.ReplaceAsync(user); - } - - private async Task UpdateFailedAuthDetailsAsync(User user, bool twoFactorInvalid, bool unknownDevice) - { - if (user == null) - { - return; - } - - var utcNow = DateTime.UtcNow; - user.FailedLoginCount = ++user.FailedLoginCount; - user.LastFailedLoginDate = user.RevisionDate = utcNow; - await _userRepository.ReplaceAsync(user); - - if (ValidateFailedAuthEmailConditions(unknownDevice, user)) - { - if (twoFactorInvalid) + else if (type == TwoFactorProviderType.Email) { - await _mailService.SendFailedTwoFactorAttemptsEmailAsync(user.Email, utcNow, _currentContext.IpAddress); + return new Dictionary + { + ["Email"] = token + }; } - else + else if (type == TwoFactorProviderType.YubiKey) { - await _mailService.SendFailedLoginAttemptsEmailAsync(user.Email, utcNow, _currentContext.IpAddress); + return new Dictionary + { + ["Nfc"] = (bool)provider.MetaData["Nfc"] + }; } - } - } - - private bool ValidateFailedAuthEmailConditions(bool unknownDevice, User user) - { - var failedLoginCeiling = _globalSettings.Captcha.MaximumFailedLoginAttempts; - var failedLoginCount = user?.FailedLoginCount ?? 0; - return unknownDevice && failedLoginCeiling > 0 && failedLoginCount == failedLoginCeiling; + return null; + case TwoFactorProviderType.OrganizationDuo: + if (await _organizationDuoWebTokenProvider.CanGenerateTwoFactorTokenAsync(organization)) + { + return new Dictionary + { + ["Host"] = provider.MetaData["Host"], + ["Signature"] = await _organizationDuoWebTokenProvider.GenerateAsync(organization, user) + }; + } + return null; + default: + return null; } } + + protected async Task KnownDeviceAsync(User user, ValidatedTokenRequest request) => + (await GetKnownDeviceAsync(user, request)) != default; + + protected async Task GetKnownDeviceAsync(User user, ValidatedTokenRequest request) + { + if (user == null) + { + return default; + } + + return await _deviceRepository.GetByIdentifierAsync(GetDeviceFromRequest(request).Identifier, user.Id); + } + + private async Task SaveDeviceAsync(User user, ValidatedTokenRequest request) + { + var device = GetDeviceFromRequest(request); + if (device != null) + { + var existingDevice = await GetKnownDeviceAsync(user, request); + if (existingDevice == null) + { + device.UserId = user.Id; + await _deviceService.SaveAsync(device); + + var now = DateTime.UtcNow; + if (now - user.CreationDate > TimeSpan.FromMinutes(10)) + { + var deviceType = device.Type.GetType().GetMember(device.Type.ToString()) + .FirstOrDefault()?.GetCustomAttribute()?.GetName(); + if (!_globalSettings.DisableEmailNewDevice) + { + await _mailService.SendNewDeviceLoggedInEmail(user.Email, deviceType, now, + _currentContext.IpAddress); + } + } + + return device; + } + + return existingDevice; + } + + return null; + } + + private async Task ResetFailedAuthDetailsAsync(User user) + { + // Early escape if db hit not necessary + if (user == null || user.FailedLoginCount == 0) + { + return; + } + + user.FailedLoginCount = 0; + user.RevisionDate = DateTime.UtcNow; + await _userRepository.ReplaceAsync(user); + } + + private async Task UpdateFailedAuthDetailsAsync(User user, bool twoFactorInvalid, bool unknownDevice) + { + if (user == null) + { + return; + } + + var utcNow = DateTime.UtcNow; + user.FailedLoginCount = ++user.FailedLoginCount; + user.LastFailedLoginDate = user.RevisionDate = utcNow; + await _userRepository.ReplaceAsync(user); + + if (ValidateFailedAuthEmailConditions(unknownDevice, user)) + { + if (twoFactorInvalid) + { + await _mailService.SendFailedTwoFactorAttemptsEmailAsync(user.Email, utcNow, _currentContext.IpAddress); + } + else + { + await _mailService.SendFailedLoginAttemptsEmailAsync(user.Email, utcNow, _currentContext.IpAddress); + } + } + } + + private bool ValidateFailedAuthEmailConditions(bool unknownDevice, User user) + { + var failedLoginCeiling = _globalSettings.Captcha.MaximumFailedLoginAttempts; + var failedLoginCount = user?.FailedLoginCount ?? 0; + return unknownDevice && failedLoginCeiling > 0 && failedLoginCount == failedLoginCeiling; + } } diff --git a/src/Core/IdentityServer/ClientStore.cs b/src/Core/IdentityServer/ClientStore.cs index ebe247f19..2e6fa06bd 100644 --- a/src/Core/IdentityServer/ClientStore.cs +++ b/src/Core/IdentityServer/ClientStore.cs @@ -10,172 +10,171 @@ using IdentityModel; using IdentityServer4.Models; using IdentityServer4.Stores; -namespace Bit.Core.IdentityServer +namespace Bit.Core.IdentityServer; + +public class ClientStore : IClientStore { - public class ClientStore : IClientStore + private readonly IInstallationRepository _installationRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly IUserRepository _userRepository; + private readonly GlobalSettings _globalSettings; + private readonly StaticClientStore _staticClientStore; + private readonly ILicensingService _licensingService; + private readonly ICurrentContext _currentContext; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IProviderUserRepository _providerUserRepository; + private readonly IProviderOrganizationRepository _providerOrganizationRepository; + private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; + + public ClientStore( + IInstallationRepository installationRepository, + IOrganizationRepository organizationRepository, + IUserRepository userRepository, + GlobalSettings globalSettings, + StaticClientStore staticClientStore, + ILicensingService licensingService, + ICurrentContext currentContext, + IOrganizationUserRepository organizationUserRepository, + IProviderUserRepository providerUserRepository, + IProviderOrganizationRepository providerOrganizationRepository, + IOrganizationApiKeyRepository organizationApiKeyRepository) { - private readonly IInstallationRepository _installationRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly IUserRepository _userRepository; - private readonly GlobalSettings _globalSettings; - private readonly StaticClientStore _staticClientStore; - private readonly ILicensingService _licensingService; - private readonly ICurrentContext _currentContext; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IProviderUserRepository _providerUserRepository; - private readonly IProviderOrganizationRepository _providerOrganizationRepository; - private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; + _installationRepository = installationRepository; + _organizationRepository = organizationRepository; + _userRepository = userRepository; + _globalSettings = globalSettings; + _staticClientStore = staticClientStore; + _licensingService = licensingService; + _currentContext = currentContext; + _organizationUserRepository = organizationUserRepository; + _providerUserRepository = providerUserRepository; + _providerOrganizationRepository = providerOrganizationRepository; + _organizationApiKeyRepository = organizationApiKeyRepository; + } - public ClientStore( - IInstallationRepository installationRepository, - IOrganizationRepository organizationRepository, - IUserRepository userRepository, - GlobalSettings globalSettings, - StaticClientStore staticClientStore, - ILicensingService licensingService, - ICurrentContext currentContext, - IOrganizationUserRepository organizationUserRepository, - IProviderUserRepository providerUserRepository, - IProviderOrganizationRepository providerOrganizationRepository, - IOrganizationApiKeyRepository organizationApiKeyRepository) + public async Task FindClientByIdAsync(string clientId) + { + if (!_globalSettings.SelfHosted && clientId.StartsWith("installation.")) { - _installationRepository = installationRepository; - _organizationRepository = organizationRepository; - _userRepository = userRepository; - _globalSettings = globalSettings; - _staticClientStore = staticClientStore; - _licensingService = licensingService; - _currentContext = currentContext; - _organizationUserRepository = organizationUserRepository; - _providerUserRepository = providerUserRepository; - _providerOrganizationRepository = providerOrganizationRepository; - _organizationApiKeyRepository = organizationApiKeyRepository; - } - - public async Task FindClientByIdAsync(string clientId) - { - if (!_globalSettings.SelfHosted && clientId.StartsWith("installation.")) + var idParts = clientId.Split('.'); + if (idParts.Length > 1 && Guid.TryParse(idParts[1], out Guid id)) { - var idParts = clientId.Split('.'); - if (idParts.Length > 1 && Guid.TryParse(idParts[1], out Guid id)) + var installation = await _installationRepository.GetByIdAsync(id); + if (installation != null) { - var installation = await _installationRepository.GetByIdAsync(id); - if (installation != null) + return new Client { - return new Client + ClientId = $"installation.{installation.Id}", + RequireClientSecret = true, + ClientSecrets = { new Secret(installation.Key.Sha256()) }, + AllowedScopes = new string[] { "api.push", "api.licensing", "api.installation" }, + AllowedGrantTypes = GrantTypes.ClientCredentials, + AccessTokenLifetime = 3600 * 24, + Enabled = installation.Enabled, + Claims = new List { - ClientId = $"installation.{installation.Id}", - RequireClientSecret = true, - ClientSecrets = { new Secret(installation.Key.Sha256()) }, - AllowedScopes = new string[] { "api.push", "api.licensing", "api.installation" }, - AllowedGrantTypes = GrantTypes.ClientCredentials, - AccessTokenLifetime = 3600 * 24, - Enabled = installation.Enabled, - Claims = new List - { - new ClientClaim(JwtClaimTypes.Subject, installation.Id.ToString()) - } - }; - } - } - } - else if (_globalSettings.SelfHosted && clientId.StartsWith("internal.") && - CoreHelpers.SettingHasValue(_globalSettings.InternalIdentityKey)) - { - var idParts = clientId.Split('.'); - if (idParts.Length > 1) - { - var id = idParts[1]; - if (!string.IsNullOrWhiteSpace(id)) - { - return new Client - { - ClientId = $"internal.{id}", - RequireClientSecret = true, - ClientSecrets = { new Secret(_globalSettings.InternalIdentityKey.Sha256()) }, - AllowedScopes = new string[] { "internal" }, - AllowedGrantTypes = GrantTypes.ClientCredentials, - AccessTokenLifetime = 3600 * 24, - Enabled = true, - Claims = new List - { - new ClientClaim(JwtClaimTypes.Subject, id) - } - }; - } - } - } - else if (clientId.StartsWith("organization.")) - { - var idParts = clientId.Split('.'); - if (idParts.Length > 1 && Guid.TryParse(idParts[1], out var id)) - { - var org = await _organizationRepository.GetByIdAsync(id); - if (org != null) - { - var orgApiKey = (await _organizationApiKeyRepository - .GetManyByOrganizationIdTypeAsync(org.Id, OrganizationApiKeyType.Default)) - .First(); - return new Client - { - ClientId = $"organization.{org.Id}", - RequireClientSecret = true, - ClientSecrets = { new Secret(orgApiKey.ApiKey.Sha256()) }, - AllowedScopes = new string[] { "api.organization" }, - AllowedGrantTypes = GrantTypes.ClientCredentials, - AccessTokenLifetime = 3600 * 1, - Enabled = org.Enabled && org.UseApi, - Claims = new List - { - new ClientClaim(JwtClaimTypes.Subject, org.Id.ToString()) - } - }; - } - } - } - else if (clientId.StartsWith("user.")) - { - var idParts = clientId.Split('.'); - if (idParts.Length > 1 && Guid.TryParse(idParts[1], out var id)) - { - var user = await _userRepository.GetByIdAsync(id); - if (user != null) - { - var claims = new Collection() - { - new ClientClaim(JwtClaimTypes.Subject, user.Id.ToString()), - new ClientClaim(JwtClaimTypes.AuthenticationMethod, "Application", "external") - }; - var orgs = await _currentContext.OrganizationMembershipAsync(_organizationUserRepository, user.Id); - var providers = await _currentContext.ProviderMembershipAsync(_providerUserRepository, user.Id); - var isPremium = await _licensingService.ValidateUserPremiumAsync(user); - foreach (var claim in CoreHelpers.BuildIdentityClaims(user, orgs, providers, isPremium)) - { - var upperValue = claim.Value.ToUpperInvariant(); - var isBool = upperValue == "TRUE" || upperValue == "FALSE"; - claims.Add(isBool ? - new ClientClaim(claim.Key, claim.Value, ClaimValueTypes.Boolean) : - new ClientClaim(claim.Key, claim.Value) - ); + new ClientClaim(JwtClaimTypes.Subject, installation.Id.ToString()) } - - return new Client - { - ClientId = clientId, - RequireClientSecret = true, - ClientSecrets = { new Secret(user.ApiKey.Sha256()) }, - AllowedScopes = new string[] { "api" }, - AllowedGrantTypes = GrantTypes.ClientCredentials, - AccessTokenLifetime = 3600 * 1, - ClientClaimsPrefix = null, - Claims = claims - }; - } + }; } } - - return _staticClientStore.ApiClients.ContainsKey(clientId) ? - _staticClientStore.ApiClients[clientId] : null; } + else if (_globalSettings.SelfHosted && clientId.StartsWith("internal.") && + CoreHelpers.SettingHasValue(_globalSettings.InternalIdentityKey)) + { + var idParts = clientId.Split('.'); + if (idParts.Length > 1) + { + var id = idParts[1]; + if (!string.IsNullOrWhiteSpace(id)) + { + return new Client + { + ClientId = $"internal.{id}", + RequireClientSecret = true, + ClientSecrets = { new Secret(_globalSettings.InternalIdentityKey.Sha256()) }, + AllowedScopes = new string[] { "internal" }, + AllowedGrantTypes = GrantTypes.ClientCredentials, + AccessTokenLifetime = 3600 * 24, + Enabled = true, + Claims = new List + { + new ClientClaim(JwtClaimTypes.Subject, id) + } + }; + } + } + } + else if (clientId.StartsWith("organization.")) + { + var idParts = clientId.Split('.'); + if (idParts.Length > 1 && Guid.TryParse(idParts[1], out var id)) + { + var org = await _organizationRepository.GetByIdAsync(id); + if (org != null) + { + var orgApiKey = (await _organizationApiKeyRepository + .GetManyByOrganizationIdTypeAsync(org.Id, OrganizationApiKeyType.Default)) + .First(); + return new Client + { + ClientId = $"organization.{org.Id}", + RequireClientSecret = true, + ClientSecrets = { new Secret(orgApiKey.ApiKey.Sha256()) }, + AllowedScopes = new string[] { "api.organization" }, + AllowedGrantTypes = GrantTypes.ClientCredentials, + AccessTokenLifetime = 3600 * 1, + Enabled = org.Enabled && org.UseApi, + Claims = new List + { + new ClientClaim(JwtClaimTypes.Subject, org.Id.ToString()) + } + }; + } + } + } + else if (clientId.StartsWith("user.")) + { + var idParts = clientId.Split('.'); + if (idParts.Length > 1 && Guid.TryParse(idParts[1], out var id)) + { + var user = await _userRepository.GetByIdAsync(id); + if (user != null) + { + var claims = new Collection() + { + new ClientClaim(JwtClaimTypes.Subject, user.Id.ToString()), + new ClientClaim(JwtClaimTypes.AuthenticationMethod, "Application", "external") + }; + var orgs = await _currentContext.OrganizationMembershipAsync(_organizationUserRepository, user.Id); + var providers = await _currentContext.ProviderMembershipAsync(_providerUserRepository, user.Id); + var isPremium = await _licensingService.ValidateUserPremiumAsync(user); + foreach (var claim in CoreHelpers.BuildIdentityClaims(user, orgs, providers, isPremium)) + { + var upperValue = claim.Value.ToUpperInvariant(); + var isBool = upperValue == "TRUE" || upperValue == "FALSE"; + claims.Add(isBool ? + new ClientClaim(claim.Key, claim.Value, ClaimValueTypes.Boolean) : + new ClientClaim(claim.Key, claim.Value) + ); + } + + return new Client + { + ClientId = clientId, + RequireClientSecret = true, + ClientSecrets = { new Secret(user.ApiKey.Sha256()) }, + AllowedScopes = new string[] { "api" }, + AllowedGrantTypes = GrantTypes.ClientCredentials, + AccessTokenLifetime = 3600 * 1, + ClientClaimsPrefix = null, + Claims = claims + }; + } + } + } + + return _staticClientStore.ApiClients.ContainsKey(clientId) ? + _staticClientStore.ApiClients[clientId] : null; } } diff --git a/src/Core/IdentityServer/ConfigureOpenIdConnectDistributedOptions.cs b/src/Core/IdentityServer/ConfigureOpenIdConnectDistributedOptions.cs index b3846e81f..084f98a27 100644 --- a/src/Core/IdentityServer/ConfigureOpenIdConnectDistributedOptions.cs +++ b/src/Core/IdentityServer/ConfigureOpenIdConnectDistributedOptions.cs @@ -5,49 +5,48 @@ using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Caching.StackExchangeRedis; using Microsoft.Extensions.Options; -namespace Bit.Core.IdentityServer -{ - public class ConfigureOpenIdConnectDistributedOptions : IPostConfigureOptions - { - private readonly IdentityServerOptions _idsrv; - private readonly IHttpContextAccessor _httpContextAccessor; - private readonly GlobalSettings _globalSettings; +namespace Bit.Core.IdentityServer; - public ConfigureOpenIdConnectDistributedOptions(IHttpContextAccessor httpContextAccessor, GlobalSettings globalSettings, - IdentityServerOptions idsrv) +public class ConfigureOpenIdConnectDistributedOptions : IPostConfigureOptions +{ + private readonly IdentityServerOptions _idsrv; + private readonly IHttpContextAccessor _httpContextAccessor; + private readonly GlobalSettings _globalSettings; + + public ConfigureOpenIdConnectDistributedOptions(IHttpContextAccessor httpContextAccessor, GlobalSettings globalSettings, + IdentityServerOptions idsrv) + { + _httpContextAccessor = httpContextAccessor ?? throw new ArgumentNullException(nameof(httpContextAccessor)); + _globalSettings = globalSettings; + _idsrv = idsrv; + } + + public void PostConfigure(string name, CookieAuthenticationOptions options) + { + options.CookieManager = new DistributedCacheCookieManager(); + + if (name != AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme) { - _httpContextAccessor = httpContextAccessor ?? throw new ArgumentNullException(nameof(httpContextAccessor)); - _globalSettings = globalSettings; - _idsrv = idsrv; + // Ignore + return; } - public void PostConfigure(string name, CookieAuthenticationOptions options) + options.Cookie.Name = AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme; + options.Cookie.IsEssential = true; + options.Cookie.SameSite = _idsrv.Authentication.CookieSameSiteMode; + options.TicketDataFormat = new DistributedCacheTicketDataFormatter(_httpContextAccessor, name); + + if (string.IsNullOrWhiteSpace(_globalSettings.IdentityServer?.RedisConnectionString)) { - options.CookieManager = new DistributedCacheCookieManager(); - - if (name != AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme) + options.SessionStore = new MemoryCacheTicketStore(); + } + else + { + var redisOptions = new RedisCacheOptions { - // Ignore - return; - } - - options.Cookie.Name = AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme; - options.Cookie.IsEssential = true; - options.Cookie.SameSite = _idsrv.Authentication.CookieSameSiteMode; - options.TicketDataFormat = new DistributedCacheTicketDataFormatter(_httpContextAccessor, name); - - if (string.IsNullOrWhiteSpace(_globalSettings.IdentityServer?.RedisConnectionString)) - { - options.SessionStore = new MemoryCacheTicketStore(); - } - else - { - var redisOptions = new RedisCacheOptions - { - Configuration = _globalSettings.IdentityServer.RedisConnectionString, - }; - options.SessionStore = new RedisCacheTicketStore(redisOptions); - } + Configuration = _globalSettings.IdentityServer.RedisConnectionString, + }; + options.SessionStore = new RedisCacheTicketStore(redisOptions); } } } diff --git a/src/Core/IdentityServer/CustomTokenRequestValidator.cs b/src/Core/IdentityServer/CustomTokenRequestValidator.cs index f37e165f3..1354af70a 100644 --- a/src/Core/IdentityServer/CustomTokenRequestValidator.cs +++ b/src/Core/IdentityServer/CustomTokenRequestValidator.cs @@ -11,143 +11,142 @@ using IdentityServer4.Validation; using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.Logging; -namespace Bit.Core.IdentityServer +namespace Bit.Core.IdentityServer; + +public class CustomTokenRequestValidator : BaseRequestValidator, + ICustomTokenRequestValidator { - public class CustomTokenRequestValidator : BaseRequestValidator, - ICustomTokenRequestValidator + private UserManager _userManager; + private readonly ISsoConfigRepository _ssoConfigRepository; + private readonly IOrganizationRepository _organizationRepository; + + public CustomTokenRequestValidator( + UserManager userManager, + IDeviceRepository deviceRepository, + IDeviceService deviceService, + IUserService userService, + IEventService eventService, + IOrganizationDuoWebTokenProvider organizationDuoWebTokenProvider, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IApplicationCacheService applicationCacheService, + IMailService mailService, + ILogger logger, + ICurrentContext currentContext, + GlobalSettings globalSettings, + IPolicyRepository policyRepository, + ISsoConfigRepository ssoConfigRepository, + IUserRepository userRepository, + ICaptchaValidationService captchaValidationService) + : base(userManager, deviceRepository, deviceService, userService, eventService, + organizationDuoWebTokenProvider, organizationRepository, organizationUserRepository, + applicationCacheService, mailService, logger, currentContext, globalSettings, policyRepository, + userRepository, captchaValidationService) { - private UserManager _userManager; - private readonly ISsoConfigRepository _ssoConfigRepository; - private readonly IOrganizationRepository _organizationRepository; + _userManager = userManager; + _ssoConfigRepository = ssoConfigRepository; + _organizationRepository = organizationRepository; + } - public CustomTokenRequestValidator( - UserManager userManager, - IDeviceRepository deviceRepository, - IDeviceService deviceService, - IUserService userService, - IEventService eventService, - IOrganizationDuoWebTokenProvider organizationDuoWebTokenProvider, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IApplicationCacheService applicationCacheService, - IMailService mailService, - ILogger logger, - ICurrentContext currentContext, - GlobalSettings globalSettings, - IPolicyRepository policyRepository, - ISsoConfigRepository ssoConfigRepository, - IUserRepository userRepository, - ICaptchaValidationService captchaValidationService) - : base(userManager, deviceRepository, deviceService, userService, eventService, - organizationDuoWebTokenProvider, organizationRepository, organizationUserRepository, - applicationCacheService, mailService, logger, currentContext, globalSettings, policyRepository, - userRepository, captchaValidationService) + public async Task ValidateAsync(CustomTokenRequestValidationContext context) + { + string[] allowedGrantTypes = { "authorization_code", "client_credentials" }; + if (!allowedGrantTypes.Contains(context.Result.ValidatedRequest.GrantType) + || context.Result.ValidatedRequest.ClientId.StartsWith("organization") + || context.Result.ValidatedRequest.ClientId.StartsWith("installation")) { - _userManager = userManager; - _ssoConfigRepository = ssoConfigRepository; - _organizationRepository = organizationRepository; + return; } + await ValidateAsync(context, context.Result.ValidatedRequest, + new CustomValidatorRequestContext { KnownDevice = true }); + } - public async Task ValidateAsync(CustomTokenRequestValidationContext context) + protected async override Task ValidateContextAsync(CustomTokenRequestValidationContext context, + CustomValidatorRequestContext validatorContext) + { + var email = context.Result.ValidatedRequest.Subject?.GetDisplayName() + ?? context.Result.ValidatedRequest.ClientClaims?.FirstOrDefault(claim => claim.Type == JwtClaimTypes.Email)?.Value; + if (!string.IsNullOrWhiteSpace(email)) { - string[] allowedGrantTypes = { "authorization_code", "client_credentials" }; - if (!allowedGrantTypes.Contains(context.Result.ValidatedRequest.GrantType) - || context.Result.ValidatedRequest.ClientId.StartsWith("organization") - || context.Result.ValidatedRequest.ClientId.StartsWith("installation")) - { - return; - } - await ValidateAsync(context, context.Result.ValidatedRequest, - new CustomValidatorRequestContext { KnownDevice = true }); + validatorContext.User = await _userManager.FindByEmailAsync(email); } + return validatorContext.User != null; + } - protected async override Task ValidateContextAsync(CustomTokenRequestValidationContext context, - CustomValidatorRequestContext validatorContext) + protected override async Task SetSuccessResult(CustomTokenRequestValidationContext context, User user, + List claims, Dictionary customResponse) + { + context.Result.CustomResponse = customResponse; + if (claims?.Any() ?? false) { - var email = context.Result.ValidatedRequest.Subject?.GetDisplayName() - ?? context.Result.ValidatedRequest.ClientClaims?.FirstOrDefault(claim => claim.Type == JwtClaimTypes.Email)?.Value; - if (!string.IsNullOrWhiteSpace(email)) + context.Result.ValidatedRequest.Client.AlwaysSendClientClaims = true; + context.Result.ValidatedRequest.Client.ClientClaimsPrefix = string.Empty; + foreach (var claim in claims) { - validatorContext.User = await _userManager.FindByEmailAsync(email); - } - return validatorContext.User != null; - } - - protected override async Task SetSuccessResult(CustomTokenRequestValidationContext context, User user, - List claims, Dictionary customResponse) - { - context.Result.CustomResponse = customResponse; - if (claims?.Any() ?? false) - { - context.Result.ValidatedRequest.Client.AlwaysSendClientClaims = true; - context.Result.ValidatedRequest.Client.ClientClaimsPrefix = string.Empty; - foreach (var claim in claims) - { - context.Result.ValidatedRequest.ClientClaims.Add(claim); - } - } - - if (context.Result.CustomResponse == null || user.MasterPassword != null) - { - return; - } - - // KeyConnector responses below - - // Apikey login - if (context.Result.ValidatedRequest.GrantType == "client_credentials") - { - if (user.UsesKeyConnector) - { - // KeyConnectorUrl is configured in the CLI client, we just need to tell the client to use it - context.Result.CustomResponse["ApiUseKeyConnector"] = true; - context.Result.CustomResponse["ResetMasterPassword"] = false; - } - return; - } - - // SSO login - var organizationClaim = context.Result.ValidatedRequest.Subject?.FindFirst(c => c.Type == "organizationId"); - if (organizationClaim?.Value != null) - { - var organizationId = new Guid(organizationClaim.Value); - - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organizationId); - var ssoConfigData = ssoConfig.GetData(); - - if (ssoConfigData is { KeyConnectorEnabled: true } && !string.IsNullOrEmpty(ssoConfigData.KeyConnectorUrl)) - { - context.Result.CustomResponse["KeyConnectorUrl"] = ssoConfigData.KeyConnectorUrl; - // Prevent clients redirecting to set-password - context.Result.CustomResponse["ResetMasterPassword"] = false; - } + context.Result.ValidatedRequest.ClientClaims.Add(claim); } } - protected override void SetTwoFactorResult(CustomTokenRequestValidationContext context, - Dictionary customResponse) + if (context.Result.CustomResponse == null || user.MasterPassword != null) { - context.Result.Error = "invalid_grant"; - context.Result.ErrorDescription = "Two factor required."; - context.Result.IsError = true; - context.Result.CustomResponse = customResponse; + return; } - protected override void SetSsoResult(CustomTokenRequestValidationContext context, - Dictionary customResponse) + // KeyConnector responses below + + // Apikey login + if (context.Result.ValidatedRequest.GrantType == "client_credentials") { - context.Result.Error = "invalid_grant"; - context.Result.ErrorDescription = "Single Sign on required."; - context.Result.IsError = true; - context.Result.CustomResponse = customResponse; + if (user.UsesKeyConnector) + { + // KeyConnectorUrl is configured in the CLI client, we just need to tell the client to use it + context.Result.CustomResponse["ApiUseKeyConnector"] = true; + context.Result.CustomResponse["ResetMasterPassword"] = false; + } + return; } - protected override void SetErrorResult(CustomTokenRequestValidationContext context, - Dictionary customResponse) + // SSO login + var organizationClaim = context.Result.ValidatedRequest.Subject?.FindFirst(c => c.Type == "organizationId"); + if (organizationClaim?.Value != null) { - context.Result.Error = "invalid_grant"; - context.Result.IsError = true; - context.Result.CustomResponse = customResponse; + var organizationId = new Guid(organizationClaim.Value); + + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organizationId); + var ssoConfigData = ssoConfig.GetData(); + + if (ssoConfigData is { KeyConnectorEnabled: true } && !string.IsNullOrEmpty(ssoConfigData.KeyConnectorUrl)) + { + context.Result.CustomResponse["KeyConnectorUrl"] = ssoConfigData.KeyConnectorUrl; + // Prevent clients redirecting to set-password + context.Result.CustomResponse["ResetMasterPassword"] = false; + } } } + + protected override void SetTwoFactorResult(CustomTokenRequestValidationContext context, + Dictionary customResponse) + { + context.Result.Error = "invalid_grant"; + context.Result.ErrorDescription = "Two factor required."; + context.Result.IsError = true; + context.Result.CustomResponse = customResponse; + } + + protected override void SetSsoResult(CustomTokenRequestValidationContext context, + Dictionary customResponse) + { + context.Result.Error = "invalid_grant"; + context.Result.ErrorDescription = "Single Sign on required."; + context.Result.IsError = true; + context.Result.CustomResponse = customResponse; + } + + protected override void SetErrorResult(CustomTokenRequestValidationContext context, + Dictionary customResponse) + { + context.Result.Error = "invalid_grant"; + context.Result.IsError = true; + context.Result.CustomResponse = customResponse; + } } diff --git a/src/Core/IdentityServer/CustomValidatorRequestContext.cs b/src/Core/IdentityServer/CustomValidatorRequestContext.cs index f5e95aaa8..66fdc1e7e 100644 --- a/src/Core/IdentityServer/CustomValidatorRequestContext.cs +++ b/src/Core/IdentityServer/CustomValidatorRequestContext.cs @@ -1,12 +1,11 @@ using Bit.Core.Entities; using Bit.Core.Models.Business; -namespace Bit.Core.IdentityServer +namespace Bit.Core.IdentityServer; + +public class CustomValidatorRequestContext { - public class CustomValidatorRequestContext - { - public User User { get; set; } - public bool KnownDevice { get; set; } - public CaptchaResponse CaptchaResponse { get; set; } - } + public User User { get; set; } + public bool KnownDevice { get; set; } + public CaptchaResponse CaptchaResponse { get; set; } } diff --git a/src/Core/IdentityServer/DistributedCacheCookieManager.cs b/src/Core/IdentityServer/DistributedCacheCookieManager.cs index 988afc018..d202581bd 100644 --- a/src/Core/IdentityServer/DistributedCacheCookieManager.cs +++ b/src/Core/IdentityServer/DistributedCacheCookieManager.cs @@ -4,66 +4,65 @@ using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Core.IdentityServer +namespace Bit.Core.IdentityServer; + +public class DistributedCacheCookieManager : ICookieManager { - public class DistributedCacheCookieManager : ICookieManager + private readonly ChunkingCookieManager _cookieManager; + + public DistributedCacheCookieManager() { - private readonly ChunkingCookieManager _cookieManager; - - public DistributedCacheCookieManager() - { - _cookieManager = new ChunkingCookieManager(); - } - - private string CacheKeyPrefix => "cookie-data"; - - public void AppendResponseCookie(HttpContext context, string key, string value, CookieOptions options) - { - var id = Guid.NewGuid().ToString(); - var cacheKey = GetKey(key, id); - - var expiresUtc = options.Expires ?? DateTimeOffset.UtcNow.AddMinutes(15); - var cacheOptions = new DistributedCacheEntryOptions() - .SetAbsoluteExpiration(expiresUtc); - - var data = Encoding.UTF8.GetBytes(value); - - var cache = GetCache(context); - cache.Set(cacheKey, data, cacheOptions); - - // Write the cookie with the identifier as the body - _cookieManager.AppendResponseCookie(context, key, id, options); - } - - public void DeleteCookie(HttpContext context, string key, CookieOptions options) - { - _cookieManager.DeleteCookie(context, key, options); - var id = GetId(context, key); - if (!string.IsNullOrWhiteSpace(id)) - { - var cacheKey = GetKey(key, id); - GetCache(context).Remove(cacheKey); - } - } - - public string GetRequestCookie(HttpContext context, string key) - { - var id = GetId(context, key); - if (string.IsNullOrWhiteSpace(id)) - { - return null; - } - var cacheKey = GetKey(key, id); - return GetCache(context).GetString(cacheKey); - } - - private IDistributedCache GetCache(HttpContext context) => - context.RequestServices.GetRequiredService(); - - private string GetKey(string key, string id) => $"{CacheKeyPrefix}-{key}-{id}"; - - private string GetId(HttpContext context, string key) => - context.Request.Cookies.ContainsKey(key) ? - context.Request.Cookies[key] : null; + _cookieManager = new ChunkingCookieManager(); } + + private string CacheKeyPrefix => "cookie-data"; + + public void AppendResponseCookie(HttpContext context, string key, string value, CookieOptions options) + { + var id = Guid.NewGuid().ToString(); + var cacheKey = GetKey(key, id); + + var expiresUtc = options.Expires ?? DateTimeOffset.UtcNow.AddMinutes(15); + var cacheOptions = new DistributedCacheEntryOptions() + .SetAbsoluteExpiration(expiresUtc); + + var data = Encoding.UTF8.GetBytes(value); + + var cache = GetCache(context); + cache.Set(cacheKey, data, cacheOptions); + + // Write the cookie with the identifier as the body + _cookieManager.AppendResponseCookie(context, key, id, options); + } + + public void DeleteCookie(HttpContext context, string key, CookieOptions options) + { + _cookieManager.DeleteCookie(context, key, options); + var id = GetId(context, key); + if (!string.IsNullOrWhiteSpace(id)) + { + var cacheKey = GetKey(key, id); + GetCache(context).Remove(cacheKey); + } + } + + public string GetRequestCookie(HttpContext context, string key) + { + var id = GetId(context, key); + if (string.IsNullOrWhiteSpace(id)) + { + return null; + } + var cacheKey = GetKey(key, id); + return GetCache(context).GetString(cacheKey); + } + + private IDistributedCache GetCache(HttpContext context) => + context.RequestServices.GetRequiredService(); + + private string GetKey(string key, string id) => $"{CacheKeyPrefix}-{key}-{id}"; + + private string GetId(HttpContext context, string key) => + context.Request.Cookies.ContainsKey(key) ? + context.Request.Cookies[key] : null; } diff --git a/src/Core/IdentityServer/DistributedCacheTicketDataFormatter.cs b/src/Core/IdentityServer/DistributedCacheTicketDataFormatter.cs index bbd1d4087..ec47a0f7c 100644 --- a/src/Core/IdentityServer/DistributedCacheTicketDataFormatter.cs +++ b/src/Core/IdentityServer/DistributedCacheTicketDataFormatter.cs @@ -4,62 +4,61 @@ using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Core.IdentityServer +namespace Bit.Core.IdentityServer; + +public class DistributedCacheTicketDataFormatter : ISecureDataFormat { - public class DistributedCacheTicketDataFormatter : ISecureDataFormat + private readonly IHttpContextAccessor _httpContext; + private readonly string _name; + + public DistributedCacheTicketDataFormatter(IHttpContextAccessor httpContext, string name) { - private readonly IHttpContextAccessor _httpContext; - private readonly string _name; + _httpContext = httpContext; + _name = name; + } - public DistributedCacheTicketDataFormatter(IHttpContextAccessor httpContext, string name) + private string CacheKeyPrefix => "ticket-data"; + private IDistributedCache Cache => _httpContext.HttpContext.RequestServices.GetRequiredService(); + private IDataProtector Protector => _httpContext.HttpContext.RequestServices.GetRequiredService() + .CreateProtector(CacheKeyPrefix, _name); + + public string Protect(AuthenticationTicket data) => Protect(data, null); + public string Protect(AuthenticationTicket data, string purpose) + { + var key = Guid.NewGuid().ToString(); + var cacheKey = $"{CacheKeyPrefix}-{_name}-{purpose}-{key}"; + + var expiresUtc = data.Properties.ExpiresUtc ?? + DateTimeOffset.UtcNow.AddMinutes(15); + + var options = new DistributedCacheEntryOptions(); + options.SetAbsoluteExpiration(expiresUtc); + + var ticket = TicketSerializer.Default.Serialize(data); + Cache.Set(cacheKey, ticket, options); + + return Protector.Protect(key); + } + + public AuthenticationTicket Unprotect(string protectedText) => Unprotect(protectedText, null); + public AuthenticationTicket Unprotect(string protectedText, string purpose) + { + if (string.IsNullOrWhiteSpace(protectedText)) { - _httpContext = httpContext; - _name = name; + return null; } - private string CacheKeyPrefix => "ticket-data"; - private IDistributedCache Cache => _httpContext.HttpContext.RequestServices.GetRequiredService(); - private IDataProtector Protector => _httpContext.HttpContext.RequestServices.GetRequiredService() - .CreateProtector(CacheKeyPrefix, _name); + // Decrypt the key and retrieve the data from the cache. + var key = Protector.Unprotect(protectedText); + var cacheKey = $"{CacheKeyPrefix}-{_name}-{purpose}-{key}"; + var ticket = Cache.Get(cacheKey); - public string Protect(AuthenticationTicket data) => Protect(data, null); - public string Protect(AuthenticationTicket data, string purpose) + if (ticket == null) { - var key = Guid.NewGuid().ToString(); - var cacheKey = $"{CacheKeyPrefix}-{_name}-{purpose}-{key}"; - - var expiresUtc = data.Properties.ExpiresUtc ?? - DateTimeOffset.UtcNow.AddMinutes(15); - - var options = new DistributedCacheEntryOptions(); - options.SetAbsoluteExpiration(expiresUtc); - - var ticket = TicketSerializer.Default.Serialize(data); - Cache.Set(cacheKey, ticket, options); - - return Protector.Protect(key); + return null; } - public AuthenticationTicket Unprotect(string protectedText) => Unprotect(protectedText, null); - public AuthenticationTicket Unprotect(string protectedText, string purpose) - { - if (string.IsNullOrWhiteSpace(protectedText)) - { - return null; - } - - // Decrypt the key and retrieve the data from the cache. - var key = Protector.Unprotect(protectedText); - var cacheKey = $"{CacheKeyPrefix}-{_name}-{purpose}-{key}"; - var ticket = Cache.Get(cacheKey); - - if (ticket == null) - { - return null; - } - - var data = TicketSerializer.Default.Deserialize(ticket); - return data; - } + var data = TicketSerializer.Default.Deserialize(ticket); + return data; } } diff --git a/src/Core/IdentityServer/MemoryCacheTicketStore.cs b/src/Core/IdentityServer/MemoryCacheTicketStore.cs index 7120aee07..dc8d763c9 100644 --- a/src/Core/IdentityServer/MemoryCacheTicketStore.cs +++ b/src/Core/IdentityServer/MemoryCacheTicketStore.cs @@ -2,53 +2,52 @@ using Microsoft.AspNetCore.Authentication.Cookies; using Microsoft.Extensions.Caching.Memory; -namespace Bit.Core.IdentityServer +namespace Bit.Core.IdentityServer; + +public class MemoryCacheTicketStore : ITicketStore { - public class MemoryCacheTicketStore : ITicketStore + private const string _keyPrefix = "auth-"; + private readonly IMemoryCache _cache; + + public MemoryCacheTicketStore() { - private const string _keyPrefix = "auth-"; - private readonly IMemoryCache _cache; + _cache = new MemoryCache(new MemoryCacheOptions()); + } - public MemoryCacheTicketStore() + public async Task StoreAsync(AuthenticationTicket ticket) + { + var key = $"{_keyPrefix}{Guid.NewGuid()}"; + await RenewAsync(key, ticket); + return key; + } + + public Task RenewAsync(string key, AuthenticationTicket ticket) + { + var options = new MemoryCacheEntryOptions(); + var expiresUtc = ticket.Properties.ExpiresUtc; + if (expiresUtc.HasValue) { - _cache = new MemoryCache(new MemoryCacheOptions()); + options.SetAbsoluteExpiration(expiresUtc.Value); + } + else + { + options.SetSlidingExpiration(TimeSpan.FromMinutes(15)); } - public async Task StoreAsync(AuthenticationTicket ticket) - { - var key = $"{_keyPrefix}{Guid.NewGuid()}"; - await RenewAsync(key, ticket); - return key; - } + _cache.Set(key, ticket, options); - public Task RenewAsync(string key, AuthenticationTicket ticket) - { - var options = new MemoryCacheEntryOptions(); - var expiresUtc = ticket.Properties.ExpiresUtc; - if (expiresUtc.HasValue) - { - options.SetAbsoluteExpiration(expiresUtc.Value); - } - else - { - options.SetSlidingExpiration(TimeSpan.FromMinutes(15)); - } + return Task.FromResult(0); + } - _cache.Set(key, ticket, options); + public Task RetrieveAsync(string key) + { + _cache.TryGetValue(key, out AuthenticationTicket ticket); + return Task.FromResult(ticket); + } - return Task.FromResult(0); - } - - public Task RetrieveAsync(string key) - { - _cache.TryGetValue(key, out AuthenticationTicket ticket); - return Task.FromResult(ticket); - } - - public Task RemoveAsync(string key) - { - _cache.Remove(key); - return Task.FromResult(0); - } + public Task RemoveAsync(string key) + { + _cache.Remove(key); + return Task.FromResult(0); } } diff --git a/src/Core/IdentityServer/OidcIdentityClient.cs b/src/Core/IdentityServer/OidcIdentityClient.cs index 7f24f66e2..822ac56cd 100644 --- a/src/Core/IdentityServer/OidcIdentityClient.cs +++ b/src/Core/IdentityServer/OidcIdentityClient.cs @@ -2,25 +2,24 @@ using IdentityServer4; using IdentityServer4.Models; -namespace Bit.Core.IdentityServer +namespace Bit.Core.IdentityServer; + +public class OidcIdentityClient : Client { - public class OidcIdentityClient : Client + public OidcIdentityClient(GlobalSettings globalSettings) { - public OidcIdentityClient(GlobalSettings globalSettings) + ClientId = "oidc-identity"; + RequireClientSecret = true; + RequirePkce = true; + ClientSecrets = new List { new Secret(globalSettings.OidcIdentityClientKey.Sha256()) }; + AllowedScopes = new string[] { - ClientId = "oidc-identity"; - RequireClientSecret = true; - RequirePkce = true; - ClientSecrets = new List { new Secret(globalSettings.OidcIdentityClientKey.Sha256()) }; - AllowedScopes = new string[] - { - IdentityServerConstants.StandardScopes.OpenId, - IdentityServerConstants.StandardScopes.Profile - }; - AllowedGrantTypes = GrantTypes.Code; - Enabled = true; - RedirectUris = new List { $"{globalSettings.BaseServiceUri.Identity}/signin-oidc" }; - RequireConsent = false; - } + IdentityServerConstants.StandardScopes.OpenId, + IdentityServerConstants.StandardScopes.Profile + }; + AllowedGrantTypes = GrantTypes.Code; + Enabled = true; + RedirectUris = new List { $"{globalSettings.BaseServiceUri.Identity}/signin-oidc" }; + RequireConsent = false; } } diff --git a/src/Core/IdentityServer/PersistedGrantStore.cs b/src/Core/IdentityServer/PersistedGrantStore.cs index 7094265e7..a1b3294ba 100644 --- a/src/Core/IdentityServer/PersistedGrantStore.cs +++ b/src/Core/IdentityServer/PersistedGrantStore.cs @@ -3,86 +3,85 @@ using IdentityServer4.Models; using IdentityServer4.Stores; using Grant = Bit.Core.Entities.Grant; -namespace Bit.Core.IdentityServer +namespace Bit.Core.IdentityServer; + +public class PersistedGrantStore : IPersistedGrantStore { - public class PersistedGrantStore : IPersistedGrantStore + private readonly IGrantRepository _grantRepository; + + public PersistedGrantStore( + IGrantRepository grantRepository) { - private readonly IGrantRepository _grantRepository; + _grantRepository = grantRepository; + } - public PersistedGrantStore( - IGrantRepository grantRepository) + public async Task GetAsync(string key) + { + var grant = await _grantRepository.GetByKeyAsync(key); + if (grant == null) { - _grantRepository = grantRepository; + return null; } - public async Task GetAsync(string key) - { - var grant = await _grantRepository.GetByKeyAsync(key); - if (grant == null) - { - return null; - } + var pGrant = ToPersistedGrant(grant); + return pGrant; + } - var pGrant = ToPersistedGrant(grant); - return pGrant; - } + public async Task> GetAllAsync(PersistedGrantFilter filter) + { + var grants = await _grantRepository.GetManyAsync(filter.SubjectId, filter.SessionId, + filter.ClientId, filter.Type); + var pGrants = grants.Select(g => ToPersistedGrant(g)); + return pGrants; + } - public async Task> GetAllAsync(PersistedGrantFilter filter) - { - var grants = await _grantRepository.GetManyAsync(filter.SubjectId, filter.SessionId, - filter.ClientId, filter.Type); - var pGrants = grants.Select(g => ToPersistedGrant(g)); - return pGrants; - } + public async Task RemoveAllAsync(PersistedGrantFilter filter) + { + await _grantRepository.DeleteManyAsync(filter.SubjectId, filter.SessionId, filter.ClientId, filter.Type); + } - public async Task RemoveAllAsync(PersistedGrantFilter filter) - { - await _grantRepository.DeleteManyAsync(filter.SubjectId, filter.SessionId, filter.ClientId, filter.Type); - } + public async Task RemoveAsync(string key) + { + await _grantRepository.DeleteByKeyAsync(key); + } - public async Task RemoveAsync(string key) - { - await _grantRepository.DeleteByKeyAsync(key); - } + public async Task StoreAsync(PersistedGrant pGrant) + { + var grant = ToGrant(pGrant); + await _grantRepository.SaveAsync(grant); + } - public async Task StoreAsync(PersistedGrant pGrant) + private Grant ToGrant(PersistedGrant pGrant) + { + return new Grant { - var grant = ToGrant(pGrant); - await _grantRepository.SaveAsync(grant); - } + Key = pGrant.Key, + Type = pGrant.Type, + SubjectId = pGrant.SubjectId, + SessionId = pGrant.SessionId, + ClientId = pGrant.ClientId, + Description = pGrant.Description, + CreationDate = pGrant.CreationTime, + ExpirationDate = pGrant.Expiration, + ConsumedDate = pGrant.ConsumedTime, + Data = pGrant.Data + }; + } - private Grant ToGrant(PersistedGrant pGrant) + private PersistedGrant ToPersistedGrant(Grant grant) + { + return new PersistedGrant { - return new Grant - { - Key = pGrant.Key, - Type = pGrant.Type, - SubjectId = pGrant.SubjectId, - SessionId = pGrant.SessionId, - ClientId = pGrant.ClientId, - Description = pGrant.Description, - CreationDate = pGrant.CreationTime, - ExpirationDate = pGrant.Expiration, - ConsumedDate = pGrant.ConsumedTime, - Data = pGrant.Data - }; - } - - private PersistedGrant ToPersistedGrant(Grant grant) - { - return new PersistedGrant - { - Key = grant.Key, - Type = grant.Type, - SubjectId = grant.SubjectId, - SessionId = grant.SessionId, - ClientId = grant.ClientId, - Description = grant.Description, - CreationTime = grant.CreationDate, - Expiration = grant.ExpirationDate, - ConsumedTime = grant.ConsumedDate, - Data = grant.Data - }; - } + Key = grant.Key, + Type = grant.Type, + SubjectId = grant.SubjectId, + SessionId = grant.SessionId, + ClientId = grant.ClientId, + Description = grant.Description, + CreationTime = grant.CreationDate, + Expiration = grant.ExpirationDate, + ConsumedTime = grant.ConsumedDate, + Data = grant.Data + }; } } diff --git a/src/Core/IdentityServer/ProfileService.cs b/src/Core/IdentityServer/ProfileService.cs index aa79c60d1..873ad6b5a 100644 --- a/src/Core/IdentityServer/ProfileService.cs +++ b/src/Core/IdentityServer/ProfileService.cs @@ -6,83 +6,82 @@ using Bit.Core.Utilities; using IdentityServer4.Models; using IdentityServer4.Services; -namespace Bit.Core.IdentityServer +namespace Bit.Core.IdentityServer; + +public class ProfileService : IProfileService { - public class ProfileService : IProfileService + private readonly IUserService _userService; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IProviderUserRepository _providerUserRepository; + private readonly IProviderOrganizationRepository _providerOrganizationRepository; + private readonly ILicensingService _licensingService; + private readonly ICurrentContext _currentContext; + + public ProfileService( + IUserService userService, + IOrganizationUserRepository organizationUserRepository, + IProviderUserRepository providerUserRepository, + IProviderOrganizationRepository providerOrganizationRepository, + ILicensingService licensingService, + ICurrentContext currentContext) { - private readonly IUserService _userService; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IProviderUserRepository _providerUserRepository; - private readonly IProviderOrganizationRepository _providerOrganizationRepository; - private readonly ILicensingService _licensingService; - private readonly ICurrentContext _currentContext; + _userService = userService; + _organizationUserRepository = organizationUserRepository; + _providerUserRepository = providerUserRepository; + _providerOrganizationRepository = providerOrganizationRepository; + _licensingService = licensingService; + _currentContext = currentContext; + } - public ProfileService( - IUserService userService, - IOrganizationUserRepository organizationUserRepository, - IProviderUserRepository providerUserRepository, - IProviderOrganizationRepository providerOrganizationRepository, - ILicensingService licensingService, - ICurrentContext currentContext) + public async Task GetProfileDataAsync(ProfileDataRequestContext context) + { + var existingClaims = context.Subject.Claims; + var newClaims = new List(); + + var user = await _userService.GetUserByPrincipalAsync(context.Subject); + if (user != null) { - _userService = userService; - _organizationUserRepository = organizationUserRepository; - _providerUserRepository = providerUserRepository; - _providerOrganizationRepository = providerOrganizationRepository; - _licensingService = licensingService; - _currentContext = currentContext; - } - - public async Task GetProfileDataAsync(ProfileDataRequestContext context) - { - var existingClaims = context.Subject.Claims; - var newClaims = new List(); - - var user = await _userService.GetUserByPrincipalAsync(context.Subject); - if (user != null) + var isPremium = await _licensingService.ValidateUserPremiumAsync(user); + var orgs = await _currentContext.OrganizationMembershipAsync(_organizationUserRepository, user.Id); + var providers = await _currentContext.ProviderMembershipAsync(_providerUserRepository, user.Id); + foreach (var claim in CoreHelpers.BuildIdentityClaims(user, orgs, providers, isPremium)) { - var isPremium = await _licensingService.ValidateUserPremiumAsync(user); - var orgs = await _currentContext.OrganizationMembershipAsync(_organizationUserRepository, user.Id); - var providers = await _currentContext.ProviderMembershipAsync(_providerUserRepository, user.Id); - foreach (var claim in CoreHelpers.BuildIdentityClaims(user, orgs, providers, isPremium)) - { - var upperValue = claim.Value.ToUpperInvariant(); - var isBool = upperValue == "TRUE" || upperValue == "FALSE"; - newClaims.Add(isBool ? - new Claim(claim.Key, claim.Value, ClaimValueTypes.Boolean) : - new Claim(claim.Key, claim.Value) - ); - } - } - - // filter out any of the new claims - var existingClaimsToKeep = existingClaims - .Where(c => !c.Type.StartsWith("org") && - (newClaims.Count == 0 || !newClaims.Any(nc => nc.Type == c.Type))) - .ToList(); - - newClaims.AddRange(existingClaimsToKeep); - if (newClaims.Any()) - { - context.IssuedClaims.AddRange(newClaims); + var upperValue = claim.Value.ToUpperInvariant(); + var isBool = upperValue == "TRUE" || upperValue == "FALSE"; + newClaims.Add(isBool ? + new Claim(claim.Key, claim.Value, ClaimValueTypes.Boolean) : + new Claim(claim.Key, claim.Value) + ); } } - public async Task IsActiveAsync(IsActiveContext context) - { - var securityTokenClaim = context.Subject?.Claims.FirstOrDefault(c => c.Type == "sstamp"); - var user = await _userService.GetUserByPrincipalAsync(context.Subject); + // filter out any of the new claims + var existingClaimsToKeep = existingClaims + .Where(c => !c.Type.StartsWith("org") && + (newClaims.Count == 0 || !newClaims.Any(nc => nc.Type == c.Type))) + .ToList(); - if (user != null && securityTokenClaim != null) - { - context.IsActive = string.Equals(user.SecurityStamp, securityTokenClaim.Value, - StringComparison.InvariantCultureIgnoreCase); - return; - } - else - { - context.IsActive = true; - } + newClaims.AddRange(existingClaimsToKeep); + if (newClaims.Any()) + { + context.IssuedClaims.AddRange(newClaims); + } + } + + public async Task IsActiveAsync(IsActiveContext context) + { + var securityTokenClaim = context.Subject?.Claims.FirstOrDefault(c => c.Type == "sstamp"); + var user = await _userService.GetUserByPrincipalAsync(context.Subject); + + if (user != null && securityTokenClaim != null) + { + context.IsActive = string.Equals(user.SecurityStamp, securityTokenClaim.Value, + StringComparison.InvariantCultureIgnoreCase); + return; + } + else + { + context.IsActive = true; } } } diff --git a/src/Core/IdentityServer/RedisCacheTicketStore.cs b/src/Core/IdentityServer/RedisCacheTicketStore.cs index f7aa8c0a9..139158c32 100644 --- a/src/Core/IdentityServer/RedisCacheTicketStore.cs +++ b/src/Core/IdentityServer/RedisCacheTicketStore.cs @@ -3,63 +3,62 @@ using Microsoft.AspNetCore.Authentication.Cookies; using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.Caching.StackExchangeRedis; -namespace Bit.Core.IdentityServer +namespace Bit.Core.IdentityServer; + +public class RedisCacheTicketStore : ITicketStore { - public class RedisCacheTicketStore : ITicketStore + private const string _keyPrefix = "auth-"; + private readonly IDistributedCache _cache; + + public RedisCacheTicketStore(RedisCacheOptions options) { - private const string _keyPrefix = "auth-"; - private readonly IDistributedCache _cache; + _cache = new RedisCache(options); + } - public RedisCacheTicketStore(RedisCacheOptions options) - { - _cache = new RedisCache(options); - } + public async Task StoreAsync(AuthenticationTicket ticket) + { + var key = $"{_keyPrefix}{Guid.NewGuid()}"; + await RenewAsync(key, ticket); - public async Task StoreAsync(AuthenticationTicket ticket) - { - var key = $"{_keyPrefix}{Guid.NewGuid()}"; - await RenewAsync(key, ticket); + return key; + } - return key; - } + public Task RenewAsync(string key, AuthenticationTicket ticket) + { + var options = new DistributedCacheEntryOptions(); + var expiresUtc = ticket.Properties.ExpiresUtc ?? + DateTimeOffset.UtcNow.AddMinutes(15); + options.SetAbsoluteExpiration(expiresUtc); - public Task RenewAsync(string key, AuthenticationTicket ticket) - { - var options = new DistributedCacheEntryOptions(); - var expiresUtc = ticket.Properties.ExpiresUtc ?? - DateTimeOffset.UtcNow.AddMinutes(15); - options.SetAbsoluteExpiration(expiresUtc); + var val = SerializeToBytes(ticket); + _cache.Set(key, val, options); - var val = SerializeToBytes(ticket); - _cache.Set(key, val, options); + return Task.FromResult(0); + } - return Task.FromResult(0); - } + public Task RetrieveAsync(string key) + { + AuthenticationTicket ticket; + var bytes = _cache.Get(key); + ticket = DeserializeFromBytes(bytes); - public Task RetrieveAsync(string key) - { - AuthenticationTicket ticket; - var bytes = _cache.Get(key); - ticket = DeserializeFromBytes(bytes); + return Task.FromResult(ticket); + } - return Task.FromResult(ticket); - } + public Task RemoveAsync(string key) + { + _cache.Remove(key); - public Task RemoveAsync(string key) - { - _cache.Remove(key); + return Task.FromResult(0); + } - return Task.FromResult(0); - } + private static byte[] SerializeToBytes(AuthenticationTicket source) + { + return TicketSerializer.Default.Serialize(source); + } - private static byte[] SerializeToBytes(AuthenticationTicket source) - { - return TicketSerializer.Default.Serialize(source); - } - - private static AuthenticationTicket DeserializeFromBytes(byte[] source) - { - return source == null ? null : TicketSerializer.Default.Deserialize(source); - } + private static AuthenticationTicket DeserializeFromBytes(byte[] source) + { + return source == null ? null : TicketSerializer.Default.Deserialize(source); } } diff --git a/src/Core/IdentityServer/ResourceOwnerPasswordValidator.cs b/src/Core/IdentityServer/ResourceOwnerPasswordValidator.cs index f83143141..82b3cf50a 100644 --- a/src/Core/IdentityServer/ResourceOwnerPasswordValidator.cs +++ b/src/Core/IdentityServer/ResourceOwnerPasswordValidator.cs @@ -11,163 +11,162 @@ using IdentityServer4.Validation; using Microsoft.AspNetCore.Identity; using Microsoft.Extensions.Logging; -namespace Bit.Core.IdentityServer +namespace Bit.Core.IdentityServer; + +public class ResourceOwnerPasswordValidator : BaseRequestValidator, + IResourceOwnerPasswordValidator { - public class ResourceOwnerPasswordValidator : BaseRequestValidator, - IResourceOwnerPasswordValidator + private UserManager _userManager; + private readonly IUserService _userService; + private readonly ICurrentContext _currentContext; + private readonly ICaptchaValidationService _captchaValidationService; + public ResourceOwnerPasswordValidator( + UserManager userManager, + IDeviceRepository deviceRepository, + IDeviceService deviceService, + IUserService userService, + IEventService eventService, + IOrganizationDuoWebTokenProvider organizationDuoWebTokenProvider, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IApplicationCacheService applicationCacheService, + IMailService mailService, + ILogger logger, + ICurrentContext currentContext, + GlobalSettings globalSettings, + IPolicyRepository policyRepository, + ICaptchaValidationService captchaValidationService, + IUserRepository userRepository) + : base(userManager, deviceRepository, deviceService, userService, eventService, + organizationDuoWebTokenProvider, organizationRepository, organizationUserRepository, + applicationCacheService, mailService, logger, currentContext, globalSettings, policyRepository, + userRepository, captchaValidationService) { - private UserManager _userManager; - private readonly IUserService _userService; - private readonly ICurrentContext _currentContext; - private readonly ICaptchaValidationService _captchaValidationService; - public ResourceOwnerPasswordValidator( - UserManager userManager, - IDeviceRepository deviceRepository, - IDeviceService deviceService, - IUserService userService, - IEventService eventService, - IOrganizationDuoWebTokenProvider organizationDuoWebTokenProvider, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IApplicationCacheService applicationCacheService, - IMailService mailService, - ILogger logger, - ICurrentContext currentContext, - GlobalSettings globalSettings, - IPolicyRepository policyRepository, - ICaptchaValidationService captchaValidationService, - IUserRepository userRepository) - : base(userManager, deviceRepository, deviceService, userService, eventService, - organizationDuoWebTokenProvider, organizationRepository, organizationUserRepository, - applicationCacheService, mailService, logger, currentContext, globalSettings, policyRepository, - userRepository, captchaValidationService) + _userManager = userManager; + _userService = userService; + _currentContext = currentContext; + _captchaValidationService = captchaValidationService; + } + + public async Task ValidateAsync(ResourceOwnerPasswordValidationContext context) + { + if (!AuthEmailHeaderIsValid(context)) { - _userManager = userManager; - _userService = userService; - _currentContext = currentContext; - _captchaValidationService = captchaValidationService; + context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, + "Auth-Email header invalid."); + return; } - public async Task ValidateAsync(ResourceOwnerPasswordValidationContext context) + var user = await _userManager.FindByEmailAsync(context.UserName.ToLowerInvariant()); + var validatorContext = new CustomValidatorRequestContext { - if (!AuthEmailHeaderIsValid(context)) + User = user, + KnownDevice = await KnownDeviceAsync(user, context.Request) + }; + string bypassToken = null; + if (!validatorContext.KnownDevice && + _captchaValidationService.RequireCaptchaValidation(_currentContext, user)) + { + var captchaResponse = context.Request.Raw["captchaResponse"]?.ToString(); + + if (string.IsNullOrWhiteSpace(captchaResponse)) { - context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, - "Auth-Email header invalid."); + context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, "Captcha required.", + new Dictionary + { + { _captchaValidationService.SiteKeyResponseKeyName, _captchaValidationService.SiteKey }, + }); return; } - var user = await _userManager.FindByEmailAsync(context.UserName.ToLowerInvariant()); - var validatorContext = new CustomValidatorRequestContext + validatorContext.CaptchaResponse = await _captchaValidationService.ValidateCaptchaResponseAsync( + captchaResponse, _currentContext.IpAddress, user); + if (!validatorContext.CaptchaResponse.Success) { - User = user, - KnownDevice = await KnownDeviceAsync(user, context.Request) - }; - string bypassToken = null; - if (!validatorContext.KnownDevice && - _captchaValidationService.RequireCaptchaValidation(_currentContext, user)) - { - var captchaResponse = context.Request.Raw["captchaResponse"]?.ToString(); + await BuildErrorResultAsync("Captcha is invalid. Please refresh and try again", false, context, null); + return; + } + bypassToken = _captchaValidationService.GenerateCaptchaBypassToken(user); + } - if (string.IsNullOrWhiteSpace(captchaResponse)) + await ValidateAsync(context, context.Request, validatorContext); + if (context.Result.CustomResponse != null && bypassToken != null) + { + context.Result.CustomResponse["CaptchaBypassToken"] = bypassToken; + } + } + + protected async override Task ValidateContextAsync(ResourceOwnerPasswordValidationContext context, + CustomValidatorRequestContext validatorContext) + { + if (string.IsNullOrWhiteSpace(context.UserName) || validatorContext.User == null) + { + return false; + } + + if (!await _userService.CheckPasswordAsync(validatorContext.User, context.Password)) + { + return false; + } + + return true; + } + + protected override Task SetSuccessResult(ResourceOwnerPasswordValidationContext context, User user, + List claims, Dictionary customResponse) + { + context.Result = new GrantValidationResult(user.Id.ToString(), "Application", + identityProvider: "bitwarden", + claims: claims.Count > 0 ? claims : null, + customResponse: customResponse); + return Task.CompletedTask; + } + + protected override void SetTwoFactorResult(ResourceOwnerPasswordValidationContext context, + Dictionary customResponse) + { + context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, "Two factor required.", + customResponse); + } + + protected override void SetSsoResult(ResourceOwnerPasswordValidationContext context, + Dictionary customResponse) + { + context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, "Sso authentication required.", + customResponse); + } + + protected override void SetErrorResult(ResourceOwnerPasswordValidationContext context, + Dictionary customResponse) + { + context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, customResponse: customResponse); + } + + private bool AuthEmailHeaderIsValid(ResourceOwnerPasswordValidationContext context) + { + if (!_currentContext.HttpContext.Request.Headers.ContainsKey("Auth-Email")) + { + return false; + } + else + { + try + { + var authEmailHeader = _currentContext.HttpContext.Request.Headers["Auth-Email"]; + var authEmailDecoded = CoreHelpers.Base64UrlDecodeString(authEmailHeader); + + if (authEmailDecoded != context.UserName) { - context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, "Captcha required.", - new Dictionary - { - { _captchaValidationService.SiteKeyResponseKeyName, _captchaValidationService.SiteKey }, - }); - return; - } - - validatorContext.CaptchaResponse = await _captchaValidationService.ValidateCaptchaResponseAsync( - captchaResponse, _currentContext.IpAddress, user); - if (!validatorContext.CaptchaResponse.Success) - { - await BuildErrorResultAsync("Captcha is invalid. Please refresh and try again", false, context, null); - return; - } - bypassToken = _captchaValidationService.GenerateCaptchaBypassToken(user); - } - - await ValidateAsync(context, context.Request, validatorContext); - if (context.Result.CustomResponse != null && bypassToken != null) - { - context.Result.CustomResponse["CaptchaBypassToken"] = bypassToken; - } - } - - protected async override Task ValidateContextAsync(ResourceOwnerPasswordValidationContext context, - CustomValidatorRequestContext validatorContext) - { - if (string.IsNullOrWhiteSpace(context.UserName) || validatorContext.User == null) - { - return false; - } - - if (!await _userService.CheckPasswordAsync(validatorContext.User, context.Password)) - { - return false; - } - - return true; - } - - protected override Task SetSuccessResult(ResourceOwnerPasswordValidationContext context, User user, - List claims, Dictionary customResponse) - { - context.Result = new GrantValidationResult(user.Id.ToString(), "Application", - identityProvider: "bitwarden", - claims: claims.Count > 0 ? claims : null, - customResponse: customResponse); - return Task.CompletedTask; - } - - protected override void SetTwoFactorResult(ResourceOwnerPasswordValidationContext context, - Dictionary customResponse) - { - context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, "Two factor required.", - customResponse); - } - - protected override void SetSsoResult(ResourceOwnerPasswordValidationContext context, - Dictionary customResponse) - { - context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, "Sso authentication required.", - customResponse); - } - - protected override void SetErrorResult(ResourceOwnerPasswordValidationContext context, - Dictionary customResponse) - { - context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, customResponse: customResponse); - } - - private bool AuthEmailHeaderIsValid(ResourceOwnerPasswordValidationContext context) - { - if (!_currentContext.HttpContext.Request.Headers.ContainsKey("Auth-Email")) - { - return false; - } - else - { - try - { - var authEmailHeader = _currentContext.HttpContext.Request.Headers["Auth-Email"]; - var authEmailDecoded = CoreHelpers.Base64UrlDecodeString(authEmailHeader); - - if (authEmailDecoded != context.UserName) - { - return false; - } - } - catch (System.Exception e) when (e is System.InvalidOperationException || e is System.FormatException) - { - // Invalid B64 encoding return false; } } - - return true; + catch (System.Exception e) when (e is System.InvalidOperationException || e is System.FormatException) + { + // Invalid B64 encoding + return false; + } } + + return true; } } diff --git a/src/Core/IdentityServer/StaticClientStore.cs b/src/Core/IdentityServer/StaticClientStore.cs index 60bff26e7..92c124f26 100644 --- a/src/Core/IdentityServer/StaticClientStore.cs +++ b/src/Core/IdentityServer/StaticClientStore.cs @@ -2,23 +2,22 @@ using Bit.Core.Settings; using IdentityServer4.Models; -namespace Bit.Core.IdentityServer -{ - public class StaticClientStore - { - public StaticClientStore(GlobalSettings globalSettings) - { - ApiClients = new List - { - new ApiClient(globalSettings, BitwardenClient.Mobile, 90, 1), - new ApiClient(globalSettings, BitwardenClient.Web, 30, 1), - new ApiClient(globalSettings, BitwardenClient.Browser, 30, 1), - new ApiClient(globalSettings, BitwardenClient.Desktop, 30, 1), - new ApiClient(globalSettings, BitwardenClient.Cli, 30, 1), - new ApiClient(globalSettings, BitwardenClient.DirectoryConnector, 30, 24) - }.ToDictionary(c => c.ClientId); - } +namespace Bit.Core.IdentityServer; - public IDictionary ApiClients { get; private set; } +public class StaticClientStore +{ + public StaticClientStore(GlobalSettings globalSettings) + { + ApiClients = new List + { + new ApiClient(globalSettings, BitwardenClient.Mobile, 90, 1), + new ApiClient(globalSettings, BitwardenClient.Web, 30, 1), + new ApiClient(globalSettings, BitwardenClient.Browser, 30, 1), + new ApiClient(globalSettings, BitwardenClient.Desktop, 30, 1), + new ApiClient(globalSettings, BitwardenClient.Cli, 30, 1), + new ApiClient(globalSettings, BitwardenClient.DirectoryConnector, 30, 24) + }.ToDictionary(c => c.ClientId); } + + public IDictionary ApiClients { get; private set; } } diff --git a/src/Core/IdentityServer/TokenRetrieval.cs b/src/Core/IdentityServer/TokenRetrieval.cs index 7290576f0..8c8ecfbc4 100644 --- a/src/Core/IdentityServer/TokenRetrieval.cs +++ b/src/Core/IdentityServer/TokenRetrieval.cs @@ -1,30 +1,29 @@ using Microsoft.AspNetCore.Http; -namespace Bit.Core.IdentityServer +namespace Bit.Core.IdentityServer; + +public static class TokenRetrieval { - public static class TokenRetrieval + private static string _headerScheme = "Bearer "; + private static string _queuryScheme = "access_token"; + private static string _authHeader = "Authorization"; + + public static Func FromAuthorizationHeaderOrQueryString() { - private static string _headerScheme = "Bearer "; - private static string _queuryScheme = "access_token"; - private static string _authHeader = "Authorization"; - - public static Func FromAuthorizationHeaderOrQueryString() + return (request) => { - return (request) => + var authorization = request.Headers[_authHeader].FirstOrDefault(); + if (string.IsNullOrWhiteSpace(authorization)) { - var authorization = request.Headers[_authHeader].FirstOrDefault(); - if (string.IsNullOrWhiteSpace(authorization)) - { - return request.Query[_queuryScheme].FirstOrDefault(); - } + return request.Query[_queuryScheme].FirstOrDefault(); + } - if (authorization.StartsWith(_headerScheme, StringComparison.OrdinalIgnoreCase)) - { - return authorization.Substring(_headerScheme.Length).Trim(); - } + if (authorization.StartsWith(_headerScheme, StringComparison.OrdinalIgnoreCase)) + { + return authorization.Substring(_headerScheme.Length).Trim(); + } - return null; - }; - } + return null; + }; } } diff --git a/src/Core/IdentityServer/VaultCorsPolicyService.cs b/src/Core/IdentityServer/VaultCorsPolicyService.cs index 42b76135e..49abcb4aa 100644 --- a/src/Core/IdentityServer/VaultCorsPolicyService.cs +++ b/src/Core/IdentityServer/VaultCorsPolicyService.cs @@ -2,20 +2,19 @@ using Bit.Core.Utilities; using IdentityServer4.Services; -namespace Bit.Core.IdentityServer +namespace Bit.Core.IdentityServer; + +public class CustomCorsPolicyService : ICorsPolicyService { - public class CustomCorsPolicyService : ICorsPolicyService + private readonly GlobalSettings _globalSettings; + + public CustomCorsPolicyService(GlobalSettings globalSettings) { - private readonly GlobalSettings _globalSettings; + _globalSettings = globalSettings; + } - public CustomCorsPolicyService(GlobalSettings globalSettings) - { - _globalSettings = globalSettings; - } - - public Task IsOriginAllowedAsync(string origin) - { - return Task.FromResult(CoreHelpers.IsCorsOriginAllowed(origin, _globalSettings)); - } + public Task IsOriginAllowedAsync(string origin) + { + return Task.FromResult(CoreHelpers.IsCorsOriginAllowed(origin, _globalSettings)); } } diff --git a/src/Core/Jobs/BaseJob.cs b/src/Core/Jobs/BaseJob.cs index c1eb7d264..56c39014a 100644 --- a/src/Core/Jobs/BaseJob.cs +++ b/src/Core/Jobs/BaseJob.cs @@ -1,29 +1,28 @@ using Microsoft.Extensions.Logging; using Quartz; -namespace Bit.Core.Jobs +namespace Bit.Core.Jobs; + +public abstract class BaseJob : IJob { - public abstract class BaseJob : IJob + protected readonly ILogger _logger; + + public BaseJob(ILogger logger) { - protected readonly ILogger _logger; - - public BaseJob(ILogger logger) - { - _logger = logger; - } - - public async Task Execute(IJobExecutionContext context) - { - try - { - await ExecuteJobAsync(context); - } - catch (Exception e) - { - _logger.LogError(2, e, "Error performing {0}.", GetType().Name); - } - } - - protected abstract Task ExecuteJobAsync(IJobExecutionContext context); + _logger = logger; } + + public async Task Execute(IJobExecutionContext context) + { + try + { + await ExecuteJobAsync(context); + } + catch (Exception e) + { + _logger.LogError(2, e, "Error performing {0}.", GetType().Name); + } + } + + protected abstract Task ExecuteJobAsync(IJobExecutionContext context); } diff --git a/src/Core/Jobs/BaseJobsHostedService.cs b/src/Core/Jobs/BaseJobsHostedService.cs index c9d2bda1c..897a382a2 100644 --- a/src/Core/Jobs/BaseJobsHostedService.cs +++ b/src/Core/Jobs/BaseJobsHostedService.cs @@ -6,146 +6,145 @@ using Quartz; using Quartz.Impl; using Quartz.Impl.Matchers; -namespace Bit.Core.Jobs +namespace Bit.Core.Jobs; + +public abstract class BaseJobsHostedService : IHostedService, IDisposable { - public abstract class BaseJobsHostedService : IHostedService, IDisposable + private const int MaximumJobRetries = 10; + + private readonly IServiceProvider _serviceProvider; + private readonly ILogger _listenerLogger; + protected readonly ILogger _logger; + + private IScheduler _scheduler; + protected GlobalSettings _globalSettings; + + public BaseJobsHostedService( + GlobalSettings globalSettings, + IServiceProvider serviceProvider, + ILogger logger, + ILogger listenerLogger) { - private const int MaximumJobRetries = 10; + _serviceProvider = serviceProvider; + _logger = logger; + _listenerLogger = listenerLogger; + _globalSettings = globalSettings; + } - private readonly IServiceProvider _serviceProvider; - private readonly ILogger _listenerLogger; - protected readonly ILogger _logger; + public IEnumerable> Jobs { get; protected set; } - private IScheduler _scheduler; - protected GlobalSettings _globalSettings; - - public BaseJobsHostedService( - GlobalSettings globalSettings, - IServiceProvider serviceProvider, - ILogger logger, - ILogger listenerLogger) + public virtual async Task StartAsync(CancellationToken cancellationToken) + { + var props = new NameValueCollection { - _serviceProvider = serviceProvider; - _logger = logger; - _listenerLogger = listenerLogger; - _globalSettings = globalSettings; + {"quartz.serializer.type", "binary"}, + }; + + if (!string.IsNullOrEmpty(_globalSettings.SqlServer.JobSchedulerConnectionString)) + { + // Ensure each project has a unique instanceName + props.Add("quartz.scheduler.instanceName", GetType().FullName); + props.Add("quartz.scheduler.instanceId", "AUTO"); + props.Add("quartz.jobStore.type", "Quartz.Impl.AdoJobStore.JobStoreTX"); + props.Add("quartz.jobStore.driverDelegateType", "Quartz.Impl.AdoJobStore.SqlServerDelegate"); + props.Add("quartz.jobStore.useProperties", "true"); + props.Add("quartz.jobStore.dataSource", "default"); + props.Add("quartz.jobStore.tablePrefix", "QRTZ_"); + props.Add("quartz.jobStore.clustered", "true"); + props.Add("quartz.dataSource.default.provider", "SqlServer"); + props.Add("quartz.dataSource.default.connectionString", _globalSettings.SqlServer.JobSchedulerConnectionString); } - public IEnumerable> Jobs { get; protected set; } - - public virtual async Task StartAsync(CancellationToken cancellationToken) + var factory = new StdSchedulerFactory(props); + _scheduler = await factory.GetScheduler(cancellationToken); + _scheduler.JobFactory = new JobFactory(_serviceProvider); + _scheduler.ListenerManager.AddJobListener(new JobListener(_listenerLogger), + GroupMatcher.AnyGroup()); + await _scheduler.Start(cancellationToken); + if (Jobs != null) { - var props = new NameValueCollection + foreach (var (job, trigger) in Jobs) { - {"quartz.serializer.type", "binary"}, - }; - - if (!string.IsNullOrEmpty(_globalSettings.SqlServer.JobSchedulerConnectionString)) - { - // Ensure each project has a unique instanceName - props.Add("quartz.scheduler.instanceName", GetType().FullName); - props.Add("quartz.scheduler.instanceId", "AUTO"); - props.Add("quartz.jobStore.type", "Quartz.Impl.AdoJobStore.JobStoreTX"); - props.Add("quartz.jobStore.driverDelegateType", "Quartz.Impl.AdoJobStore.SqlServerDelegate"); - props.Add("quartz.jobStore.useProperties", "true"); - props.Add("quartz.jobStore.dataSource", "default"); - props.Add("quartz.jobStore.tablePrefix", "QRTZ_"); - props.Add("quartz.jobStore.clustered", "true"); - props.Add("quartz.dataSource.default.provider", "SqlServer"); - props.Add("quartz.dataSource.default.connectionString", _globalSettings.SqlServer.JobSchedulerConnectionString); - } - - var factory = new StdSchedulerFactory(props); - _scheduler = await factory.GetScheduler(cancellationToken); - _scheduler.JobFactory = new JobFactory(_serviceProvider); - _scheduler.ListenerManager.AddJobListener(new JobListener(_listenerLogger), - GroupMatcher.AnyGroup()); - await _scheduler.Start(cancellationToken); - if (Jobs != null) - { - foreach (var (job, trigger) in Jobs) + for (var retry = 0; retry < MaximumJobRetries; retry++) { - for (var retry = 0; retry < MaximumJobRetries; retry++) + // There's a race condition when starting multiple containers simultaneously, retry until it succeeds.. + try { - // There's a race condition when starting multiple containers simultaneously, retry until it succeeds.. - try + var dupeT = await _scheduler.GetTrigger(trigger.Key); + if (dupeT != null) { - var dupeT = await _scheduler.GetTrigger(trigger.Key); - if (dupeT != null) - { - await _scheduler.RescheduleJob(trigger.Key, trigger); - } - - var jobDetail = JobBuilder.Create(job) - .WithIdentity(job.FullName) - .Build(); - - var dupeJ = await _scheduler.GetJobDetail(jobDetail.Key); - if (dupeJ != null) - { - await _scheduler.DeleteJob(jobDetail.Key); - } - - await _scheduler.ScheduleJob(jobDetail, trigger); - break; + await _scheduler.RescheduleJob(trigger.Key, trigger); } - catch (Exception e) + + var jobDetail = JobBuilder.Create(job) + .WithIdentity(job.FullName) + .Build(); + + var dupeJ = await _scheduler.GetJobDetail(jobDetail.Key); + if (dupeJ != null) { - if (retry == MaximumJobRetries - 1) - { - throw new Exception("Job failed to start after 10 retries."); - } - - _logger.LogWarning($"Exception while trying to schedule job: {job.FullName}, {e}"); - var random = new Random(); - Thread.Sleep(random.Next(50, 250)); + await _scheduler.DeleteJob(jobDetail.Key); } + + await _scheduler.ScheduleJob(jobDetail, trigger); + break; + } + catch (Exception e) + { + if (retry == MaximumJobRetries - 1) + { + throw new Exception("Job failed to start after 10 retries."); + } + + _logger.LogWarning($"Exception while trying to schedule job: {job.FullName}, {e}"); + var random = new Random(); + Thread.Sleep(random.Next(50, 250)); } } } - - // Delete old Jobs and Triggers - var existingJobKeys = await _scheduler.GetJobKeys(GroupMatcher.AnyGroup()); - var jobKeys = Jobs.Select(j => - { - var job = j.Item1; - return JobBuilder.Create(job) - .WithIdentity(job.FullName) - .Build().Key; - }); - - foreach (var key in existingJobKeys) - { - if (jobKeys.Contains(key)) - { - continue; - } - - _logger.LogInformation($"Deleting old job with key {key}"); - await _scheduler.DeleteJob(key); - } - - var existingTriggerKeys = await _scheduler.GetTriggerKeys(GroupMatcher.AnyGroup()); - var triggerKeys = Jobs.Select(j => j.Item2.Key); - - foreach (var key in existingTriggerKeys) - { - if (triggerKeys.Contains(key)) - { - continue; - } - - _logger.LogInformation($"Unscheduling old trigger with key {key}"); - await _scheduler.UnscheduleJob(key); - } } - public virtual async Task StopAsync(CancellationToken cancellationToken) + // Delete old Jobs and Triggers + var existingJobKeys = await _scheduler.GetJobKeys(GroupMatcher.AnyGroup()); + var jobKeys = Jobs.Select(j => { - await _scheduler?.Shutdown(cancellationToken); + var job = j.Item1; + return JobBuilder.Create(job) + .WithIdentity(job.FullName) + .Build().Key; + }); + + foreach (var key in existingJobKeys) + { + if (jobKeys.Contains(key)) + { + continue; + } + + _logger.LogInformation($"Deleting old job with key {key}"); + await _scheduler.DeleteJob(key); } - public virtual void Dispose() - { } + var existingTriggerKeys = await _scheduler.GetTriggerKeys(GroupMatcher.AnyGroup()); + var triggerKeys = Jobs.Select(j => j.Item2.Key); + + foreach (var key in existingTriggerKeys) + { + if (triggerKeys.Contains(key)) + { + continue; + } + + _logger.LogInformation($"Unscheduling old trigger with key {key}"); + await _scheduler.UnscheduleJob(key); + } } + + public virtual async Task StopAsync(CancellationToken cancellationToken) + { + await _scheduler?.Shutdown(cancellationToken); + } + + public virtual void Dispose() + { } } diff --git a/src/Core/Jobs/JobFactory.cs b/src/Core/Jobs/JobFactory.cs index 00cf63b26..ee95c6b2d 100644 --- a/src/Core/Jobs/JobFactory.cs +++ b/src/Core/Jobs/JobFactory.cs @@ -1,26 +1,25 @@ using Quartz; using Quartz.Spi; -namespace Bit.Core.Jobs +namespace Bit.Core.Jobs; + +public class JobFactory : IJobFactory { - public class JobFactory : IJobFactory + private readonly IServiceProvider _container; + + public JobFactory(IServiceProvider container) { - private readonly IServiceProvider _container; + _container = container; + } - public JobFactory(IServiceProvider container) - { - _container = container; - } + public IJob NewJob(TriggerFiredBundle bundle, IScheduler scheduler) + { + return _container.GetService(bundle.JobDetail.JobType) as IJob; + } - public IJob NewJob(TriggerFiredBundle bundle, IScheduler scheduler) - { - return _container.GetService(bundle.JobDetail.JobType) as IJob; - } - - public void ReturnJob(IJob job) - { - var disposable = job as IDisposable; - disposable?.Dispose(); - } + public void ReturnJob(IJob job) + { + var disposable = job as IDisposable; + disposable?.Dispose(); } } diff --git a/src/Core/Jobs/JobListener.cs b/src/Core/Jobs/JobListener.cs index 8fb56828e..e5e05e4b6 100644 --- a/src/Core/Jobs/JobListener.cs +++ b/src/Core/Jobs/JobListener.cs @@ -1,39 +1,38 @@ using Microsoft.Extensions.Logging; using Quartz; -namespace Bit.Core.Jobs +namespace Bit.Core.Jobs; + +public class JobListener : IJobListener { - public class JobListener : IJobListener + private readonly ILogger _logger; + + public JobListener(ILogger logger) { - private readonly ILogger _logger; + _logger = logger; + } - public JobListener(ILogger logger) - { - _logger = logger; - } + public string Name => "JobListener"; - public string Name => "JobListener"; + public Task JobExecutionVetoed(IJobExecutionContext context, + CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.FromResult(0); + } - public Task JobExecutionVetoed(IJobExecutionContext context, - CancellationToken cancellationToken = default(CancellationToken)) - { - return Task.FromResult(0); - } + public Task JobToBeExecuted(IJobExecutionContext context, + CancellationToken cancellationToken = default(CancellationToken)) + { + _logger.LogInformation(Constants.BypassFiltersEventId, null, "Starting job {0} at {1}.", + context.JobDetail.JobType.Name, DateTime.UtcNow); + return Task.FromResult(0); + } - public Task JobToBeExecuted(IJobExecutionContext context, - CancellationToken cancellationToken = default(CancellationToken)) - { - _logger.LogInformation(Constants.BypassFiltersEventId, null, "Starting job {0} at {1}.", - context.JobDetail.JobType.Name, DateTime.UtcNow); - return Task.FromResult(0); - } - - public Task JobWasExecuted(IJobExecutionContext context, JobExecutionException jobException, - CancellationToken cancellationToken = default(CancellationToken)) - { - _logger.LogInformation(Constants.BypassFiltersEventId, null, "Finished job {0} at {1}.", - context.JobDetail.JobType.Name, DateTime.UtcNow); - return Task.FromResult(0); - } + public Task JobWasExecuted(IJobExecutionContext context, JobExecutionException jobException, + CancellationToken cancellationToken = default(CancellationToken)) + { + _logger.LogInformation(Constants.BypassFiltersEventId, null, "Finished job {0} at {1}.", + context.JobDetail.JobType.Name, DateTime.UtcNow); + return Task.FromResult(0); } } diff --git a/src/Core/Models/Api/Request/Accounts/KeysRequestModel.cs b/src/Core/Models/Api/Request/Accounts/KeysRequestModel.cs index 77015a96e..18e6c1f5e 100644 --- a/src/Core/Models/Api/Request/Accounts/KeysRequestModel.cs +++ b/src/Core/Models/Api/Request/Accounts/KeysRequestModel.cs @@ -1,27 +1,26 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Entities; -namespace Bit.Core.Models.Api.Request.Accounts +namespace Bit.Core.Models.Api.Request.Accounts; + +public class KeysRequestModel { - public class KeysRequestModel + public string PublicKey { get; set; } + [Required] + public string EncryptedPrivateKey { get; set; } + + public User ToUser(User existingUser) { - public string PublicKey { get; set; } - [Required] - public string EncryptedPrivateKey { get; set; } - - public User ToUser(User existingUser) + if (string.IsNullOrWhiteSpace(existingUser.PublicKey) && !string.IsNullOrWhiteSpace(PublicKey)) { - if (string.IsNullOrWhiteSpace(existingUser.PublicKey) && !string.IsNullOrWhiteSpace(PublicKey)) - { - existingUser.PublicKey = PublicKey; - } - - if (string.IsNullOrWhiteSpace(existingUser.PrivateKey)) - { - existingUser.PrivateKey = EncryptedPrivateKey; - } - - return existingUser; + existingUser.PublicKey = PublicKey; } + + if (string.IsNullOrWhiteSpace(existingUser.PrivateKey)) + { + existingUser.PrivateKey = EncryptedPrivateKey; + } + + return existingUser; } } diff --git a/src/Core/Models/Api/Request/Accounts/PreloginRequestModel.cs b/src/Core/Models/Api/Request/Accounts/PreloginRequestModel.cs index dca9e08bf..43a391ab9 100644 --- a/src/Core/Models/Api/Request/Accounts/PreloginRequestModel.cs +++ b/src/Core/Models/Api/Request/Accounts/PreloginRequestModel.cs @@ -1,12 +1,11 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Models.Api.Request.Accounts +namespace Bit.Core.Models.Api.Request.Accounts; + +public class PreloginRequestModel { - public class PreloginRequestModel - { - [Required] - [EmailAddress] - [StringLength(256)] - public string Email { get; set; } - } + [Required] + [EmailAddress] + [StringLength(256)] + public string Email { get; set; } } diff --git a/src/Core/Models/Api/Request/Accounts/RegisterRequestModel.cs b/src/Core/Models/Api/Request/Accounts/RegisterRequestModel.cs index 2b7c36a89..eac394b11 100644 --- a/src/Core/Models/Api/Request/Accounts/RegisterRequestModel.cs +++ b/src/Core/Models/Api/Request/Accounts/RegisterRequestModel.cs @@ -4,74 +4,73 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Core.Models.Api.Request.Accounts +namespace Bit.Core.Models.Api.Request.Accounts; + +public class RegisterRequestModel : IValidatableObject, ICaptchaProtectedModel { - public class RegisterRequestModel : IValidatableObject, ICaptchaProtectedModel + [StringLength(50)] + public string Name { get; set; } + [Required] + [StrictEmailAddress] + [StringLength(256)] + public string Email { get; set; } + [Required] + [StringLength(1000)] + public string MasterPasswordHash { get; set; } + [StringLength(50)] + public string MasterPasswordHint { get; set; } + public string CaptchaResponse { get; set; } + public string Key { get; set; } + public KeysRequestModel Keys { get; set; } + public string Token { get; set; } + public Guid? OrganizationUserId { get; set; } + public KdfType? Kdf { get; set; } + public int? KdfIterations { get; set; } + public Dictionary ReferenceData { get; set; } + + public User ToUser() { - [StringLength(50)] - public string Name { get; set; } - [Required] - [StrictEmailAddress] - [StringLength(256)] - public string Email { get; set; } - [Required] - [StringLength(1000)] - public string MasterPasswordHash { get; set; } - [StringLength(50)] - public string MasterPasswordHint { get; set; } - public string CaptchaResponse { get; set; } - public string Key { get; set; } - public KeysRequestModel Keys { get; set; } - public string Token { get; set; } - public Guid? OrganizationUserId { get; set; } - public KdfType? Kdf { get; set; } - public int? KdfIterations { get; set; } - public Dictionary ReferenceData { get; set; } - - public User ToUser() + var user = new User { - var user = new User - { - Name = Name, - Email = Email, - MasterPasswordHint = MasterPasswordHint, - Kdf = Kdf.GetValueOrDefault(KdfType.PBKDF2_SHA256), - KdfIterations = KdfIterations.GetValueOrDefault(5000), - }; + Name = Name, + Email = Email, + MasterPasswordHint = MasterPasswordHint, + Kdf = Kdf.GetValueOrDefault(KdfType.PBKDF2_SHA256), + KdfIterations = KdfIterations.GetValueOrDefault(5000), + }; - if (ReferenceData != null) - { - user.ReferenceData = JsonSerializer.Serialize(ReferenceData); - } - - if (Key != null) - { - user.Key = Key; - } - - if (Keys != null) - { - Keys.ToUser(user); - } - - return user; + if (ReferenceData != null) + { + user.ReferenceData = JsonSerializer.Serialize(ReferenceData); } - public IEnumerable Validate(ValidationContext validationContext) + if (Key != null) { - if (Kdf.HasValue && KdfIterations.HasValue) + user.Key = Key; + } + + if (Keys != null) + { + Keys.ToUser(user); + } + + return user; + } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (Kdf.HasValue && KdfIterations.HasValue) + { + switch (Kdf.Value) { - switch (Kdf.Value) - { - case KdfType.PBKDF2_SHA256: - if (KdfIterations.Value < 5000 || KdfIterations.Value > 1_000_000) - { - yield return new ValidationResult("KDF iterations must be between 5000 and 1000000."); - } - break; - default: - break; - } + case KdfType.PBKDF2_SHA256: + if (KdfIterations.Value < 5000 || KdfIterations.Value > 1_000_000) + { + yield return new ValidationResult("KDF iterations must be between 5000 and 1000000."); + } + break; + default: + break; } } } diff --git a/src/Core/Models/Api/Request/ICaptchaProtectedModel.cs b/src/Core/Models/Api/Request/ICaptchaProtectedModel.cs index 9084ecc89..f1c9771d1 100644 --- a/src/Core/Models/Api/Request/ICaptchaProtectedModel.cs +++ b/src/Core/Models/Api/Request/ICaptchaProtectedModel.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.Models.Api +namespace Bit.Core.Models.Api; + +public interface ICaptchaProtectedModel { - public interface ICaptchaProtectedModel - { - string CaptchaResponse { get; set; } - } + string CaptchaResponse { get; set; } } diff --git a/src/Core/Models/Api/Request/OrganizationSponsorships/OrganizationSponsorshipRequestModel.cs b/src/Core/Models/Api/Request/OrganizationSponsorships/OrganizationSponsorshipRequestModel.cs index 7440e7ba3..8be4a672d 100644 --- a/src/Core/Models/Api/Request/OrganizationSponsorships/OrganizationSponsorshipRequestModel.cs +++ b/src/Core/Models/Api/Request/OrganizationSponsorships/OrganizationSponsorshipRequestModel.cs @@ -2,55 +2,54 @@ using Bit.Core.Enums; using Bit.Core.Models.Data.Organizations.OrganizationSponsorships; -namespace Bit.Core.Models.Api.Request.OrganizationSponsorships +namespace Bit.Core.Models.Api.Request.OrganizationSponsorships; + +public class OrganizationSponsorshipRequestModel { - public class OrganizationSponsorshipRequestModel + public Guid SponsoringOrganizationUserId { get; set; } + public string FriendlyName { get; set; } + public string OfferedToEmail { get; set; } + public PlanSponsorshipType PlanSponsorshipType { get; set; } + public DateTime? LastSyncDate { get; set; } + public DateTime? ValidUntil { get; set; } + public bool ToDelete { get; set; } + + public OrganizationSponsorshipRequestModel() { } + + public OrganizationSponsorshipRequestModel(OrganizationSponsorshipData sponsorshipData) { - public Guid SponsoringOrganizationUserId { get; set; } - public string FriendlyName { get; set; } - public string OfferedToEmail { get; set; } - public PlanSponsorshipType PlanSponsorshipType { get; set; } - public DateTime? LastSyncDate { get; set; } - public DateTime? ValidUntil { get; set; } - public bool ToDelete { get; set; } + SponsoringOrganizationUserId = sponsorshipData.SponsoringOrganizationUserId; + FriendlyName = sponsorshipData.FriendlyName; + OfferedToEmail = sponsorshipData.OfferedToEmail; + PlanSponsorshipType = sponsorshipData.PlanSponsorshipType; + LastSyncDate = sponsorshipData.LastSyncDate; + ValidUntil = sponsorshipData.ValidUntil; + ToDelete = sponsorshipData.ToDelete; + } - public OrganizationSponsorshipRequestModel() { } + public OrganizationSponsorshipRequestModel(OrganizationSponsorship sponsorship) + { + SponsoringOrganizationUserId = sponsorship.SponsoringOrganizationUserId; + FriendlyName = sponsorship.FriendlyName; + OfferedToEmail = sponsorship.OfferedToEmail; + PlanSponsorshipType = sponsorship.PlanSponsorshipType.GetValueOrDefault(); + LastSyncDate = sponsorship.LastSyncDate; + ValidUntil = sponsorship.ValidUntil; + ToDelete = sponsorship.ToDelete; + } - public OrganizationSponsorshipRequestModel(OrganizationSponsorshipData sponsorshipData) + public OrganizationSponsorshipData ToOrganizationSponsorship() + { + return new OrganizationSponsorshipData { - SponsoringOrganizationUserId = sponsorshipData.SponsoringOrganizationUserId; - FriendlyName = sponsorshipData.FriendlyName; - OfferedToEmail = sponsorshipData.OfferedToEmail; - PlanSponsorshipType = sponsorshipData.PlanSponsorshipType; - LastSyncDate = sponsorshipData.LastSyncDate; - ValidUntil = sponsorshipData.ValidUntil; - ToDelete = sponsorshipData.ToDelete; - } + SponsoringOrganizationUserId = SponsoringOrganizationUserId, + FriendlyName = FriendlyName, + OfferedToEmail = OfferedToEmail, + PlanSponsorshipType = PlanSponsorshipType, + LastSyncDate = LastSyncDate, + ValidUntil = ValidUntil, + ToDelete = ToDelete, + }; - public OrganizationSponsorshipRequestModel(OrganizationSponsorship sponsorship) - { - SponsoringOrganizationUserId = sponsorship.SponsoringOrganizationUserId; - FriendlyName = sponsorship.FriendlyName; - OfferedToEmail = sponsorship.OfferedToEmail; - PlanSponsorshipType = sponsorship.PlanSponsorshipType.GetValueOrDefault(); - LastSyncDate = sponsorship.LastSyncDate; - ValidUntil = sponsorship.ValidUntil; - ToDelete = sponsorship.ToDelete; - } - - public OrganizationSponsorshipData ToOrganizationSponsorship() - { - return new OrganizationSponsorshipData - { - SponsoringOrganizationUserId = SponsoringOrganizationUserId, - FriendlyName = FriendlyName, - OfferedToEmail = OfferedToEmail, - PlanSponsorshipType = PlanSponsorshipType, - LastSyncDate = LastSyncDate, - ValidUntil = ValidUntil, - ToDelete = ToDelete, - }; - - } } } diff --git a/src/Core/Models/Api/Request/OrganizationSponsorships/OrganizationSponsorshipSyncRequestModel.cs b/src/Core/Models/Api/Request/OrganizationSponsorships/OrganizationSponsorshipSyncRequestModel.cs index 9def44d60..283c07d19 100644 --- a/src/Core/Models/Api/Request/OrganizationSponsorships/OrganizationSponsorshipSyncRequestModel.cs +++ b/src/Core/Models/Api/Request/OrganizationSponsorships/OrganizationSponsorshipSyncRequestModel.cs @@ -1,40 +1,39 @@ using Bit.Core.Models.Data.Organizations.OrganizationSponsorships; -namespace Bit.Core.Models.Api.Request.OrganizationSponsorships +namespace Bit.Core.Models.Api.Request.OrganizationSponsorships; + +public class OrganizationSponsorshipSyncRequestModel { - public class OrganizationSponsorshipSyncRequestModel + public string BillingSyncKey { get; set; } + public Guid SponsoringOrganizationCloudId { get; set; } + public IEnumerable SponsorshipsBatch { get; set; } + + public OrganizationSponsorshipSyncRequestModel() { } + + public OrganizationSponsorshipSyncRequestModel(IEnumerable sponsorshipsBatch) { - public string BillingSyncKey { get; set; } - public Guid SponsoringOrganizationCloudId { get; set; } - public IEnumerable SponsorshipsBatch { get; set; } - - public OrganizationSponsorshipSyncRequestModel() { } - - public OrganizationSponsorshipSyncRequestModel(IEnumerable sponsorshipsBatch) - { - SponsorshipsBatch = sponsorshipsBatch; - } - - public OrganizationSponsorshipSyncRequestModel(OrganizationSponsorshipSyncData syncData) - { - if (syncData == null) - { - return; - } - BillingSyncKey = syncData.BillingSyncKey; - SponsoringOrganizationCloudId = syncData.SponsoringOrganizationCloudId; - SponsorshipsBatch = syncData.SponsorshipsBatch.Select(o => new OrganizationSponsorshipRequestModel(o)); - } - - public OrganizationSponsorshipSyncData ToOrganizationSponsorshipSync() - { - return new OrganizationSponsorshipSyncData() - { - BillingSyncKey = BillingSyncKey, - SponsoringOrganizationCloudId = SponsoringOrganizationCloudId, - SponsorshipsBatch = SponsorshipsBatch.Select(o => o.ToOrganizationSponsorship()) - }; - } - + SponsorshipsBatch = sponsorshipsBatch; } + + public OrganizationSponsorshipSyncRequestModel(OrganizationSponsorshipSyncData syncData) + { + if (syncData == null) + { + return; + } + BillingSyncKey = syncData.BillingSyncKey; + SponsoringOrganizationCloudId = syncData.SponsoringOrganizationCloudId; + SponsorshipsBatch = syncData.SponsorshipsBatch.Select(o => new OrganizationSponsorshipRequestModel(o)); + } + + public OrganizationSponsorshipSyncData ToOrganizationSponsorshipSync() + { + return new OrganizationSponsorshipSyncData() + { + BillingSyncKey = BillingSyncKey, + SponsoringOrganizationCloudId = SponsoringOrganizationCloudId, + SponsorshipsBatch = SponsorshipsBatch.Select(o => o.ToOrganizationSponsorship()) + }; + } + } diff --git a/src/Core/Models/Api/Request/PushRegistrationRequestModel.cs b/src/Core/Models/Api/Request/PushRegistrationRequestModel.cs index fd74b50af..580c1c3b6 100644 --- a/src/Core/Models/Api/Request/PushRegistrationRequestModel.cs +++ b/src/Core/Models/Api/Request/PushRegistrationRequestModel.cs @@ -1,19 +1,18 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Enums; -namespace Bit.Core.Models.Api +namespace Bit.Core.Models.Api; + +public class PushRegistrationRequestModel { - public class PushRegistrationRequestModel - { - [Required] - public string DeviceId { get; set; } - [Required] - public string PushToken { get; set; } - [Required] - public string UserId { get; set; } - [Required] - public DeviceType Type { get; set; } - [Required] - public string Identifier { get; set; } - } + [Required] + public string DeviceId { get; set; } + [Required] + public string PushToken { get; set; } + [Required] + public string UserId { get; set; } + [Required] + public DeviceType Type { get; set; } + [Required] + public string Identifier { get; set; } } diff --git a/src/Core/Models/Api/Request/PushSendRequestModel.cs b/src/Core/Models/Api/Request/PushSendRequestModel.cs index 108db5804..b85c8fb55 100644 --- a/src/Core/Models/Api/Request/PushSendRequestModel.cs +++ b/src/Core/Models/Api/Request/PushSendRequestModel.cs @@ -1,25 +1,24 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Enums; -namespace Bit.Core.Models.Api -{ - public class PushSendRequestModel : IValidatableObject - { - public string UserId { get; set; } - public string OrganizationId { get; set; } - public string DeviceId { get; set; } - public string Identifier { get; set; } - [Required] - public PushType? Type { get; set; } - [Required] - public object Payload { get; set; } +namespace Bit.Core.Models.Api; - public IEnumerable Validate(ValidationContext validationContext) +public class PushSendRequestModel : IValidatableObject +{ + public string UserId { get; set; } + public string OrganizationId { get; set; } + public string DeviceId { get; set; } + public string Identifier { get; set; } + [Required] + public PushType? Type { get; set; } + [Required] + public object Payload { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (string.IsNullOrWhiteSpace(UserId) && string.IsNullOrWhiteSpace(OrganizationId)) { - if (string.IsNullOrWhiteSpace(UserId) && string.IsNullOrWhiteSpace(OrganizationId)) - { - yield return new ValidationResult($"{nameof(UserId)} or {nameof(OrganizationId)} is required."); - } + yield return new ValidationResult($"{nameof(UserId)} or {nameof(OrganizationId)} is required."); } } } diff --git a/src/Core/Models/Api/Request/PushUpdateRequestModel.cs b/src/Core/Models/Api/Request/PushUpdateRequestModel.cs index ba5c3bf96..2ccbf6eb0 100644 --- a/src/Core/Models/Api/Request/PushUpdateRequestModel.cs +++ b/src/Core/Models/Api/Request/PushUpdateRequestModel.cs @@ -1,21 +1,20 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Models.Api +namespace Bit.Core.Models.Api; + +public class PushUpdateRequestModel { - public class PushUpdateRequestModel + public PushUpdateRequestModel() + { } + + public PushUpdateRequestModel(IEnumerable deviceIds, string organizationId) { - public PushUpdateRequestModel() - { } - - public PushUpdateRequestModel(IEnumerable deviceIds, string organizationId) - { - DeviceIds = deviceIds; - OrganizationId = organizationId; - } - - [Required] - public IEnumerable DeviceIds { get; set; } - [Required] - public string OrganizationId { get; set; } + DeviceIds = deviceIds; + OrganizationId = organizationId; } + + [Required] + public IEnumerable DeviceIds { get; set; } + [Required] + public string OrganizationId { get; set; } } diff --git a/src/Core/Models/Api/Response/Accounts/PreloginResponseModel.cs b/src/Core/Models/Api/Response/Accounts/PreloginResponseModel.cs index 755182f76..9fb2de7de 100644 --- a/src/Core/Models/Api/Response/Accounts/PreloginResponseModel.cs +++ b/src/Core/Models/Api/Response/Accounts/PreloginResponseModel.cs @@ -1,17 +1,16 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; -namespace Bit.Core.Models.Api.Response.Accounts -{ - public class PreloginResponseModel - { - public PreloginResponseModel(UserKdfInformation kdfInformation) - { - Kdf = kdfInformation.Kdf; - KdfIterations = kdfInformation.KdfIterations; - } +namespace Bit.Core.Models.Api.Response.Accounts; - public KdfType Kdf { get; set; } - public int KdfIterations { get; set; } +public class PreloginResponseModel +{ + public PreloginResponseModel(UserKdfInformation kdfInformation) + { + Kdf = kdfInformation.Kdf; + KdfIterations = kdfInformation.KdfIterations; } + + public KdfType Kdf { get; set; } + public int KdfIterations { get; set; } } diff --git a/src/Core/Models/Api/Response/ErrorResponseModel.cs b/src/Core/Models/Api/Response/ErrorResponseModel.cs index e7f77099c..39d6adddb 100644 --- a/src/Core/Models/Api/Response/ErrorResponseModel.cs +++ b/src/Core/Models/Api/Response/ErrorResponseModel.cs @@ -1,74 +1,73 @@ using Microsoft.AspNetCore.Mvc.ModelBinding; -namespace Bit.Core.Models.Api +namespace Bit.Core.Models.Api; + +public class ErrorResponseModel : ResponseModel { - public class ErrorResponseModel : ResponseModel + public ErrorResponseModel() + : base("error") + { } + + public ErrorResponseModel(string message) + : this() { - public ErrorResponseModel() - : base("error") - { } - - public ErrorResponseModel(string message) - : this() - { - Message = message; - } - - public ErrorResponseModel(ModelStateDictionary modelState) - : this() - { - Message = "The model state is invalid."; - ValidationErrors = new Dictionary>(); - - var keys = modelState.Keys.ToList(); - var values = modelState.Values.ToList(); - - for (var i = 0; i < values.Count; i++) - { - var value = values[i]; - - if (keys.Count <= i) - { - // Keys not available for some reason. - break; - } - - var key = keys[i]; - - if (value.ValidationState != ModelValidationState.Invalid || value.Errors.Count == 0) - { - continue; - } - - var errors = value.Errors.Select(e => e.ErrorMessage); - ValidationErrors.Add(key, errors); - } - } - - public ErrorResponseModel(Dictionary> errors) - : this("Errors have occurred.", errors) - { } - - public ErrorResponseModel(string errorKey, string errorValue) - : this(errorKey, new string[] { errorValue }) - { } - - public ErrorResponseModel(string errorKey, IEnumerable errorValues) - : this(new Dictionary> { { errorKey, errorValues } }) - { } - - public ErrorResponseModel(string message, Dictionary> errors) - : this() - { - Message = message; - ValidationErrors = errors; - } - - public string Message { get; set; } - public Dictionary> ValidationErrors { get; set; } - // For use in development environments. - public string ExceptionMessage { get; set; } - public string ExceptionStackTrace { get; set; } - public string InnerExceptionMessage { get; set; } + Message = message; } + + public ErrorResponseModel(ModelStateDictionary modelState) + : this() + { + Message = "The model state is invalid."; + ValidationErrors = new Dictionary>(); + + var keys = modelState.Keys.ToList(); + var values = modelState.Values.ToList(); + + for (var i = 0; i < values.Count; i++) + { + var value = values[i]; + + if (keys.Count <= i) + { + // Keys not available for some reason. + break; + } + + var key = keys[i]; + + if (value.ValidationState != ModelValidationState.Invalid || value.Errors.Count == 0) + { + continue; + } + + var errors = value.Errors.Select(e => e.ErrorMessage); + ValidationErrors.Add(key, errors); + } + } + + public ErrorResponseModel(Dictionary> errors) + : this("Errors have occurred.", errors) + { } + + public ErrorResponseModel(string errorKey, string errorValue) + : this(errorKey, new string[] { errorValue }) + { } + + public ErrorResponseModel(string errorKey, IEnumerable errorValues) + : this(new Dictionary> { { errorKey, errorValues } }) + { } + + public ErrorResponseModel(string message, Dictionary> errors) + : this() + { + Message = message; + ValidationErrors = errors; + } + + public string Message { get; set; } + public Dictionary> ValidationErrors { get; set; } + // For use in development environments. + public string ExceptionMessage { get; set; } + public string ExceptionStackTrace { get; set; } + public string InnerExceptionMessage { get; set; } } diff --git a/src/Core/Models/Api/Response/OrganizationSponsorships/OrganizationSponsorshipResponseModel.cs b/src/Core/Models/Api/Response/OrganizationSponsorships/OrganizationSponsorshipResponseModel.cs index fc5fbc70d..58c1b2cff 100644 --- a/src/Core/Models/Api/Response/OrganizationSponsorships/OrganizationSponsorshipResponseModel.cs +++ b/src/Core/Models/Api/Response/OrganizationSponsorships/OrganizationSponsorshipResponseModel.cs @@ -1,48 +1,47 @@ using Bit.Core.Enums; using Bit.Core.Models.Data.Organizations.OrganizationSponsorships; -namespace Bit.Core.Models.Api.Response.OrganizationSponsorships +namespace Bit.Core.Models.Api.Response.OrganizationSponsorships; + +public class OrganizationSponsorshipResponseModel { - public class OrganizationSponsorshipResponseModel + public Guid SponsoringOrganizationUserId { get; set; } + public string FriendlyName { get; set; } + public string OfferedToEmail { get; set; } + public PlanSponsorshipType PlanSponsorshipType { get; set; } + public DateTime? LastSyncDate { get; set; } + public DateTime? ValidUntil { get; set; } + public bool ToDelete { get; set; } + + public bool CloudSponsorshipRemoved { get; set; } + + public OrganizationSponsorshipResponseModel() { } + + public OrganizationSponsorshipResponseModel(OrganizationSponsorshipData sponsorshipData) { - public Guid SponsoringOrganizationUserId { get; set; } - public string FriendlyName { get; set; } - public string OfferedToEmail { get; set; } - public PlanSponsorshipType PlanSponsorshipType { get; set; } - public DateTime? LastSyncDate { get; set; } - public DateTime? ValidUntil { get; set; } - public bool ToDelete { get; set; } + SponsoringOrganizationUserId = sponsorshipData.SponsoringOrganizationUserId; + FriendlyName = sponsorshipData.FriendlyName; + OfferedToEmail = sponsorshipData.OfferedToEmail; + PlanSponsorshipType = sponsorshipData.PlanSponsorshipType; + LastSyncDate = sponsorshipData.LastSyncDate; + ValidUntil = sponsorshipData.ValidUntil; + ToDelete = sponsorshipData.ToDelete; + CloudSponsorshipRemoved = sponsorshipData.CloudSponsorshipRemoved; + } - public bool CloudSponsorshipRemoved { get; set; } - - public OrganizationSponsorshipResponseModel() { } - - public OrganizationSponsorshipResponseModel(OrganizationSponsorshipData sponsorshipData) + public OrganizationSponsorshipData ToOrganizationSponsorship() + { + return new OrganizationSponsorshipData { - SponsoringOrganizationUserId = sponsorshipData.SponsoringOrganizationUserId; - FriendlyName = sponsorshipData.FriendlyName; - OfferedToEmail = sponsorshipData.OfferedToEmail; - PlanSponsorshipType = sponsorshipData.PlanSponsorshipType; - LastSyncDate = sponsorshipData.LastSyncDate; - ValidUntil = sponsorshipData.ValidUntil; - ToDelete = sponsorshipData.ToDelete; - CloudSponsorshipRemoved = sponsorshipData.CloudSponsorshipRemoved; - } + SponsoringOrganizationUserId = SponsoringOrganizationUserId, + FriendlyName = FriendlyName, + OfferedToEmail = OfferedToEmail, + PlanSponsorshipType = PlanSponsorshipType, + LastSyncDate = LastSyncDate, + ValidUntil = ValidUntil, + ToDelete = ToDelete, + CloudSponsorshipRemoved = CloudSponsorshipRemoved + }; - public OrganizationSponsorshipData ToOrganizationSponsorship() - { - return new OrganizationSponsorshipData - { - SponsoringOrganizationUserId = SponsoringOrganizationUserId, - FriendlyName = FriendlyName, - OfferedToEmail = OfferedToEmail, - PlanSponsorshipType = PlanSponsorshipType, - LastSyncDate = LastSyncDate, - ValidUntil = ValidUntil, - ToDelete = ToDelete, - CloudSponsorshipRemoved = CloudSponsorshipRemoved - }; - - } } } diff --git a/src/Core/Models/Api/Response/OrganizationSponsorships/OrganizationSponsorshipSyncResponseModel.cs b/src/Core/Models/Api/Response/OrganizationSponsorships/OrganizationSponsorshipSyncResponseModel.cs index 4d44ab165..5a6b635c5 100644 --- a/src/Core/Models/Api/Response/OrganizationSponsorships/OrganizationSponsorshipSyncResponseModel.cs +++ b/src/Core/Models/Api/Response/OrganizationSponsorships/OrganizationSponsorshipSyncResponseModel.cs @@ -1,30 +1,29 @@ using Bit.Core.Models.Data.Organizations.OrganizationSponsorships; -namespace Bit.Core.Models.Api.Response.OrganizationSponsorships +namespace Bit.Core.Models.Api.Response.OrganizationSponsorships; + +public class OrganizationSponsorshipSyncResponseModel { - public class OrganizationSponsorshipSyncResponseModel + public IEnumerable SponsorshipsBatch { get; set; } + + public OrganizationSponsorshipSyncResponseModel() { } + + public OrganizationSponsorshipSyncResponseModel(OrganizationSponsorshipSyncData syncData) { - public IEnumerable SponsorshipsBatch { get; set; } - - public OrganizationSponsorshipSyncResponseModel() { } - - public OrganizationSponsorshipSyncResponseModel(OrganizationSponsorshipSyncData syncData) + if (syncData == null) { - if (syncData == null) - { - return; - } - SponsorshipsBatch = syncData.SponsorshipsBatch.Select(o => new OrganizationSponsorshipResponseModel(o)); - - } - - public OrganizationSponsorshipSyncData ToOrganizationSponsorshipSync() - { - return new OrganizationSponsorshipSyncData() - { - SponsorshipsBatch = SponsorshipsBatch.Select(o => o.ToOrganizationSponsorship()) - }; + return; } + SponsorshipsBatch = syncData.SponsorshipsBatch.Select(o => new OrganizationSponsorshipResponseModel(o)); } + + public OrganizationSponsorshipSyncData ToOrganizationSponsorshipSync() + { + return new OrganizationSponsorshipSyncData() + { + SponsorshipsBatch = SponsorshipsBatch.Select(o => o.ToOrganizationSponsorship()) + }; + } + } diff --git a/src/Core/Models/Api/Response/ResponseModel.cs b/src/Core/Models/Api/Response/ResponseModel.cs index 539d52d10..22278b807 100644 --- a/src/Core/Models/Api/Response/ResponseModel.cs +++ b/src/Core/Models/Api/Response/ResponseModel.cs @@ -1,20 +1,19 @@ using Newtonsoft.Json; -namespace Bit.Core.Models.Api -{ - public abstract class ResponseModel - { - public ResponseModel(string obj) - { - if (string.IsNullOrWhiteSpace(obj)) - { - throw new ArgumentNullException(nameof(obj)); - } +namespace Bit.Core.Models.Api; - Object = obj; +public abstract class ResponseModel +{ + public ResponseModel(string obj) + { + if (string.IsNullOrWhiteSpace(obj)) + { + throw new ArgumentNullException(nameof(obj)); } - [JsonProperty(Order = -200)] // Always the first property - public string Object { get; private set; } + Object = obj; } + + [JsonProperty(Order = -200)] // Always the first property + public string Object { get; private set; } } diff --git a/src/Core/Models/Business/AppleReceiptStatus.cs b/src/Core/Models/Business/AppleReceiptStatus.cs index 26f7537af..e54ce91e6 100644 --- a/src/Core/Models/Business/AppleReceiptStatus.cs +++ b/src/Core/Models/Business/AppleReceiptStatus.cs @@ -3,133 +3,132 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Billing.Models +namespace Bit.Billing.Models; + +public class AppleReceiptStatus { - public class AppleReceiptStatus + [JsonPropertyName("status")] + public int? Status { get; set; } + [JsonPropertyName("environment")] + public string Environment { get; set; } + [JsonPropertyName("latest_receipt")] + public string LatestReceipt { get; set; } + [JsonPropertyName("receipt")] + public AppleReceipt Receipt { get; set; } + [JsonPropertyName("latest_receipt_info")] + public List LatestReceiptInfo { get; set; } + [JsonPropertyName("pending_renewal_info")] + public List PendingRenewalInfo { get; set; } + + public string GetOriginalTransactionId() { - [JsonPropertyName("status")] - public int? Status { get; set; } - [JsonPropertyName("environment")] - public string Environment { get; set; } - [JsonPropertyName("latest_receipt")] - public string LatestReceipt { get; set; } - [JsonPropertyName("receipt")] - public AppleReceipt Receipt { get; set; } - [JsonPropertyName("latest_receipt_info")] - public List LatestReceiptInfo { get; set; } - [JsonPropertyName("pending_renewal_info")] - public List PendingRenewalInfo { get; set; } + return LatestReceiptInfo?.LastOrDefault()?.OriginalTransactionId; + } - public string GetOriginalTransactionId() - { - return LatestReceiptInfo?.LastOrDefault()?.OriginalTransactionId; - } + public string GetLastTransactionId() + { + return LatestReceiptInfo?.LastOrDefault()?.TransactionId; + } - public string GetLastTransactionId() - { - return LatestReceiptInfo?.LastOrDefault()?.TransactionId; - } + public AppleTransaction GetLastTransaction() + { + return LatestReceiptInfo?.LastOrDefault(); + } - public AppleTransaction GetLastTransaction() - { - return LatestReceiptInfo?.LastOrDefault(); - } + public DateTime? GetLastExpiresDate() + { + return LatestReceiptInfo?.LastOrDefault()?.ExpiresDate; + } - public DateTime? GetLastExpiresDate() - { - return LatestReceiptInfo?.LastOrDefault()?.ExpiresDate; - } + public string GetReceiptData() + { + return LatestReceipt; + } - public string GetReceiptData() - { - return LatestReceipt; - } + public DateTime? GetLastCancellationDate() + { + return LatestReceiptInfo?.LastOrDefault()?.CancellationDate; + } - public DateTime? GetLastCancellationDate() + public bool IsRefunded() + { + var cancellationDate = GetLastCancellationDate(); + var expiresDate = GetLastCancellationDate(); + if (cancellationDate.HasValue && expiresDate.HasValue) { - return LatestReceiptInfo?.LastOrDefault()?.CancellationDate; + return cancellationDate.Value <= expiresDate.Value; } + return false; + } - public bool IsRefunded() + public Transaction BuildTransactionFromLastTransaction(decimal amount, Guid userId) + { + return new Transaction { - var cancellationDate = GetLastCancellationDate(); - var expiresDate = GetLastCancellationDate(); - if (cancellationDate.HasValue && expiresDate.HasValue) - { - return cancellationDate.Value <= expiresDate.Value; - } - return false; - } + Amount = amount, + CreationDate = GetLastTransaction().PurchaseDate, + Gateway = GatewayType.AppStore, + GatewayId = GetLastTransactionId(), + UserId = userId, + PaymentMethodType = PaymentMethodType.AppleInApp, + Details = GetLastTransactionId() + }; + } - public Transaction BuildTransactionFromLastTransaction(decimal amount, Guid userId) - { - return new Transaction - { - Amount = amount, - CreationDate = GetLastTransaction().PurchaseDate, - Gateway = GatewayType.AppStore, - GatewayId = GetLastTransactionId(), - UserId = userId, - PaymentMethodType = PaymentMethodType.AppleInApp, - Details = GetLastTransactionId() - }; - } + public class AppleReceipt + { + [JsonPropertyName("receipt_type")] + public string ReceiptType { get; set; } + [JsonPropertyName("bundle_id")] + public string BundleId { get; set; } + [JsonPropertyName("receipt_creation_date_ms")] + [JsonConverter(typeof(MsEpochConverter))] + public DateTime ReceiptCreationDate { get; set; } + [JsonPropertyName("in_app")] + public List InApp { get; set; } + } - public class AppleReceipt - { - [JsonPropertyName("receipt_type")] - public string ReceiptType { get; set; } - [JsonPropertyName("bundle_id")] - public string BundleId { get; set; } - [JsonPropertyName("receipt_creation_date_ms")] - [JsonConverter(typeof(MsEpochConverter))] - public DateTime ReceiptCreationDate { get; set; } - [JsonPropertyName("in_app")] - public List InApp { get; set; } - } + public class AppleRenewalInfo + { + [JsonPropertyName("expiration_intent")] + public string ExpirationIntent { get; set; } + [JsonPropertyName("auto_renew_product_id")] + public string AutoRenewProductId { get; set; } + [JsonPropertyName("original_transaction_id")] + public string OriginalTransactionId { get; set; } + [JsonPropertyName("is_in_billing_retry_period")] + public string IsInBillingRetryPeriod { get; set; } + [JsonPropertyName("product_id")] + public string ProductId { get; set; } + [JsonPropertyName("auto_renew_status")] + public string AutoRenewStatus { get; set; } + } - public class AppleRenewalInfo - { - [JsonPropertyName("expiration_intent")] - public string ExpirationIntent { get; set; } - [JsonPropertyName("auto_renew_product_id")] - public string AutoRenewProductId { get; set; } - [JsonPropertyName("original_transaction_id")] - public string OriginalTransactionId { get; set; } - [JsonPropertyName("is_in_billing_retry_period")] - public string IsInBillingRetryPeriod { get; set; } - [JsonPropertyName("product_id")] - public string ProductId { get; set; } - [JsonPropertyName("auto_renew_status")] - public string AutoRenewStatus { get; set; } - } - - public class AppleTransaction - { - [JsonPropertyName("quantity")] - public string Quantity { get; set; } - [JsonPropertyName("product_id")] - public string ProductId { get; set; } - [JsonPropertyName("transaction_id")] - public string TransactionId { get; set; } - [JsonPropertyName("original_transaction_id")] - public string OriginalTransactionId { get; set; } - [JsonPropertyName("purchase_date_ms")] - [JsonConverter(typeof(MsEpochConverter))] - public DateTime PurchaseDate { get; set; } - [JsonPropertyName("original_purchase_date_ms")] - [JsonConverter(typeof(MsEpochConverter))] - public DateTime OriginalPurchaseDate { get; set; } - [JsonPropertyName("expires_date_ms")] - [JsonConverter(typeof(MsEpochConverter))] - public DateTime ExpiresDate { get; set; } - [JsonPropertyName("cancellation_date_ms")] - [JsonConverter(typeof(MsEpochConverter))] - public DateTime? CancellationDate { get; set; } - [JsonPropertyName("web_order_line_item_id")] - public string WebOrderLineItemId { get; set; } - [JsonPropertyName("cancellation_reason")] - public string CancellationReason { get; set; } - } + public class AppleTransaction + { + [JsonPropertyName("quantity")] + public string Quantity { get; set; } + [JsonPropertyName("product_id")] + public string ProductId { get; set; } + [JsonPropertyName("transaction_id")] + public string TransactionId { get; set; } + [JsonPropertyName("original_transaction_id")] + public string OriginalTransactionId { get; set; } + [JsonPropertyName("purchase_date_ms")] + [JsonConverter(typeof(MsEpochConverter))] + public DateTime PurchaseDate { get; set; } + [JsonPropertyName("original_purchase_date_ms")] + [JsonConverter(typeof(MsEpochConverter))] + public DateTime OriginalPurchaseDate { get; set; } + [JsonPropertyName("expires_date_ms")] + [JsonConverter(typeof(MsEpochConverter))] + public DateTime ExpiresDate { get; set; } + [JsonPropertyName("cancellation_date_ms")] + [JsonConverter(typeof(MsEpochConverter))] + public DateTime? CancellationDate { get; set; } + [JsonPropertyName("web_order_line_item_id")] + public string WebOrderLineItemId { get; set; } + [JsonPropertyName("cancellation_reason")] + public string CancellationReason { get; set; } } } diff --git a/src/Core/Models/Business/BillingInfo.cs b/src/Core/Models/Business/BillingInfo.cs index 557a3288f..1e1915566 100644 --- a/src/Core/Models/Business/BillingInfo.cs +++ b/src/Core/Models/Business/BillingInfo.cs @@ -2,155 +2,154 @@ using Bit.Core.Enums; using Stripe; -namespace Bit.Core.Models.Business +namespace Bit.Core.Models.Business; + +public class BillingInfo { - public class BillingInfo + public decimal Balance { get; set; } + public BillingSource PaymentSource { get; set; } + public IEnumerable Invoices { get; set; } = new List(); + public IEnumerable Transactions { get; set; } = new List(); + + public class BillingSource { - public decimal Balance { get; set; } - public BillingSource PaymentSource { get; set; } - public IEnumerable Invoices { get; set; } = new List(); - public IEnumerable Transactions { get; set; } = new List(); + public BillingSource() { } - public class BillingSource + public BillingSource(PaymentMethod method) { - public BillingSource() { } - - public BillingSource(PaymentMethod method) + if (method.Card != null) { - if (method.Card != null) - { - Type = PaymentMethodType.Card; - Description = $"{method.Card.Brand?.ToUpperInvariant()}, *{method.Card.Last4}, " + - string.Format("{0}/{1}", - string.Concat(method.Card.ExpMonth < 10 ? - "0" : string.Empty, method.Card.ExpMonth), - method.Card.ExpYear); - CardBrand = method.Card.Brand; - } + Type = PaymentMethodType.Card; + Description = $"{method.Card.Brand?.ToUpperInvariant()}, *{method.Card.Last4}, " + + string.Format("{0}/{1}", + string.Concat(method.Card.ExpMonth < 10 ? + "0" : string.Empty, method.Card.ExpMonth), + method.Card.ExpYear); + CardBrand = method.Card.Brand; } + } - public BillingSource(IPaymentSource source) + public BillingSource(IPaymentSource source) + { + if (source is BankAccount bankAccount) { - if (source is BankAccount bankAccount) - { - Type = PaymentMethodType.BankAccount; - Description = $"{bankAccount.BankName}, *{bankAccount.Last4} - " + - (bankAccount.Status == "verified" ? "verified" : - bankAccount.Status == "errored" ? "invalid" : - bankAccount.Status == "verification_failed" ? "verification failed" : "unverified"); - NeedsVerification = bankAccount.Status == "new" || bankAccount.Status == "validated"; - } - else if (source is Card card) - { - Type = PaymentMethodType.Card; - Description = $"{card.Brand}, *{card.Last4}, " + - string.Format("{0}/{1}", - string.Concat(card.ExpMonth < 10 ? - "0" : string.Empty, card.ExpMonth), - card.ExpYear); - CardBrand = card.Brand; - } - else if (source is Source src && src.Card != null) - { - Type = PaymentMethodType.Card; - Description = $"{src.Card.Brand}, *{src.Card.Last4}, " + - string.Format("{0}/{1}", - string.Concat(src.Card.ExpMonth < 10 ? - "0" : string.Empty, src.Card.ExpMonth), - src.Card.ExpYear); - CardBrand = src.Card.Brand; - } + Type = PaymentMethodType.BankAccount; + Description = $"{bankAccount.BankName}, *{bankAccount.Last4} - " + + (bankAccount.Status == "verified" ? "verified" : + bankAccount.Status == "errored" ? "invalid" : + bankAccount.Status == "verification_failed" ? "verification failed" : "unverified"); + NeedsVerification = bankAccount.Status == "new" || bankAccount.Status == "validated"; } - - public BillingSource(Braintree.PaymentMethod method) + else if (source is Card card) { - if (method is Braintree.PayPalAccount paypal) - { - Type = PaymentMethodType.PayPal; - Description = paypal.Email; - } - else if (method is Braintree.CreditCard card) - { - Type = PaymentMethodType.Card; - Description = $"{card.CardType.ToString()}, *{card.LastFour}, " + - string.Format("{0}/{1}", - string.Concat(card.ExpirationMonth.Length == 1 ? - "0" : string.Empty, card.ExpirationMonth), - card.ExpirationYear); - CardBrand = card.CardType.ToString(); - } - else if (method is Braintree.UsBankAccount bank) - { - Type = PaymentMethodType.BankAccount; - Description = $"{bank.BankName}, *{bank.Last4}"; - } - else - { - throw new NotSupportedException("Method not supported."); - } + Type = PaymentMethodType.Card; + Description = $"{card.Brand}, *{card.Last4}, " + + string.Format("{0}/{1}", + string.Concat(card.ExpMonth < 10 ? + "0" : string.Empty, card.ExpMonth), + card.ExpYear); + CardBrand = card.Brand; } + else if (source is Source src && src.Card != null) + { + Type = PaymentMethodType.Card; + Description = $"{src.Card.Brand}, *{src.Card.Last4}, " + + string.Format("{0}/{1}", + string.Concat(src.Card.ExpMonth < 10 ? + "0" : string.Empty, src.Card.ExpMonth), + src.Card.ExpYear); + CardBrand = src.Card.Brand; + } + } - public BillingSource(Braintree.UsBankAccountDetails bank) + public BillingSource(Braintree.PaymentMethod method) + { + if (method is Braintree.PayPalAccount paypal) + { + Type = PaymentMethodType.PayPal; + Description = paypal.Email; + } + else if (method is Braintree.CreditCard card) + { + Type = PaymentMethodType.Card; + Description = $"{card.CardType.ToString()}, *{card.LastFour}, " + + string.Format("{0}/{1}", + string.Concat(card.ExpirationMonth.Length == 1 ? + "0" : string.Empty, card.ExpirationMonth), + card.ExpirationYear); + CardBrand = card.CardType.ToString(); + } + else if (method is Braintree.UsBankAccount bank) { Type = PaymentMethodType.BankAccount; Description = $"{bank.BankName}, *{bank.Last4}"; } - - public BillingSource(Braintree.PayPalDetails paypal) + else { - Type = PaymentMethodType.PayPal; - Description = paypal.PayerEmail; + throw new NotSupportedException("Method not supported."); } - - public PaymentMethodType Type { get; set; } - public string CardBrand { get; set; } - public string Description { get; set; } - public bool NeedsVerification { get; set; } } - public class BillingTransaction + public BillingSource(Braintree.UsBankAccountDetails bank) { - public BillingTransaction(Transaction transaction) - { - Id = transaction.Id; - CreatedDate = transaction.CreationDate; - Refunded = transaction.Refunded; - Type = transaction.Type; - PaymentMethodType = transaction.PaymentMethodType; - Details = transaction.Details; - Amount = transaction.Amount; - RefundedAmount = transaction.RefundedAmount; - } - - public Guid Id { get; set; } - public DateTime CreatedDate { get; set; } - public decimal Amount { get; set; } - public bool? Refunded { get; set; } - public bool? PartiallyRefunded => !Refunded.GetValueOrDefault() && RefundedAmount.GetValueOrDefault() > 0; - public decimal? RefundedAmount { get; set; } - public TransactionType Type { get; set; } - public PaymentMethodType? PaymentMethodType { get; set; } - public string Details { get; set; } + Type = PaymentMethodType.BankAccount; + Description = $"{bank.BankName}, *{bank.Last4}"; } - public class BillingInvoice + public BillingSource(Braintree.PayPalDetails paypal) { - public BillingInvoice(Invoice inv) - { - Date = inv.Created; - Url = inv.HostedInvoiceUrl; - PdfUrl = inv.InvoicePdf; - Number = inv.Number; - Paid = inv.Paid; - Amount = inv.Total / 100M; - } - - public decimal Amount { get; set; } - public DateTime? Date { get; set; } - public string Url { get; set; } - public string PdfUrl { get; set; } - public string Number { get; set; } - public bool Paid { get; set; } + Type = PaymentMethodType.PayPal; + Description = paypal.PayerEmail; } + + public PaymentMethodType Type { get; set; } + public string CardBrand { get; set; } + public string Description { get; set; } + public bool NeedsVerification { get; set; } + } + + public class BillingTransaction + { + public BillingTransaction(Transaction transaction) + { + Id = transaction.Id; + CreatedDate = transaction.CreationDate; + Refunded = transaction.Refunded; + Type = transaction.Type; + PaymentMethodType = transaction.PaymentMethodType; + Details = transaction.Details; + Amount = transaction.Amount; + RefundedAmount = transaction.RefundedAmount; + } + + public Guid Id { get; set; } + public DateTime CreatedDate { get; set; } + public decimal Amount { get; set; } + public bool? Refunded { get; set; } + public bool? PartiallyRefunded => !Refunded.GetValueOrDefault() && RefundedAmount.GetValueOrDefault() > 0; + public decimal? RefundedAmount { get; set; } + public TransactionType Type { get; set; } + public PaymentMethodType? PaymentMethodType { get; set; } + public string Details { get; set; } + } + + public class BillingInvoice + { + public BillingInvoice(Invoice inv) + { + Date = inv.Created; + Url = inv.HostedInvoiceUrl; + PdfUrl = inv.InvoicePdf; + Number = inv.Number; + Paid = inv.Paid; + Amount = inv.Total / 100M; + } + + public decimal Amount { get; set; } + public DateTime? Date { get; set; } + public string Url { get; set; } + public string PdfUrl { get; set; } + public string Number { get; set; } + public bool Paid { get; set; } } } diff --git a/src/Core/Models/Business/CaptchaResponse.cs b/src/Core/Models/Business/CaptchaResponse.cs index c77330242..aaafc8e7d 100644 --- a/src/Core/Models/Business/CaptchaResponse.cs +++ b/src/Core/Models/Business/CaptchaResponse.cs @@ -1,10 +1,9 @@ -namespace Bit.Core.Models.Business +namespace Bit.Core.Models.Business; + +public class CaptchaResponse { - public class CaptchaResponse - { - public bool Success { get; set; } - public bool MaybeBot { get; set; } - public bool IsBot { get; set; } - public double Score { get; set; } - } + public bool Success { get; set; } + public bool MaybeBot { get; set; } + public bool IsBot { get; set; } + public double Score { get; set; } } diff --git a/src/Core/Models/Business/ExpiringToken.cs b/src/Core/Models/Business/ExpiringToken.cs index 3ed16a1cf..db09a540f 100644 --- a/src/Core/Models/Business/ExpiringToken.cs +++ b/src/Core/Models/Business/ExpiringToken.cs @@ -1,14 +1,13 @@ -namespace Bit.Core.Models.Business -{ - public class ExpiringToken - { - public readonly string Token; - public readonly DateTime ExpirationDate; +namespace Bit.Core.Models.Business; - public ExpiringToken(string token, DateTime expirationDate) - { - Token = token; - ExpirationDate = expirationDate; - } +public class ExpiringToken +{ + public readonly string Token; + public readonly DateTime ExpirationDate; + + public ExpiringToken(string token, DateTime expirationDate) + { + Token = token; + ExpirationDate = expirationDate; } } diff --git a/src/Core/Models/Business/ILicense.cs b/src/Core/Models/Business/ILicense.cs index 0e03e41c6..ad389b0a1 100644 --- a/src/Core/Models/Business/ILicense.cs +++ b/src/Core/Models/Business/ILicense.cs @@ -1,21 +1,20 @@ using System.Security.Cryptography.X509Certificates; -namespace Bit.Core.Models.Business +namespace Bit.Core.Models.Business; + +public interface ILicense { - public interface ILicense - { - string LicenseKey { get; set; } - int Version { get; set; } - DateTime Issued { get; set; } - DateTime? Refresh { get; set; } - DateTime? Expires { get; set; } - bool Trial { get; set; } - string Hash { get; set; } - string Signature { get; set; } - byte[] SignatureBytes { get; } - byte[] GetDataBytes(bool forHash = false); - byte[] ComputeHash(); - bool VerifySignature(X509Certificate2 certificate); - byte[] Sign(X509Certificate2 certificate); - } + string LicenseKey { get; set; } + int Version { get; set; } + DateTime Issued { get; set; } + DateTime? Refresh { get; set; } + DateTime? Expires { get; set; } + bool Trial { get; set; } + string Hash { get; set; } + string Signature { get; set; } + byte[] SignatureBytes { get; } + byte[] GetDataBytes(bool forHash = false); + byte[] ComputeHash(); + bool VerifySignature(X509Certificate2 certificate); + byte[] Sign(X509Certificate2 certificate); } diff --git a/src/Core/Models/Business/ImportedGroup.cs b/src/Core/Models/Business/ImportedGroup.cs index ee4589dfa..bd0e38933 100644 --- a/src/Core/Models/Business/ImportedGroup.cs +++ b/src/Core/Models/Business/ImportedGroup.cs @@ -1,10 +1,9 @@ using Bit.Core.Entities; -namespace Bit.Core.Models.Business +namespace Bit.Core.Models.Business; + +public class ImportedGroup { - public class ImportedGroup - { - public Group Group { get; set; } - public HashSet ExternalUserIds { get; set; } - } + public Group Group { get; set; } + public HashSet ExternalUserIds { get; set; } } diff --git a/src/Core/Models/Business/ImportedOrganizationUser.cs b/src/Core/Models/Business/ImportedOrganizationUser.cs index c57ce2123..967cdf253 100644 --- a/src/Core/Models/Business/ImportedOrganizationUser.cs +++ b/src/Core/Models/Business/ImportedOrganizationUser.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Models.Business +namespace Bit.Core.Models.Business; + +public class ImportedOrganizationUser { - public class ImportedOrganizationUser - { - public string Email { get; set; } - public string ExternalId { get; set; } - } + public string Email { get; set; } + public string ExternalId { get; set; } } diff --git a/src/Core/Models/Business/OrganizationLicense.cs b/src/Core/Models/Business/OrganizationLicense.cs index 6d1ff069d..6f13bc358 100644 --- a/src/Core/Models/Business/OrganizationLicense.cs +++ b/src/Core/Models/Business/OrganizationLicense.cs @@ -8,304 +8,303 @@ using Bit.Core.Enums; using Bit.Core.Services; using Bit.Core.Settings; -namespace Bit.Core.Models.Business +namespace Bit.Core.Models.Business; + +public class OrganizationLicense : ILicense { - public class OrganizationLicense : ILicense + public OrganizationLicense() + { } + + public OrganizationLicense(Organization org, SubscriptionInfo subscriptionInfo, Guid installationId, + ILicensingService licenseService, int? version = null) { - public OrganizationLicense() - { } + Version = version.GetValueOrDefault(CURRENT_LICENSE_FILE_VERSION); // TODO: Remember to change the constant + LicenseType = Enums.LicenseType.Organization; + LicenseKey = org.LicenseKey; + InstallationId = installationId; + Id = org.Id; + Name = org.Name; + BillingEmail = org.BillingEmail; + BusinessName = org.BusinessName; + Enabled = org.Enabled; + Plan = org.Plan; + PlanType = org.PlanType; + Seats = org.Seats; + MaxCollections = org.MaxCollections; + UsePolicies = org.UsePolicies; + UseSso = org.UseSso; + UseKeyConnector = org.UseKeyConnector; + UseScim = org.UseScim; + UseGroups = org.UseGroups; + UseEvents = org.UseEvents; + UseDirectory = org.UseDirectory; + UseTotp = org.UseTotp; + Use2fa = org.Use2fa; + UseApi = org.UseApi; + UseResetPassword = org.UseResetPassword; + MaxStorageGb = org.MaxStorageGb; + SelfHost = org.SelfHost; + UsersGetPremium = org.UsersGetPremium; + Issued = DateTime.UtcNow; - public OrganizationLicense(Organization org, SubscriptionInfo subscriptionInfo, Guid installationId, - ILicensingService licenseService, int? version = null) + if (subscriptionInfo?.Subscription == null) { - Version = version.GetValueOrDefault(CURRENT_LICENSE_FILE_VERSION); // TODO: Remember to change the constant - LicenseType = Enums.LicenseType.Organization; - LicenseKey = org.LicenseKey; - InstallationId = installationId; - Id = org.Id; - Name = org.Name; - BillingEmail = org.BillingEmail; - BusinessName = org.BusinessName; - Enabled = org.Enabled; - Plan = org.Plan; - PlanType = org.PlanType; - Seats = org.Seats; - MaxCollections = org.MaxCollections; - UsePolicies = org.UsePolicies; - UseSso = org.UseSso; - UseKeyConnector = org.UseKeyConnector; - UseScim = org.UseScim; - UseGroups = org.UseGroups; - UseEvents = org.UseEvents; - UseDirectory = org.UseDirectory; - UseTotp = org.UseTotp; - Use2fa = org.Use2fa; - UseApi = org.UseApi; - UseResetPassword = org.UseResetPassword; - MaxStorageGb = org.MaxStorageGb; - SelfHost = org.SelfHost; - UsersGetPremium = org.UsersGetPremium; - Issued = DateTime.UtcNow; - - if (subscriptionInfo?.Subscription == null) + if (org.PlanType == PlanType.Custom && org.ExpirationDate.HasValue) { - if (org.PlanType == PlanType.Custom && org.ExpirationDate.HasValue) - { - Expires = Refresh = org.ExpirationDate.Value; - Trial = false; - } - else - { - Expires = Refresh = Issued.AddDays(7); - Trial = true; - } - } - else if (subscriptionInfo.Subscription.TrialEndDate.HasValue && - subscriptionInfo.Subscription.TrialEndDate.Value > DateTime.UtcNow) - { - Expires = Refresh = subscriptionInfo.Subscription.TrialEndDate.Value; - Trial = true; - } - else - { - if (org.ExpirationDate.HasValue && org.ExpirationDate.Value < DateTime.UtcNow) - { - // expired - Expires = Refresh = org.ExpirationDate.Value; - } - else if (subscriptionInfo?.Subscription?.PeriodDuration != null && - subscriptionInfo.Subscription.PeriodDuration > TimeSpan.FromDays(180)) - { - Refresh = DateTime.UtcNow.AddDays(30); - Expires = subscriptionInfo?.Subscription.PeriodEndDate.Value.AddDays(60); - } - else - { - Expires = org.ExpirationDate.HasValue ? org.ExpirationDate.Value.AddMonths(11) : Issued.AddYears(1); - Refresh = DateTime.UtcNow - Expires > TimeSpan.FromDays(30) ? DateTime.UtcNow.AddDays(30) : Expires; - } - + Expires = Refresh = org.ExpirationDate.Value; Trial = false; } - - Hash = Convert.ToBase64String(ComputeHash()); - Signature = Convert.ToBase64String(licenseService.SignLicense(this)); - } - - public string LicenseKey { get; set; } - public Guid InstallationId { get; set; } - public Guid Id { get; set; } - public string Name { get; set; } - public string BillingEmail { get; set; } - public string BusinessName { get; set; } - public bool Enabled { get; set; } - public string Plan { get; set; } - public PlanType PlanType { get; set; } - public int? Seats { get; set; } - public short? MaxCollections { get; set; } - public bool UsePolicies { get; set; } - public bool UseSso { get; set; } - public bool UseKeyConnector { get; set; } - public bool UseScim { get; set; } - public bool UseGroups { get; set; } - public bool UseEvents { get; set; } - public bool UseDirectory { get; set; } - public bool UseTotp { get; set; } - public bool Use2fa { get; set; } - public bool UseApi { get; set; } - public bool UseResetPassword { get; set; } - public short? MaxStorageGb { get; set; } - public bool SelfHost { get; set; } - public bool UsersGetPremium { get; set; } - public int Version { get; set; } - public DateTime Issued { get; set; } - public DateTime? Refresh { get; set; } - public DateTime? Expires { get; set; } - public bool Trial { get; set; } - public LicenseType? LicenseType { get; set; } - public string Hash { get; set; } - public string Signature { get; set; } - [JsonIgnore] - public byte[] SignatureBytes => Convert.FromBase64String(Signature); - - /// - /// Represents the current version of the license format. Should be updated whenever new fields are added. - /// - private const int CURRENT_LICENSE_FILE_VERSION = 9; - private bool ValidLicenseVersion - { - get => Version is >= 1 and <= 10; - } - - public byte[] GetDataBytes(bool forHash = false) - { - string data = null; - if (ValidLicenseVersion) + else { - var props = typeof(OrganizationLicense) - .GetProperties(BindingFlags.Public | BindingFlags.Instance) - .Where(p => - !p.Name.Equals(nameof(Signature)) && - !p.Name.Equals(nameof(SignatureBytes)) && - !p.Name.Equals(nameof(LicenseType)) && - // UsersGetPremium was added in Version 2 - (Version >= 2 || !p.Name.Equals(nameof(UsersGetPremium))) && - // UseEvents was added in Version 3 - (Version >= 3 || !p.Name.Equals(nameof(UseEvents))) && - // Use2fa was added in Version 4 - (Version >= 4 || !p.Name.Equals(nameof(Use2fa))) && - // UseApi was added in Version 5 - (Version >= 5 || !p.Name.Equals(nameof(UseApi))) && - // UsePolicies was added in Version 6 - (Version >= 6 || !p.Name.Equals(nameof(UsePolicies))) && - // UseSso was added in Version 7 - (Version >= 7 || !p.Name.Equals(nameof(UseSso))) && - // UseResetPassword was added in Version 8 - (Version >= 8 || !p.Name.Equals(nameof(UseResetPassword))) && - // UseKeyConnector was added in Version 9 - (Version >= 9 || !p.Name.Equals(nameof(UseKeyConnector))) && - // UseScim was added in Version 10 - (Version >= 10 || !p.Name.Equals(nameof(UseScim))) && + Expires = Refresh = Issued.AddDays(7); + Trial = true; + } + } + else if (subscriptionInfo.Subscription.TrialEndDate.HasValue && + subscriptionInfo.Subscription.TrialEndDate.Value > DateTime.UtcNow) + { + Expires = Refresh = subscriptionInfo.Subscription.TrialEndDate.Value; + Trial = true; + } + else + { + if (org.ExpirationDate.HasValue && org.ExpirationDate.Value < DateTime.UtcNow) + { + // expired + Expires = Refresh = org.ExpirationDate.Value; + } + else if (subscriptionInfo?.Subscription?.PeriodDuration != null && + subscriptionInfo.Subscription.PeriodDuration > TimeSpan.FromDays(180)) + { + Refresh = DateTime.UtcNow.AddDays(30); + Expires = subscriptionInfo?.Subscription.PeriodEndDate.Value.AddDays(60); + } + else + { + Expires = org.ExpirationDate.HasValue ? org.ExpirationDate.Value.AddMonths(11) : Issued.AddYears(1); + Refresh = DateTime.UtcNow - Expires > TimeSpan.FromDays(30) ? DateTime.UtcNow.AddDays(30) : Expires; + } + + Trial = false; + } + + Hash = Convert.ToBase64String(ComputeHash()); + Signature = Convert.ToBase64String(licenseService.SignLicense(this)); + } + + public string LicenseKey { get; set; } + public Guid InstallationId { get; set; } + public Guid Id { get; set; } + public string Name { get; set; } + public string BillingEmail { get; set; } + public string BusinessName { get; set; } + public bool Enabled { get; set; } + public string Plan { get; set; } + public PlanType PlanType { get; set; } + public int? Seats { get; set; } + public short? MaxCollections { get; set; } + public bool UsePolicies { get; set; } + public bool UseSso { get; set; } + public bool UseKeyConnector { get; set; } + public bool UseScim { get; set; } + public bool UseGroups { get; set; } + public bool UseEvents { get; set; } + public bool UseDirectory { get; set; } + public bool UseTotp { get; set; } + public bool Use2fa { get; set; } + public bool UseApi { get; set; } + public bool UseResetPassword { get; set; } + public short? MaxStorageGb { get; set; } + public bool SelfHost { get; set; } + public bool UsersGetPremium { get; set; } + public int Version { get; set; } + public DateTime Issued { get; set; } + public DateTime? Refresh { get; set; } + public DateTime? Expires { get; set; } + public bool Trial { get; set; } + public LicenseType? LicenseType { get; set; } + public string Hash { get; set; } + public string Signature { get; set; } + [JsonIgnore] + public byte[] SignatureBytes => Convert.FromBase64String(Signature); + + /// + /// Represents the current version of the license format. Should be updated whenever new fields are added. + /// + private const int CURRENT_LICENSE_FILE_VERSION = 9; + private bool ValidLicenseVersion + { + get => Version is >= 1 and <= 10; + } + + public byte[] GetDataBytes(bool forHash = false) + { + string data = null; + if (ValidLicenseVersion) + { + var props = typeof(OrganizationLicense) + .GetProperties(BindingFlags.Public | BindingFlags.Instance) + .Where(p => + !p.Name.Equals(nameof(Signature)) && + !p.Name.Equals(nameof(SignatureBytes)) && + !p.Name.Equals(nameof(LicenseType)) && + // UsersGetPremium was added in Version 2 + (Version >= 2 || !p.Name.Equals(nameof(UsersGetPremium))) && + // UseEvents was added in Version 3 + (Version >= 3 || !p.Name.Equals(nameof(UseEvents))) && + // Use2fa was added in Version 4 + (Version >= 4 || !p.Name.Equals(nameof(Use2fa))) && + // UseApi was added in Version 5 + (Version >= 5 || !p.Name.Equals(nameof(UseApi))) && + // UsePolicies was added in Version 6 + (Version >= 6 || !p.Name.Equals(nameof(UsePolicies))) && + // UseSso was added in Version 7 + (Version >= 7 || !p.Name.Equals(nameof(UseSso))) && + // UseResetPassword was added in Version 8 + (Version >= 8 || !p.Name.Equals(nameof(UseResetPassword))) && + // UseKeyConnector was added in Version 9 + (Version >= 9 || !p.Name.Equals(nameof(UseKeyConnector))) && + // UseScim was added in Version 10 + (Version >= 10 || !p.Name.Equals(nameof(UseScim))) && + ( + !forHash || ( - !forHash || - ( - !p.Name.Equals(nameof(Hash)) && - !p.Name.Equals(nameof(Issued)) && - !p.Name.Equals(nameof(Refresh)) - ) - )) - .OrderBy(p => p.Name) - .Select(p => $"{p.Name}:{Utilities.CoreHelpers.FormatLicenseSignatureValue(p.GetValue(this, null))}") - .Aggregate((c, n) => $"{c}|{n}"); - data = $"license:organization|{props}"; - } - else - { - throw new NotSupportedException($"Version {Version} is not supported."); - } - - return Encoding.UTF8.GetBytes(data); + !p.Name.Equals(nameof(Hash)) && + !p.Name.Equals(nameof(Issued)) && + !p.Name.Equals(nameof(Refresh)) + ) + )) + .OrderBy(p => p.Name) + .Select(p => $"{p.Name}:{Utilities.CoreHelpers.FormatLicenseSignatureValue(p.GetValue(this, null))}") + .Aggregate((c, n) => $"{c}|{n}"); + data = $"license:organization|{props}"; + } + else + { + throw new NotSupportedException($"Version {Version} is not supported."); } - public byte[] ComputeHash() + return Encoding.UTF8.GetBytes(data); + } + + public byte[] ComputeHash() + { + using (var alg = SHA256.Create()) { - using (var alg = SHA256.Create()) - { - return alg.ComputeHash(GetDataBytes(true)); - } + return alg.ComputeHash(GetDataBytes(true)); + } + } + + public bool CanUse(IGlobalSettings globalSettings) + { + if (!Enabled || Issued > DateTime.UtcNow || Expires < DateTime.UtcNow) + { + return false; } - public bool CanUse(IGlobalSettings globalSettings) + if (ValidLicenseVersion) { - if (!Enabled || Issued > DateTime.UtcNow || Expires < DateTime.UtcNow) - { - return false; - } + return InstallationId == globalSettings.Installation.Id && SelfHost; + } + else + { + throw new NotSupportedException($"Version {Version} is not supported."); + } + } - if (ValidLicenseVersion) - { - return InstallationId == globalSettings.Installation.Id && SelfHost; - } - else - { - throw new NotSupportedException($"Version {Version} is not supported."); - } + public bool VerifyData(Organization organization, IGlobalSettings globalSettings) + { + if (Issued > DateTime.UtcNow || Expires < DateTime.UtcNow) + { + return false; } - public bool VerifyData(Organization organization, IGlobalSettings globalSettings) + if (ValidLicenseVersion) { - if (Issued > DateTime.UtcNow || Expires < DateTime.UtcNow) + var valid = + globalSettings.Installation.Id == InstallationId && + organization.LicenseKey != null && organization.LicenseKey.Equals(LicenseKey) && + organization.Enabled == Enabled && + organization.PlanType == PlanType && + organization.Seats == Seats && + organization.MaxCollections == MaxCollections && + organization.UseGroups == UseGroups && + organization.UseDirectory == UseDirectory && + organization.UseTotp == UseTotp && + organization.SelfHost == SelfHost && + organization.Name.Equals(Name); + + if (valid && Version >= 2) { - return false; + valid = organization.UsersGetPremium == UsersGetPremium; } - if (ValidLicenseVersion) + if (valid && Version >= 3) { - var valid = - globalSettings.Installation.Id == InstallationId && - organization.LicenseKey != null && organization.LicenseKey.Equals(LicenseKey) && - organization.Enabled == Enabled && - organization.PlanType == PlanType && - organization.Seats == Seats && - organization.MaxCollections == MaxCollections && - organization.UseGroups == UseGroups && - organization.UseDirectory == UseDirectory && - organization.UseTotp == UseTotp && - organization.SelfHost == SelfHost && - organization.Name.Equals(Name); - - if (valid && Version >= 2) - { - valid = organization.UsersGetPremium == UsersGetPremium; - } - - if (valid && Version >= 3) - { - valid = organization.UseEvents == UseEvents; - } - - if (valid && Version >= 4) - { - valid = organization.Use2fa == Use2fa; - } - - if (valid && Version >= 5) - { - valid = organization.UseApi == UseApi; - } - - if (valid && Version >= 6) - { - valid = organization.UsePolicies == UsePolicies; - } - - if (valid && Version >= 7) - { - valid = organization.UseSso == UseSso; - } - - if (valid && Version >= 8) - { - valid = organization.UseResetPassword == UseResetPassword; - } - - if (valid && Version >= 9) - { - valid = organization.UseKeyConnector == UseKeyConnector; - } - - if (valid && Version >= 10) - { - valid = organization.UseScim == UseScim; - } - - return valid; + valid = organization.UseEvents == UseEvents; } - else + + if (valid && Version >= 4) { - throw new NotSupportedException($"Version {Version} is not supported."); + valid = organization.Use2fa == Use2fa; } + + if (valid && Version >= 5) + { + valid = organization.UseApi == UseApi; + } + + if (valid && Version >= 6) + { + valid = organization.UsePolicies == UsePolicies; + } + + if (valid && Version >= 7) + { + valid = organization.UseSso == UseSso; + } + + if (valid && Version >= 8) + { + valid = organization.UseResetPassword == UseResetPassword; + } + + if (valid && Version >= 9) + { + valid = organization.UseKeyConnector == UseKeyConnector; + } + + if (valid && Version >= 10) + { + valid = organization.UseScim == UseScim; + } + + return valid; + } + else + { + throw new NotSupportedException($"Version {Version} is not supported."); + } + } + + public bool VerifySignature(X509Certificate2 certificate) + { + using (var rsa = certificate.GetRSAPublicKey()) + { + return rsa.VerifyData(GetDataBytes(), SignatureBytes, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + } + } + + public byte[] Sign(X509Certificate2 certificate) + { + if (!certificate.HasPrivateKey) + { + throw new InvalidOperationException("You don't have the private key!"); } - public bool VerifySignature(X509Certificate2 certificate) + using (var rsa = certificate.GetRSAPrivateKey()) { - using (var rsa = certificate.GetRSAPublicKey()) - { - return rsa.VerifyData(GetDataBytes(), SignatureBytes, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); - } - } - - public byte[] Sign(X509Certificate2 certificate) - { - if (!certificate.HasPrivateKey) - { - throw new InvalidOperationException("You don't have the private key!"); - } - - using (var rsa = certificate.GetRSAPrivateKey()) - { - return rsa.SignData(GetDataBytes(), HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); - } + return rsa.SignData(GetDataBytes(), HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); } } } diff --git a/src/Core/Models/Business/OrganizationSignup.cs b/src/Core/Models/Business/OrganizationSignup.cs index a257410fd..970ede9af 100644 --- a/src/Core/Models/Business/OrganizationSignup.cs +++ b/src/Core/Models/Business/OrganizationSignup.cs @@ -1,17 +1,16 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Models.Business +namespace Bit.Core.Models.Business; + +public class OrganizationSignup : OrganizationUpgrade { - public class OrganizationSignup : OrganizationUpgrade - { - public string Name { get; set; } - public string BillingEmail { get; set; } - public User Owner { get; set; } - public string OwnerKey { get; set; } - public string CollectionName { get; set; } - public PaymentMethodType? PaymentMethodType { get; set; } - public string PaymentToken { get; set; } - public int? MaxAutoscaleSeats { get; set; } = null; - } + public string Name { get; set; } + public string BillingEmail { get; set; } + public User Owner { get; set; } + public string OwnerKey { get; set; } + public string CollectionName { get; set; } + public PaymentMethodType? PaymentMethodType { get; set; } + public string PaymentToken { get; set; } + public int? MaxAutoscaleSeats { get; set; } = null; } diff --git a/src/Core/Models/Business/OrganizationUpgrade.cs b/src/Core/Models/Business/OrganizationUpgrade.cs index f6d8aa415..b77a9d012 100644 --- a/src/Core/Models/Business/OrganizationUpgrade.cs +++ b/src/Core/Models/Business/OrganizationUpgrade.cs @@ -1,16 +1,15 @@ using Bit.Core.Enums; -namespace Bit.Core.Models.Business +namespace Bit.Core.Models.Business; + +public class OrganizationUpgrade { - public class OrganizationUpgrade - { - public string BusinessName { get; set; } - public PlanType Plan { get; set; } - public int AdditionalSeats { get; set; } - public short AdditionalStorageGb { get; set; } - public bool PremiumAccessAddon { get; set; } - public TaxInfo TaxInfo { get; set; } - public string PublicKey { get; set; } - public string PrivateKey { get; set; } - } + public string BusinessName { get; set; } + public PlanType Plan { get; set; } + public int AdditionalSeats { get; set; } + public short AdditionalStorageGb { get; set; } + public bool PremiumAccessAddon { get; set; } + public TaxInfo TaxInfo { get; set; } + public string PublicKey { get; set; } + public string PrivateKey { get; set; } } diff --git a/src/Core/Models/Business/OrganizationUserInvite.cs b/src/Core/Models/Business/OrganizationUserInvite.cs index 8e7f6f865..4fa61d55c 100644 --- a/src/Core/Models/Business/OrganizationUserInvite.cs +++ b/src/Core/Models/Business/OrganizationUserInvite.cs @@ -1,25 +1,24 @@ using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; -namespace Bit.Core.Models.Business +namespace Bit.Core.Models.Business; + +public class OrganizationUserInvite { - public class OrganizationUserInvite + public IEnumerable Emails { get; set; } + public Enums.OrganizationUserType? Type { get; set; } + public bool AccessAll { get; set; } + public Permissions Permissions { get; set; } + public IEnumerable Collections { get; set; } + + public OrganizationUserInvite() { } + + public OrganizationUserInvite(OrganizationUserInviteData requestModel) { - public IEnumerable Emails { get; set; } - public Enums.OrganizationUserType? Type { get; set; } - public bool AccessAll { get; set; } - public Permissions Permissions { get; set; } - public IEnumerable Collections { get; set; } - - public OrganizationUserInvite() { } - - public OrganizationUserInvite(OrganizationUserInviteData requestModel) - { - Emails = requestModel.Emails; - Type = requestModel.Type; - AccessAll = requestModel.AccessAll; - Collections = requestModel.Collections; - Permissions = requestModel.Permissions; - } + Emails = requestModel.Emails; + Type = requestModel.Type; + AccessAll = requestModel.AccessAll; + Collections = requestModel.Collections; + Permissions = requestModel.Permissions; } } diff --git a/src/Core/Models/Business/Provider/ProviderUserInvite.cs b/src/Core/Models/Business/Provider/ProviderUserInvite.cs index 39f609479..72e87728d 100644 --- a/src/Core/Models/Business/Provider/ProviderUserInvite.cs +++ b/src/Core/Models/Business/Provider/ProviderUserInvite.cs @@ -1,36 +1,35 @@ using Bit.Core.Enums.Provider; -namespace Bit.Core.Models.Business.Provider +namespace Bit.Core.Models.Business.Provider; + +public class ProviderUserInvite { - public class ProviderUserInvite + public IEnumerable UserIdentifiers { get; set; } + public ProviderUserType Type { get; set; } + public Guid InvitingUserId { get; set; } + public Guid ProviderId { get; set; } +} + +public static class ProviderUserInviteFactory +{ + public static ProviderUserInvite CreateIntialInvite(IEnumerable inviteeEmails, ProviderUserType type, Guid invitingUserId, Guid providerId) { - public IEnumerable UserIdentifiers { get; set; } - public ProviderUserType Type { get; set; } - public Guid InvitingUserId { get; set; } - public Guid ProviderId { get; set; } + return new ProviderUserInvite + { + UserIdentifiers = inviteeEmails, + Type = type, + InvitingUserId = invitingUserId, + ProviderId = providerId + }; } - public static class ProviderUserInviteFactory + public static ProviderUserInvite CreateReinvite(IEnumerable inviteeUserIds, Guid invitingUserId, Guid providerId) { - public static ProviderUserInvite CreateIntialInvite(IEnumerable inviteeEmails, ProviderUserType type, Guid invitingUserId, Guid providerId) + return new ProviderUserInvite { - return new ProviderUserInvite - { - UserIdentifiers = inviteeEmails, - Type = type, - InvitingUserId = invitingUserId, - ProviderId = providerId - }; - } - - public static ProviderUserInvite CreateReinvite(IEnumerable inviteeUserIds, Guid invitingUserId, Guid providerId) - { - return new ProviderUserInvite - { - UserIdentifiers = inviteeUserIds, - InvitingUserId = invitingUserId, - ProviderId = providerId - }; - } + UserIdentifiers = inviteeUserIds, + InvitingUserId = invitingUserId, + ProviderId = providerId + }; } } diff --git a/src/Core/Models/Business/ReferenceEvent.cs b/src/Core/Models/Business/ReferenceEvent.cs index 35cb5dc73..4f20b2455 100644 --- a/src/Core/Models/Business/ReferenceEvent.cs +++ b/src/Core/Models/Business/ReferenceEvent.cs @@ -2,61 +2,60 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Models.Business +namespace Bit.Core.Models.Business; + +public class ReferenceEvent { - public class ReferenceEvent + public ReferenceEvent() { } + + public ReferenceEvent(ReferenceEventType type, IReferenceable source) { - public ReferenceEvent() { } - - public ReferenceEvent(ReferenceEventType type, IReferenceable source) + Type = type; + if (source != null) { - Type = type; - if (source != null) - { - Source = source.IsUser() ? ReferenceEventSource.User : ReferenceEventSource.Organization; - Id = source.Id; - ReferenceData = source.ReferenceData; - } + Source = source.IsUser() ? ReferenceEventSource.User : ReferenceEventSource.Organization; + Id = source.Id; + ReferenceData = source.ReferenceData; } - - [JsonConverter(typeof(JsonStringEnumConverter))] - public ReferenceEventType Type { get; set; } - - [JsonConverter(typeof(JsonStringEnumConverter))] - public ReferenceEventSource Source { get; set; } - - public Guid Id { get; set; } - - public string ReferenceData { get; set; } - - public DateTime EventDate { get; set; } = DateTime.UtcNow; - - public int? Users { get; set; } - - public bool? EndOfPeriod { get; set; } - - public string PlanName { get; set; } - - public PlanType? PlanType { get; set; } - - public string OldPlanName { get; set; } - - public PlanType? OldPlanType { get; set; } - - public int? Seats { get; set; } - public int? PreviousSeats { get; set; } - - public short? Storage { get; set; } - - [JsonConverter(typeof(JsonStringEnumConverter))] - public SendType? SendType { get; set; } - - public int? MaxAccessCount { get; set; } - - public bool? HasPassword { get; set; } - - public string EventRaisedByUser { get; set; } - - public bool? SalesAssistedTrialStarted { get; set; } } + + [JsonConverter(typeof(JsonStringEnumConverter))] + public ReferenceEventType Type { get; set; } + + [JsonConverter(typeof(JsonStringEnumConverter))] + public ReferenceEventSource Source { get; set; } + + public Guid Id { get; set; } + + public string ReferenceData { get; set; } + + public DateTime EventDate { get; set; } = DateTime.UtcNow; + + public int? Users { get; set; } + + public bool? EndOfPeriod { get; set; } + + public string PlanName { get; set; } + + public PlanType? PlanType { get; set; } + + public string OldPlanName { get; set; } + + public PlanType? OldPlanType { get; set; } + + public int? Seats { get; set; } + public int? PreviousSeats { get; set; } + + public short? Storage { get; set; } + + [JsonConverter(typeof(JsonStringEnumConverter))] + public SendType? SendType { get; set; } + + public int? MaxAccessCount { get; set; } + + public bool? HasPassword { get; set; } + + public string EventRaisedByUser { get; set; } + + public bool? SalesAssistedTrialStarted { get; set; } } diff --git a/src/Core/Models/Business/SubscriptionCreateOptions.cs b/src/Core/Models/Business/SubscriptionCreateOptions.cs index e78aaeda0..4964a625c 100644 --- a/src/Core/Models/Business/SubscriptionCreateOptions.cs +++ b/src/Core/Models/Business/SubscriptionCreateOptions.cs @@ -1,84 +1,83 @@ using Bit.Core.Entities; using Stripe; -namespace Bit.Core.Models.Business +namespace Bit.Core.Models.Business; + +public class OrganizationSubscriptionOptionsBase : Stripe.SubscriptionCreateOptions { - public class OrganizationSubscriptionOptionsBase : Stripe.SubscriptionCreateOptions + public OrganizationSubscriptionOptionsBase(Organization org, StaticStore.Plan plan, TaxInfo taxInfo, int additionalSeats, int additionalStorageGb, bool premiumAccessAddon) { - public OrganizationSubscriptionOptionsBase(Organization org, StaticStore.Plan plan, TaxInfo taxInfo, int additionalSeats, int additionalStorageGb, bool premiumAccessAddon) + Items = new List(); + Metadata = new Dictionary { - Items = new List(); - Metadata = new Dictionary - { - [org.GatewayIdField()] = org.Id.ToString() - }; + [org.GatewayIdField()] = org.Id.ToString() + }; - if (plan.StripePlanId != null) + if (plan.StripePlanId != null) + { + Items.Add(new SubscriptionItemOptions { - Items.Add(new SubscriptionItemOptions - { - Plan = plan.StripePlanId, - Quantity = 1 - }); - } - - if (additionalSeats > 0 && plan.StripeSeatPlanId != null) - { - Items.Add(new SubscriptionItemOptions - { - Plan = plan.StripeSeatPlanId, - Quantity = additionalSeats - }); - } - - if (additionalStorageGb > 0) - { - Items.Add(new SubscriptionItemOptions - { - Plan = plan.StripeStoragePlanId, - Quantity = additionalStorageGb - }); - } - - if (premiumAccessAddon && plan.StripePremiumAccessPlanId != null) - { - Items.Add(new SubscriptionItemOptions - { - Plan = plan.StripePremiumAccessPlanId, - Quantity = 1 - }); - } - - if (!string.IsNullOrWhiteSpace(taxInfo?.StripeTaxRateId)) - { - DefaultTaxRates = new List { taxInfo.StripeTaxRateId }; - } + Plan = plan.StripePlanId, + Quantity = 1 + }); } - } - public class OrganizationPurchaseSubscriptionOptions : OrganizationSubscriptionOptionsBase - { - public OrganizationPurchaseSubscriptionOptions( - Organization org, StaticStore.Plan plan, - TaxInfo taxInfo, int additionalSeats = 0, - int additionalStorageGb = 0, bool premiumAccessAddon = false) : - base(org, plan, taxInfo, additionalSeats, additionalStorageGb, premiumAccessAddon) + if (additionalSeats > 0 && plan.StripeSeatPlanId != null) { - OffSession = true; - TrialPeriodDays = plan.TrialPeriodDays; + Items.Add(new SubscriptionItemOptions + { + Plan = plan.StripeSeatPlanId, + Quantity = additionalSeats + }); } - } - public class OrganizationUpgradeSubscriptionOptions : OrganizationSubscriptionOptionsBase - { - public OrganizationUpgradeSubscriptionOptions( - string customerId, Organization org, - StaticStore.Plan plan, TaxInfo taxInfo, - int additionalSeats = 0, int additionalStorageGb = 0, - bool premiumAccessAddon = false) : - base(org, plan, taxInfo, additionalSeats, additionalStorageGb, premiumAccessAddon) + if (additionalStorageGb > 0) { - Customer = customerId; + Items.Add(new SubscriptionItemOptions + { + Plan = plan.StripeStoragePlanId, + Quantity = additionalStorageGb + }); + } + + if (premiumAccessAddon && plan.StripePremiumAccessPlanId != null) + { + Items.Add(new SubscriptionItemOptions + { + Plan = plan.StripePremiumAccessPlanId, + Quantity = 1 + }); + } + + if (!string.IsNullOrWhiteSpace(taxInfo?.StripeTaxRateId)) + { + DefaultTaxRates = new List { taxInfo.StripeTaxRateId }; } } } + +public class OrganizationPurchaseSubscriptionOptions : OrganizationSubscriptionOptionsBase +{ + public OrganizationPurchaseSubscriptionOptions( + Organization org, StaticStore.Plan plan, + TaxInfo taxInfo, int additionalSeats = 0, + int additionalStorageGb = 0, bool premiumAccessAddon = false) : + base(org, plan, taxInfo, additionalSeats, additionalStorageGb, premiumAccessAddon) + { + OffSession = true; + TrialPeriodDays = plan.TrialPeriodDays; + } +} + +public class OrganizationUpgradeSubscriptionOptions : OrganizationSubscriptionOptionsBase +{ + public OrganizationUpgradeSubscriptionOptions( + string customerId, Organization org, + StaticStore.Plan plan, TaxInfo taxInfo, + int additionalSeats = 0, int additionalStorageGb = 0, + bool premiumAccessAddon = false) : + base(org, plan, taxInfo, additionalSeats, additionalStorageGb, premiumAccessAddon) + { + Customer = customerId; + } +} diff --git a/src/Core/Models/Business/SubscriptionInfo.cs b/src/Core/Models/Business/SubscriptionInfo.cs index e8e339db8..61aa060cd 100644 --- a/src/Core/Models/Business/SubscriptionInfo.cs +++ b/src/Core/Models/Business/SubscriptionInfo.cs @@ -1,87 +1,86 @@ using Stripe; -namespace Bit.Core.Models.Business +namespace Bit.Core.Models.Business; + +public class SubscriptionInfo { - public class SubscriptionInfo + public BillingSubscription Subscription { get; set; } + public BillingUpcomingInvoice UpcomingInvoice { get; set; } + public bool UsingInAppPurchase { get; set; } + + public class BillingSubscription { - public BillingSubscription Subscription { get; set; } - public BillingUpcomingInvoice UpcomingInvoice { get; set; } - public bool UsingInAppPurchase { get; set; } - - public class BillingSubscription + public BillingSubscription(Subscription sub) { - public BillingSubscription(Subscription sub) + Status = sub.Status; + TrialStartDate = sub.TrialStart; + TrialEndDate = sub.TrialEnd; + PeriodStartDate = sub.CurrentPeriodStart; + PeriodEndDate = sub.CurrentPeriodEnd; + CancelledDate = sub.CanceledAt; + CancelAtEndDate = sub.CancelAtPeriodEnd; + Cancelled = sub.Status == "canceled" || sub.Status == "unpaid" || sub.Status == "incomplete_expired"; + if (sub.Items?.Data != null) { - Status = sub.Status; - TrialStartDate = sub.TrialStart; - TrialEndDate = sub.TrialEnd; - PeriodStartDate = sub.CurrentPeriodStart; - PeriodEndDate = sub.CurrentPeriodEnd; - CancelledDate = sub.CanceledAt; - CancelAtEndDate = sub.CancelAtPeriodEnd; - Cancelled = sub.Status == "canceled" || sub.Status == "unpaid" || sub.Status == "incomplete_expired"; - if (sub.Items?.Data != null) - { - Items = sub.Items.Data.Select(i => new BillingSubscriptionItem(i)); - } - } - - public DateTime? TrialStartDate { get; set; } - public DateTime? TrialEndDate { get; set; } - public DateTime? PeriodStartDate { get; set; } - public DateTime? PeriodEndDate { get; set; } - public TimeSpan? PeriodDuration => PeriodEndDate - PeriodStartDate; - public DateTime? CancelledDate { get; set; } - public bool CancelAtEndDate { get; set; } - public string Status { get; set; } - public bool Cancelled { get; set; } - public IEnumerable Items { get; set; } = new List(); - - public class BillingSubscriptionItem - { - public BillingSubscriptionItem(SubscriptionItem item) - { - if (item.Plan != null) - { - Name = item.Plan.Nickname; - Amount = item.Plan.Amount.GetValueOrDefault() / 100M; - Interval = item.Plan.Interval; - } - - Quantity = (int)item.Quantity; - SponsoredSubscriptionItem = Utilities.StaticStore.SponsoredPlans.Any(p => p.StripePlanId == item.Plan.Id); - } - - public string Name { get; set; } - public decimal Amount { get; set; } - public int Quantity { get; set; } - public string Interval { get; set; } - public bool SponsoredSubscriptionItem { get; set; } + Items = sub.Items.Data.Select(i => new BillingSubscriptionItem(i)); } } - public class BillingUpcomingInvoice + public DateTime? TrialStartDate { get; set; } + public DateTime? TrialEndDate { get; set; } + public DateTime? PeriodStartDate { get; set; } + public DateTime? PeriodEndDate { get; set; } + public TimeSpan? PeriodDuration => PeriodEndDate - PeriodStartDate; + public DateTime? CancelledDate { get; set; } + public bool CancelAtEndDate { get; set; } + public string Status { get; set; } + public bool Cancelled { get; set; } + public IEnumerable Items { get; set; } = new List(); + + public class BillingSubscriptionItem { - public BillingUpcomingInvoice() { } - - public BillingUpcomingInvoice(Invoice inv) + public BillingSubscriptionItem(SubscriptionItem item) { - Amount = inv.AmountDue / 100M; - Date = inv.Created; - } - - public BillingUpcomingInvoice(Braintree.Subscription sub) - { - Amount = sub.NextBillAmount.GetValueOrDefault() + sub.Balance.GetValueOrDefault(); - if (Amount < 0) + if (item.Plan != null) { - Amount = 0; + Name = item.Plan.Nickname; + Amount = item.Plan.Amount.GetValueOrDefault() / 100M; + Interval = item.Plan.Interval; } - Date = sub.NextBillingDate; + + Quantity = (int)item.Quantity; + SponsoredSubscriptionItem = Utilities.StaticStore.SponsoredPlans.Any(p => p.StripePlanId == item.Plan.Id); } + public string Name { get; set; } public decimal Amount { get; set; } - public DateTime? Date { get; set; } + public int Quantity { get; set; } + public string Interval { get; set; } + public bool SponsoredSubscriptionItem { get; set; } } } + + public class BillingUpcomingInvoice + { + public BillingUpcomingInvoice() { } + + public BillingUpcomingInvoice(Invoice inv) + { + Amount = inv.AmountDue / 100M; + Date = inv.Created; + } + + public BillingUpcomingInvoice(Braintree.Subscription sub) + { + Amount = sub.NextBillAmount.GetValueOrDefault() + sub.Balance.GetValueOrDefault(); + if (Amount < 0) + { + Amount = 0; + } + Date = sub.NextBillingDate; + } + + public decimal Amount { get; set; } + public DateTime? Date { get; set; } + } } diff --git a/src/Core/Models/Business/SubscriptionUpdate.cs b/src/Core/Models/Business/SubscriptionUpdate.cs index 56902524d..64b43a8de 100644 --- a/src/Core/Models/Business/SubscriptionUpdate.cs +++ b/src/Core/Models/Business/SubscriptionUpdate.cs @@ -1,210 +1,209 @@ using Bit.Core.Entities; using Stripe; -namespace Bit.Core.Models.Business +namespace Bit.Core.Models.Business; + +public abstract class SubscriptionUpdate { - public abstract class SubscriptionUpdate + protected abstract List PlanIds { get; } + + public abstract List RevertItemsOptions(Subscription subscription); + public abstract List UpgradeItemsOptions(Subscription subscription); + + public bool UpdateNeeded(Subscription subscription) { - protected abstract List PlanIds { get; } - - public abstract List RevertItemsOptions(Subscription subscription); - public abstract List UpgradeItemsOptions(Subscription subscription); - - public bool UpdateNeeded(Subscription subscription) + var upgradeItemsOptions = UpgradeItemsOptions(subscription); + foreach (var upgradeItemOptions in upgradeItemsOptions) { - var upgradeItemsOptions = UpgradeItemsOptions(subscription); - foreach (var upgradeItemOptions in upgradeItemsOptions) + var upgradeQuantity = upgradeItemOptions.Quantity ?? 0; + var existingQuantity = SubscriptionItem(subscription, upgradeItemOptions.Plan)?.Quantity ?? 0; + if (upgradeQuantity != existingQuantity) { - var upgradeQuantity = upgradeItemOptions.Quantity ?? 0; - var existingQuantity = SubscriptionItem(subscription, upgradeItemOptions.Plan)?.Quantity ?? 0; - if (upgradeQuantity != existingQuantity) - { - return true; - } + return true; } - return false; } - - protected static SubscriptionItem SubscriptionItem(Subscription subscription, string planId) => - planId == null ? null : subscription.Items?.Data?.FirstOrDefault(i => i.Plan.Id == planId); + return false; } + protected static SubscriptionItem SubscriptionItem(Subscription subscription, string planId) => + planId == null ? null : subscription.Items?.Data?.FirstOrDefault(i => i.Plan.Id == planId); +} - public class SeatSubscriptionUpdate : SubscriptionUpdate + +public class SeatSubscriptionUpdate : SubscriptionUpdate +{ + private readonly int _previousSeats; + private readonly StaticStore.Plan _plan; + private readonly long? _additionalSeats; + protected override List PlanIds => new() { _plan.StripeSeatPlanId }; + + + public SeatSubscriptionUpdate(Organization organization, StaticStore.Plan plan, long? additionalSeats) { - private readonly int _previousSeats; - private readonly StaticStore.Plan _plan; - private readonly long? _additionalSeats; - protected override List PlanIds => new() { _plan.StripeSeatPlanId }; - - - public SeatSubscriptionUpdate(Organization organization, StaticStore.Plan plan, long? additionalSeats) - { - _plan = plan; - _additionalSeats = additionalSeats; - _previousSeats = organization.Seats ?? 0; - } - - public override List UpgradeItemsOptions(Subscription subscription) - { - var item = SubscriptionItem(subscription, PlanIds.Single()); - return new() - { - new SubscriptionItemOptions - { - Id = item?.Id, - Plan = PlanIds.Single(), - Quantity = _additionalSeats, - Deleted = (item?.Id != null && _additionalSeats == 0) ? true : (bool?)null, - } - }; - } - - public override List RevertItemsOptions(Subscription subscription) - { - - var item = SubscriptionItem(subscription, PlanIds.Single()); - return new() - { - new SubscriptionItemOptions - { - Id = item?.Id, - Plan = PlanIds.Single(), - Quantity = _previousSeats, - Deleted = _previousSeats == 0 ? true : (bool?)null, - } - }; - } + _plan = plan; + _additionalSeats = additionalSeats; + _previousSeats = organization.Seats ?? 0; } - public class StorageSubscriptionUpdate : SubscriptionUpdate + public override List UpgradeItemsOptions(Subscription subscription) { - private long? _prevStorage; - private readonly string _plan; - private readonly long? _additionalStorage; - protected override List PlanIds => new() { _plan }; - - public StorageSubscriptionUpdate(string plan, long? additionalStorage) + var item = SubscriptionItem(subscription, PlanIds.Single()); + return new() { - _plan = plan; - _additionalStorage = additionalStorage; - } - - public override List UpgradeItemsOptions(Subscription subscription) - { - var item = SubscriptionItem(subscription, PlanIds.Single()); - _prevStorage = item?.Quantity ?? 0; - return new() + new SubscriptionItemOptions { - new SubscriptionItemOptions - { - Id = item?.Id, - Plan = _plan, - Quantity = _additionalStorage, - Deleted = (item?.Id != null && _additionalStorage == 0) ? true : (bool?)null, - } - }; - } - - public override List RevertItemsOptions(Subscription subscription) - { - if (!_prevStorage.HasValue) - { - throw new Exception("Unknown previous value, must first call UpgradeItemsOptions"); + Id = item?.Id, + Plan = PlanIds.Single(), + Quantity = _additionalSeats, + Deleted = (item?.Id != null && _additionalSeats == 0) ? true : (bool?)null, } - - var item = SubscriptionItem(subscription, PlanIds.Single()); - return new() - { - new SubscriptionItemOptions - { - Id = item?.Id, - Plan = _plan, - Quantity = _prevStorage.Value, - Deleted = _prevStorage.Value == 0 ? true : (bool?)null, - } - }; - } + }; } - public class SponsorOrganizationSubscriptionUpdate : SubscriptionUpdate + public override List RevertItemsOptions(Subscription subscription) { - private readonly string _existingPlanStripeId; - private readonly string _sponsoredPlanStripeId; - private readonly bool _applySponsorship; - protected override List PlanIds => new() { _existingPlanStripeId, _sponsoredPlanStripeId }; - public SponsorOrganizationSubscriptionUpdate(StaticStore.Plan existingPlan, StaticStore.SponsoredPlan sponsoredPlan, bool applySponsorship) + var item = SubscriptionItem(subscription, PlanIds.Single()); + return new() { - _existingPlanStripeId = existingPlan.StripePlanId; - _sponsoredPlanStripeId = sponsoredPlan?.StripePlanId; - _applySponsorship = applySponsorship; - } - - public override List RevertItemsOptions(Subscription subscription) - { - var result = new List(); - if (!string.IsNullOrWhiteSpace(AddStripePlanId)) + new SubscriptionItemOptions { - result.Add(new SubscriptionItemOptions - { - Id = AddStripeItem(subscription)?.Id, - Plan = AddStripePlanId, - Quantity = 0, - Deleted = true, - }); + Id = item?.Id, + Plan = PlanIds.Single(), + Quantity = _previousSeats, + Deleted = _previousSeats == 0 ? true : (bool?)null, } - - if (!string.IsNullOrWhiteSpace(RemoveStripePlanId)) - { - result.Add(new SubscriptionItemOptions - { - Id = RemoveStripeItem(subscription)?.Id, - Plan = RemoveStripePlanId, - Quantity = 1, - Deleted = false, - }); - } - return result; - } - - public override List UpgradeItemsOptions(Subscription subscription) - { - var result = new List(); - if (RemoveStripeItem(subscription) != null) - { - result.Add(new SubscriptionItemOptions - { - Id = RemoveStripeItem(subscription)?.Id, - Plan = RemoveStripePlanId, - Quantity = 0, - Deleted = true, - }); - } - - if (!string.IsNullOrWhiteSpace(AddStripePlanId)) - { - result.Add(new SubscriptionItemOptions - { - Id = AddStripeItem(subscription)?.Id, - Plan = AddStripePlanId, - Quantity = 1, - Deleted = false, - }); - } - return result; - } - - private string RemoveStripePlanId => _applySponsorship ? _existingPlanStripeId : _sponsoredPlanStripeId; - private string AddStripePlanId => _applySponsorship ? _sponsoredPlanStripeId : _existingPlanStripeId; - private Stripe.SubscriptionItem RemoveStripeItem(Subscription subscription) => - _applySponsorship ? - SubscriptionItem(subscription, _existingPlanStripeId) : - SubscriptionItem(subscription, _sponsoredPlanStripeId); - private Stripe.SubscriptionItem AddStripeItem(Subscription subscription) => - _applySponsorship ? - SubscriptionItem(subscription, _sponsoredPlanStripeId) : - SubscriptionItem(subscription, _existingPlanStripeId); - + }; } } + +public class StorageSubscriptionUpdate : SubscriptionUpdate +{ + private long? _prevStorage; + private readonly string _plan; + private readonly long? _additionalStorage; + protected override List PlanIds => new() { _plan }; + + public StorageSubscriptionUpdate(string plan, long? additionalStorage) + { + _plan = plan; + _additionalStorage = additionalStorage; + } + + public override List UpgradeItemsOptions(Subscription subscription) + { + var item = SubscriptionItem(subscription, PlanIds.Single()); + _prevStorage = item?.Quantity ?? 0; + return new() + { + new SubscriptionItemOptions + { + Id = item?.Id, + Plan = _plan, + Quantity = _additionalStorage, + Deleted = (item?.Id != null && _additionalStorage == 0) ? true : (bool?)null, + } + }; + } + + public override List RevertItemsOptions(Subscription subscription) + { + if (!_prevStorage.HasValue) + { + throw new Exception("Unknown previous value, must first call UpgradeItemsOptions"); + } + + var item = SubscriptionItem(subscription, PlanIds.Single()); + return new() + { + new SubscriptionItemOptions + { + Id = item?.Id, + Plan = _plan, + Quantity = _prevStorage.Value, + Deleted = _prevStorage.Value == 0 ? true : (bool?)null, + } + }; + } +} + +public class SponsorOrganizationSubscriptionUpdate : SubscriptionUpdate +{ + private readonly string _existingPlanStripeId; + private readonly string _sponsoredPlanStripeId; + private readonly bool _applySponsorship; + protected override List PlanIds => new() { _existingPlanStripeId, _sponsoredPlanStripeId }; + + public SponsorOrganizationSubscriptionUpdate(StaticStore.Plan existingPlan, StaticStore.SponsoredPlan sponsoredPlan, bool applySponsorship) + { + _existingPlanStripeId = existingPlan.StripePlanId; + _sponsoredPlanStripeId = sponsoredPlan?.StripePlanId; + _applySponsorship = applySponsorship; + } + + public override List RevertItemsOptions(Subscription subscription) + { + var result = new List(); + if (!string.IsNullOrWhiteSpace(AddStripePlanId)) + { + result.Add(new SubscriptionItemOptions + { + Id = AddStripeItem(subscription)?.Id, + Plan = AddStripePlanId, + Quantity = 0, + Deleted = true, + }); + } + + if (!string.IsNullOrWhiteSpace(RemoveStripePlanId)) + { + result.Add(new SubscriptionItemOptions + { + Id = RemoveStripeItem(subscription)?.Id, + Plan = RemoveStripePlanId, + Quantity = 1, + Deleted = false, + }); + } + return result; + } + + public override List UpgradeItemsOptions(Subscription subscription) + { + var result = new List(); + if (RemoveStripeItem(subscription) != null) + { + result.Add(new SubscriptionItemOptions + { + Id = RemoveStripeItem(subscription)?.Id, + Plan = RemoveStripePlanId, + Quantity = 0, + Deleted = true, + }); + } + + if (!string.IsNullOrWhiteSpace(AddStripePlanId)) + { + result.Add(new SubscriptionItemOptions + { + Id = AddStripeItem(subscription)?.Id, + Plan = AddStripePlanId, + Quantity = 1, + Deleted = false, + }); + } + return result; + } + + private string RemoveStripePlanId => _applySponsorship ? _existingPlanStripeId : _sponsoredPlanStripeId; + private string AddStripePlanId => _applySponsorship ? _sponsoredPlanStripeId : _existingPlanStripeId; + private Stripe.SubscriptionItem RemoveStripeItem(Subscription subscription) => + _applySponsorship ? + SubscriptionItem(subscription, _existingPlanStripeId) : + SubscriptionItem(subscription, _sponsoredPlanStripeId); + private Stripe.SubscriptionItem AddStripeItem(Subscription subscription) => + _applySponsorship ? + SubscriptionItem(subscription, _sponsoredPlanStripeId) : + SubscriptionItem(subscription, _existingPlanStripeId); + +} diff --git a/src/Core/Models/Business/TaxInfo.cs b/src/Core/Models/Business/TaxInfo.cs index 62d30b8fe..e763b7223 100644 --- a/src/Core/Models/Business/TaxInfo.cs +++ b/src/Core/Models/Business/TaxInfo.cs @@ -1,154 +1,153 @@ -namespace Bit.Core.Models.Business +namespace Bit.Core.Models.Business; + +public class TaxInfo { - public class TaxInfo + private string _taxIdNumber = null; + private string _taxIdType = null; + + public string TaxIdNumber { - private string _taxIdNumber = null; - private string _taxIdType = null; - - public string TaxIdNumber + get => _taxIdNumber; + set { - get => _taxIdNumber; - set - { - _taxIdNumber = value; - _taxIdType = null; - } - } - public string StripeTaxRateId { get; set; } - public string BillingAddressLine1 { get; set; } - public string BillingAddressLine2 { get; set; } - public string BillingAddressCity { get; set; } - public string BillingAddressState { get; set; } - public string BillingAddressPostalCode { get; set; } - public string BillingAddressCountry { get; set; } = "US"; - public string TaxIdType - { - get - { - if (string.IsNullOrWhiteSpace(BillingAddressCountry) || - string.IsNullOrWhiteSpace(TaxIdNumber)) - { - return null; - } - if (!string.IsNullOrWhiteSpace(_taxIdType)) - { - return _taxIdType; - } - - switch (BillingAddressCountry) - { - case "AE": - _taxIdType = "ae_trn"; - break; - case "AU": - _taxIdType = "au_abn"; - break; - case "BR": - _taxIdType = "br_cnpj"; - break; - case "CA": - // May break for those in Québec given the assumption of QST - if (BillingAddressState?.Contains("bec") ?? false) - { - _taxIdType = "ca_qst"; - break; - } - _taxIdType = "ca_bn"; - break; - case "CL": - _taxIdType = "cl_tin"; - break; - case "AT": - case "BE": - case "BG": - case "CY": - case "CZ": - case "DE": - case "DK": - case "EE": - case "ES": - case "FI": - case "FR": - case "GB": - case "GR": - case "HR": - case "HU": - case "IE": - case "IT": - case "LT": - case "LU": - case "LV": - case "MT": - case "NL": - case "PL": - case "PT": - case "RO": - case "SE": - case "SI": - case "SK": - _taxIdType = "eu_vat"; - break; - case "HK": - _taxIdType = "hk_br"; - break; - case "IN": - _taxIdType = "in_gst"; - break; - case "JP": - _taxIdType = "jp_cn"; - break; - case "KR": - _taxIdType = "kr_brn"; - break; - case "LI": - _taxIdType = "li_uid"; - break; - case "MX": - _taxIdType = "mx_rfc"; - break; - case "MY": - _taxIdType = "my_sst"; - break; - case "NO": - _taxIdType = "no_vat"; - break; - case "NZ": - _taxIdType = "nz_gst"; - break; - case "RU": - _taxIdType = "ru_inn"; - break; - case "SA": - _taxIdType = "sa_vat"; - break; - case "SG": - _taxIdType = "sg_gst"; - break; - case "TH": - _taxIdType = "th_vat"; - break; - case "TW": - _taxIdType = "tw_vat"; - break; - case "US": - _taxIdType = "us_ein"; - break; - case "ZA": - _taxIdType = "za_vat"; - break; - default: - _taxIdType = null; - break; - } - - return _taxIdType; - } - } - - public bool HasTaxId - { - get => !string.IsNullOrWhiteSpace(TaxIdNumber) && - !string.IsNullOrWhiteSpace(TaxIdType); + _taxIdNumber = value; + _taxIdType = null; } } + public string StripeTaxRateId { get; set; } + public string BillingAddressLine1 { get; set; } + public string BillingAddressLine2 { get; set; } + public string BillingAddressCity { get; set; } + public string BillingAddressState { get; set; } + public string BillingAddressPostalCode { get; set; } + public string BillingAddressCountry { get; set; } = "US"; + public string TaxIdType + { + get + { + if (string.IsNullOrWhiteSpace(BillingAddressCountry) || + string.IsNullOrWhiteSpace(TaxIdNumber)) + { + return null; + } + if (!string.IsNullOrWhiteSpace(_taxIdType)) + { + return _taxIdType; + } + + switch (BillingAddressCountry) + { + case "AE": + _taxIdType = "ae_trn"; + break; + case "AU": + _taxIdType = "au_abn"; + break; + case "BR": + _taxIdType = "br_cnpj"; + break; + case "CA": + // May break for those in Québec given the assumption of QST + if (BillingAddressState?.Contains("bec") ?? false) + { + _taxIdType = "ca_qst"; + break; + } + _taxIdType = "ca_bn"; + break; + case "CL": + _taxIdType = "cl_tin"; + break; + case "AT": + case "BE": + case "BG": + case "CY": + case "CZ": + case "DE": + case "DK": + case "EE": + case "ES": + case "FI": + case "FR": + case "GB": + case "GR": + case "HR": + case "HU": + case "IE": + case "IT": + case "LT": + case "LU": + case "LV": + case "MT": + case "NL": + case "PL": + case "PT": + case "RO": + case "SE": + case "SI": + case "SK": + _taxIdType = "eu_vat"; + break; + case "HK": + _taxIdType = "hk_br"; + break; + case "IN": + _taxIdType = "in_gst"; + break; + case "JP": + _taxIdType = "jp_cn"; + break; + case "KR": + _taxIdType = "kr_brn"; + break; + case "LI": + _taxIdType = "li_uid"; + break; + case "MX": + _taxIdType = "mx_rfc"; + break; + case "MY": + _taxIdType = "my_sst"; + break; + case "NO": + _taxIdType = "no_vat"; + break; + case "NZ": + _taxIdType = "nz_gst"; + break; + case "RU": + _taxIdType = "ru_inn"; + break; + case "SA": + _taxIdType = "sa_vat"; + break; + case "SG": + _taxIdType = "sg_gst"; + break; + case "TH": + _taxIdType = "th_vat"; + break; + case "TW": + _taxIdType = "tw_vat"; + break; + case "US": + _taxIdType = "us_ein"; + break; + case "ZA": + _taxIdType = "za_vat"; + break; + default: + _taxIdType = null; + break; + } + + return _taxIdType; + } + } + + public bool HasTaxId + { + get => !string.IsNullOrWhiteSpace(TaxIdNumber) && + !string.IsNullOrWhiteSpace(TaxIdType); + } } diff --git a/src/Core/Models/Business/Tokenables/EmergencyAccessInviteTokenable.cs b/src/Core/Models/Business/Tokenables/EmergencyAccessInviteTokenable.cs index f8d7b02b7..9d0e6cafa 100644 --- a/src/Core/Models/Business/Tokenables/EmergencyAccessInviteTokenable.cs +++ b/src/Core/Models/Business/Tokenables/EmergencyAccessInviteTokenable.cs @@ -1,36 +1,35 @@ using System.Text.Json.Serialization; using Bit.Core.Entities; -namespace Bit.Core.Models.Business.Tokenables +namespace Bit.Core.Models.Business.Tokenables; + +public class EmergencyAccessInviteTokenable : Tokens.ExpiringTokenable { - public class EmergencyAccessInviteTokenable : Tokens.ExpiringTokenable + public const string ClearTextPrefix = ""; + public const string DataProtectorPurpose = "EmergencyAccessServiceDataProtector"; + public const string TokenIdentifier = "EmergencyAccessInvite"; + public string Identifier { get; set; } = TokenIdentifier; + public Guid Id { get; set; } + public string Email { get; set; } + + [JsonConstructor] + public EmergencyAccessInviteTokenable(DateTime expirationDate) { - public const string ClearTextPrefix = ""; - public const string DataProtectorPurpose = "EmergencyAccessServiceDataProtector"; - public const string TokenIdentifier = "EmergencyAccessInvite"; - public string Identifier { get; set; } = TokenIdentifier; - public Guid Id { get; set; } - public string Email { get; set; } - - [JsonConstructor] - public EmergencyAccessInviteTokenable(DateTime expirationDate) - { - ExpirationDate = expirationDate; - } - - public EmergencyAccessInviteTokenable(EmergencyAccess user, int hoursTillExpiration) - { - Id = user.Id; - Email = user.Email; - ExpirationDate = DateTime.UtcNow.AddHours(hoursTillExpiration); - } - - public bool IsValid(Guid id, string email) - { - return Id == id && - Email.Equals(email, StringComparison.InvariantCultureIgnoreCase); - } - - protected override bool TokenIsValid() => Identifier == TokenIdentifier && Id != default && !string.IsNullOrWhiteSpace(Email); + ExpirationDate = expirationDate; } + + public EmergencyAccessInviteTokenable(EmergencyAccess user, int hoursTillExpiration) + { + Id = user.Id; + Email = user.Email; + ExpirationDate = DateTime.UtcNow.AddHours(hoursTillExpiration); + } + + public bool IsValid(Guid id, string email) + { + return Id == id && + Email.Equals(email, StringComparison.InvariantCultureIgnoreCase); + } + + protected override bool TokenIsValid() => Identifier == TokenIdentifier && Id != default && !string.IsNullOrWhiteSpace(Email); } diff --git a/src/Core/Models/Business/Tokenables/HCaptchaTokenable.cs b/src/Core/Models/Business/Tokenables/HCaptchaTokenable.cs index 774df7d79..c62c7189a 100644 --- a/src/Core/Models/Business/Tokenables/HCaptchaTokenable.cs +++ b/src/Core/Models/Business/Tokenables/HCaptchaTokenable.cs @@ -2,43 +2,42 @@ using Bit.Core.Entities; using Bit.Core.Tokens; -namespace Bit.Core.Models.Business.Tokenables +namespace Bit.Core.Models.Business.Tokenables; + +public class HCaptchaTokenable : ExpiringTokenable { - public class HCaptchaTokenable : ExpiringTokenable + private const double _tokenLifetimeInHours = (double)5 / 60; // 5 minutes + public const string ClearTextPrefix = "BWCaptchaBypass_"; + public const string DataProtectorPurpose = "CaptchaServiceDataProtector"; + public const string TokenIdentifier = "CaptchaBypassToken"; + + public string Identifier { get; set; } = TokenIdentifier; + public Guid Id { get; set; } + public string Email { get; set; } + + [JsonConstructor] + public HCaptchaTokenable() { - private const double _tokenLifetimeInHours = (double)5 / 60; // 5 minutes - public const string ClearTextPrefix = "BWCaptchaBypass_"; - public const string DataProtectorPurpose = "CaptchaServiceDataProtector"; - public const string TokenIdentifier = "CaptchaBypassToken"; - - public string Identifier { get; set; } = TokenIdentifier; - public Guid Id { get; set; } - public string Email { get; set; } - - [JsonConstructor] - public HCaptchaTokenable() - { - ExpirationDate = DateTime.UtcNow.AddHours(_tokenLifetimeInHours); - } - - public HCaptchaTokenable(User user) : this() - { - Id = user?.Id ?? default; - Email = user?.Email; - } - - public bool TokenIsValid(User user) - { - if (Id == default || Email == default || user == null) - { - return false; - } - - return Id == user.Id && - Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase); - } - - // Validates deserialized - protected override bool TokenIsValid() => Identifier == TokenIdentifier && Id != default && !string.IsNullOrWhiteSpace(Email); + ExpirationDate = DateTime.UtcNow.AddHours(_tokenLifetimeInHours); } + + public HCaptchaTokenable(User user) : this() + { + Id = user?.Id ?? default; + Email = user?.Email; + } + + public bool TokenIsValid(User user) + { + if (Id == default || Email == default || user == null) + { + return false; + } + + return Id == user.Id && + Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase); + } + + // Validates deserialized + protected override bool TokenIsValid() => Identifier == TokenIdentifier && Id != default && !string.IsNullOrWhiteSpace(Email); } diff --git a/src/Core/Models/Business/Tokenables/OrganizationSponsorshipOfferTokenable.cs b/src/Core/Models/Business/Tokenables/OrganizationSponsorshipOfferTokenable.cs index 0360f3542..4bca8e1ca 100644 --- a/src/Core/Models/Business/Tokenables/OrganizationSponsorshipOfferTokenable.cs +++ b/src/Core/Models/Business/Tokenables/OrganizationSponsorshipOfferTokenable.cs @@ -3,55 +3,54 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Tokens; -namespace Bit.Core.Models.Business.Tokenables +namespace Bit.Core.Models.Business.Tokenables; + +public class OrganizationSponsorshipOfferTokenable : Tokenable { - public class OrganizationSponsorshipOfferTokenable : Tokenable + public const string ClearTextPrefix = "BWOrganizationSponsorship_"; + public const string DataProtectorPurpose = "OrganizationSponsorshipDataProtector"; + public const string TokenIdentifier = "OrganizationSponsorshipOfferToken"; + public string Identifier { get; set; } = TokenIdentifier; + public Guid Id { get; set; } + public PlanSponsorshipType SponsorshipType { get; set; } + public string Email { get; set; } + + public override bool Valid => !string.IsNullOrWhiteSpace(Email) && + Identifier == TokenIdentifier && + Id != default; + + + [JsonConstructor] + public OrganizationSponsorshipOfferTokenable() { } + + public OrganizationSponsorshipOfferTokenable(OrganizationSponsorship sponsorship) { - public const string ClearTextPrefix = "BWOrganizationSponsorship_"; - public const string DataProtectorPurpose = "OrganizationSponsorshipDataProtector"; - public const string TokenIdentifier = "OrganizationSponsorshipOfferToken"; - public string Identifier { get; set; } = TokenIdentifier; - public Guid Id { get; set; } - public PlanSponsorshipType SponsorshipType { get; set; } - public string Email { get; set; } - - public override bool Valid => !string.IsNullOrWhiteSpace(Email) && - Identifier == TokenIdentifier && - Id != default; - - - [JsonConstructor] - public OrganizationSponsorshipOfferTokenable() { } - - public OrganizationSponsorshipOfferTokenable(OrganizationSponsorship sponsorship) + if (string.IsNullOrWhiteSpace(sponsorship.OfferedToEmail)) { - if (string.IsNullOrWhiteSpace(sponsorship.OfferedToEmail)) - { - throw new ArgumentException("Invalid OrganizationSponsorship to create a token, OfferedToEmail is required", nameof(sponsorship)); - } - Email = sponsorship.OfferedToEmail; - - if (!sponsorship.PlanSponsorshipType.HasValue) - { - throw new ArgumentException("Invalid OrganizationSponsorship to create a token, PlanSponsorshipType is required", nameof(sponsorship)); - } - SponsorshipType = sponsorship.PlanSponsorshipType.Value; - - if (sponsorship.Id == default) - { - throw new ArgumentException("Invalid OrganizationSponsorship to create a token, Id is required", nameof(sponsorship)); - } - Id = sponsorship.Id; + throw new ArgumentException("Invalid OrganizationSponsorship to create a token, OfferedToEmail is required", nameof(sponsorship)); } + Email = sponsorship.OfferedToEmail; - public bool IsValid(OrganizationSponsorship sponsorship, string currentUserEmail) => - sponsorship != null && - sponsorship.PlanSponsorshipType.HasValue && - SponsorshipType == sponsorship.PlanSponsorshipType.Value && - Id == sponsorship.Id && - !string.IsNullOrWhiteSpace(sponsorship.OfferedToEmail) && - Email.Equals(currentUserEmail, StringComparison.InvariantCultureIgnoreCase) && - Email.Equals(sponsorship.OfferedToEmail, StringComparison.InvariantCultureIgnoreCase); + if (!sponsorship.PlanSponsorshipType.HasValue) + { + throw new ArgumentException("Invalid OrganizationSponsorship to create a token, PlanSponsorshipType is required", nameof(sponsorship)); + } + SponsorshipType = sponsorship.PlanSponsorshipType.Value; + if (sponsorship.Id == default) + { + throw new ArgumentException("Invalid OrganizationSponsorship to create a token, Id is required", nameof(sponsorship)); + } + Id = sponsorship.Id; } + + public bool IsValid(OrganizationSponsorship sponsorship, string currentUserEmail) => + sponsorship != null && + sponsorship.PlanSponsorshipType.HasValue && + SponsorshipType == sponsorship.PlanSponsorshipType.Value && + Id == sponsorship.Id && + !string.IsNullOrWhiteSpace(sponsorship.OfferedToEmail) && + Email.Equals(currentUserEmail, StringComparison.InvariantCultureIgnoreCase) && + Email.Equals(sponsorship.OfferedToEmail, StringComparison.InvariantCultureIgnoreCase); + } diff --git a/src/Core/Models/Business/Tokenables/SsoTokenable.cs b/src/Core/Models/Business/Tokenables/SsoTokenable.cs index 765e6ce59..f6524d2c7 100644 --- a/src/Core/Models/Business/Tokenables/SsoTokenable.cs +++ b/src/Core/Models/Business/Tokenables/SsoTokenable.cs @@ -2,43 +2,42 @@ using Bit.Core.Entities; using Bit.Core.Tokens; -namespace Bit.Core.Models.Business.Tokenables +namespace Bit.Core.Models.Business.Tokenables; + +public class SsoTokenable : ExpiringTokenable { - public class SsoTokenable : ExpiringTokenable + public const string ClearTextPrefix = "BWUserPrefix_"; + public const string DataProtectorPurpose = "SsoTokenDataProtector"; + public const string TokenIdentifier = "ssoToken"; + + public Guid OrganizationId { get; set; } + public string DomainHint { get; set; } + public string Identifier { get; set; } = TokenIdentifier; + + [JsonConstructor] + public SsoTokenable() { } + + public SsoTokenable(Organization organization, double tokenLifetimeInSeconds) : this() { - public const string ClearTextPrefix = "BWUserPrefix_"; - public const string DataProtectorPurpose = "SsoTokenDataProtector"; - public const string TokenIdentifier = "ssoToken"; - - public Guid OrganizationId { get; set; } - public string DomainHint { get; set; } - public string Identifier { get; set; } = TokenIdentifier; - - [JsonConstructor] - public SsoTokenable() { } - - public SsoTokenable(Organization organization, double tokenLifetimeInSeconds) : this() - { - OrganizationId = organization?.Id ?? default; - DomainHint = organization?.Identifier; - ExpirationDate = DateTime.UtcNow.AddSeconds(tokenLifetimeInSeconds); - } - - public bool TokenIsValid(Organization organization) - { - if (OrganizationId == default || DomainHint == default || organization == null || !Valid) - { - return false; - } - - return organization.Identifier.Equals(DomainHint, StringComparison.InvariantCultureIgnoreCase) - && organization.Id.Equals(OrganizationId); - } - - // Validates deserialized - protected override bool TokenIsValid() => - Identifier == TokenIdentifier - && OrganizationId != default - && !string.IsNullOrWhiteSpace(DomainHint); + OrganizationId = organization?.Id ?? default; + DomainHint = organization?.Identifier; + ExpirationDate = DateTime.UtcNow.AddSeconds(tokenLifetimeInSeconds); } + + public bool TokenIsValid(Organization organization) + { + if (OrganizationId == default || DomainHint == default || organization == null || !Valid) + { + return false; + } + + return organization.Identifier.Equals(DomainHint, StringComparison.InvariantCultureIgnoreCase) + && organization.Id.Equals(OrganizationId); + } + + // Validates deserialized + protected override bool TokenIsValid() => + Identifier == TokenIdentifier + && OrganizationId != default + && !string.IsNullOrWhiteSpace(DomainHint); } diff --git a/src/Core/Models/Business/UserLicense.cs b/src/Core/Models/Business/UserLicense.cs index 183bf9576..f079a7183 100644 --- a/src/Core/Models/Business/UserLicense.cs +++ b/src/Core/Models/Business/UserLicense.cs @@ -7,168 +7,167 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Services; -namespace Bit.Core.Models.Business +namespace Bit.Core.Models.Business; + +public class UserLicense : ILicense { - public class UserLicense : ILicense + public UserLicense() + { } + + public UserLicense(User user, SubscriptionInfo subscriptionInfo, ILicensingService licenseService, + int? version = null) { - public UserLicense() - { } + LicenseType = Enums.LicenseType.User; + LicenseKey = user.LicenseKey; + Id = user.Id; + Name = user.Name; + Email = user.Email; + Version = version.GetValueOrDefault(1); + Premium = user.Premium; + MaxStorageGb = user.MaxStorageGb; + Issued = DateTime.UtcNow; + Expires = subscriptionInfo?.UpcomingInvoice?.Date != null ? + subscriptionInfo.UpcomingInvoice.Date.Value.AddDays(7) : + user.PremiumExpirationDate?.AddDays(7); + Refresh = subscriptionInfo?.UpcomingInvoice?.Date; + Trial = (subscriptionInfo?.Subscription?.TrialEndDate.HasValue ?? false) && + subscriptionInfo.Subscription.TrialEndDate.Value > DateTime.UtcNow; - public UserLicense(User user, SubscriptionInfo subscriptionInfo, ILicensingService licenseService, - int? version = null) + Hash = Convert.ToBase64String(ComputeHash()); + Signature = Convert.ToBase64String(licenseService.SignLicense(this)); + } + + public UserLicense(User user, ILicensingService licenseService, int? version = null) + { + LicenseType = Enums.LicenseType.User; + LicenseKey = user.LicenseKey; + Id = user.Id; + Name = user.Name; + Email = user.Email; + Version = version.GetValueOrDefault(1); + Premium = user.Premium; + MaxStorageGb = user.MaxStorageGb; + Issued = DateTime.UtcNow; + Expires = user.PremiumExpirationDate?.AddDays(7); + Refresh = user.PremiumExpirationDate?.Date; + Trial = false; + + Hash = Convert.ToBase64String(ComputeHash()); + Signature = Convert.ToBase64String(licenseService.SignLicense(this)); + } + + public string LicenseKey { get; set; } + public Guid Id { get; set; } + public string Name { get; set; } + public string Email { get; set; } + public bool Premium { get; set; } + public short? MaxStorageGb { get; set; } + public int Version { get; set; } + public DateTime Issued { get; set; } + public DateTime? Refresh { get; set; } + public DateTime? Expires { get; set; } + public bool Trial { get; set; } + public LicenseType? LicenseType { get; set; } + public string Hash { get; set; } + public string Signature { get; set; } + [JsonIgnore] + public byte[] SignatureBytes => Convert.FromBase64String(Signature); + + public byte[] GetDataBytes(bool forHash = false) + { + string data = null; + if (Version == 1) { - LicenseType = Enums.LicenseType.User; - LicenseKey = user.LicenseKey; - Id = user.Id; - Name = user.Name; - Email = user.Email; - Version = version.GetValueOrDefault(1); - Premium = user.Premium; - MaxStorageGb = user.MaxStorageGb; - Issued = DateTime.UtcNow; - Expires = subscriptionInfo?.UpcomingInvoice?.Date != null ? - subscriptionInfo.UpcomingInvoice.Date.Value.AddDays(7) : - user.PremiumExpirationDate?.AddDays(7); - Refresh = subscriptionInfo?.UpcomingInvoice?.Date; - Trial = (subscriptionInfo?.Subscription?.TrialEndDate.HasValue ?? false) && - subscriptionInfo.Subscription.TrialEndDate.Value > DateTime.UtcNow; - - Hash = Convert.ToBase64String(ComputeHash()); - Signature = Convert.ToBase64String(licenseService.SignLicense(this)); - } - - public UserLicense(User user, ILicensingService licenseService, int? version = null) - { - LicenseType = Enums.LicenseType.User; - LicenseKey = user.LicenseKey; - Id = user.Id; - Name = user.Name; - Email = user.Email; - Version = version.GetValueOrDefault(1); - Premium = user.Premium; - MaxStorageGb = user.MaxStorageGb; - Issued = DateTime.UtcNow; - Expires = user.PremiumExpirationDate?.AddDays(7); - Refresh = user.PremiumExpirationDate?.Date; - Trial = false; - - Hash = Convert.ToBase64String(ComputeHash()); - Signature = Convert.ToBase64String(licenseService.SignLicense(this)); - } - - public string LicenseKey { get; set; } - public Guid Id { get; set; } - public string Name { get; set; } - public string Email { get; set; } - public bool Premium { get; set; } - public short? MaxStorageGb { get; set; } - public int Version { get; set; } - public DateTime Issued { get; set; } - public DateTime? Refresh { get; set; } - public DateTime? Expires { get; set; } - public bool Trial { get; set; } - public LicenseType? LicenseType { get; set; } - public string Hash { get; set; } - public string Signature { get; set; } - [JsonIgnore] - public byte[] SignatureBytes => Convert.FromBase64String(Signature); - - public byte[] GetDataBytes(bool forHash = false) - { - string data = null; - if (Version == 1) - { - var props = typeof(UserLicense) - .GetProperties(BindingFlags.Public | BindingFlags.Instance) - .Where(p => - !p.Name.Equals(nameof(Signature)) && - !p.Name.Equals(nameof(SignatureBytes)) && - !p.Name.Equals(nameof(LicenseType)) && + var props = typeof(UserLicense) + .GetProperties(BindingFlags.Public | BindingFlags.Instance) + .Where(p => + !p.Name.Equals(nameof(Signature)) && + !p.Name.Equals(nameof(SignatureBytes)) && + !p.Name.Equals(nameof(LicenseType)) && + ( + !forHash || ( - !forHash || - ( - !p.Name.Equals(nameof(Hash)) && - !p.Name.Equals(nameof(Issued)) && - !p.Name.Equals(nameof(Refresh)) - ) - )) - .OrderBy(p => p.Name) - .Select(p => $"{p.Name}:{Utilities.CoreHelpers.FormatLicenseSignatureValue(p.GetValue(this, null))}") - .Aggregate((c, n) => $"{c}|{n}"); - data = $"license:user|{props}"; - } - else - { - throw new NotSupportedException($"Version {Version} is not supported."); - } - - return Encoding.UTF8.GetBytes(data); + !p.Name.Equals(nameof(Hash)) && + !p.Name.Equals(nameof(Issued)) && + !p.Name.Equals(nameof(Refresh)) + ) + )) + .OrderBy(p => p.Name) + .Select(p => $"{p.Name}:{Utilities.CoreHelpers.FormatLicenseSignatureValue(p.GetValue(this, null))}") + .Aggregate((c, n) => $"{c}|{n}"); + data = $"license:user|{props}"; + } + else + { + throw new NotSupportedException($"Version {Version} is not supported."); } - public byte[] ComputeHash() + return Encoding.UTF8.GetBytes(data); + } + + public byte[] ComputeHash() + { + using (var alg = SHA256.Create()) { - using (var alg = SHA256.Create()) - { - return alg.ComputeHash(GetDataBytes(true)); - } + return alg.ComputeHash(GetDataBytes(true)); + } + } + + public bool CanUse(User user) + { + if (Issued > DateTime.UtcNow || Expires < DateTime.UtcNow) + { + return false; } - public bool CanUse(User user) + if (Version == 1) { - if (Issued > DateTime.UtcNow || Expires < DateTime.UtcNow) - { - return false; - } + return user.EmailVerified && user.Email.Equals(Email, StringComparison.InvariantCultureIgnoreCase); + } + else + { + throw new NotSupportedException($"Version {Version} is not supported."); + } + } - if (Version == 1) - { - return user.EmailVerified && user.Email.Equals(Email, StringComparison.InvariantCultureIgnoreCase); - } - else - { - throw new NotSupportedException($"Version {Version} is not supported."); - } + public bool VerifyData(User user) + { + if (Issued > DateTime.UtcNow || Expires < DateTime.UtcNow) + { + return false; } - public bool VerifyData(User user) + if (Version == 1) { - if (Issued > DateTime.UtcNow || Expires < DateTime.UtcNow) - { - return false; - } + return + user.LicenseKey != null && user.LicenseKey.Equals(LicenseKey) && + user.Premium == Premium && + user.Email.Equals(Email, StringComparison.InvariantCultureIgnoreCase); + } + else + { + throw new NotSupportedException($"Version {Version} is not supported."); + } + } - if (Version == 1) - { - return - user.LicenseKey != null && user.LicenseKey.Equals(LicenseKey) && - user.Premium == Premium && - user.Email.Equals(Email, StringComparison.InvariantCultureIgnoreCase); - } - else - { - throw new NotSupportedException($"Version {Version} is not supported."); - } + public bool VerifySignature(X509Certificate2 certificate) + { + using (var rsa = certificate.GetRSAPublicKey()) + { + return rsa.VerifyData(GetDataBytes(), SignatureBytes, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + } + } + + public byte[] Sign(X509Certificate2 certificate) + { + if (!certificate.HasPrivateKey) + { + throw new InvalidOperationException("You don't have the private key!"); } - public bool VerifySignature(X509Certificate2 certificate) + using (var rsa = certificate.GetRSAPrivateKey()) { - using (var rsa = certificate.GetRSAPublicKey()) - { - return rsa.VerifyData(GetDataBytes(), SignatureBytes, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); - } - } - - public byte[] Sign(X509Certificate2 certificate) - { - if (!certificate.HasPrivateKey) - { - throw new InvalidOperationException("You don't have the private key!"); - } - - using (var rsa = certificate.GetRSAPrivateKey()) - { - return rsa.SignData(GetDataBytes(), HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); - } + return rsa.SignData(GetDataBytes(), HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); } } } diff --git a/src/Core/Models/Data/AttachmentResponseData.cs b/src/Core/Models/Data/AttachmentResponseData.cs index 1a5c0de43..f45125c3d 100644 --- a/src/Core/Models/Data/AttachmentResponseData.cs +++ b/src/Core/Models/Data/AttachmentResponseData.cs @@ -1,12 +1,11 @@ using Bit.Core.Entities; -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class AttachmentResponseData { - public class AttachmentResponseData - { - public string Id { get; set; } - public CipherAttachment.MetaData Data { get; set; } - public Cipher Cipher { get; set; } - public string Url { get; set; } - } + public string Id { get; set; } + public CipherAttachment.MetaData Data { get; set; } + public Cipher Cipher { get; set; } + public string Url { get; set; } } diff --git a/src/Core/Models/Data/CipherAttachment.cs b/src/Core/Models/Data/CipherAttachment.cs index a306c76ad..62b46335a 100644 --- a/src/Core/Models/Data/CipherAttachment.cs +++ b/src/Core/Models/Data/CipherAttachment.cs @@ -1,36 +1,35 @@ using System.Text.Json.Serialization; -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class CipherAttachment { - public class CipherAttachment + public Guid Id { get; set; } + public Guid? UserId { get; set; } + public Guid? OrganizationId { get; set; } + public string AttachmentId { get; set; } + public string AttachmentData { get; set; } + + public class MetaData { - public Guid Id { get; set; } - public Guid? UserId { get; set; } - public Guid? OrganizationId { get; set; } - public string AttachmentId { get; set; } - public string AttachmentData { get; set; } + private long _size; - public class MetaData + // We serialize Size as a string since JSON (or Javascript) doesn't support full precision for long numbers + [JsonNumberHandling(JsonNumberHandling.AllowReadingFromString | JsonNumberHandling.WriteAsString)] + public long Size { - private long _size; - - // We serialize Size as a string since JSON (or Javascript) doesn't support full precision for long numbers - [JsonNumberHandling(JsonNumberHandling.AllowReadingFromString | JsonNumberHandling.WriteAsString)] - public long Size - { - get { return _size; } - set { _size = value; } - } - - public string FileName { get; set; } - public string Key { get; set; } - - public string ContainerName { get; set; } = "attachments"; - public bool Validated { get; set; } = true; - - // This is stored alongside metadata as an identifier. It does not need repeating in serialization - [JsonIgnore] - public string AttachmentId { get; set; } + get { return _size; } + set { _size = value; } } + + public string FileName { get; set; } + public string Key { get; set; } + + public string ContainerName { get; set; } = "attachments"; + public bool Validated { get; set; } = true; + + // This is stored alongside metadata as an identifier. It does not need repeating in serialization + [JsonIgnore] + public string AttachmentId { get; set; } } } diff --git a/src/Core/Models/Data/CipherCardData.cs b/src/Core/Models/Data/CipherCardData.cs index 0d8745eb9..fdfc604da 100644 --- a/src/Core/Models/Data/CipherCardData.cs +++ b/src/Core/Models/Data/CipherCardData.cs @@ -1,14 +1,13 @@ -namespace Bit.Core.Models.Data -{ - public class CipherCardData : CipherData - { - public CipherCardData() { } +namespace Bit.Core.Models.Data; - public string CardholderName { get; set; } - public string Brand { get; set; } - public string Number { get; set; } - public string ExpMonth { get; set; } - public string ExpYear { get; set; } - public string Code { get; set; } - } +public class CipherCardData : CipherData +{ + public CipherCardData() { } + + public string CardholderName { get; set; } + public string Brand { get; set; } + public string Number { get; set; } + public string ExpMonth { get; set; } + public string ExpYear { get; set; } + public string Code { get; set; } } diff --git a/src/Core/Models/Data/CipherData.cs b/src/Core/Models/Data/CipherData.cs index 3c7598f26..9881ed6ba 100644 --- a/src/Core/Models/Data/CipherData.cs +++ b/src/Core/Models/Data/CipherData.cs @@ -1,12 +1,11 @@ -namespace Bit.Core.Models.Data -{ - public abstract class CipherData - { - public CipherData() { } +namespace Bit.Core.Models.Data; - public string Name { get; set; } - public string Notes { get; set; } - public IEnumerable Fields { get; set; } - public IEnumerable PasswordHistory { get; set; } - } +public abstract class CipherData +{ + public CipherData() { } + + public string Name { get; set; } + public string Notes { get; set; } + public IEnumerable Fields { get; set; } + public IEnumerable PasswordHistory { get; set; } } diff --git a/src/Core/Models/Data/CipherDetails.cs b/src/Core/Models/Data/CipherDetails.cs index e7276ac3c..21a636bf7 100644 --- a/src/Core/Models/Data/CipherDetails.cs +++ b/src/Core/Models/Data/CipherDetails.cs @@ -1,10 +1,9 @@ -namespace Core.Models.Data +namespace Core.Models.Data; + +public class CipherDetails : CipherOrganizationDetails { - public class CipherDetails : CipherOrganizationDetails - { - public Guid? FolderId { get; set; } - public bool Favorite { get; set; } - public bool Edit { get; set; } - public bool ViewPassword { get; set; } - } + public Guid? FolderId { get; set; } + public bool Favorite { get; set; } + public bool Edit { get; set; } + public bool ViewPassword { get; set; } } diff --git a/src/Core/Models/Data/CipherFieldData.cs b/src/Core/Models/Data/CipherFieldData.cs index b46d16099..748a478cf 100644 --- a/src/Core/Models/Data/CipherFieldData.cs +++ b/src/Core/Models/Data/CipherFieldData.cs @@ -1,14 +1,13 @@ using Bit.Core.Enums; -namespace Bit.Core.Models.Data -{ - public class CipherFieldData - { - public CipherFieldData() { } +namespace Bit.Core.Models.Data; - public FieldType Type { get; set; } - public string Name { get; set; } - public string Value { get; set; } - public int? LinkedId { get; set; } - } +public class CipherFieldData +{ + public CipherFieldData() { } + + public FieldType Type { get; set; } + public string Name { get; set; } + public string Value { get; set; } + public int? LinkedId { get; set; } } diff --git a/src/Core/Models/Data/CipherIdentityData.cs b/src/Core/Models/Data/CipherIdentityData.cs index 3a5aa70e8..19773424a 100644 --- a/src/Core/Models/Data/CipherIdentityData.cs +++ b/src/Core/Models/Data/CipherIdentityData.cs @@ -1,26 +1,25 @@ -namespace Bit.Core.Models.Data -{ - public class CipherIdentityData : CipherData - { - public CipherIdentityData() { } +namespace Bit.Core.Models.Data; - public string Title { get; set; } - public string FirstName { get; set; } - public string MiddleName { get; set; } - public string LastName { get; set; } - public string Address1 { get; set; } - public string Address2 { get; set; } - public string Address3 { get; set; } - public string City { get; set; } - public string State { get; set; } - public string PostalCode { get; set; } - public string Country { get; set; } - public string Company { get; set; } - public string Email { get; set; } - public string Phone { get; set; } - public string SSN { get; set; } - public string Username { get; set; } - public string PassportNumber { get; set; } - public string LicenseNumber { get; set; } - } +public class CipherIdentityData : CipherData +{ + public CipherIdentityData() { } + + public string Title { get; set; } + public string FirstName { get; set; } + public string MiddleName { get; set; } + public string LastName { get; set; } + public string Address1 { get; set; } + public string Address2 { get; set; } + public string Address3 { get; set; } + public string City { get; set; } + public string State { get; set; } + public string PostalCode { get; set; } + public string Country { get; set; } + public string Company { get; set; } + public string Email { get; set; } + public string Phone { get; set; } + public string SSN { get; set; } + public string Username { get; set; } + public string PassportNumber { get; set; } + public string LicenseNumber { get; set; } } diff --git a/src/Core/Models/Data/CipherLoginData.cs b/src/Core/Models/Data/CipherLoginData.cs index 2a98ff155..d266d7786 100644 --- a/src/Core/Models/Data/CipherLoginData.cs +++ b/src/Core/Models/Data/CipherLoginData.cs @@ -1,31 +1,30 @@ using Bit.Core.Enums; -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class CipherLoginData : CipherData { - public class CipherLoginData : CipherData + private string _uri; + + public CipherLoginData() { } + + public string Uri { - private string _uri; + get => Uris?.FirstOrDefault()?.Uri ?? _uri; + set { _uri = value; } + } + public IEnumerable Uris { get; set; } + public string Username { get; set; } + public string Password { get; set; } + public DateTime? PasswordRevisionDate { get; set; } + public string Totp { get; set; } + public bool? AutofillOnPageLoad { get; set; } - public CipherLoginData() { } + public class CipherLoginUriData + { + public CipherLoginUriData() { } - public string Uri - { - get => Uris?.FirstOrDefault()?.Uri ?? _uri; - set { _uri = value; } - } - public IEnumerable Uris { get; set; } - public string Username { get; set; } - public string Password { get; set; } - public DateTime? PasswordRevisionDate { get; set; } - public string Totp { get; set; } - public bool? AutofillOnPageLoad { get; set; } - - public class CipherLoginUriData - { - public CipherLoginUriData() { } - - public string Uri { get; set; } - public UriMatchType? Match { get; set; } = null; - } + public string Uri { get; set; } + public UriMatchType? Match { get; set; } = null; } } diff --git a/src/Core/Models/Data/CipherOrganizationDetails.cs b/src/Core/Models/Data/CipherOrganizationDetails.cs index 522ebdd2f..d2717b30f 100644 --- a/src/Core/Models/Data/CipherOrganizationDetails.cs +++ b/src/Core/Models/Data/CipherOrganizationDetails.cs @@ -1,9 +1,8 @@ using Bit.Core.Entities; -namespace Core.Models.Data +namespace Core.Models.Data; + +public class CipherOrganizationDetails : Cipher { - public class CipherOrganizationDetails : Cipher - { - public bool OrganizationUseTotp { get; set; } - } + public bool OrganizationUseTotp { get; set; } } diff --git a/src/Core/Models/Data/CipherPasswordHistoryData.cs b/src/Core/Models/Data/CipherPasswordHistoryData.cs index 2362572a1..3ea5edab4 100644 --- a/src/Core/Models/Data/CipherPasswordHistoryData.cs +++ b/src/Core/Models/Data/CipherPasswordHistoryData.cs @@ -1,10 +1,9 @@ -namespace Bit.Core.Models.Data -{ - public class CipherPasswordHistoryData - { - public CipherPasswordHistoryData() { } +namespace Bit.Core.Models.Data; - public string Password { get; set; } - public DateTime LastUsedDate { get; set; } - } +public class CipherPasswordHistoryData +{ + public CipherPasswordHistoryData() { } + + public string Password { get; set; } + public DateTime LastUsedDate { get; set; } } diff --git a/src/Core/Models/Data/CipherSecureNoteData.cs b/src/Core/Models/Data/CipherSecureNoteData.cs index 1287e71df..88b7384cd 100644 --- a/src/Core/Models/Data/CipherSecureNoteData.cs +++ b/src/Core/Models/Data/CipherSecureNoteData.cs @@ -1,11 +1,10 @@ using Bit.Core.Enums; -namespace Bit.Core.Models.Data -{ - public class CipherSecureNoteData : CipherData - { - public CipherSecureNoteData() { } +namespace Bit.Core.Models.Data; - public SecureNoteType Type { get; set; } - } +public class CipherSecureNoteData : CipherData +{ + public CipherSecureNoteData() { } + + public SecureNoteType Type { get; set; } } diff --git a/src/Core/Models/Data/CollectionDetails.cs b/src/Core/Models/Data/CollectionDetails.cs index 110acc3e5..4b618749e 100644 --- a/src/Core/Models/Data/CollectionDetails.cs +++ b/src/Core/Models/Data/CollectionDetails.cs @@ -1,10 +1,9 @@ using Bit.Core.Entities; -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class CollectionDetails : Collection { - public class CollectionDetails : Collection - { - public bool ReadOnly { get; set; } - public bool HidePasswords { get; set; } - } + public bool ReadOnly { get; set; } + public bool HidePasswords { get; set; } } diff --git a/src/Core/Models/Data/DictionaryEntity.cs b/src/Core/Models/Data/DictionaryEntity.cs index 00b85d6a2..72e6c871c 100644 --- a/src/Core/Models/Data/DictionaryEntity.cs +++ b/src/Core/Models/Data/DictionaryEntity.cs @@ -1,135 +1,134 @@ using System.Collections; using Microsoft.Azure.Cosmos.Table; -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class DictionaryEntity : TableEntity, IDictionary { - public class DictionaryEntity : TableEntity, IDictionary + private IDictionary _properties = new Dictionary(); + + public ICollection Values => _properties.Values; + + public EntityProperty this[string key] { - private IDictionary _properties = new Dictionary(); + get => _properties[key]; + set => _properties[key] = value; + } - public ICollection Values => _properties.Values; + public int Count => _properties.Count; - public EntityProperty this[string key] - { - get => _properties[key]; - set => _properties[key] = value; - } + public bool IsReadOnly => _properties.IsReadOnly; - public int Count => _properties.Count; + public ICollection Keys => _properties.Keys; - public bool IsReadOnly => _properties.IsReadOnly; + public override void ReadEntity(IDictionary properties, + OperationContext operationContext) + { + _properties = properties; + } - public ICollection Keys => _properties.Keys; + public override IDictionary WriteEntity(OperationContext operationContext) + { + return _properties; + } - public override void ReadEntity(IDictionary properties, - OperationContext operationContext) - { - _properties = properties; - } + public void Add(string key, EntityProperty value) + { + _properties.Add(key, value); + } - public override IDictionary WriteEntity(OperationContext operationContext) - { - return _properties; - } + public void Add(string key, bool value) + { + _properties.Add(key, new EntityProperty(value)); + } - public void Add(string key, EntityProperty value) - { - _properties.Add(key, value); - } + public void Add(string key, byte[] value) + { + _properties.Add(key, new EntityProperty(value)); + } - public void Add(string key, bool value) - { - _properties.Add(key, new EntityProperty(value)); - } + public void Add(string key, DateTime? value) + { + _properties.Add(key, new EntityProperty(value)); + } - public void Add(string key, byte[] value) - { - _properties.Add(key, new EntityProperty(value)); - } + public void Add(string key, DateTimeOffset? value) + { + _properties.Add(key, new EntityProperty(value)); + } - public void Add(string key, DateTime? value) - { - _properties.Add(key, new EntityProperty(value)); - } + public void Add(string key, double value) + { + _properties.Add(key, new EntityProperty(value)); + } - public void Add(string key, DateTimeOffset? value) - { - _properties.Add(key, new EntityProperty(value)); - } + public void Add(string key, Guid value) + { + _properties.Add(key, new EntityProperty(value)); + } - public void Add(string key, double value) - { - _properties.Add(key, new EntityProperty(value)); - } + public void Add(string key, int value) + { + _properties.Add(key, new EntityProperty(value)); + } - public void Add(string key, Guid value) - { - _properties.Add(key, new EntityProperty(value)); - } + public void Add(string key, long value) + { + _properties.Add(key, new EntityProperty(value)); + } - public void Add(string key, int value) - { - _properties.Add(key, new EntityProperty(value)); - } + public void Add(string key, string value) + { + _properties.Add(key, new EntityProperty(value)); + } - public void Add(string key, long value) - { - _properties.Add(key, new EntityProperty(value)); - } + public void Add(KeyValuePair item) + { + _properties.Add(item); + } - public void Add(string key, string value) - { - _properties.Add(key, new EntityProperty(value)); - } + public bool ContainsKey(string key) + { + return _properties.ContainsKey(key); + } - public void Add(KeyValuePair item) - { - _properties.Add(item); - } + public bool Remove(string key) + { + return _properties.Remove(key); + } - public bool ContainsKey(string key) - { - return _properties.ContainsKey(key); - } + public bool TryGetValue(string key, out EntityProperty value) + { + return _properties.TryGetValue(key, out value); + } - public bool Remove(string key) - { - return _properties.Remove(key); - } + public void Clear() + { + _properties.Clear(); + } - public bool TryGetValue(string key, out EntityProperty value) - { - return _properties.TryGetValue(key, out value); - } + public bool Contains(KeyValuePair item) + { + return _properties.Contains(item); + } - public void Clear() - { - _properties.Clear(); - } + public void CopyTo(KeyValuePair[] array, int arrayIndex) + { + _properties.CopyTo(array, arrayIndex); + } - public bool Contains(KeyValuePair item) - { - return _properties.Contains(item); - } + public bool Remove(KeyValuePair item) + { + return _properties.Remove(item); + } - public void CopyTo(KeyValuePair[] array, int arrayIndex) - { - _properties.CopyTo(array, arrayIndex); - } + public IEnumerator> GetEnumerator() + { + return _properties.GetEnumerator(); + } - public bool Remove(KeyValuePair item) - { - return _properties.Remove(item); - } - - public IEnumerator> GetEnumerator() - { - return _properties.GetEnumerator(); - } - - IEnumerator IEnumerable.GetEnumerator() - { - return _properties.GetEnumerator(); - } + IEnumerator IEnumerable.GetEnumerator() + { + return _properties.GetEnumerator(); } } diff --git a/src/Core/Models/Data/EmergencyAccessDetails.cs b/src/Core/Models/Data/EmergencyAccessDetails.cs index 54e5069a0..89b04e3fc 100644 --- a/src/Core/Models/Data/EmergencyAccessDetails.cs +++ b/src/Core/Models/Data/EmergencyAccessDetails.cs @@ -1,12 +1,11 @@ using Bit.Core.Entities; -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class EmergencyAccessDetails : EmergencyAccess { - public class EmergencyAccessDetails : EmergencyAccess - { - public string GranteeName { get; set; } - public string GranteeEmail { get; set; } - public string GrantorName { get; set; } - public string GrantorEmail { get; set; } - } + public string GranteeName { get; set; } + public string GranteeEmail { get; set; } + public string GrantorName { get; set; } + public string GrantorEmail { get; set; } } diff --git a/src/Core/Models/Data/EmergencyAccessNotify.cs b/src/Core/Models/Data/EmergencyAccessNotify.cs index 4661a1b49..6eaccd272 100644 --- a/src/Core/Models/Data/EmergencyAccessNotify.cs +++ b/src/Core/Models/Data/EmergencyAccessNotify.cs @@ -1,11 +1,10 @@ using Bit.Core.Entities; -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class EmergencyAccessNotify : EmergencyAccess { - public class EmergencyAccessNotify : EmergencyAccess - { - public string GrantorEmail { get; set; } - public string GranteeName { get; set; } - public string GranteeEmail { get; set; } - } + public string GrantorEmail { get; set; } + public string GranteeName { get; set; } + public string GranteeEmail { get; set; } } diff --git a/src/Core/Models/Data/EmergencyAccessViewData.cs b/src/Core/Models/Data/EmergencyAccessViewData.cs index 86260e823..ef9ffb0a2 100644 --- a/src/Core/Models/Data/EmergencyAccessViewData.cs +++ b/src/Core/Models/Data/EmergencyAccessViewData.cs @@ -1,11 +1,10 @@ using Bit.Core.Entities; using Core.Models.Data; -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class EmergencyAccessViewData { - public class EmergencyAccessViewData - { - public EmergencyAccess EmergencyAccess { get; set; } - public IEnumerable Ciphers { get; set; } - } + public EmergencyAccess EmergencyAccess { get; set; } + public IEnumerable Ciphers { get; set; } } diff --git a/src/Core/Models/Data/EventMessage.cs b/src/Core/Models/Data/EventMessage.cs index f99330d01..c77eceab0 100644 --- a/src/Core/Models/Data/EventMessage.cs +++ b/src/Core/Models/Data/EventMessage.cs @@ -1,35 +1,34 @@ using Bit.Core.Context; using Bit.Core.Enums; -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class EventMessage : IEvent { - public class EventMessage : IEvent + public EventMessage() { } + + public EventMessage(ICurrentContext currentContext) + : base() { - public EventMessage() { } - - public EventMessage(ICurrentContext currentContext) - : base() - { - IpAddress = currentContext.IpAddress; - DeviceType = currentContext.DeviceType; - } - - public DateTime Date { get; set; } - public EventType Type { get; set; } - public Guid? UserId { get; set; } - public Guid? OrganizationId { get; set; } - public Guid? InstallationId { get; set; } - public Guid? ProviderId { get; set; } - public Guid? CipherId { get; set; } - public Guid? CollectionId { get; set; } - public Guid? GroupId { get; set; } - public Guid? PolicyId { get; set; } - public Guid? OrganizationUserId { get; set; } - public Guid? ProviderUserId { get; set; } - public Guid? ProviderOrganizationId { get; set; } - public Guid? ActingUserId { get; set; } - public DeviceType? DeviceType { get; set; } - public string IpAddress { get; set; } - public Guid? IdempotencyId { get; private set; } = Guid.NewGuid(); + IpAddress = currentContext.IpAddress; + DeviceType = currentContext.DeviceType; } + + public DateTime Date { get; set; } + public EventType Type { get; set; } + public Guid? UserId { get; set; } + public Guid? OrganizationId { get; set; } + public Guid? InstallationId { get; set; } + public Guid? ProviderId { get; set; } + public Guid? CipherId { get; set; } + public Guid? CollectionId { get; set; } + public Guid? GroupId { get; set; } + public Guid? PolicyId { get; set; } + public Guid? OrganizationUserId { get; set; } + public Guid? ProviderUserId { get; set; } + public Guid? ProviderOrganizationId { get; set; } + public Guid? ActingUserId { get; set; } + public DeviceType? DeviceType { get; set; } + public string IpAddress { get; set; } + public Guid? IdempotencyId { get; private set; } = Guid.NewGuid(); } diff --git a/src/Core/Models/Data/EventTableEntity.cs b/src/Core/Models/Data/EventTableEntity.cs index 83e25b296..182a3171d 100644 --- a/src/Core/Models/Data/EventTableEntity.cs +++ b/src/Core/Models/Data/EventTableEntity.cs @@ -2,154 +2,153 @@ using Bit.Core.Utilities; using Microsoft.Azure.Cosmos.Table; -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class EventTableEntity : TableEntity, IEvent { - public class EventTableEntity : TableEntity, IEvent + public EventTableEntity() { } + + private EventTableEntity(IEvent e) { - public EventTableEntity() { } + Date = e.Date; + Type = e.Type; + UserId = e.UserId; + OrganizationId = e.OrganizationId; + InstallationId = e.InstallationId; + ProviderId = e.ProviderId; + CipherId = e.CipherId; + CollectionId = e.CollectionId; + PolicyId = e.PolicyId; + GroupId = e.GroupId; + OrganizationUserId = e.OrganizationUserId; + ProviderUserId = e.ProviderUserId; + ProviderOrganizationId = e.ProviderOrganizationId; + DeviceType = e.DeviceType; + IpAddress = e.IpAddress; + ActingUserId = e.ActingUserId; + } - private EventTableEntity(IEvent e) + public DateTime Date { get; set; } + public EventType Type { get; set; } + public Guid? UserId { get; set; } + public Guid? OrganizationId { get; set; } + public Guid? InstallationId { get; set; } + public Guid? ProviderId { get; set; } + public Guid? CipherId { get; set; } + public Guid? CollectionId { get; set; } + public Guid? PolicyId { get; set; } + public Guid? GroupId { get; set; } + public Guid? OrganizationUserId { get; set; } + public Guid? ProviderUserId { get; set; } + public Guid? ProviderOrganizationId { get; set; } + public DeviceType? DeviceType { get; set; } + public string IpAddress { get; set; } + public Guid? ActingUserId { get; set; } + + public override IDictionary WriteEntity(OperationContext operationContext) + { + var result = base.WriteEntity(operationContext); + + var typeName = nameof(Type); + if (result.ContainsKey(typeName)) { - Date = e.Date; - Type = e.Type; - UserId = e.UserId; - OrganizationId = e.OrganizationId; - InstallationId = e.InstallationId; - ProviderId = e.ProviderId; - CipherId = e.CipherId; - CollectionId = e.CollectionId; - PolicyId = e.PolicyId; - GroupId = e.GroupId; - OrganizationUserId = e.OrganizationUserId; - ProviderUserId = e.ProviderUserId; - ProviderOrganizationId = e.ProviderOrganizationId; - DeviceType = e.DeviceType; - IpAddress = e.IpAddress; - ActingUserId = e.ActingUserId; + result[typeName] = new EntityProperty((int)Type); + } + else + { + result.Add(typeName, new EntityProperty((int)Type)); } - public DateTime Date { get; set; } - public EventType Type { get; set; } - public Guid? UserId { get; set; } - public Guid? OrganizationId { get; set; } - public Guid? InstallationId { get; set; } - public Guid? ProviderId { get; set; } - public Guid? CipherId { get; set; } - public Guid? CollectionId { get; set; } - public Guid? PolicyId { get; set; } - public Guid? GroupId { get; set; } - public Guid? OrganizationUserId { get; set; } - public Guid? ProviderUserId { get; set; } - public Guid? ProviderOrganizationId { get; set; } - public DeviceType? DeviceType { get; set; } - public string IpAddress { get; set; } - public Guid? ActingUserId { get; set; } - - public override IDictionary WriteEntity(OperationContext operationContext) + var deviceTypeName = nameof(DeviceType); + if (result.ContainsKey(deviceTypeName)) { - var result = base.WriteEntity(operationContext); - - var typeName = nameof(Type); - if (result.ContainsKey(typeName)) - { - result[typeName] = new EntityProperty((int)Type); - } - else - { - result.Add(typeName, new EntityProperty((int)Type)); - } - - var deviceTypeName = nameof(DeviceType); - if (result.ContainsKey(deviceTypeName)) - { - result[deviceTypeName] = new EntityProperty((int?)DeviceType); - } - else - { - result.Add(deviceTypeName, new EntityProperty((int?)DeviceType)); - } - - return result; + result[deviceTypeName] = new EntityProperty((int?)DeviceType); + } + else + { + result.Add(deviceTypeName, new EntityProperty((int?)DeviceType)); } - public override void ReadEntity(IDictionary properties, - OperationContext operationContext) + return result; + } + + public override void ReadEntity(IDictionary properties, + OperationContext operationContext) + { + base.ReadEntity(properties, operationContext); + + var typeName = nameof(Type); + if (properties.ContainsKey(typeName) && properties[typeName].Int32Value.HasValue) { - base.ReadEntity(properties, operationContext); - - var typeName = nameof(Type); - if (properties.ContainsKey(typeName) && properties[typeName].Int32Value.HasValue) - { - Type = (EventType)properties[typeName].Int32Value.Value; - } - - var deviceTypeName = nameof(DeviceType); - if (properties.ContainsKey(deviceTypeName) && properties[deviceTypeName].Int32Value.HasValue) - { - DeviceType = (DeviceType)properties[deviceTypeName].Int32Value.Value; - } + Type = (EventType)properties[typeName].Int32Value.Value; } - public static List IndexEvent(EventMessage e) + var deviceTypeName = nameof(DeviceType); + if (properties.ContainsKey(deviceTypeName) && properties[deviceTypeName].Int32Value.HasValue) { - var uniquifier = e.IdempotencyId.GetValueOrDefault(Guid.NewGuid()); - - var pKey = GetPartitionKey(e); - - var dateKey = CoreHelpers.DateTimeToTableStorageKey(e.Date); - - var entities = new List - { - new EventTableEntity(e) - { - PartitionKey = pKey, - RowKey = $"Date={dateKey}__Uniquifier={uniquifier}" - } - }; - - if (e.OrganizationId.HasValue && e.ActingUserId.HasValue) - { - entities.Add(new EventTableEntity(e) - { - PartitionKey = pKey, - RowKey = $"ActingUserId={e.ActingUserId}__Date={dateKey}__Uniquifier={uniquifier}" - }); - } - - if (!e.OrganizationId.HasValue && e.ProviderId.HasValue && e.ActingUserId.HasValue) - { - entities.Add(new EventTableEntity(e) - { - PartitionKey = pKey, - RowKey = $"ActingUserId={e.ActingUserId}__Date={dateKey}__Uniquifier={uniquifier}" - }); - } - - if (e.CipherId.HasValue) - { - entities.Add(new EventTableEntity(e) - { - PartitionKey = pKey, - RowKey = $"CipherId={e.CipherId}__Date={dateKey}__Uniquifier={uniquifier}" - }); - } - - return entities; - } - - private static string GetPartitionKey(EventMessage e) - { - if (e.OrganizationId.HasValue) - { - return $"OrganizationId={e.OrganizationId}"; - } - - if (e.ProviderId.HasValue) - { - return $"ProviderId={e.ProviderId}"; - } - - return $"UserId={e.UserId}"; + DeviceType = (DeviceType)properties[deviceTypeName].Int32Value.Value; } } + + public static List IndexEvent(EventMessage e) + { + var uniquifier = e.IdempotencyId.GetValueOrDefault(Guid.NewGuid()); + + var pKey = GetPartitionKey(e); + + var dateKey = CoreHelpers.DateTimeToTableStorageKey(e.Date); + + var entities = new List + { + new EventTableEntity(e) + { + PartitionKey = pKey, + RowKey = $"Date={dateKey}__Uniquifier={uniquifier}" + } + }; + + if (e.OrganizationId.HasValue && e.ActingUserId.HasValue) + { + entities.Add(new EventTableEntity(e) + { + PartitionKey = pKey, + RowKey = $"ActingUserId={e.ActingUserId}__Date={dateKey}__Uniquifier={uniquifier}" + }); + } + + if (!e.OrganizationId.HasValue && e.ProviderId.HasValue && e.ActingUserId.HasValue) + { + entities.Add(new EventTableEntity(e) + { + PartitionKey = pKey, + RowKey = $"ActingUserId={e.ActingUserId}__Date={dateKey}__Uniquifier={uniquifier}" + }); + } + + if (e.CipherId.HasValue) + { + entities.Add(new EventTableEntity(e) + { + PartitionKey = pKey, + RowKey = $"CipherId={e.CipherId}__Date={dateKey}__Uniquifier={uniquifier}" + }); + } + + return entities; + } + + private static string GetPartitionKey(EventMessage e) + { + if (e.OrganizationId.HasValue) + { + return $"OrganizationId={e.OrganizationId}"; + } + + if (e.ProviderId.HasValue) + { + return $"ProviderId={e.ProviderId}"; + } + + return $"UserId={e.UserId}"; + } } diff --git a/src/Core/Models/Data/GroupWithCollections.cs b/src/Core/Models/Data/GroupWithCollections.cs index 958f70feb..3fa08bc45 100644 --- a/src/Core/Models/Data/GroupWithCollections.cs +++ b/src/Core/Models/Data/GroupWithCollections.cs @@ -1,10 +1,9 @@ using System.Data; using Bit.Core.Entities; -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class GroupWithCollections : Group { - public class GroupWithCollections : Group - { - public DataTable Collections { get; set; } - } + public DataTable Collections { get; set; } } diff --git a/src/Core/Models/Data/IEvent.cs b/src/Core/Models/Data/IEvent.cs index 860c6d446..82d8f74ba 100644 --- a/src/Core/Models/Data/IEvent.cs +++ b/src/Core/Models/Data/IEvent.cs @@ -1,24 +1,23 @@ using Bit.Core.Enums; -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public interface IEvent { - public interface IEvent - { - EventType Type { get; set; } - Guid? UserId { get; set; } - Guid? OrganizationId { get; set; } - Guid? InstallationId { get; set; } - Guid? ProviderId { get; set; } - Guid? CipherId { get; set; } - Guid? CollectionId { get; set; } - Guid? GroupId { get; set; } - Guid? PolicyId { get; set; } - Guid? OrganizationUserId { get; set; } - Guid? ProviderUserId { get; set; } - Guid? ProviderOrganizationId { get; set; } - Guid? ActingUserId { get; set; } - DeviceType? DeviceType { get; set; } - string IpAddress { get; set; } - DateTime Date { get; set; } - } + EventType Type { get; set; } + Guid? UserId { get; set; } + Guid? OrganizationId { get; set; } + Guid? InstallationId { get; set; } + Guid? ProviderId { get; set; } + Guid? CipherId { get; set; } + Guid? CollectionId { get; set; } + Guid? GroupId { get; set; } + Guid? PolicyId { get; set; } + Guid? OrganizationUserId { get; set; } + Guid? ProviderUserId { get; set; } + Guid? ProviderOrganizationId { get; set; } + Guid? ActingUserId { get; set; } + DeviceType? DeviceType { get; set; } + string IpAddress { get; set; } + DateTime Date { get; set; } } diff --git a/src/Core/Models/Data/InstallationDeviceEntity.cs b/src/Core/Models/Data/InstallationDeviceEntity.cs index 0fb81e340..cb7bf0087 100644 --- a/src/Core/Models/Data/InstallationDeviceEntity.cs +++ b/src/Core/Models/Data/InstallationDeviceEntity.cs @@ -1,35 +1,34 @@ using Microsoft.Azure.Cosmos.Table; -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class InstallationDeviceEntity : TableEntity { - public class InstallationDeviceEntity : TableEntity + public InstallationDeviceEntity() { } + + public InstallationDeviceEntity(Guid installationId, Guid deviceId) { - public InstallationDeviceEntity() { } + PartitionKey = installationId.ToString(); + RowKey = deviceId.ToString(); + } - public InstallationDeviceEntity(Guid installationId, Guid deviceId) + public InstallationDeviceEntity(string prefixedDeviceId) + { + var parts = prefixedDeviceId.Split("_"); + if (parts.Length < 2) { - PartitionKey = installationId.ToString(); - RowKey = deviceId.ToString(); + throw new ArgumentException("Not enough parts."); } + if (!Guid.TryParse(parts[0], out var installationId) || !Guid.TryParse(parts[1], out var deviceId)) + { + throw new ArgumentException("Could not parse parts."); + } + PartitionKey = parts[0]; + RowKey = parts[1]; + } - public InstallationDeviceEntity(string prefixedDeviceId) - { - var parts = prefixedDeviceId.Split("_"); - if (parts.Length < 2) - { - throw new ArgumentException("Not enough parts."); - } - if (!Guid.TryParse(parts[0], out var installationId) || !Guid.TryParse(parts[1], out var deviceId)) - { - throw new ArgumentException("Could not parse parts."); - } - PartitionKey = parts[0]; - RowKey = parts[1]; - } - - public static bool IsInstallationDeviceId(string deviceId) - { - return deviceId != null && deviceId.Length == 73 && deviceId[36] == '_'; - } + public static bool IsInstallationDeviceId(string deviceId) + { + return deviceId != null && deviceId.Length == 73 && deviceId[36] == '_'; } } diff --git a/src/Core/Models/Data/Organizations/OrganizationAbility.cs b/src/Core/Models/Data/Organizations/OrganizationAbility.cs index 6ec693185..9b9ee8509 100644 --- a/src/Core/Models/Data/Organizations/OrganizationAbility.cs +++ b/src/Core/Models/Data/Organizations/OrganizationAbility.cs @@ -1,35 +1,34 @@ using Bit.Core.Entities; -namespace Bit.Core.Models.Data.Organizations +namespace Bit.Core.Models.Data.Organizations; + +public class OrganizationAbility { - public class OrganizationAbility + public OrganizationAbility() { } + + public OrganizationAbility(Organization organization) { - public OrganizationAbility() { } - - public OrganizationAbility(Organization organization) - { - Id = organization.Id; - UseEvents = organization.UseEvents; - Use2fa = organization.Use2fa; - Using2fa = organization.Use2fa && organization.TwoFactorProviders != null && - organization.TwoFactorProviders != "{}"; - UsersGetPremium = organization.UsersGetPremium; - Enabled = organization.Enabled; - UseSso = organization.UseSso; - UseKeyConnector = organization.UseKeyConnector; - UseScim = organization.UseScim; - UseResetPassword = organization.UseResetPassword; - } - - public Guid Id { get; set; } - public bool UseEvents { get; set; } - public bool Use2fa { get; set; } - public bool Using2fa { get; set; } - public bool UsersGetPremium { get; set; } - public bool Enabled { get; set; } - public bool UseSso { get; set; } - public bool UseKeyConnector { get; set; } - public bool UseScim { get; set; } - public bool UseResetPassword { get; set; } + Id = organization.Id; + UseEvents = organization.UseEvents; + Use2fa = organization.Use2fa; + Using2fa = organization.Use2fa && organization.TwoFactorProviders != null && + organization.TwoFactorProviders != "{}"; + UsersGetPremium = organization.UsersGetPremium; + Enabled = organization.Enabled; + UseSso = organization.UseSso; + UseKeyConnector = organization.UseKeyConnector; + UseScim = organization.UseScim; + UseResetPassword = organization.UseResetPassword; } + + public Guid Id { get; set; } + public bool UseEvents { get; set; } + public bool Use2fa { get; set; } + public bool Using2fa { get; set; } + public bool UsersGetPremium { get; set; } + public bool Enabled { get; set; } + public bool UseSso { get; set; } + public bool UseKeyConnector { get; set; } + public bool UseScim { get; set; } + public bool UseResetPassword { get; set; } } diff --git a/src/Core/Models/Data/Organizations/OrganizationConnections/OrganizationConnectionData.cs b/src/Core/Models/Data/Organizations/OrganizationConnections/OrganizationConnectionData.cs index 272f411c5..3a3edaed4 100644 --- a/src/Core/Models/Data/Organizations/OrganizationConnections/OrganizationConnectionData.cs +++ b/src/Core/Models/Data/Organizations/OrganizationConnections/OrganizationConnectionData.cs @@ -1,32 +1,31 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Models.Data.Organizations.OrganizationConnections +namespace Bit.Core.Models.Data.Organizations.OrganizationConnections; + +public class OrganizationConnectionData where T : new() { - public class OrganizationConnectionData where T : new() + public Guid? Id { get; set; } + public OrganizationConnectionType Type { get; set; } + public Guid OrganizationId { get; set; } + public bool Enabled { get; set; } + public T Config { get; set; } + + public OrganizationConnection ToEntity() { - public Guid? Id { get; set; } - public OrganizationConnectionType Type { get; set; } - public Guid OrganizationId { get; set; } - public bool Enabled { get; set; } - public T Config { get; set; } - - public OrganizationConnection ToEntity() + var result = new OrganizationConnection() { - var result = new OrganizationConnection() - { - Type = Type, - OrganizationId = OrganizationId, - Enabled = Enabled, - }; - result.SetConfig(Config); + Type = Type, + OrganizationId = OrganizationId, + Enabled = Enabled, + }; + result.SetConfig(Config); - if (Id.HasValue) - { - result.Id = Id.Value; - } - - return result; + if (Id.HasValue) + { + result.Id = Id.Value; } + + return result; } } diff --git a/src/Core/Models/Data/Organizations/OrganizationSponsorships/OrganizationSponsorshipData.cs b/src/Core/Models/Data/Organizations/OrganizationSponsorships/OrganizationSponsorshipData.cs index 2a964ec99..927262957 100644 --- a/src/Core/Models/Data/Organizations/OrganizationSponsorships/OrganizationSponsorshipData.cs +++ b/src/Core/Models/Data/Organizations/OrganizationSponsorships/OrganizationSponsorshipData.cs @@ -1,31 +1,30 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Models.Data.Organizations.OrganizationSponsorships -{ - public class OrganizationSponsorshipData - { - public OrganizationSponsorshipData() { } - public OrganizationSponsorshipData(OrganizationSponsorship sponsorship) - { - SponsoringOrganizationUserId = sponsorship.SponsoringOrganizationUserId; - SponsoredOrganizationId = sponsorship.SponsoredOrganizationId; - FriendlyName = sponsorship.FriendlyName; - OfferedToEmail = sponsorship.OfferedToEmail; - PlanSponsorshipType = sponsorship.PlanSponsorshipType.GetValueOrDefault(); - LastSyncDate = sponsorship.LastSyncDate; - ValidUntil = sponsorship.ValidUntil; - ToDelete = sponsorship.ToDelete; - } - public Guid SponsoringOrganizationUserId { get; set; } - public Guid? SponsoredOrganizationId { get; set; } - public string FriendlyName { get; set; } - public string OfferedToEmail { get; set; } - public PlanSponsorshipType PlanSponsorshipType { get; set; } - public DateTime? LastSyncDate { get; set; } - public DateTime? ValidUntil { get; set; } - public bool ToDelete { get; set; } +namespace Bit.Core.Models.Data.Organizations.OrganizationSponsorships; - public bool CloudSponsorshipRemoved { get; set; } +public class OrganizationSponsorshipData +{ + public OrganizationSponsorshipData() { } + public OrganizationSponsorshipData(OrganizationSponsorship sponsorship) + { + SponsoringOrganizationUserId = sponsorship.SponsoringOrganizationUserId; + SponsoredOrganizationId = sponsorship.SponsoredOrganizationId; + FriendlyName = sponsorship.FriendlyName; + OfferedToEmail = sponsorship.OfferedToEmail; + PlanSponsorshipType = sponsorship.PlanSponsorshipType.GetValueOrDefault(); + LastSyncDate = sponsorship.LastSyncDate; + ValidUntil = sponsorship.ValidUntil; + ToDelete = sponsorship.ToDelete; } + public Guid SponsoringOrganizationUserId { get; set; } + public Guid? SponsoredOrganizationId { get; set; } + public string FriendlyName { get; set; } + public string OfferedToEmail { get; set; } + public PlanSponsorshipType PlanSponsorshipType { get; set; } + public DateTime? LastSyncDate { get; set; } + public DateTime? ValidUntil { get; set; } + public bool ToDelete { get; set; } + + public bool CloudSponsorshipRemoved { get; set; } } diff --git a/src/Core/Models/Data/Organizations/OrganizationSponsorships/OrganizationSponsorshipSyncData.cs b/src/Core/Models/Data/Organizations/OrganizationSponsorships/OrganizationSponsorshipSyncData.cs index 29cd20030..8c1018711 100644 --- a/src/Core/Models/Data/Organizations/OrganizationSponsorships/OrganizationSponsorshipSyncData.cs +++ b/src/Core/Models/Data/Organizations/OrganizationSponsorships/OrganizationSponsorshipSyncData.cs @@ -1,9 +1,8 @@ -namespace Bit.Core.Models.Data.Organizations.OrganizationSponsorships +namespace Bit.Core.Models.Data.Organizations.OrganizationSponsorships; + +public class OrganizationSponsorshipSyncData { - public class OrganizationSponsorshipSyncData - { - public string BillingSyncKey { get; set; } - public Guid SponsoringOrganizationCloudId { get; set; } - public IEnumerable SponsorshipsBatch { get; set; } - } + public string BillingSyncKey { get; set; } + public Guid SponsoringOrganizationCloudId { get; set; } + public IEnumerable SponsorshipsBatch { get; set; } } diff --git a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserInviteData.cs b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserInviteData.cs index e23be4468..ff360c10f 100644 --- a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserInviteData.cs +++ b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserInviteData.cs @@ -1,13 +1,12 @@ using Bit.Core.Enums; -namespace Bit.Core.Models.Data.Organizations.OrganizationUsers +namespace Bit.Core.Models.Data.Organizations.OrganizationUsers; + +public class OrganizationUserInviteData { - public class OrganizationUserInviteData - { - public IEnumerable Emails { get; set; } - public OrganizationUserType? Type { get; set; } - public bool AccessAll { get; set; } - public IEnumerable Collections { get; set; } - public Permissions Permissions { get; set; } - } + public IEnumerable Emails { get; set; } + public OrganizationUserType? Type { get; set; } + public bool AccessAll { get; set; } + public IEnumerable Collections { get; set; } + public Permissions Permissions { get; set; } } diff --git a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserOrganizationDetails.cs b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserOrganizationDetails.cs index 554e99e6c..c132aee64 100644 --- a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserOrganizationDetails.cs +++ b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserOrganizationDetails.cs @@ -1,43 +1,42 @@ -namespace Bit.Core.Models.Data.Organizations.OrganizationUsers +namespace Bit.Core.Models.Data.Organizations.OrganizationUsers; + +public class OrganizationUserOrganizationDetails { - public class OrganizationUserOrganizationDetails - { - public Guid OrganizationId { get; set; } - public Guid? UserId { get; set; } - public string Name { get; set; } - public bool UsePolicies { get; set; } - public bool UseSso { get; set; } - public bool UseKeyConnector { get; set; } - public bool UseScim { get; set; } - public bool UseGroups { get; set; } - public bool UseDirectory { get; set; } - public bool UseEvents { get; set; } - public bool UseTotp { get; set; } - public bool Use2fa { get; set; } - public bool UseApi { get; set; } - public bool UseResetPassword { get; set; } - public bool SelfHost { get; set; } - public bool UsersGetPremium { get; set; } - public int? Seats { get; set; } - public short? MaxCollections { get; set; } - public short? MaxStorageGb { get; set; } - public string Key { get; set; } - public Enums.OrganizationUserStatusType Status { get; set; } - public Enums.OrganizationUserType Type { get; set; } - public bool Enabled { get; set; } - public Enums.PlanType PlanType { get; set; } - public string SsoExternalId { get; set; } - public string Identifier { get; set; } - public string Permissions { get; set; } - public string ResetPasswordKey { get; set; } - public string PublicKey { get; set; } - public string PrivateKey { get; set; } - public Guid? ProviderId { get; set; } - public string ProviderName { get; set; } - public string FamilySponsorshipFriendlyName { get; set; } - public string SsoConfig { get; set; } - public DateTime? FamilySponsorshipLastSyncDate { get; set; } - public DateTime? FamilySponsorshipValidUntil { get; set; } - public bool? FamilySponsorshipToDelete { get; set; } - } + public Guid OrganizationId { get; set; } + public Guid? UserId { get; set; } + public string Name { get; set; } + public bool UsePolicies { get; set; } + public bool UseSso { get; set; } + public bool UseKeyConnector { get; set; } + public bool UseScim { get; set; } + public bool UseGroups { get; set; } + public bool UseDirectory { get; set; } + public bool UseEvents { get; set; } + public bool UseTotp { get; set; } + public bool Use2fa { get; set; } + public bool UseApi { get; set; } + public bool UseResetPassword { get; set; } + public bool SelfHost { get; set; } + public bool UsersGetPremium { get; set; } + public int? Seats { get; set; } + public short? MaxCollections { get; set; } + public short? MaxStorageGb { get; set; } + public string Key { get; set; } + public Enums.OrganizationUserStatusType Status { get; set; } + public Enums.OrganizationUserType Type { get; set; } + public bool Enabled { get; set; } + public Enums.PlanType PlanType { get; set; } + public string SsoExternalId { get; set; } + public string Identifier { get; set; } + public string Permissions { get; set; } + public string ResetPasswordKey { get; set; } + public string PublicKey { get; set; } + public string PrivateKey { get; set; } + public Guid? ProviderId { get; set; } + public string ProviderName { get; set; } + public string FamilySponsorshipFriendlyName { get; set; } + public string SsoConfig { get; set; } + public DateTime? FamilySponsorshipLastSyncDate { get; set; } + public DateTime? FamilySponsorshipValidUntil { get; set; } + public bool? FamilySponsorshipToDelete { get; set; } } diff --git a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserPublicKey.cs b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserPublicKey.cs index c465f49a0..7c0496787 100644 --- a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserPublicKey.cs +++ b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserPublicKey.cs @@ -1,9 +1,8 @@ -namespace Bit.Core.Models.Data.Organizations.OrganizationUsers +namespace Bit.Core.Models.Data.Organizations.OrganizationUsers; + +public class OrganizationUserPublicKey { - public class OrganizationUserPublicKey - { - public Guid Id { get; set; } - public Guid UserId { get; set; } - public string PublicKey { get; set; } - } + public Guid Id { get; set; } + public Guid UserId { get; set; } + public string PublicKey { get; set; } } diff --git a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserResetPasswordDetails.cs b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserResetPasswordDetails.cs index ccac4a587..66fa27dfd 100644 --- a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserResetPasswordDetails.cs +++ b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserResetPasswordDetails.cs @@ -1,35 +1,34 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Models.Data.Organizations.OrganizationUsers +namespace Bit.Core.Models.Data.Organizations.OrganizationUsers; + +public class OrganizationUserResetPasswordDetails { - public class OrganizationUserResetPasswordDetails + public OrganizationUserResetPasswordDetails(OrganizationUser orgUser, User user, Organization org) { - public OrganizationUserResetPasswordDetails(OrganizationUser orgUser, User user, Organization org) + if (orgUser == null) { - if (orgUser == null) - { - throw new ArgumentNullException(nameof(orgUser)); - } - - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - if (org == null) - { - throw new ArgumentNullException(nameof(org)); - } - - Kdf = user.Kdf; - KdfIterations = user.KdfIterations; - ResetPasswordKey = orgUser.ResetPasswordKey; - EncryptedPrivateKey = org.PrivateKey; + throw new ArgumentNullException(nameof(orgUser)); } - public KdfType Kdf { get; set; } - public int KdfIterations { get; set; } - public string ResetPasswordKey { get; set; } - public string EncryptedPrivateKey { get; set; } + + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + if (org == null) + { + throw new ArgumentNullException(nameof(org)); + } + + Kdf = user.Kdf; + KdfIterations = user.KdfIterations; + ResetPasswordKey = orgUser.ResetPasswordKey; + EncryptedPrivateKey = org.PrivateKey; } + public KdfType Kdf { get; set; } + public int KdfIterations { get; set; } + public string ResetPasswordKey { get; set; } + public string EncryptedPrivateKey { get; set; } } diff --git a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserUserDetails.cs b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserUserDetails.cs index 334ee7417..ff28d1f3c 100644 --- a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserUserDetails.cs +++ b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserUserDetails.cs @@ -1,60 +1,59 @@ using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Core.Models.Data.Organizations.OrganizationUsers +namespace Bit.Core.Models.Data.Organizations.OrganizationUsers; + +public class OrganizationUserUserDetails : IExternal, ITwoFactorProvidersUser { - public class OrganizationUserUserDetails : IExternal, ITwoFactorProvidersUser + private Dictionary _twoFactorProviders; + + public Guid Id { get; set; } + public Guid OrganizationId { get; set; } + public Guid? UserId { get; set; } + public string Name { get; set; } + public string Email { get; set; } + public string TwoFactorProviders { get; set; } + public bool? Premium { get; set; } + public OrganizationUserStatusType Status { get; set; } + public OrganizationUserType Type { get; set; } + public bool AccessAll { get; set; } + public string ExternalId { get; set; } + public string SsoExternalId { get; set; } + public string Permissions { get; set; } + public string ResetPasswordKey { get; set; } + public bool UsesKeyConnector { get; set; } + + public Dictionary GetTwoFactorProviders() { - private Dictionary _twoFactorProviders; - - public Guid Id { get; set; } - public Guid OrganizationId { get; set; } - public Guid? UserId { get; set; } - public string Name { get; set; } - public string Email { get; set; } - public string TwoFactorProviders { get; set; } - public bool? Premium { get; set; } - public OrganizationUserStatusType Status { get; set; } - public OrganizationUserType Type { get; set; } - public bool AccessAll { get; set; } - public string ExternalId { get; set; } - public string SsoExternalId { get; set; } - public string Permissions { get; set; } - public string ResetPasswordKey { get; set; } - public bool UsesKeyConnector { get; set; } - - public Dictionary GetTwoFactorProviders() + if (string.IsNullOrWhiteSpace(TwoFactorProviders)) { - if (string.IsNullOrWhiteSpace(TwoFactorProviders)) - { - return null; - } - - try - { - if (_twoFactorProviders == null) - { - _twoFactorProviders = - JsonHelpers.LegacyDeserialize>( - TwoFactorProviders); - } - - return _twoFactorProviders; - } - catch (Newtonsoft.Json.JsonException) - { - return null; - } + return null; } - public Guid? GetUserId() + try { - return UserId; - } + if (_twoFactorProviders == null) + { + _twoFactorProviders = + JsonHelpers.LegacyDeserialize>( + TwoFactorProviders); + } - public bool GetPremium() + return _twoFactorProviders; + } + catch (Newtonsoft.Json.JsonException) { - return Premium.GetValueOrDefault(false); + return null; } } + + public Guid? GetUserId() + { + return UserId; + } + + public bool GetPremium() + { + return Premium.GetValueOrDefault(false); + } } diff --git a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserWithCollections.cs b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserWithCollections.cs index c96a49f56..d86c6c158 100644 --- a/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserWithCollections.cs +++ b/src/Core/Models/Data/Organizations/OrganizationUsers/OrganizationUserWithCollections.cs @@ -1,10 +1,9 @@ using System.Data; using Bit.Core.Entities; -namespace Bit.Core.Models.Data.Organizations.OrganizationUsers +namespace Bit.Core.Models.Data.Organizations.OrganizationUsers; + +public class OrganizationUserWithCollections : OrganizationUser { - public class OrganizationUserWithCollections : OrganizationUser - { - public DataTable Collections { get; set; } - } + public DataTable Collections { get; set; } } diff --git a/src/Core/Models/Data/Organizations/Policies/IPolicyDataModel.cs b/src/Core/Models/Data/Organizations/Policies/IPolicyDataModel.cs index 1d263cedb..ef8789d48 100644 --- a/src/Core/Models/Data/Organizations/Policies/IPolicyDataModel.cs +++ b/src/Core/Models/Data/Organizations/Policies/IPolicyDataModel.cs @@ -1,6 +1,5 @@ -namespace Bit.Core.Models.Data.Organizations.Policies +namespace Bit.Core.Models.Data.Organizations.Policies; + +public interface IPolicyDataModel { - public interface IPolicyDataModel - { - } } diff --git a/src/Core/Models/Data/Organizations/Policies/ResetPasswordDataModel.cs b/src/Core/Models/Data/Organizations/Policies/ResetPasswordDataModel.cs index c77d8ef01..1931cc5b7 100644 --- a/src/Core/Models/Data/Organizations/Policies/ResetPasswordDataModel.cs +++ b/src/Core/Models/Data/Organizations/Policies/ResetPasswordDataModel.cs @@ -1,10 +1,9 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Models.Data.Organizations.Policies +namespace Bit.Core.Models.Data.Organizations.Policies; + +public class ResetPasswordDataModel : IPolicyDataModel { - public class ResetPasswordDataModel : IPolicyDataModel - { - [Display(Name = "ResetPasswordAutoEnrollCheckbox")] - public bool AutoEnrollEnabled { get; set; } - } + [Display(Name = "ResetPasswordAutoEnrollCheckbox")] + public bool AutoEnrollEnabled { get; set; } } diff --git a/src/Core/Models/Data/Organizations/Policies/SendOptionsPolicyData.cs b/src/Core/Models/Data/Organizations/Policies/SendOptionsPolicyData.cs index d9bb5ef9d..aa9f65166 100644 --- a/src/Core/Models/Data/Organizations/Policies/SendOptionsPolicyData.cs +++ b/src/Core/Models/Data/Organizations/Policies/SendOptionsPolicyData.cs @@ -1,10 +1,9 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Models.Data.Organizations.Policies +namespace Bit.Core.Models.Data.Organizations.Policies; + +public class SendOptionsPolicyData : IPolicyDataModel { - public class SendOptionsPolicyData : IPolicyDataModel - { - [Display(Name = "DisableHideEmail")] - public bool DisableHideEmail { get; set; } - } + [Display(Name = "DisableHideEmail")] + public bool DisableHideEmail { get; set; } } diff --git a/src/Core/Models/Data/PageOptions.cs b/src/Core/Models/Data/PageOptions.cs index 1b354932e..e9f12ece9 100644 --- a/src/Core/Models/Data/PageOptions.cs +++ b/src/Core/Models/Data/PageOptions.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class PageOptions { - public class PageOptions - { - public string ContinuationToken { get; set; } - public int PageSize { get; set; } = 50; - } + public string ContinuationToken { get; set; } + public int PageSize { get; set; } = 50; } diff --git a/src/Core/Models/Data/PagedResult.cs b/src/Core/Models/Data/PagedResult.cs index 1bb7e3cd2..b02044dd8 100644 --- a/src/Core/Models/Data/PagedResult.cs +++ b/src/Core/Models/Data/PagedResult.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class PagedResult { - public class PagedResult - { - public List Data { get; set; } = new List(); - public string ContinuationToken { get; set; } - } + public List Data { get; set; } = new List(); + public string ContinuationToken { get; set; } } diff --git a/src/Core/Models/Data/Permissions.cs b/src/Core/Models/Data/Permissions.cs index 5cb0149a3..49a7e37f0 100644 --- a/src/Core/Models/Data/Permissions.cs +++ b/src/Core/Models/Data/Permissions.cs @@ -1,45 +1,44 @@ using System.Text.Json.Serialization; -namespace Bit.Core.Models.Data -{ - public class Permissions - { - public bool AccessEventLogs { get; set; } - public bool AccessImportExport { get; set; } - public bool AccessReports { get; set; } - [Obsolete("This permission exists for client backwards-compatibility. It should not be used to determine permissions in this repository", true)] - public bool ManageAllCollections => CreateNewCollections && EditAnyCollection && DeleteAnyCollection; - public bool CreateNewCollections { get; set; } - public bool EditAnyCollection { get; set; } - public bool DeleteAnyCollection { get; set; } - [Obsolete("This permission exists for client backwards-compatibility. It should not be used to determine permissions in this repository", true)] - public bool ManageAssignedCollections => EditAssignedCollections && DeleteAssignedCollections; - public bool EditAssignedCollections { get; set; } - public bool DeleteAssignedCollections { get; set; } - public bool ManageGroups { get; set; } - public bool ManagePolicies { get; set; } - public bool ManageSso { get; set; } - public bool ManageUsers { get; set; } - public bool ManageResetPassword { get; set; } - public bool ManageScim { get; set; } +namespace Bit.Core.Models.Data; - [JsonIgnore] - public List<(bool Permission, string ClaimName)> ClaimsMap => new() - { - (AccessEventLogs, "accesseventlogs"), - (AccessImportExport, "accessimportexport"), - (AccessReports, "accessreports"), - (CreateNewCollections, "createnewcollections"), - (EditAnyCollection, "editanycollection"), - (DeleteAnyCollection, "deleteanycollection"), - (EditAssignedCollections, "editassignedcollections"), - (DeleteAssignedCollections, "deleteassignedcollections"), - (ManageGroups, "managegroups"), - (ManagePolicies, "managepolicies"), - (ManageSso, "managesso"), - (ManageUsers, "manageusers"), - (ManageResetPassword, "manageresetpassword"), - (ManageScim, "managescim"), - }; - } +public class Permissions +{ + public bool AccessEventLogs { get; set; } + public bool AccessImportExport { get; set; } + public bool AccessReports { get; set; } + [Obsolete("This permission exists for client backwards-compatibility. It should not be used to determine permissions in this repository", true)] + public bool ManageAllCollections => CreateNewCollections && EditAnyCollection && DeleteAnyCollection; + public bool CreateNewCollections { get; set; } + public bool EditAnyCollection { get; set; } + public bool DeleteAnyCollection { get; set; } + [Obsolete("This permission exists for client backwards-compatibility. It should not be used to determine permissions in this repository", true)] + public bool ManageAssignedCollections => EditAssignedCollections && DeleteAssignedCollections; + public bool EditAssignedCollections { get; set; } + public bool DeleteAssignedCollections { get; set; } + public bool ManageGroups { get; set; } + public bool ManagePolicies { get; set; } + public bool ManageSso { get; set; } + public bool ManageUsers { get; set; } + public bool ManageResetPassword { get; set; } + public bool ManageScim { get; set; } + + [JsonIgnore] + public List<(bool Permission, string ClaimName)> ClaimsMap => new() + { + (AccessEventLogs, "accesseventlogs"), + (AccessImportExport, "accessimportexport"), + (AccessReports, "accessreports"), + (CreateNewCollections, "createnewcollections"), + (EditAnyCollection, "editanycollection"), + (DeleteAnyCollection, "deleteanycollection"), + (EditAssignedCollections, "editassignedcollections"), + (DeleteAssignedCollections, "deleteassignedcollections"), + (ManageGroups, "managegroups"), + (ManagePolicies, "managepolicies"), + (ManageSso, "managesso"), + (ManageUsers, "manageusers"), + (ManageResetPassword, "manageresetpassword"), + (ManageScim, "managescim"), + }; } diff --git a/src/Core/Models/Data/Provider/ProviderAbility.cs b/src/Core/Models/Data/Provider/ProviderAbility.cs index a77203014..b7e45eaed 100644 --- a/src/Core/Models/Data/Provider/ProviderAbility.cs +++ b/src/Core/Models/Data/Provider/ProviderAbility.cs @@ -1,20 +1,19 @@ using Bit.Core.Entities.Provider; -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class ProviderAbility { - public class ProviderAbility + public ProviderAbility() { } + + public ProviderAbility(Provider provider) { - public ProviderAbility() { } - - public ProviderAbility(Provider provider) - { - Id = provider.Id; - UseEvents = provider.UseEvents; - Enabled = provider.Enabled; - } - - public Guid Id { get; set; } - public bool UseEvents { get; set; } - public bool Enabled { get; set; } + Id = provider.Id; + UseEvents = provider.UseEvents; + Enabled = provider.Enabled; } + + public Guid Id { get; set; } + public bool UseEvents { get; set; } + public bool Enabled { get; set; } } diff --git a/src/Core/Models/Data/Provider/ProviderOrganizationOrganizationDetails.cs b/src/Core/Models/Data/Provider/ProviderOrganizationOrganizationDetails.cs index 279994df4..923bba4af 100644 --- a/src/Core/Models/Data/Provider/ProviderOrganizationOrganizationDetails.cs +++ b/src/Core/Models/Data/Provider/ProviderOrganizationOrganizationDetails.cs @@ -1,17 +1,16 @@ -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class ProviderOrganizationOrganizationDetails { - public class ProviderOrganizationOrganizationDetails - { - public Guid Id { get; set; } - public Guid ProviderId { get; set; } - public Guid OrganizationId { get; set; } - public string OrganizationName { get; set; } - public string Key { get; set; } - public string Settings { get; set; } - public DateTime CreationDate { get; set; } - public DateTime RevisionDate { get; set; } - public int UserCount { get; set; } - public int? Seats { get; set; } - public string Plan { get; set; } - } + public Guid Id { get; set; } + public Guid ProviderId { get; set; } + public Guid OrganizationId { get; set; } + public string OrganizationName { get; set; } + public string Key { get; set; } + public string Settings { get; set; } + public DateTime CreationDate { get; set; } + public DateTime RevisionDate { get; set; } + public int UserCount { get; set; } + public int? Seats { get; set; } + public string Plan { get; set; } } diff --git a/src/Core/Models/Data/Provider/ProviderUserOrganizationDetails.cs b/src/Core/Models/Data/Provider/ProviderUserOrganizationDetails.cs index ab19931b6..9d0740b73 100644 --- a/src/Core/Models/Data/Provider/ProviderUserOrganizationDetails.cs +++ b/src/Core/Models/Data/Provider/ProviderUserOrganizationDetails.cs @@ -1,37 +1,36 @@ using Bit.Core.Enums.Provider; -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class ProviderUserOrganizationDetails { - public class ProviderUserOrganizationDetails - { - public Guid OrganizationId { get; set; } - public Guid? UserId { get; set; } - public string Name { get; set; } - public bool UsePolicies { get; set; } - public bool UseSso { get; set; } - public bool UseKeyConnector { get; set; } - public bool UseScim { get; set; } - public bool UseGroups { get; set; } - public bool UseDirectory { get; set; } - public bool UseEvents { get; set; } - public bool UseTotp { get; set; } - public bool Use2fa { get; set; } - public bool UseApi { get; set; } - public bool UseResetPassword { get; set; } - public bool SelfHost { get; set; } - public bool UsersGetPremium { get; set; } - public int? Seats { get; set; } - public short? MaxCollections { get; set; } - public short? MaxStorageGb { get; set; } - public string Key { get; set; } - public ProviderUserStatusType Status { get; set; } - public ProviderUserType Type { get; set; } - public bool Enabled { get; set; } - public string Identifier { get; set; } - public string PublicKey { get; set; } - public string PrivateKey { get; set; } - public Guid? ProviderId { get; set; } - public Guid? ProviderUserId { get; set; } - public string ProviderName { get; set; } - } + public Guid OrganizationId { get; set; } + public Guid? UserId { get; set; } + public string Name { get; set; } + public bool UsePolicies { get; set; } + public bool UseSso { get; set; } + public bool UseKeyConnector { get; set; } + public bool UseScim { get; set; } + public bool UseGroups { get; set; } + public bool UseDirectory { get; set; } + public bool UseEvents { get; set; } + public bool UseTotp { get; set; } + public bool Use2fa { get; set; } + public bool UseApi { get; set; } + public bool UseResetPassword { get; set; } + public bool SelfHost { get; set; } + public bool UsersGetPremium { get; set; } + public int? Seats { get; set; } + public short? MaxCollections { get; set; } + public short? MaxStorageGb { get; set; } + public string Key { get; set; } + public ProviderUserStatusType Status { get; set; } + public ProviderUserType Type { get; set; } + public bool Enabled { get; set; } + public string Identifier { get; set; } + public string PublicKey { get; set; } + public string PrivateKey { get; set; } + public Guid? ProviderId { get; set; } + public Guid? ProviderUserId { get; set; } + public string ProviderName { get; set; } } diff --git a/src/Core/Models/Data/Provider/ProviderUserProviderDetails.cs b/src/Core/Models/Data/Provider/ProviderUserProviderDetails.cs index a14a455d9..16f2e1dda 100644 --- a/src/Core/Models/Data/Provider/ProviderUserProviderDetails.cs +++ b/src/Core/Models/Data/Provider/ProviderUserProviderDetails.cs @@ -1,18 +1,17 @@ using Bit.Core.Enums.Provider; -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class ProviderUserProviderDetails { - public class ProviderUserProviderDetails - { - public Guid ProviderId { get; set; } - public Guid? UserId { get; set; } - public string Name { get; set; } - public string Key { get; set; } - public ProviderUserStatusType Status { get; set; } - public ProviderUserType Type { get; set; } - public bool Enabled { get; set; } - public string Permissions { get; set; } - public bool UseEvents { get; set; } - public ProviderStatusType ProviderStatus { get; set; } - } + public Guid ProviderId { get; set; } + public Guid? UserId { get; set; } + public string Name { get; set; } + public string Key { get; set; } + public ProviderUserStatusType Status { get; set; } + public ProviderUserType Type { get; set; } + public bool Enabled { get; set; } + public string Permissions { get; set; } + public bool UseEvents { get; set; } + public ProviderStatusType ProviderStatus { get; set; } } diff --git a/src/Core/Models/Data/Provider/ProviderUserPublicKey.cs b/src/Core/Models/Data/Provider/ProviderUserPublicKey.cs index 0be26770c..0b161fd86 100644 --- a/src/Core/Models/Data/Provider/ProviderUserPublicKey.cs +++ b/src/Core/Models/Data/Provider/ProviderUserPublicKey.cs @@ -1,9 +1,8 @@ -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class ProviderUserPublicKey { - public class ProviderUserPublicKey - { - public Guid Id { get; set; } - public Guid UserId { get; set; } - public string PublicKey { get; set; } - } + public Guid Id { get; set; } + public Guid UserId { get; set; } + public string PublicKey { get; set; } } diff --git a/src/Core/Models/Data/Provider/ProviderUserUserDetails.cs b/src/Core/Models/Data/Provider/ProviderUserUserDetails.cs index 6d0c4daa6..51df1d44e 100644 --- a/src/Core/Models/Data/Provider/ProviderUserUserDetails.cs +++ b/src/Core/Models/Data/Provider/ProviderUserUserDetails.cs @@ -1,16 +1,15 @@ using Bit.Core.Enums.Provider; -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class ProviderUserUserDetails { - public class ProviderUserUserDetails - { - public Guid Id { get; set; } - public Guid ProviderId { get; set; } - public Guid? UserId { get; set; } - public string Name { get; set; } - public string Email { get; set; } - public ProviderUserStatusType Status { get; set; } - public ProviderUserType Type { get; set; } - public string Permissions { get; set; } - } + public Guid Id { get; set; } + public Guid ProviderId { get; set; } + public Guid? UserId { get; set; } + public string Name { get; set; } + public string Email { get; set; } + public ProviderUserStatusType Status { get; set; } + public ProviderUserType Type { get; set; } + public string Permissions { get; set; } } diff --git a/src/Core/Models/Data/SelectionReadOnly.cs b/src/Core/Models/Data/SelectionReadOnly.cs index b1dd09d71..426abb57f 100644 --- a/src/Core/Models/Data/SelectionReadOnly.cs +++ b/src/Core/Models/Data/SelectionReadOnly.cs @@ -1,9 +1,8 @@ -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class SelectionReadOnly { - public class SelectionReadOnly - { - public Guid Id { get; set; } - public bool ReadOnly { get; set; } - public bool HidePasswords { get; set; } - } + public Guid Id { get; set; } + public bool ReadOnly { get; set; } + public bool HidePasswords { get; set; } } diff --git a/src/Core/Models/Data/SendData.cs b/src/Core/Models/Data/SendData.cs index 956f934ba..7210caae6 100644 --- a/src/Core/Models/Data/SendData.cs +++ b/src/Core/Models/Data/SendData.cs @@ -1,16 +1,15 @@ -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public abstract class SendData { - public abstract class SendData + public SendData() { } + + public SendData(string name, string notes) { - public SendData() { } - - public SendData(string name, string notes) - { - Name = name; - Notes = notes; - } - - public string Name { get; set; } - public string Notes { get; set; } + Name = name; + Notes = notes; } + + public string Name { get; set; } + public string Notes { get; set; } } diff --git a/src/Core/Models/Data/SendFileData.cs b/src/Core/Models/Data/SendFileData.cs index 8ec61ec79..253ee01ce 100644 --- a/src/Core/Models/Data/SendFileData.cs +++ b/src/Core/Models/Data/SendFileData.cs @@ -1,23 +1,22 @@ using System.Text.Json.Serialization; -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class SendFileData : SendData { - public class SendFileData : SendData + public SendFileData() { } + + public SendFileData(string name, string notes, string fileName) + : base(name, notes) { - public SendFileData() { } - - public SendFileData(string name, string notes, string fileName) - : base(name, notes) - { - FileName = fileName; - } - - // We serialize Size as a string since JSON (or Javascript) doesn't support full precision for long numbers - [JsonNumberHandling(JsonNumberHandling.WriteAsString | JsonNumberHandling.AllowReadingFromString)] - public long Size { get; set; } - - public string Id { get; set; } - public string FileName { get; set; } - public bool Validated { get; set; } = true; + FileName = fileName; } + + // We serialize Size as a string since JSON (or Javascript) doesn't support full precision for long numbers + [JsonNumberHandling(JsonNumberHandling.WriteAsString | JsonNumberHandling.AllowReadingFromString)] + public long Size { get; set; } + + public string Id { get; set; } + public string FileName { get; set; } + public bool Validated { get; set; } = true; } diff --git a/src/Core/Models/Data/SendTextData.cs b/src/Core/Models/Data/SendTextData.cs index 0e6d30115..2aa6d0481 100644 --- a/src/Core/Models/Data/SendTextData.cs +++ b/src/Core/Models/Data/SendTextData.cs @@ -1,17 +1,16 @@ -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class SendTextData : SendData { - public class SendTextData : SendData + public SendTextData() { } + + public SendTextData(string name, string notes, string text, bool hidden) + : base(name, notes) { - public SendTextData() { } - - public SendTextData(string name, string notes, string text, bool hidden) - : base(name, notes) - { - Text = text; - Hidden = hidden; - } - - public string Text { get; set; } - public bool Hidden { get; set; } + Text = text; + Hidden = hidden; } + + public string Text { get; set; } + public bool Hidden { get; set; } } diff --git a/src/Core/Models/Data/SsoConfigurationData.cs b/src/Core/Models/Data/SsoConfigurationData.cs index 093ec9a8a..844c52146 100644 --- a/src/Core/Models/Data/SsoConfigurationData.cs +++ b/src/Core/Models/Data/SsoConfigurationData.cs @@ -2,125 +2,124 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Authentication.OpenIdConnect; -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class SsoConfigurationData { - public class SsoConfigurationData + private static string _oidcSigninPath = "/oidc-signin"; + private static string _oidcSignedOutPath = "/oidc-signedout"; + private static string _saml2ModulePath = "/saml2"; + + public static SsoConfigurationData Deserialize(string data) { - private static string _oidcSigninPath = "/oidc-signin"; - private static string _oidcSignedOutPath = "/oidc-signedout"; - private static string _saml2ModulePath = "/saml2"; + return CoreHelpers.LoadClassFromJsonData(data); + } - public static SsoConfigurationData Deserialize(string data) + public string Serialize() + { + return CoreHelpers.ClassToJsonData(this); + } + + public SsoType ConfigType { get; set; } + + public bool KeyConnectorEnabled { get; set; } + public string KeyConnectorUrl { get; set; } + + // OIDC + public string Authority { get; set; } + public string ClientId { get; set; } + public string ClientSecret { get; set; } + public string MetadataAddress { get; set; } + public OpenIdConnectRedirectBehavior RedirectBehavior { get; set; } + public bool GetClaimsFromUserInfoEndpoint { get; set; } + public string AdditionalScopes { get; set; } + public string AdditionalUserIdClaimTypes { get; set; } + public string AdditionalEmailClaimTypes { get; set; } + public string AdditionalNameClaimTypes { get; set; } + public string AcrValues { get; set; } + public string ExpectedReturnAcrValue { get; set; } + + // SAML2 IDP + public string IdpEntityId { get; set; } + public string IdpSingleSignOnServiceUrl { get; set; } + public string IdpSingleLogoutServiceUrl { get; set; } + public string IdpX509PublicCert { get; set; } + public Saml2BindingType IdpBindingType { get; set; } + public bool IdpAllowUnsolicitedAuthnResponse { get; set; } + public string IdpArtifactResolutionServiceUrl { get => null; set { /*IGNORE*/ } } + public bool IdpDisableOutboundLogoutRequests { get; set; } + public string IdpOutboundSigningAlgorithm { get; set; } + public bool IdpWantAuthnRequestsSigned { get; set; } + + // SAML2 SP + public Saml2NameIdFormat SpNameIdFormat { get; set; } + public string SpOutboundSigningAlgorithm { get; set; } + public Saml2SigningBehavior SpSigningBehavior { get; set; } + public bool SpWantAssertionsSigned { get; set; } + public bool SpValidateCertificates { get; set; } + public string SpMinIncomingSigningAlgorithm { get; set; } + + public static string BuildCallbackPath(string ssoUri = null) + { + return BuildSsoUrl(_oidcSigninPath, ssoUri); + } + + public static string BuildSignedOutCallbackPath(string ssoUri = null) + { + return BuildSsoUrl(_oidcSignedOutPath, ssoUri); + } + + public static string BuildSaml2ModulePath(string ssoUri = null, string scheme = null) + { + return string.Concat(BuildSsoUrl(_saml2ModulePath, ssoUri), + string.IsNullOrWhiteSpace(scheme) ? string.Empty : $"/{scheme}"); + } + + public static string BuildSaml2AcsUrl(string ssoUri = null, string scheme = null) + { + return string.Concat(BuildSaml2ModulePath(ssoUri, scheme), "/Acs"); + } + + public static string BuildSaml2MetadataUrl(string ssoUri = null, string scheme = null) + { + return BuildSaml2ModulePath(ssoUri, scheme); + } + + public IEnumerable GetAdditionalScopes() => AdditionalScopes? + .Split(',')? + .Where(c => !string.IsNullOrWhiteSpace(c))? + .Select(c => c.Trim()) ?? + Array.Empty(); + + public IEnumerable GetAdditionalUserIdClaimTypes() => AdditionalUserIdClaimTypes? + .Split(',')? + .Where(c => !string.IsNullOrWhiteSpace(c))? + .Select(c => c.Trim()) ?? + Array.Empty(); + + public IEnumerable GetAdditionalEmailClaimTypes() => AdditionalEmailClaimTypes? + .Split(',')? + .Where(c => !string.IsNullOrWhiteSpace(c))? + .Select(c => c.Trim()) ?? + Array.Empty(); + + public IEnumerable GetAdditionalNameClaimTypes() => AdditionalNameClaimTypes? + .Split(',')? + .Where(c => !string.IsNullOrWhiteSpace(c))? + .Select(c => c.Trim()) ?? + Array.Empty(); + + private static string BuildSsoUrl(string relativePath, string ssoUri) + { + if (string.IsNullOrWhiteSpace(ssoUri) || + !Uri.IsWellFormedUriString(ssoUri, UriKind.Absolute)) { - return CoreHelpers.LoadClassFromJsonData(data); - } - - public string Serialize() - { - return CoreHelpers.ClassToJsonData(this); - } - - public SsoType ConfigType { get; set; } - - public bool KeyConnectorEnabled { get; set; } - public string KeyConnectorUrl { get; set; } - - // OIDC - public string Authority { get; set; } - public string ClientId { get; set; } - public string ClientSecret { get; set; } - public string MetadataAddress { get; set; } - public OpenIdConnectRedirectBehavior RedirectBehavior { get; set; } - public bool GetClaimsFromUserInfoEndpoint { get; set; } - public string AdditionalScopes { get; set; } - public string AdditionalUserIdClaimTypes { get; set; } - public string AdditionalEmailClaimTypes { get; set; } - public string AdditionalNameClaimTypes { get; set; } - public string AcrValues { get; set; } - public string ExpectedReturnAcrValue { get; set; } - - // SAML2 IDP - public string IdpEntityId { get; set; } - public string IdpSingleSignOnServiceUrl { get; set; } - public string IdpSingleLogoutServiceUrl { get; set; } - public string IdpX509PublicCert { get; set; } - public Saml2BindingType IdpBindingType { get; set; } - public bool IdpAllowUnsolicitedAuthnResponse { get; set; } - public string IdpArtifactResolutionServiceUrl { get => null; set { /*IGNORE*/ } } - public bool IdpDisableOutboundLogoutRequests { get; set; } - public string IdpOutboundSigningAlgorithm { get; set; } - public bool IdpWantAuthnRequestsSigned { get; set; } - - // SAML2 SP - public Saml2NameIdFormat SpNameIdFormat { get; set; } - public string SpOutboundSigningAlgorithm { get; set; } - public Saml2SigningBehavior SpSigningBehavior { get; set; } - public bool SpWantAssertionsSigned { get; set; } - public bool SpValidateCertificates { get; set; } - public string SpMinIncomingSigningAlgorithm { get; set; } - - public static string BuildCallbackPath(string ssoUri = null) - { - return BuildSsoUrl(_oidcSigninPath, ssoUri); - } - - public static string BuildSignedOutCallbackPath(string ssoUri = null) - { - return BuildSsoUrl(_oidcSignedOutPath, ssoUri); - } - - public static string BuildSaml2ModulePath(string ssoUri = null, string scheme = null) - { - return string.Concat(BuildSsoUrl(_saml2ModulePath, ssoUri), - string.IsNullOrWhiteSpace(scheme) ? string.Empty : $"/{scheme}"); - } - - public static string BuildSaml2AcsUrl(string ssoUri = null, string scheme = null) - { - return string.Concat(BuildSaml2ModulePath(ssoUri, scheme), "/Acs"); - } - - public static string BuildSaml2MetadataUrl(string ssoUri = null, string scheme = null) - { - return BuildSaml2ModulePath(ssoUri, scheme); - } - - public IEnumerable GetAdditionalScopes() => AdditionalScopes? - .Split(',')? - .Where(c => !string.IsNullOrWhiteSpace(c))? - .Select(c => c.Trim()) ?? - Array.Empty(); - - public IEnumerable GetAdditionalUserIdClaimTypes() => AdditionalUserIdClaimTypes? - .Split(',')? - .Where(c => !string.IsNullOrWhiteSpace(c))? - .Select(c => c.Trim()) ?? - Array.Empty(); - - public IEnumerable GetAdditionalEmailClaimTypes() => AdditionalEmailClaimTypes? - .Split(',')? - .Where(c => !string.IsNullOrWhiteSpace(c))? - .Select(c => c.Trim()) ?? - Array.Empty(); - - public IEnumerable GetAdditionalNameClaimTypes() => AdditionalNameClaimTypes? - .Split(',')? - .Where(c => !string.IsNullOrWhiteSpace(c))? - .Select(c => c.Trim()) ?? - Array.Empty(); - - private static string BuildSsoUrl(string relativePath, string ssoUri) - { - if (string.IsNullOrWhiteSpace(ssoUri) || - !Uri.IsWellFormedUriString(ssoUri, UriKind.Absolute)) - { - return relativePath; - } - if (Uri.TryCreate(string.Concat(ssoUri.TrimEnd('/'), relativePath), UriKind.Absolute, out var newUri)) - { - return newUri.ToString(); - } return relativePath; } + if (Uri.TryCreate(string.Concat(ssoUri.TrimEnd('/'), relativePath), UriKind.Absolute, out var newUri)) + { + return newUri.ToString(); + } + return relativePath; } } diff --git a/src/Core/Models/Data/UserKdfInformation.cs b/src/Core/Models/Data/UserKdfInformation.cs index 3825006d1..0fa3d6f83 100644 --- a/src/Core/Models/Data/UserKdfInformation.cs +++ b/src/Core/Models/Data/UserKdfInformation.cs @@ -1,10 +1,9 @@ using Bit.Core.Enums; -namespace Bit.Core.Models.Data +namespace Bit.Core.Models.Data; + +public class UserKdfInformation { - public class UserKdfInformation - { - public KdfType Kdf { get; set; } - public int KdfIterations { get; set; } - } + public KdfType Kdf { get; set; } + public int KdfIterations { get; set; } } diff --git a/src/Core/Models/IExternal.cs b/src/Core/Models/IExternal.cs index f6d51add2..e81de1d47 100644 --- a/src/Core/Models/IExternal.cs +++ b/src/Core/Models/IExternal.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.Models +namespace Bit.Core.Models; + +public interface IExternal { - public interface IExternal - { - string ExternalId { get; } - } + string ExternalId { get; } } diff --git a/src/Core/Models/ITwoFactorProvidersUser.cs b/src/Core/Models/ITwoFactorProvidersUser.cs index c617960ad..b056ba31c 100644 --- a/src/Core/Models/ITwoFactorProvidersUser.cs +++ b/src/Core/Models/ITwoFactorProvidersUser.cs @@ -1,12 +1,11 @@ using Bit.Core.Enums; -namespace Bit.Core.Models +namespace Bit.Core.Models; + +public interface ITwoFactorProvidersUser { - public interface ITwoFactorProvidersUser - { - string TwoFactorProviders { get; } - Dictionary GetTwoFactorProviders(); - Guid? GetUserId(); - bool GetPremium(); - } + string TwoFactorProviders { get; } + Dictionary GetTwoFactorProviders(); + Guid? GetUserId(); + bool GetPremium(); } diff --git a/src/Core/Models/Mail/AddedCreditViewModel.cs b/src/Core/Models/Mail/AddedCreditViewModel.cs index fe3d99501..6ccfb5fcc 100644 --- a/src/Core/Models/Mail/AddedCreditViewModel.cs +++ b/src/Core/Models/Mail/AddedCreditViewModel.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class AddedCreditViewModel : BaseMailModel { - public class AddedCreditViewModel : BaseMailModel - { - public decimal Amount { get; set; } - } + public decimal Amount { get; set; } } diff --git a/src/Core/Models/Mail/AdminResetPasswordViewModel.cs b/src/Core/Models/Mail/AdminResetPasswordViewModel.cs index 5f5e859ac..18e257fea 100644 --- a/src/Core/Models/Mail/AdminResetPasswordViewModel.cs +++ b/src/Core/Models/Mail/AdminResetPasswordViewModel.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class AdminResetPasswordViewModel : BaseMailModel { - public class AdminResetPasswordViewModel : BaseMailModel - { - public string UserName { get; set; } - public string OrgName { get; set; } - } + public string UserName { get; set; } + public string OrgName { get; set; } } diff --git a/src/Core/Models/Mail/BaseMailModel.cs b/src/Core/Models/Mail/BaseMailModel.cs index 416e50d25..e3aa4d2c4 100644 --- a/src/Core/Models/Mail/BaseMailModel.cs +++ b/src/Core/Models/Mail/BaseMailModel.cs @@ -1,27 +1,26 @@ -namespace Bit.Core.Models.Mail -{ - public class BaseMailModel - { - public string SiteName { get; set; } - public string WebVaultUrl { get; set; } - public string WebVaultUrlHostname - { - get - { - if (Uri.TryCreate(WebVaultUrl, UriKind.Absolute, out Uri uri)) - { - return uri.Host; - } +namespace Bit.Core.Models.Mail; - return WebVaultUrl; - } - } - public string CurrentYear +public class BaseMailModel +{ + public string SiteName { get; set; } + public string WebVaultUrl { get; set; } + public string WebVaultUrlHostname + { + get { - get + if (Uri.TryCreate(WebVaultUrl, UriKind.Absolute, out Uri uri)) { - return DateTime.UtcNow.Year.ToString(); + return uri.Host; } + + return WebVaultUrl; + } + } + public string CurrentYear + { + get + { + return DateTime.UtcNow.Year.ToString(); } } } diff --git a/src/Core/Models/Mail/ChangeEmailExistsViewModel.cs b/src/Core/Models/Mail/ChangeEmailExistsViewModel.cs index 8eda66882..22367e8f2 100644 --- a/src/Core/Models/Mail/ChangeEmailExistsViewModel.cs +++ b/src/Core/Models/Mail/ChangeEmailExistsViewModel.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class ChangeEmailExistsViewModel : BaseMailModel { - public class ChangeEmailExistsViewModel : BaseMailModel - { - public string FromEmail { get; set; } - public string ToEmail { get; set; } - } + public string FromEmail { get; set; } + public string ToEmail { get; set; } } diff --git a/src/Core/Models/Mail/EmailTokenViewModel.cs b/src/Core/Models/Mail/EmailTokenViewModel.cs index 596fc7c21..561df580e 100644 --- a/src/Core/Models/Mail/EmailTokenViewModel.cs +++ b/src/Core/Models/Mail/EmailTokenViewModel.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class EmailTokenViewModel : BaseMailModel { - public class EmailTokenViewModel : BaseMailModel - { - public string Token { get; set; } - } + public string Token { get; set; } } diff --git a/src/Core/Models/Mail/EmergencyAccessAcceptedViewModel.cs b/src/Core/Models/Mail/EmergencyAccessAcceptedViewModel.cs index 1073ea859..afe29b984 100644 --- a/src/Core/Models/Mail/EmergencyAccessAcceptedViewModel.cs +++ b/src/Core/Models/Mail/EmergencyAccessAcceptedViewModel.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class EmergencyAccessAcceptedViewModel : BaseMailModel { - public class EmergencyAccessAcceptedViewModel : BaseMailModel - { - public string GranteeEmail { get; set; } - } + public string GranteeEmail { get; set; } } diff --git a/src/Core/Models/Mail/EmergencyAccessApprovedViewModel.cs b/src/Core/Models/Mail/EmergencyAccessApprovedViewModel.cs index b8cb13b7f..9ad446aab 100644 --- a/src/Core/Models/Mail/EmergencyAccessApprovedViewModel.cs +++ b/src/Core/Models/Mail/EmergencyAccessApprovedViewModel.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class EmergencyAccessApprovedViewModel : BaseMailModel { - public class EmergencyAccessApprovedViewModel : BaseMailModel - { - public string Name { get; set; } - } + public string Name { get; set; } } diff --git a/src/Core/Models/Mail/EmergencyAccessConfirmedViewModel.cs b/src/Core/Models/Mail/EmergencyAccessConfirmedViewModel.cs index c7f457e33..2ab55a05e 100644 --- a/src/Core/Models/Mail/EmergencyAccessConfirmedViewModel.cs +++ b/src/Core/Models/Mail/EmergencyAccessConfirmedViewModel.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class EmergencyAccessConfirmedViewModel : BaseMailModel { - public class EmergencyAccessConfirmedViewModel : BaseMailModel - { - public string Name { get; set; } - } + public string Name { get; set; } } diff --git a/src/Core/Models/Mail/EmergencyAccessInvitedViewModel.cs b/src/Core/Models/Mail/EmergencyAccessInvitedViewModel.cs index a211208c4..fa432c5b7 100644 --- a/src/Core/Models/Mail/EmergencyAccessInvitedViewModel.cs +++ b/src/Core/Models/Mail/EmergencyAccessInvitedViewModel.cs @@ -1,11 +1,10 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class EmergencyAccessInvitedViewModel : BaseMailModel { - public class EmergencyAccessInvitedViewModel : BaseMailModel - { - public string Name { get; set; } - public string Id { get; set; } - public string Email { get; set; } - public string Token { get; set; } - public string Url => $"{WebVaultUrl}/accept-emergency?id={Id}&name={Name}&email={Email}&token={Token}"; - } + public string Name { get; set; } + public string Id { get; set; } + public string Email { get; set; } + public string Token { get; set; } + public string Url => $"{WebVaultUrl}/accept-emergency?id={Id}&name={Name}&email={Email}&token={Token}"; } diff --git a/src/Core/Models/Mail/EmergencyAccessRecoveryTimedOutViewModel.cs b/src/Core/Models/Mail/EmergencyAccessRecoveryTimedOutViewModel.cs index 2c0a287ca..dd3ae3dd8 100644 --- a/src/Core/Models/Mail/EmergencyAccessRecoveryTimedOutViewModel.cs +++ b/src/Core/Models/Mail/EmergencyAccessRecoveryTimedOutViewModel.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class EmergencyAccessRecoveryTimedOutViewModel : BaseMailModel { - public class EmergencyAccessRecoveryTimedOutViewModel : BaseMailModel - { - public string Name { get; set; } - public string Action { get; set; } - } + public string Name { get; set; } + public string Action { get; set; } } diff --git a/src/Core/Models/Mail/EmergencyAccessRecoveryViewModel.cs b/src/Core/Models/Mail/EmergencyAccessRecoveryViewModel.cs index bea6059fc..3811b49ff 100644 --- a/src/Core/Models/Mail/EmergencyAccessRecoveryViewModel.cs +++ b/src/Core/Models/Mail/EmergencyAccessRecoveryViewModel.cs @@ -1,9 +1,8 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class EmergencyAccessRecoveryViewModel : BaseMailModel { - public class EmergencyAccessRecoveryViewModel : BaseMailModel - { - public string Name { get; set; } - public string Action { get; set; } - public int DaysLeft { get; set; } - } + public string Name { get; set; } + public string Action { get; set; } + public int DaysLeft { get; set; } } diff --git a/src/Core/Models/Mail/EmergencyAccessRejectedViewModel.cs b/src/Core/Models/Mail/EmergencyAccessRejectedViewModel.cs index 4cf188726..101cb9c16 100644 --- a/src/Core/Models/Mail/EmergencyAccessRejectedViewModel.cs +++ b/src/Core/Models/Mail/EmergencyAccessRejectedViewModel.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class EmergencyAccessRejectedViewModel : BaseMailModel { - public class EmergencyAccessRejectedViewModel : BaseMailModel - { - public string Name { get; set; } - } + public string Name { get; set; } } diff --git a/src/Core/Models/Mail/FailedAuthAttemptsModel.cs b/src/Core/Models/Mail/FailedAuthAttemptsModel.cs index 030616d35..8ef66061d 100644 --- a/src/Core/Models/Mail/FailedAuthAttemptsModel.cs +++ b/src/Core/Models/Mail/FailedAuthAttemptsModel.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class FailedAuthAttemptsModel : NewDeviceLoggedInModel { - public class FailedAuthAttemptsModel : NewDeviceLoggedInModel - { - public string AffectedEmail { get; set; } - } + public string AffectedEmail { get; set; } } diff --git a/src/Core/Models/Mail/FamiliesForEnterprise/FamiliesForEnterpriseOfferViewModel.cs b/src/Core/Models/Mail/FamiliesForEnterprise/FamiliesForEnterpriseOfferViewModel.cs index 97f028253..7e9d8ee19 100644 --- a/src/Core/Models/Mail/FamiliesForEnterprise/FamiliesForEnterpriseOfferViewModel.cs +++ b/src/Core/Models/Mail/FamiliesForEnterprise/FamiliesForEnterpriseOfferViewModel.cs @@ -1,17 +1,16 @@ -namespace Bit.Core.Models.Mail.FamiliesForEnterprise +namespace Bit.Core.Models.Mail.FamiliesForEnterprise; + +public class FamiliesForEnterpriseOfferViewModel : BaseMailModel { - public class FamiliesForEnterpriseOfferViewModel : BaseMailModel - { - public string SponsorOrgName { get; set; } - public string SponsoredEmail { get; set; } - public string SponsorshipToken { get; set; } - public bool ExistingAccount { get; set; } - public string Url => string.Concat( - WebVaultUrl, - "/accept-families-for-enterprise", - $"?token={SponsorshipToken}", - $"&email={SponsoredEmail}", - ExistingAccount ? "" : "®ister=true" - ); - } + public string SponsorOrgName { get; set; } + public string SponsoredEmail { get; set; } + public string SponsorshipToken { get; set; } + public bool ExistingAccount { get; set; } + public string Url => string.Concat( + WebVaultUrl, + "/accept-families-for-enterprise", + $"?token={SponsorshipToken}", + $"&email={SponsoredEmail}", + ExistingAccount ? "" : "®ister=true" + ); } diff --git a/src/Core/Models/Mail/FamiliesForEnterprise/FamiliesForEnterpriseSponsorshipRevertingViewModel.cs b/src/Core/Models/Mail/FamiliesForEnterprise/FamiliesForEnterpriseSponsorshipRevertingViewModel.cs index b15717c87..08c445b6f 100644 --- a/src/Core/Models/Mail/FamiliesForEnterprise/FamiliesForEnterpriseSponsorshipRevertingViewModel.cs +++ b/src/Core/Models/Mail/FamiliesForEnterprise/FamiliesForEnterpriseSponsorshipRevertingViewModel.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.Models.Mail.FamiliesForEnterprise +namespace Bit.Core.Models.Mail.FamiliesForEnterprise; + +public class FamiliesForEnterpriseSponsorshipRevertingViewModel : BaseMailModel { - public class FamiliesForEnterpriseSponsorshipRevertingViewModel : BaseMailModel - { - public DateTime ExpirationDate { get; set; } - } + public DateTime ExpirationDate { get; set; } } diff --git a/src/Core/Models/Mail/IMailQueueMessage.cs b/src/Core/Models/Mail/IMailQueueMessage.cs index 37c09c90e..085e811c5 100644 --- a/src/Core/Models/Mail/IMailQueueMessage.cs +++ b/src/Core/Models/Mail/IMailQueueMessage.cs @@ -1,12 +1,11 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public interface IMailQueueMessage { - public interface IMailQueueMessage - { - string Subject { get; set; } - IEnumerable ToEmails { get; set; } - IEnumerable BccEmails { get; set; } - string Category { get; set; } - string TemplateName { get; set; } - object Model { get; set; } - } + string Subject { get; set; } + IEnumerable ToEmails { get; set; } + IEnumerable BccEmails { get; set; } + string Category { get; set; } + string TemplateName { get; set; } + object Model { get; set; } } diff --git a/src/Core/Models/Mail/InvoiceUpcomingViewModel.cs b/src/Core/Models/Mail/InvoiceUpcomingViewModel.cs index 7a3bdacea..29c40bf92 100644 --- a/src/Core/Models/Mail/InvoiceUpcomingViewModel.cs +++ b/src/Core/Models/Mail/InvoiceUpcomingViewModel.cs @@ -1,10 +1,9 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class InvoiceUpcomingViewModel : BaseMailModel { - public class InvoiceUpcomingViewModel : BaseMailModel - { - public decimal AmountDue { get; set; } - public DateTime DueDate { get; set; } - public List Items { get; set; } - public bool MentionInvoices { get; set; } - } + public decimal AmountDue { get; set; } + public DateTime DueDate { get; set; } + public List Items { get; set; } + public bool MentionInvoices { get; set; } } diff --git a/src/Core/Models/Mail/LicenseExpiredViewModel.cs b/src/Core/Models/Mail/LicenseExpiredViewModel.cs index 70f5f32cd..922b35cfb 100644 --- a/src/Core/Models/Mail/LicenseExpiredViewModel.cs +++ b/src/Core/Models/Mail/LicenseExpiredViewModel.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class LicenseExpiredViewModel : BaseMailModel { - public class LicenseExpiredViewModel : BaseMailModel - { - public string OrganizationName { get; set; } - public bool IsOrganization => !string.IsNullOrWhiteSpace(OrganizationName); - } + public string OrganizationName { get; set; } + public bool IsOrganization => !string.IsNullOrWhiteSpace(OrganizationName); } diff --git a/src/Core/Models/Mail/MailMessage.cs b/src/Core/Models/Mail/MailMessage.cs index 1ccb87acf..df444c77f 100644 --- a/src/Core/Models/Mail/MailMessage.cs +++ b/src/Core/Models/Mail/MailMessage.cs @@ -1,13 +1,12 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class MailMessage { - public class MailMessage - { - public string Subject { get; set; } - public IEnumerable ToEmails { get; set; } - public IEnumerable BccEmails { get; set; } - public string HtmlContent { get; set; } - public string TextContent { get; set; } - public string Category { get; set; } - public IDictionary MetaData { get; set; } - } + public string Subject { get; set; } + public IEnumerable ToEmails { get; set; } + public IEnumerable BccEmails { get; set; } + public string HtmlContent { get; set; } + public string TextContent { get; set; } + public string Category { get; set; } + public IDictionary MetaData { get; set; } } diff --git a/src/Core/Models/Mail/MailQueueMessage.cs b/src/Core/Models/Mail/MailQueueMessage.cs index 2aa2b3c65..d413c5f1a 100644 --- a/src/Core/Models/Mail/MailQueueMessage.cs +++ b/src/Core/Models/Mail/MailQueueMessage.cs @@ -1,29 +1,28 @@ using System.Text.Json.Serialization; using Bit.Core.Utilities; -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class MailQueueMessage : IMailQueueMessage { - public class MailQueueMessage : IMailQueueMessage + public string Subject { get; set; } + public IEnumerable ToEmails { get; set; } + public IEnumerable BccEmails { get; set; } + public string Category { get; set; } + public string TemplateName { get; set; } + + [JsonConverter(typeof(HandlebarsObjectJsonConverter))] + public object Model { get; set; } + + public MailQueueMessage() { } + + public MailQueueMessage(MailMessage message, string templateName, object model) { - public string Subject { get; set; } - public IEnumerable ToEmails { get; set; } - public IEnumerable BccEmails { get; set; } - public string Category { get; set; } - public string TemplateName { get; set; } - - [JsonConverter(typeof(HandlebarsObjectJsonConverter))] - public object Model { get; set; } - - public MailQueueMessage() { } - - public MailQueueMessage(MailMessage message, string templateName, object model) - { - Subject = message.Subject; - ToEmails = message.ToEmails; - BccEmails = message.BccEmails; - Category = string.IsNullOrEmpty(message.Category) ? templateName : message.Category; - TemplateName = templateName; - Model = model; - } + Subject = message.Subject; + ToEmails = message.ToEmails; + BccEmails = message.BccEmails; + Category = string.IsNullOrEmpty(message.Category) ? templateName : message.Category; + TemplateName = templateName; + Model = model; } } diff --git a/src/Core/Models/Mail/MasterPasswordHintViewModel.cs b/src/Core/Models/Mail/MasterPasswordHintViewModel.cs index d2cfff49e..01eb883a2 100644 --- a/src/Core/Models/Mail/MasterPasswordHintViewModel.cs +++ b/src/Core/Models/Mail/MasterPasswordHintViewModel.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class MasterPasswordHintViewModel : BaseMailModel { - public class MasterPasswordHintViewModel : BaseMailModel - { - public string Hint { get; set; } - } + public string Hint { get; set; } } diff --git a/src/Core/Models/Mail/NewDeviceLoggedInModel.cs b/src/Core/Models/Mail/NewDeviceLoggedInModel.cs index ee550fc4e..6d55a19b6 100644 --- a/src/Core/Models/Mail/NewDeviceLoggedInModel.cs +++ b/src/Core/Models/Mail/NewDeviceLoggedInModel.cs @@ -1,11 +1,10 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class NewDeviceLoggedInModel : BaseMailModel { - public class NewDeviceLoggedInModel : BaseMailModel - { - public string TheDate { get; set; } - public string TheTime { get; set; } - public string TimeZone { get; set; } - public string IpAddress { get; set; } - public string DeviceType { get; set; } - } + public string TheDate { get; set; } + public string TheTime { get; set; } + public string TimeZone { get; set; } + public string IpAddress { get; set; } + public string DeviceType { get; set; } } diff --git a/src/Core/Models/Mail/OrganizationSeatsAutoscaledViewModel.cs b/src/Core/Models/Mail/OrganizationSeatsAutoscaledViewModel.cs index 44299c390..87f87b1c6 100644 --- a/src/Core/Models/Mail/OrganizationSeatsAutoscaledViewModel.cs +++ b/src/Core/Models/Mail/OrganizationSeatsAutoscaledViewModel.cs @@ -1,9 +1,8 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class OrganizationSeatsAutoscaledViewModel : BaseMailModel { - public class OrganizationSeatsAutoscaledViewModel : BaseMailModel - { - public Guid OrganizationId { get; set; } - public int InitialSeatCount { get; set; } - public int CurrentSeatCount { get; set; } - } + public Guid OrganizationId { get; set; } + public int InitialSeatCount { get; set; } + public int CurrentSeatCount { get; set; } } diff --git a/src/Core/Models/Mail/OrganizationSeatsMaxReachedViewModel.cs b/src/Core/Models/Mail/OrganizationSeatsMaxReachedViewModel.cs index 5fcdee704..cdfb57b2d 100644 --- a/src/Core/Models/Mail/OrganizationSeatsMaxReachedViewModel.cs +++ b/src/Core/Models/Mail/OrganizationSeatsMaxReachedViewModel.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class OrganizationSeatsMaxReachedViewModel : BaseMailModel { - public class OrganizationSeatsMaxReachedViewModel : BaseMailModel - { - public Guid OrganizationId { get; set; } - public int MaxSeatCount { get; set; } - } + public Guid OrganizationId { get; set; } + public int MaxSeatCount { get; set; } } diff --git a/src/Core/Models/Mail/OrganizationUserAcceptedViewModel.cs b/src/Core/Models/Mail/OrganizationUserAcceptedViewModel.cs index 5bfd502a5..919463c2c 100644 --- a/src/Core/Models/Mail/OrganizationUserAcceptedViewModel.cs +++ b/src/Core/Models/Mail/OrganizationUserAcceptedViewModel.cs @@ -1,9 +1,8 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class OrganizationUserAcceptedViewModel : BaseMailModel { - public class OrganizationUserAcceptedViewModel : BaseMailModel - { - public Guid OrganizationId { get; set; } - public string OrganizationName { get; set; } - public string UserIdentifier { get; set; } - } + public Guid OrganizationId { get; set; } + public string OrganizationName { get; set; } + public string UserIdentifier { get; set; } } diff --git a/src/Core/Models/Mail/OrganizationUserConfirmedViewModel.cs b/src/Core/Models/Mail/OrganizationUserConfirmedViewModel.cs index e15cf54ee..61e710774 100644 --- a/src/Core/Models/Mail/OrganizationUserConfirmedViewModel.cs +++ b/src/Core/Models/Mail/OrganizationUserConfirmedViewModel.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class OrganizationUserConfirmedViewModel : BaseMailModel { - public class OrganizationUserConfirmedViewModel : BaseMailModel - { - public string OrganizationName { get; set; } - } + public string OrganizationName { get; set; } } diff --git a/src/Core/Models/Mail/OrganizationUserInvitedViewModel.cs b/src/Core/Models/Mail/OrganizationUserInvitedViewModel.cs index 0e13fa663..4bf9fbb86 100644 --- a/src/Core/Models/Mail/OrganizationUserInvitedViewModel.cs +++ b/src/Core/Models/Mail/OrganizationUserInvitedViewModel.cs @@ -1,21 +1,20 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class OrganizationUserInvitedViewModel : BaseMailModel { - public class OrganizationUserInvitedViewModel : BaseMailModel - { - public string OrganizationName { get; set; } - public string OrganizationId { get; set; } - public string OrganizationUserId { get; set; } - public string Email { get; set; } - public string OrganizationNameUrlEncoded { get; set; } - public string Token { get; set; } - public string ExpirationDate { get; set; } - public string Url => string.Format("{0}/accept-organization?organizationId={1}&" + - "organizationUserId={2}&email={3}&organizationName={4}&token={5}", - WebVaultUrl, - OrganizationId, - OrganizationUserId, - Email, - OrganizationNameUrlEncoded, - Token); - } + public string OrganizationName { get; set; } + public string OrganizationId { get; set; } + public string OrganizationUserId { get; set; } + public string Email { get; set; } + public string OrganizationNameUrlEncoded { get; set; } + public string Token { get; set; } + public string ExpirationDate { get; set; } + public string Url => string.Format("{0}/accept-organization?organizationId={1}&" + + "organizationUserId={2}&email={3}&organizationName={4}&token={5}", + WebVaultUrl, + OrganizationId, + OrganizationUserId, + Email, + OrganizationNameUrlEncoded, + Token); } diff --git a/src/Core/Models/Mail/OrganizationUserRemovedForPolicySingleOrgViewModel.cs b/src/Core/Models/Mail/OrganizationUserRemovedForPolicySingleOrgViewModel.cs index 9a92f0e0b..46020ae46 100644 --- a/src/Core/Models/Mail/OrganizationUserRemovedForPolicySingleOrgViewModel.cs +++ b/src/Core/Models/Mail/OrganizationUserRemovedForPolicySingleOrgViewModel.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class OrganizationUserRemovedForPolicySingleOrgViewModel : BaseMailModel { - public class OrganizationUserRemovedForPolicySingleOrgViewModel : BaseMailModel - { - public string OrganizationName { get; set; } - } + public string OrganizationName { get; set; } } diff --git a/src/Core/Models/Mail/OrganizationUserRemovedForPolicyTwoStepViewModel.cs b/src/Core/Models/Mail/OrganizationUserRemovedForPolicyTwoStepViewModel.cs index 10beaa5d7..cd4528ad5 100644 --- a/src/Core/Models/Mail/OrganizationUserRemovedForPolicyTwoStepViewModel.cs +++ b/src/Core/Models/Mail/OrganizationUserRemovedForPolicyTwoStepViewModel.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class OrganizationUserRemovedForPolicyTwoStepViewModel : BaseMailModel { - public class OrganizationUserRemovedForPolicyTwoStepViewModel : BaseMailModel - { - public string OrganizationName { get; set; } - } + public string OrganizationName { get; set; } } diff --git a/src/Core/Models/Mail/PasswordlessSignInModel.cs b/src/Core/Models/Mail/PasswordlessSignInModel.cs index a09d5f7b0..07754cf80 100644 --- a/src/Core/Models/Mail/PasswordlessSignInModel.cs +++ b/src/Core/Models/Mail/PasswordlessSignInModel.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class PasswordlessSignInModel { - public class PasswordlessSignInModel - { - public string Url { get; set; } - } + public string Url { get; set; } } diff --git a/src/Core/Models/Mail/PaymentFailedViewModel.cs b/src/Core/Models/Mail/PaymentFailedViewModel.cs index 1eb5e6952..387feeb02 100644 --- a/src/Core/Models/Mail/PaymentFailedViewModel.cs +++ b/src/Core/Models/Mail/PaymentFailedViewModel.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class PaymentFailedViewModel : BaseMailModel { - public class PaymentFailedViewModel : BaseMailModel - { - public decimal Amount { get; set; } - public bool MentionInvoices { get; set; } - } + public decimal Amount { get; set; } + public bool MentionInvoices { get; set; } } diff --git a/src/Core/Models/Mail/Provider/ProviderSetupInviteViewModel.cs b/src/Core/Models/Mail/Provider/ProviderSetupInviteViewModel.cs index daaba8a49..f351a5fe1 100644 --- a/src/Core/Models/Mail/Provider/ProviderSetupInviteViewModel.cs +++ b/src/Core/Models/Mail/Provider/ProviderSetupInviteViewModel.cs @@ -1,14 +1,13 @@ -namespace Bit.Core.Models.Mail.Provider +namespace Bit.Core.Models.Mail.Provider; + +public class ProviderSetupInviteViewModel : BaseMailModel { - public class ProviderSetupInviteViewModel : BaseMailModel - { - public string ProviderId { get; set; } - public string Email { get; set; } - public string Token { get; set; } - public string Url => string.Format("{0}/providers/setup-provider?providerId={1}&email={2}&token={3}", - WebVaultUrl, - ProviderId, - Email, - Token); - } + public string ProviderId { get; set; } + public string Email { get; set; } + public string Token { get; set; } + public string Url => string.Format("{0}/providers/setup-provider?providerId={1}&email={2}&token={3}", + WebVaultUrl, + ProviderId, + Email, + Token); } diff --git a/src/Core/Models/Mail/Provider/ProviderUserConfirmedViewModel.cs b/src/Core/Models/Mail/Provider/ProviderUserConfirmedViewModel.cs index 8a8716c46..30d24ad1e 100644 --- a/src/Core/Models/Mail/Provider/ProviderUserConfirmedViewModel.cs +++ b/src/Core/Models/Mail/Provider/ProviderUserConfirmedViewModel.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.Models.Mail.Provider +namespace Bit.Core.Models.Mail.Provider; + +public class ProviderUserConfirmedViewModel : BaseMailModel { - public class ProviderUserConfirmedViewModel : BaseMailModel - { - public string ProviderName { get; set; } - } + public string ProviderName { get; set; } } diff --git a/src/Core/Models/Mail/Provider/ProviderUserInvitedViewModel.cs b/src/Core/Models/Mail/Provider/ProviderUserInvitedViewModel.cs index 964c51759..e418d30f2 100644 --- a/src/Core/Models/Mail/Provider/ProviderUserInvitedViewModel.cs +++ b/src/Core/Models/Mail/Provider/ProviderUserInvitedViewModel.cs @@ -1,20 +1,19 @@ -namespace Bit.Core.Models.Mail.Provider +namespace Bit.Core.Models.Mail.Provider; + +public class ProviderUserInvitedViewModel : BaseMailModel { - public class ProviderUserInvitedViewModel : BaseMailModel - { - public string ProviderName { get; set; } - public string ProviderId { get; set; } - public string ProviderUserId { get; set; } - public string Email { get; set; } - public string ProviderNameUrlEncoded { get; set; } - public string Token { get; set; } - public string Url => string.Format("{0}/providers/accept-provider?providerId={1}&" + - "providerUserId={2}&email={3}&providerName={4}&token={5}", - WebVaultUrl, - ProviderId, - ProviderUserId, - Email, - ProviderNameUrlEncoded, - Token); - } + public string ProviderName { get; set; } + public string ProviderId { get; set; } + public string ProviderUserId { get; set; } + public string Email { get; set; } + public string ProviderNameUrlEncoded { get; set; } + public string Token { get; set; } + public string Url => string.Format("{0}/providers/accept-provider?providerId={1}&" + + "providerUserId={2}&email={3}&providerName={4}&token={5}", + WebVaultUrl, + ProviderId, + ProviderUserId, + Email, + ProviderNameUrlEncoded, + Token); } diff --git a/src/Core/Models/Mail/Provider/ProviderUserRemovedViewModel.cs b/src/Core/Models/Mail/Provider/ProviderUserRemovedViewModel.cs index 4d64ed3d7..aef9d9c59 100644 --- a/src/Core/Models/Mail/Provider/ProviderUserRemovedViewModel.cs +++ b/src/Core/Models/Mail/Provider/ProviderUserRemovedViewModel.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.Models.Mail.Provider +namespace Bit.Core.Models.Mail.Provider; + +public class ProviderUserRemovedViewModel : BaseMailModel { - public class ProviderUserRemovedViewModel : BaseMailModel - { - public string ProviderName { get; set; } - } + public string ProviderName { get; set; } } diff --git a/src/Core/Models/Mail/RecoverTwoFactorModel.cs b/src/Core/Models/Mail/RecoverTwoFactorModel.cs index f9b8cb5d4..b62f07671 100644 --- a/src/Core/Models/Mail/RecoverTwoFactorModel.cs +++ b/src/Core/Models/Mail/RecoverTwoFactorModel.cs @@ -1,10 +1,9 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class RecoverTwoFactorModel : BaseMailModel { - public class RecoverTwoFactorModel : BaseMailModel - { - public string TheDate { get; set; } - public string TheTime { get; set; } - public string TimeZone { get; set; } - public string IpAddress { get; set; } - } + public string TheDate { get; set; } + public string TheTime { get; set; } + public string TimeZone { get; set; } + public string IpAddress { get; set; } } diff --git a/src/Core/Models/Mail/UpdateTempPasswordViewModel.cs b/src/Core/Models/Mail/UpdateTempPasswordViewModel.cs index ed35d3e97..6e45df530 100644 --- a/src/Core/Models/Mail/UpdateTempPasswordViewModel.cs +++ b/src/Core/Models/Mail/UpdateTempPasswordViewModel.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.Models.Mail +namespace Bit.Core.Models.Mail; + +public class UpdateTempPasswordViewModel { - public class UpdateTempPasswordViewModel - { - public string UserName { get; set; } - } + public string UserName { get; set; } } diff --git a/src/Core/Models/Mail/VerifyDeleteModel.cs b/src/Core/Models/Mail/VerifyDeleteModel.cs index dbe719981..22775aae1 100644 --- a/src/Core/Models/Mail/VerifyDeleteModel.cs +++ b/src/Core/Models/Mail/VerifyDeleteModel.cs @@ -1,16 +1,15 @@ -namespace Bit.Core.Models.Mail -{ - public class VerifyDeleteModel : BaseMailModel - { - public string Url => string.Format("{0}/verify-recover-delete?userId={1}&token={2}&email={3}", - WebVaultUrl, - UserId, - Token, - EmailEncoded); +namespace Bit.Core.Models.Mail; - public Guid UserId { get; set; } - public string Email { get; set; } - public string EmailEncoded { get; set; } - public string Token { get; set; } - } +public class VerifyDeleteModel : BaseMailModel +{ + public string Url => string.Format("{0}/verify-recover-delete?userId={1}&token={2}&email={3}", + WebVaultUrl, + UserId, + Token, + EmailEncoded); + + public Guid UserId { get; set; } + public string Email { get; set; } + public string EmailEncoded { get; set; } + public string Token { get; set; } } diff --git a/src/Core/Models/Mail/VerifyEmailModel.cs b/src/Core/Models/Mail/VerifyEmailModel.cs index 934ac590f..17b2eba86 100644 --- a/src/Core/Models/Mail/VerifyEmailModel.cs +++ b/src/Core/Models/Mail/VerifyEmailModel.cs @@ -1,13 +1,12 @@ -namespace Bit.Core.Models.Mail -{ - public class VerifyEmailModel : BaseMailModel - { - public string Url => string.Format("{0}/verify-email?userId={1}&token={2}", - WebVaultUrl, - UserId, - Token); +namespace Bit.Core.Models.Mail; - public Guid UserId { get; set; } - public string Token { get; set; } - } +public class VerifyEmailModel : BaseMailModel +{ + public string Url => string.Format("{0}/verify-email?userId={1}&token={2}", + WebVaultUrl, + UserId, + Token); + + public Guid UserId { get; set; } + public string Token { get; set; } } diff --git a/src/Core/Models/OrganizationConnectionConfigs/BillingSyncConfig.cs b/src/Core/Models/OrganizationConnectionConfigs/BillingSyncConfig.cs index 8b46eb831..204e165d0 100644 --- a/src/Core/Models/OrganizationConnectionConfigs/BillingSyncConfig.cs +++ b/src/Core/Models/OrganizationConnectionConfigs/BillingSyncConfig.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Models.OrganizationConnectionConfigs +namespace Bit.Core.Models.OrganizationConnectionConfigs; + +public class BillingSyncConfig { - public class BillingSyncConfig - { - public string BillingSyncKey { get; set; } - public Guid CloudOrganizationId { get; set; } - } + public string BillingSyncKey { get; set; } + public Guid CloudOrganizationId { get; set; } } diff --git a/src/Core/Models/OrganizationConnectionConfigs/ScimConfig.cs b/src/Core/Models/OrganizationConnectionConfigs/ScimConfig.cs index a7eeb632b..63a1606cb 100644 --- a/src/Core/Models/OrganizationConnectionConfigs/ScimConfig.cs +++ b/src/Core/Models/OrganizationConnectionConfigs/ScimConfig.cs @@ -1,12 +1,11 @@ using System.Text.Json.Serialization; using Bit.Core.Enums; -namespace Bit.Core.Models.OrganizationConnectionConfigs +namespace Bit.Core.Models.OrganizationConnectionConfigs; + +public class ScimConfig { - public class ScimConfig - { - public bool Enabled { get; set; } - [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] - public ScimProviderType? ScimProvider { get; set; } - } + public bool Enabled { get; set; } + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public ScimProviderType? ScimProvider { get; set; } } diff --git a/src/Core/Models/PushNotification.cs b/src/Core/Models/PushNotification.cs index 7f34b2562..4cbdae8b6 100644 --- a/src/Core/Models/PushNotification.cs +++ b/src/Core/Models/PushNotification.cs @@ -1,47 +1,46 @@ using Bit.Core.Enums; -namespace Bit.Core.Models +namespace Bit.Core.Models; + +public class PushNotificationData { - public class PushNotificationData + public PushNotificationData(PushType type, T payload, string contextId) { - public PushNotificationData(PushType type, T payload, string contextId) - { - Type = type; - Payload = payload; - ContextId = contextId; - } - - public PushType Type { get; set; } - public T Payload { get; set; } - public string ContextId { get; set; } + Type = type; + Payload = payload; + ContextId = contextId; } - public class SyncCipherPushNotification - { - public Guid Id { get; set; } - public Guid? UserId { get; set; } - public Guid? OrganizationId { get; set; } - public IEnumerable CollectionIds { get; set; } - public DateTime RevisionDate { get; set; } - } - - public class SyncFolderPushNotification - { - public Guid Id { get; set; } - public Guid UserId { get; set; } - public DateTime RevisionDate { get; set; } - } - - public class UserPushNotification - { - public Guid UserId { get; set; } - public DateTime Date { get; set; } - } - - public class SyncSendPushNotification - { - public Guid Id { get; set; } - public Guid UserId { get; set; } - public DateTime RevisionDate { get; set; } - } + public PushType Type { get; set; } + public T Payload { get; set; } + public string ContextId { get; set; } +} + +public class SyncCipherPushNotification +{ + public Guid Id { get; set; } + public Guid? UserId { get; set; } + public Guid? OrganizationId { get; set; } + public IEnumerable CollectionIds { get; set; } + public DateTime RevisionDate { get; set; } +} + +public class SyncFolderPushNotification +{ + public Guid Id { get; set; } + public Guid UserId { get; set; } + public DateTime RevisionDate { get; set; } +} + +public class UserPushNotification +{ + public Guid UserId { get; set; } + public DateTime Date { get; set; } +} + +public class SyncSendPushNotification +{ + public Guid Id { get; set; } + public Guid UserId { get; set; } + public DateTime RevisionDate { get; set; } } diff --git a/src/Core/Models/StaticStore/Plan.cs b/src/Core/Models/StaticStore/Plan.cs index 98686f5ab..25a947d18 100644 --- a/src/Core/Models/StaticStore/Plan.cs +++ b/src/Core/Models/StaticStore/Plan.cs @@ -1,55 +1,54 @@ using Bit.Core.Enums; -namespace Bit.Core.Models.StaticStore +namespace Bit.Core.Models.StaticStore; + +public class Plan { - public class Plan - { - public PlanType Type { get; set; } - public ProductType Product { get; set; } - public string Name { get; set; } - public bool IsAnnual { get; set; } - public string NameLocalizationKey { get; set; } - public string DescriptionLocalizationKey { get; set; } - public bool CanBeUsedByBusiness { get; set; } - public int BaseSeats { get; set; } - public short? BaseStorageGb { get; set; } - public short? MaxCollections { get; set; } - public short? MaxUsers { get; set; } - public bool AllowSeatAutoscale { get; set; } + public PlanType Type { get; set; } + public ProductType Product { get; set; } + public string Name { get; set; } + public bool IsAnnual { get; set; } + public string NameLocalizationKey { get; set; } + public string DescriptionLocalizationKey { get; set; } + public bool CanBeUsedByBusiness { get; set; } + public int BaseSeats { get; set; } + public short? BaseStorageGb { get; set; } + public short? MaxCollections { get; set; } + public short? MaxUsers { get; set; } + public bool AllowSeatAutoscale { get; set; } - public bool HasAdditionalSeatsOption { get; set; } - public int? MaxAdditionalSeats { get; set; } - public bool HasAdditionalStorageOption { get; set; } - public short? MaxAdditionalStorage { get; set; } - public bool HasPremiumAccessOption { get; set; } - public int? TrialPeriodDays { get; set; } + public bool HasAdditionalSeatsOption { get; set; } + public int? MaxAdditionalSeats { get; set; } + public bool HasAdditionalStorageOption { get; set; } + public short? MaxAdditionalStorage { get; set; } + public bool HasPremiumAccessOption { get; set; } + public int? TrialPeriodDays { get; set; } - public bool HasSelfHost { get; set; } - public bool HasPolicies { get; set; } - public bool HasGroups { get; set; } - public bool HasDirectory { get; set; } - public bool HasEvents { get; set; } - public bool HasTotp { get; set; } - public bool Has2fa { get; set; } - public bool HasApi { get; set; } - public bool HasSso { get; set; } - public bool HasKeyConnector { get; set; } - public bool HasScim { get; set; } - public bool HasResetPassword { get; set; } - public bool UsersGetPremium { get; set; } + public bool HasSelfHost { get; set; } + public bool HasPolicies { get; set; } + public bool HasGroups { get; set; } + public bool HasDirectory { get; set; } + public bool HasEvents { get; set; } + public bool HasTotp { get; set; } + public bool Has2fa { get; set; } + public bool HasApi { get; set; } + public bool HasSso { get; set; } + public bool HasKeyConnector { get; set; } + public bool HasScim { get; set; } + public bool HasResetPassword { get; set; } + public bool UsersGetPremium { get; set; } - public int UpgradeSortOrder { get; set; } - public int DisplaySortOrder { get; set; } - public int? LegacyYear { get; set; } - public bool Disabled { get; set; } + public int UpgradeSortOrder { get; set; } + public int DisplaySortOrder { get; set; } + public int? LegacyYear { get; set; } + public bool Disabled { get; set; } - public string StripePlanId { get; set; } - public string StripeSeatPlanId { get; set; } - public string StripeStoragePlanId { get; set; } - public string StripePremiumAccessPlanId { get; set; } - public decimal BasePrice { get; set; } - public decimal SeatPrice { get; set; } - public decimal AdditionalStoragePricePerGb { get; set; } - public decimal PremiumAccessOptionPrice { get; set; } - } + public string StripePlanId { get; set; } + public string StripeSeatPlanId { get; set; } + public string StripeStoragePlanId { get; set; } + public string StripePremiumAccessPlanId { get; set; } + public decimal BasePrice { get; set; } + public decimal SeatPrice { get; set; } + public decimal AdditionalStoragePricePerGb { get; set; } + public decimal PremiumAccessOptionPrice { get; set; } } diff --git a/src/Core/Models/StaticStore/SponsoredPlan.cs b/src/Core/Models/StaticStore/SponsoredPlan.cs index e1a8dbd96..bcd23874a 100644 --- a/src/Core/Models/StaticStore/SponsoredPlan.cs +++ b/src/Core/Models/StaticStore/SponsoredPlan.cs @@ -1,14 +1,13 @@ using Bit.Core.Enums; using Bit.Core.Models.Data.Organizations.OrganizationUsers; -namespace Bit.Core.Models.StaticStore +namespace Bit.Core.Models.StaticStore; + +public class SponsoredPlan { - public class SponsoredPlan - { - public PlanSponsorshipType PlanSponsorshipType { get; set; } - public ProductType SponsoredProductType { get; set; } - public ProductType SponsoringProductType { get; set; } - public string StripePlanId { get; set; } - public Func UsersCanSponsor { get; set; } - } + public PlanSponsorshipType PlanSponsorshipType { get; set; } + public ProductType SponsoredProductType { get; set; } + public ProductType SponsoringProductType { get; set; } + public string StripePlanId { get; set; } + public Func UsersCanSponsor { get; set; } } diff --git a/src/Core/Models/Stripe/StripeSubscriptionListOptions.cs b/src/Core/Models/Stripe/StripeSubscriptionListOptions.cs index 1672dc407..f32576c40 100644 --- a/src/Core/Models/Stripe/StripeSubscriptionListOptions.cs +++ b/src/Core/Models/Stripe/StripeSubscriptionListOptions.cs @@ -1,49 +1,48 @@ -namespace Bit.Core.Models.BitStripe +namespace Bit.Core.Models.BitStripe; + +// Stripe's SubscriptionListOptions model has a complex input for date filters. +// It expects a dictionary, and has lots of validation rules around what can have a value and what can't. +// To simplify this a bit we are extending Stripe's model and using our own date inputs, and building the dictionary they expect JiT. +// ___ +// Our model also facilitates selecting all elements in a list, which is unsupported by Stripe's model. +public class StripeSubscriptionListOptions : Stripe.SubscriptionListOptions { - // Stripe's SubscriptionListOptions model has a complex input for date filters. - // It expects a dictionary, and has lots of validation rules around what can have a value and what can't. - // To simplify this a bit we are extending Stripe's model and using our own date inputs, and building the dictionary they expect JiT. - // ___ - // Our model also facilitates selecting all elements in a list, which is unsupported by Stripe's model. - public class StripeSubscriptionListOptions : Stripe.SubscriptionListOptions + public DateTime? CurrentPeriodEndDate { get; set; } + public string CurrentPeriodEndRange { get; set; } = "lt"; + public bool SelectAll { get; set; } + public new Stripe.DateRangeOptions CurrentPeriodEnd { - public DateTime? CurrentPeriodEndDate { get; set; } - public string CurrentPeriodEndRange { get; set; } = "lt"; - public bool SelectAll { get; set; } - public new Stripe.DateRangeOptions CurrentPeriodEnd + get { - get - { - return CurrentPeriodEndDate.HasValue ? - new Stripe.DateRangeOptions() - { - LessThan = CurrentPeriodEndRange == "lt" ? CurrentPeriodEndDate : null, - GreaterThan = CurrentPeriodEndRange == "gt" ? CurrentPeriodEndDate : null - } : - null; - } - } - - public Stripe.SubscriptionListOptions ToStripeApiOptions() - { - var stripeApiOptions = (Stripe.SubscriptionListOptions)this; - - if (SelectAll) - { - stripeApiOptions.EndingBefore = null; - stripeApiOptions.StartingAfter = null; - } - - if (CurrentPeriodEndDate.HasValue) - { - stripeApiOptions.CurrentPeriodEnd = new Stripe.DateRangeOptions() + return CurrentPeriodEndDate.HasValue ? + new Stripe.DateRangeOptions() { LessThan = CurrentPeriodEndRange == "lt" ? CurrentPeriodEndDate : null, GreaterThan = CurrentPeriodEndRange == "gt" ? CurrentPeriodEndDate : null - }; - } - - return stripeApiOptions; + } : + null; } } + + public Stripe.SubscriptionListOptions ToStripeApiOptions() + { + var stripeApiOptions = (Stripe.SubscriptionListOptions)this; + + if (SelectAll) + { + stripeApiOptions.EndingBefore = null; + stripeApiOptions.StartingAfter = null; + } + + if (CurrentPeriodEndDate.HasValue) + { + stripeApiOptions.CurrentPeriodEnd = new Stripe.DateRangeOptions() + { + LessThan = CurrentPeriodEndRange == "lt" ? CurrentPeriodEndDate : null, + GreaterThan = CurrentPeriodEndRange == "gt" ? CurrentPeriodEndDate : null + }; + } + + return stripeApiOptions; + } } diff --git a/src/Core/Models/TwoFactorProvider.cs b/src/Core/Models/TwoFactorProvider.cs index 7e48ed397..0ff791ff8 100644 --- a/src/Core/Models/TwoFactorProvider.cs +++ b/src/Core/Models/TwoFactorProvider.cs @@ -2,66 +2,65 @@ using Bit.Core.Enums; using Fido2NetLib.Objects; -namespace Bit.Core.Models +namespace Bit.Core.Models; + +public class TwoFactorProvider { - public class TwoFactorProvider + public bool Enabled { get; set; } + public Dictionary MetaData { get; set; } = new Dictionary(); + + public class WebAuthnData { - public bool Enabled { get; set; } - public Dictionary MetaData { get; set; } = new Dictionary(); + public WebAuthnData() { } - public class WebAuthnData + public WebAuthnData(dynamic o) { - public WebAuthnData() { } - - public WebAuthnData(dynamic o) + Name = o.Name; + try { - Name = o.Name; - try - { - Descriptor = o.Descriptor; - } - catch - { - // Fallback for older newtonsoft serialized tokens. - if (o.Descriptor.Type == 0) - { - o.Descriptor.Type = "public-key"; - } - Descriptor = JsonSerializer.Deserialize(o.Descriptor.ToString(), - new JsonSerializerOptions { PropertyNameCaseInsensitive = true }); - } - PublicKey = o.PublicKey; - UserHandle = o.UserHandle; - SignatureCounter = o.SignatureCounter; - CredType = o.CredType; - RegDate = o.RegDate; - AaGuid = o.AaGuid; - Migrated = o.Migrated; + Descriptor = o.Descriptor; } - - public string Name { get; set; } - public PublicKeyCredentialDescriptor Descriptor { get; internal set; } - public byte[] PublicKey { get; internal set; } - public byte[] UserHandle { get; internal set; } - public uint SignatureCounter { get; set; } - public string CredType { get; internal set; } - public DateTime RegDate { get; internal set; } - public Guid AaGuid { get; internal set; } - public bool Migrated { get; internal set; } + catch + { + // Fallback for older newtonsoft serialized tokens. + if (o.Descriptor.Type == 0) + { + o.Descriptor.Type = "public-key"; + } + Descriptor = JsonSerializer.Deserialize(o.Descriptor.ToString(), + new JsonSerializerOptions { PropertyNameCaseInsensitive = true }); + } + PublicKey = o.PublicKey; + UserHandle = o.UserHandle; + SignatureCounter = o.SignatureCounter; + CredType = o.CredType; + RegDate = o.RegDate; + AaGuid = o.AaGuid; + Migrated = o.Migrated; } - public static bool RequiresPremium(TwoFactorProviderType type) + public string Name { get; set; } + public PublicKeyCredentialDescriptor Descriptor { get; internal set; } + public byte[] PublicKey { get; internal set; } + public byte[] UserHandle { get; internal set; } + public uint SignatureCounter { get; set; } + public string CredType { get; internal set; } + public DateTime RegDate { get; internal set; } + public Guid AaGuid { get; internal set; } + public bool Migrated { get; internal set; } + } + + public static bool RequiresPremium(TwoFactorProviderType type) + { + switch (type) { - switch (type) - { - case TwoFactorProviderType.Duo: - case TwoFactorProviderType.YubiKey: - case TwoFactorProviderType.U2f: // Keep to ensure old U2f keys are considered premium - case TwoFactorProviderType.WebAuthn: - return true; - default: - return false; - } + case TwoFactorProviderType.Duo: + case TwoFactorProviderType.YubiKey: + case TwoFactorProviderType.U2f: // Keep to ensure old U2f keys are considered premium + case TwoFactorProviderType.WebAuthn: + return true; + default: + return false; } } } diff --git a/src/Core/OrganizationFeatures/OrganizationApiKeys/GetOrganizationApiKeyCommand.cs b/src/Core/OrganizationFeatures/OrganizationApiKeys/GetOrganizationApiKeyCommand.cs index 93b4d15bd..1a0156241 100644 --- a/src/Core/OrganizationFeatures/OrganizationApiKeys/GetOrganizationApiKeyCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationApiKeys/GetOrganizationApiKeyCommand.cs @@ -4,43 +4,42 @@ using Bit.Core.OrganizationFeatures.OrganizationApiKeys.Interfaces; using Bit.Core.Repositories; using Bit.Core.Utilities; -namespace Bit.Core.OrganizationFeatures.OrganizationApiKeys +namespace Bit.Core.OrganizationFeatures.OrganizationApiKeys; + +public class GetOrganizationApiKeyCommand : IGetOrganizationApiKeyCommand { - public class GetOrganizationApiKeyCommand : IGetOrganizationApiKeyCommand + private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; + + public GetOrganizationApiKeyCommand(IOrganizationApiKeyRepository organizationApiKeyRepository) { - private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; + _organizationApiKeyRepository = organizationApiKeyRepository; + } - public GetOrganizationApiKeyCommand(IOrganizationApiKeyRepository organizationApiKeyRepository) + public async Task GetOrganizationApiKeyAsync(Guid organizationId, OrganizationApiKeyType organizationApiKeyType) + { + if (!Enum.IsDefined(organizationApiKeyType)) { - _organizationApiKeyRepository = organizationApiKeyRepository; + throw new ArgumentOutOfRangeException(nameof(organizationApiKeyType), $"Invalid value for enum {nameof(OrganizationApiKeyType)}"); } - public async Task GetOrganizationApiKeyAsync(Guid organizationId, OrganizationApiKeyType organizationApiKeyType) + var apiKeys = await _organizationApiKeyRepository + .GetManyByOrganizationIdTypeAsync(organizationId, organizationApiKeyType); + + if (apiKeys == null || !apiKeys.Any()) { - if (!Enum.IsDefined(organizationApiKeyType)) + var apiKey = new OrganizationApiKey { - throw new ArgumentOutOfRangeException(nameof(organizationApiKeyType), $"Invalid value for enum {nameof(OrganizationApiKeyType)}"); - } + OrganizationId = organizationId, + Type = organizationApiKeyType, + ApiKey = CoreHelpers.SecureRandomString(30), + RevisionDate = DateTime.UtcNow, + }; - var apiKeys = await _organizationApiKeyRepository - .GetManyByOrganizationIdTypeAsync(organizationId, organizationApiKeyType); - - if (apiKeys == null || !apiKeys.Any()) - { - var apiKey = new OrganizationApiKey - { - OrganizationId = organizationId, - Type = organizationApiKeyType, - ApiKey = CoreHelpers.SecureRandomString(30), - RevisionDate = DateTime.UtcNow, - }; - - await _organizationApiKeyRepository.CreateAsync(apiKey); - return apiKey; - } - - // NOTE: Currently we only allow one type of api key per organization - return apiKeys.Single(); + await _organizationApiKeyRepository.CreateAsync(apiKey); + return apiKey; } + + // NOTE: Currently we only allow one type of api key per organization + return apiKeys.Single(); } } diff --git a/src/Core/OrganizationFeatures/OrganizationApiKeys/Interfaces/IGetOrganizationApiKeyCommand.cs b/src/Core/OrganizationFeatures/OrganizationApiKeys/Interfaces/IGetOrganizationApiKeyCommand.cs index 645fb1086..5fcfdedd9 100644 --- a/src/Core/OrganizationFeatures/OrganizationApiKeys/Interfaces/IGetOrganizationApiKeyCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationApiKeys/Interfaces/IGetOrganizationApiKeyCommand.cs @@ -1,10 +1,9 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.OrganizationFeatures.OrganizationApiKeys.Interfaces +namespace Bit.Core.OrganizationFeatures.OrganizationApiKeys.Interfaces; + +public interface IGetOrganizationApiKeyCommand { - public interface IGetOrganizationApiKeyCommand - { - Task GetOrganizationApiKeyAsync(Guid organizationId, OrganizationApiKeyType organizationApiKeyType); - } + Task GetOrganizationApiKeyAsync(Guid organizationId, OrganizationApiKeyType organizationApiKeyType); } diff --git a/src/Core/OrganizationFeatures/OrganizationApiKeys/Interfaces/IRotateOrganizationApiKeyCommand.cs b/src/Core/OrganizationFeatures/OrganizationApiKeys/Interfaces/IRotateOrganizationApiKeyCommand.cs index 85d047987..a5cf51c3f 100644 --- a/src/Core/OrganizationFeatures/OrganizationApiKeys/Interfaces/IRotateOrganizationApiKeyCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationApiKeys/Interfaces/IRotateOrganizationApiKeyCommand.cs @@ -1,9 +1,8 @@ using Bit.Core.Entities; -namespace Bit.Core.OrganizationFeatures.OrganizationApiKeys.Interfaces +namespace Bit.Core.OrganizationFeatures.OrganizationApiKeys.Interfaces; + +public interface IRotateOrganizationApiKeyCommand { - public interface IRotateOrganizationApiKeyCommand - { - Task RotateApiKeyAsync(OrganizationApiKey organizationApiKey); - } + Task RotateApiKeyAsync(OrganizationApiKey organizationApiKey); } diff --git a/src/Core/OrganizationFeatures/OrganizationApiKeys/RotateOrganizationApiKeyCommand.cs b/src/Core/OrganizationFeatures/OrganizationApiKeys/RotateOrganizationApiKeyCommand.cs index 967f39947..f43aaa5f3 100644 --- a/src/Core/OrganizationFeatures/OrganizationApiKeys/RotateOrganizationApiKeyCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationApiKeys/RotateOrganizationApiKeyCommand.cs @@ -3,23 +3,22 @@ using Bit.Core.OrganizationFeatures.OrganizationApiKeys.Interfaces; using Bit.Core.Repositories; using Bit.Core.Utilities; -namespace Bit.Core.OrganizationFeatures.OrganizationApiKeys +namespace Bit.Core.OrganizationFeatures.OrganizationApiKeys; + +public class RotateOrganizationApiKeyCommand : IRotateOrganizationApiKeyCommand { - public class RotateOrganizationApiKeyCommand : IRotateOrganizationApiKeyCommand + private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; + + public RotateOrganizationApiKeyCommand(IOrganizationApiKeyRepository organizationApiKeyRepository) { - private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; + _organizationApiKeyRepository = organizationApiKeyRepository; + } - public RotateOrganizationApiKeyCommand(IOrganizationApiKeyRepository organizationApiKeyRepository) - { - _organizationApiKeyRepository = organizationApiKeyRepository; - } - - public async Task RotateApiKeyAsync(OrganizationApiKey organizationApiKey) - { - organizationApiKey.ApiKey = CoreHelpers.SecureRandomString(30); - organizationApiKey.RevisionDate = DateTime.UtcNow; - await _organizationApiKeyRepository.UpsertAsync(organizationApiKey); - return organizationApiKey; - } + public async Task RotateApiKeyAsync(OrganizationApiKey organizationApiKey) + { + organizationApiKey.ApiKey = CoreHelpers.SecureRandomString(30); + organizationApiKey.RevisionDate = DateTime.UtcNow; + await _organizationApiKeyRepository.UpsertAsync(organizationApiKey); + return organizationApiKey; } } diff --git a/src/Core/OrganizationFeatures/OrganizationConnections/CreateOrganizationConnectionCommand.cs b/src/Core/OrganizationFeatures/OrganizationConnections/CreateOrganizationConnectionCommand.cs index c54ef5dfc..e3f308bc5 100644 --- a/src/Core/OrganizationFeatures/OrganizationConnections/CreateOrganizationConnectionCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationConnections/CreateOrganizationConnectionCommand.cs @@ -3,20 +3,19 @@ using Bit.Core.Models.Data.Organizations.OrganizationConnections; using Bit.Core.OrganizationFeatures.OrganizationConnections.Interfaces; using Bit.Core.Repositories; -namespace Bit.Core.OrganizationFeatures.OrganizationConnections +namespace Bit.Core.OrganizationFeatures.OrganizationConnections; + +public class CreateOrganizationConnectionCommand : ICreateOrganizationConnectionCommand { - public class CreateOrganizationConnectionCommand : ICreateOrganizationConnectionCommand + private readonly IOrganizationConnectionRepository _organizationConnectionRepository; + + public CreateOrganizationConnectionCommand(IOrganizationConnectionRepository organizationConnectionRepository) { - private readonly IOrganizationConnectionRepository _organizationConnectionRepository; + _organizationConnectionRepository = organizationConnectionRepository; + } - public CreateOrganizationConnectionCommand(IOrganizationConnectionRepository organizationConnectionRepository) - { - _organizationConnectionRepository = organizationConnectionRepository; - } - - public async Task CreateAsync(OrganizationConnectionData connectionData) where T : new() - { - return await _organizationConnectionRepository.CreateAsync(connectionData.ToEntity()); - } + public async Task CreateAsync(OrganizationConnectionData connectionData) where T : new() + { + return await _organizationConnectionRepository.CreateAsync(connectionData.ToEntity()); } } diff --git a/src/Core/OrganizationFeatures/OrganizationConnections/DeleteOrganizationConnectionCommand.cs b/src/Core/OrganizationFeatures/OrganizationConnections/DeleteOrganizationConnectionCommand.cs index 784975780..7166059db 100644 --- a/src/Core/OrganizationFeatures/OrganizationConnections/DeleteOrganizationConnectionCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationConnections/DeleteOrganizationConnectionCommand.cs @@ -2,20 +2,19 @@ using Bit.Core.OrganizationFeatures.OrganizationConnections.Interfaces; using Bit.Core.Repositories; -namespace Bit.Core.OrganizationFeatures.OrganizationConnections +namespace Bit.Core.OrganizationFeatures.OrganizationConnections; + +public class DeleteOrganizationConnectionCommand : IDeleteOrganizationConnectionCommand { - public class DeleteOrganizationConnectionCommand : IDeleteOrganizationConnectionCommand + private readonly IOrganizationConnectionRepository _organizationConnectionRepository; + + public DeleteOrganizationConnectionCommand(IOrganizationConnectionRepository organizationConnectionRepository) { - private readonly IOrganizationConnectionRepository _organizationConnectionRepository; + _organizationConnectionRepository = organizationConnectionRepository; + } - public DeleteOrganizationConnectionCommand(IOrganizationConnectionRepository organizationConnectionRepository) - { - _organizationConnectionRepository = organizationConnectionRepository; - } - - public async Task DeleteAsync(OrganizationConnection connection) - { - await _organizationConnectionRepository.DeleteAsync(connection); - } + public async Task DeleteAsync(OrganizationConnection connection) + { + await _organizationConnectionRepository.DeleteAsync(connection); } } diff --git a/src/Core/OrganizationFeatures/OrganizationConnections/Interfaces/ICreateOrganizationConnectionCommand.cs b/src/Core/OrganizationFeatures/OrganizationConnections/Interfaces/ICreateOrganizationConnectionCommand.cs index c91985d75..b31920b10 100644 --- a/src/Core/OrganizationFeatures/OrganizationConnections/Interfaces/ICreateOrganizationConnectionCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationConnections/Interfaces/ICreateOrganizationConnectionCommand.cs @@ -1,10 +1,9 @@ using Bit.Core.Entities; using Bit.Core.Models.Data.Organizations.OrganizationConnections; -namespace Bit.Core.OrganizationFeatures.OrganizationConnections.Interfaces +namespace Bit.Core.OrganizationFeatures.OrganizationConnections.Interfaces; + +public interface ICreateOrganizationConnectionCommand { - public interface ICreateOrganizationConnectionCommand - { - Task CreateAsync(OrganizationConnectionData connectionData) where T : new(); - } + Task CreateAsync(OrganizationConnectionData connectionData) where T : new(); } diff --git a/src/Core/OrganizationFeatures/OrganizationConnections/Interfaces/IDeleteOrganizationConnectionCommand.cs b/src/Core/OrganizationFeatures/OrganizationConnections/Interfaces/IDeleteOrganizationConnectionCommand.cs index 1b92a9fcf..818609aef 100644 --- a/src/Core/OrganizationFeatures/OrganizationConnections/Interfaces/IDeleteOrganizationConnectionCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationConnections/Interfaces/IDeleteOrganizationConnectionCommand.cs @@ -1,9 +1,8 @@ using Bit.Core.Entities; -namespace Bit.Core.OrganizationFeatures.OrganizationConnections.Interfaces +namespace Bit.Core.OrganizationFeatures.OrganizationConnections.Interfaces; + +public interface IDeleteOrganizationConnectionCommand { - public interface IDeleteOrganizationConnectionCommand - { - Task DeleteAsync(OrganizationConnection connection); - } + Task DeleteAsync(OrganizationConnection connection); } diff --git a/src/Core/OrganizationFeatures/OrganizationConnections/Interfaces/IUpdateOrganizationConnectionCommand.cs b/src/Core/OrganizationFeatures/OrganizationConnections/Interfaces/IUpdateOrganizationConnectionCommand.cs index d01fd0b9a..742e89c97 100644 --- a/src/Core/OrganizationFeatures/OrganizationConnections/Interfaces/IUpdateOrganizationConnectionCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationConnections/Interfaces/IUpdateOrganizationConnectionCommand.cs @@ -1,10 +1,9 @@ using Bit.Core.Entities; using Bit.Core.Models.Data.Organizations.OrganizationConnections; -namespace Bit.Core.OrganizationFeatures.OrganizationConnections.Interfaces +namespace Bit.Core.OrganizationFeatures.OrganizationConnections.Interfaces; + +public interface IUpdateOrganizationConnectionCommand { - public interface IUpdateOrganizationConnectionCommand - { - Task UpdateAsync(OrganizationConnectionData connectionData) where T : new(); - } + Task UpdateAsync(OrganizationConnectionData connectionData) where T : new(); } diff --git a/src/Core/OrganizationFeatures/OrganizationConnections/UpdateOrganizationConnectionCommand.cs b/src/Core/OrganizationFeatures/OrganizationConnections/UpdateOrganizationConnectionCommand.cs index 74aa08bd3..0d872b6f1 100644 --- a/src/Core/OrganizationFeatures/OrganizationConnections/UpdateOrganizationConnectionCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationConnections/UpdateOrganizationConnectionCommand.cs @@ -4,34 +4,33 @@ using Bit.Core.Models.Data.Organizations.OrganizationConnections; using Bit.Core.OrganizationFeatures.OrganizationConnections.Interfaces; using Bit.Core.Repositories; -namespace Bit.Core.OrganizationFeatures.OrganizationConnections +namespace Bit.Core.OrganizationFeatures.OrganizationConnections; + +public class UpdateOrganizationConnectionCommand : IUpdateOrganizationConnectionCommand { - public class UpdateOrganizationConnectionCommand : IUpdateOrganizationConnectionCommand + private readonly IOrganizationConnectionRepository _organizationConnectionRepository; + + public UpdateOrganizationConnectionCommand(IOrganizationConnectionRepository organizationConnectionRepository) { - private readonly IOrganizationConnectionRepository _organizationConnectionRepository; + _organizationConnectionRepository = organizationConnectionRepository; + } - public UpdateOrganizationConnectionCommand(IOrganizationConnectionRepository organizationConnectionRepository) + public async Task UpdateAsync(OrganizationConnectionData connectionData) where T : new() + { + if (!connectionData.Id.HasValue) { - _organizationConnectionRepository = organizationConnectionRepository; + throw new Exception("Cannot update connection, Connection does not exist."); } - public async Task UpdateAsync(OrganizationConnectionData connectionData) where T : new() + var connection = await _organizationConnectionRepository.GetByIdAsync(connectionData.Id.Value); + + if (connection == null) { - if (!connectionData.Id.HasValue) - { - throw new Exception("Cannot update connection, Connection does not exist."); - } - - var connection = await _organizationConnectionRepository.GetByIdAsync(connectionData.Id.Value); - - if (connection == null) - { - throw new NotFoundException(); - } - - var entity = connectionData.ToEntity(); - await _organizationConnectionRepository.UpsertAsync(entity); - return entity; + throw new NotFoundException(); } + + var entity = connectionData.ToEntity(); + await _organizationConnectionRepository.UpsertAsync(entity); + return entity; } } diff --git a/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs b/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs index 94e59ab31..e428318c5 100644 --- a/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs +++ b/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs @@ -13,65 +13,64 @@ using Bit.Core.Tokens; using Microsoft.AspNetCore.DataProtection; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Core.OrganizationFeatures +namespace Bit.Core.OrganizationFeatures; + +public static class OrganizationServiceCollectionExtensions { - public static class OrganizationServiceCollectionExtensions + public static void AddOrganizationServices(this IServiceCollection services, IGlobalSettings globalSettings) { - public static void AddOrganizationServices(this IServiceCollection services, IGlobalSettings globalSettings) - { - services.AddScoped(); - services.AddTokenizers(); - services.AddOrganizationConnectionCommands(); - services.AddOrganizationSponsorshipCommands(globalSettings); - services.AddOrganizationApiKeyCommands(); - } + services.AddScoped(); + services.AddTokenizers(); + services.AddOrganizationConnectionCommands(); + services.AddOrganizationSponsorshipCommands(globalSettings); + services.AddOrganizationApiKeyCommands(); + } - private static void AddOrganizationConnectionCommands(this IServiceCollection services) - { - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - } + private static void AddOrganizationConnectionCommands(this IServiceCollection services) + { + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + } - private static void AddOrganizationSponsorshipCommands(this IServiceCollection services, IGlobalSettings globalSettings) + private static void AddOrganizationSponsorshipCommands(this IServiceCollection services, IGlobalSettings globalSettings) + { + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + if (globalSettings.SelfHosted) { - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - if (globalSettings.SelfHosted) - { - services.AddScoped(); - } - else - { - services.AddScoped(); - } + services.AddScoped(); } - - private static void AddOrganizationApiKeyCommands(this IServiceCollection services) + else { - services.AddScoped(); - services.AddScoped(); - } - - private static void AddTokenizers(this IServiceCollection services) - { - services.AddSingleton>(serviceProvider => - new DataProtectorTokenFactory( - OrganizationSponsorshipOfferTokenable.ClearTextPrefix, - OrganizationSponsorshipOfferTokenable.DataProtectorPurpose, - serviceProvider.GetDataProtectionProvider()) - ); + services.AddScoped(); } } + + private static void AddOrganizationApiKeyCommands(this IServiceCollection services) + { + services.AddScoped(); + services.AddScoped(); + } + + private static void AddTokenizers(this IServiceCollection services) + { + services.AddSingleton>(serviceProvider => + new DataProtectorTokenFactory( + OrganizationSponsorshipOfferTokenable.ClearTextPrefix, + OrganizationSponsorshipOfferTokenable.DataProtectorPurpose, + serviceProvider.GetDataProtectionProvider()) + ); + } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommand.cs index 71ce0b4fa..111cec395 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommand.cs @@ -2,39 +2,38 @@ using Bit.Core.Exceptions; using Bit.Core.Repositories; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise; + +public abstract class CancelSponsorshipCommand { - public abstract class CancelSponsorshipCommand + protected readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; + protected readonly IOrganizationRepository _organizationRepository; + + public CancelSponsorshipCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository, + IOrganizationRepository organizationRepository) { - protected readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; - protected readonly IOrganizationRepository _organizationRepository; + _organizationSponsorshipRepository = organizationSponsorshipRepository; + _organizationRepository = organizationRepository; + } - public CancelSponsorshipCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository, - IOrganizationRepository organizationRepository) + protected virtual async Task DeleteSponsorshipAsync(OrganizationSponsorship sponsorship = null) + { + if (sponsorship == null) { - _organizationSponsorshipRepository = organizationSponsorshipRepository; - _organizationRepository = organizationRepository; + return; } - protected virtual async Task DeleteSponsorshipAsync(OrganizationSponsorship sponsorship = null) - { - if (sponsorship == null) - { - return; - } + await _organizationSponsorshipRepository.DeleteAsync(sponsorship); + } - await _organizationSponsorshipRepository.DeleteAsync(sponsorship); + protected async Task MarkToDeleteSponsorshipAsync(OrganizationSponsorship sponsorship) + { + if (sponsorship == null) + { + throw new BadRequestException("The sponsorship you are trying to cancel does not exist"); } - protected async Task MarkToDeleteSponsorshipAsync(OrganizationSponsorship sponsorship) - { - if (sponsorship == null) - { - throw new BadRequestException("The sponsorship you are trying to cancel does not exist"); - } - - sponsorship.ToDelete = true; - await _organizationSponsorshipRepository.UpsertAsync(sponsorship); - } + sponsorship.ToDelete = true; + await _organizationSponsorshipRepository.UpsertAsync(sponsorship); } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudRevokeSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudRevokeSponsorshipCommand.cs index d12765eff..76c180f74 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudRevokeSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudRevokeSponsorshipCommand.cs @@ -3,31 +3,30 @@ using Bit.Core.Exceptions; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; + +public class CloudRevokeSponsorshipCommand : CancelSponsorshipCommand, IRevokeSponsorshipCommand { - public class CloudRevokeSponsorshipCommand : CancelSponsorshipCommand, IRevokeSponsorshipCommand + public CloudRevokeSponsorshipCommand( + IOrganizationSponsorshipRepository organizationSponsorshipRepository, + IOrganizationRepository organizationRepository) : base(organizationSponsorshipRepository, organizationRepository) { - public CloudRevokeSponsorshipCommand( - IOrganizationSponsorshipRepository organizationSponsorshipRepository, - IOrganizationRepository organizationRepository) : base(organizationSponsorshipRepository, organizationRepository) + } + + public async Task RevokeSponsorshipAsync(OrganizationSponsorship sponsorship) + { + if (sponsorship == null) { + throw new BadRequestException("You are not currently sponsoring an organization."); } - public async Task RevokeSponsorshipAsync(OrganizationSponsorship sponsorship) + if (sponsorship.SponsoredOrganizationId == null) { - if (sponsorship == null) - { - throw new BadRequestException("You are not currently sponsoring an organization."); - } - - if (sponsorship.SponsoredOrganizationId == null) - { - await base.DeleteSponsorshipAsync(sponsorship); - } - else - { - await MarkToDeleteSponsorshipAsync(sponsorship); - } + await base.DeleteSponsorshipAsync(sponsorship); + } + else + { + await MarkToDeleteSponsorshipAsync(sponsorship); } } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudSyncSponsorshipsCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudSyncSponsorshipsCommand.cs index c4da82c96..d0569278b 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudSyncSponsorshipsCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudSyncSponsorshipsCommand.cs @@ -7,128 +7,127 @@ using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Utilities; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud -{ - public class CloudSyncSponsorshipsCommand : ICloudSyncSponsorshipsCommand - { - private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; - private readonly IEventService _eventService; +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; - public CloudSyncSponsorshipsCommand( - IOrganizationSponsorshipRepository organizationSponsorshipRepository, - IEventService eventService) +public class CloudSyncSponsorshipsCommand : ICloudSyncSponsorshipsCommand +{ + private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; + private readonly IEventService _eventService; + + public CloudSyncSponsorshipsCommand( + IOrganizationSponsorshipRepository organizationSponsorshipRepository, + IEventService eventService) + { + _organizationSponsorshipRepository = organizationSponsorshipRepository; + _eventService = eventService; + } + + public async Task<(OrganizationSponsorshipSyncData, IEnumerable)> SyncOrganization(Organization sponsoringOrg, IEnumerable sponsorshipsData) + { + if (sponsoringOrg == null) { - _organizationSponsorshipRepository = organizationSponsorshipRepository; - _eventService = eventService; + throw new BadRequestException("Failed to sync sponsorship - missing organization."); } - public async Task<(OrganizationSponsorshipSyncData, IEnumerable)> SyncOrganization(Organization sponsoringOrg, IEnumerable sponsorshipsData) + var (processedSponsorshipsData, sponsorshipsToEmailOffer) = sponsorshipsData.Any() ? + await DoSyncAsync(sponsoringOrg, sponsorshipsData) : + (sponsorshipsData, Array.Empty()); + + await RecordEvent(sponsoringOrg); + + return (new OrganizationSponsorshipSyncData { - if (sponsoringOrg == null) + SponsorshipsBatch = processedSponsorshipsData + }, sponsorshipsToEmailOffer); + } + + private async Task<(IEnumerable data, IEnumerable toOffer)> DoSyncAsync(Organization sponsoringOrg, IEnumerable sponsorshipsData) + { + var existingSponsorshipsDict = (await _organizationSponsorshipRepository.GetManyBySponsoringOrganizationAsync(sponsoringOrg.Id)) + .ToDictionary(i => i.SponsoringOrganizationUserId); + + var sponsorshipsToUpsert = new List(); + var sponsorshipIdsToDelete = new List(); + var sponsorshipsToReturn = new List(); + + foreach (var selfHostedSponsorship in sponsorshipsData) + { + var requiredSponsoringProductType = StaticStore.GetSponsoredPlan(selfHostedSponsorship.PlanSponsorshipType)?.SponsoringProductType; + if (requiredSponsoringProductType == null + || StaticStore.GetPlan(sponsoringOrg.PlanType).Product != requiredSponsoringProductType.Value) { - throw new BadRequestException("Failed to sync sponsorship - missing organization."); + continue; // prevent unsupported sponsorships } - var (processedSponsorshipsData, sponsorshipsToEmailOffer) = sponsorshipsData.Any() ? - await DoSyncAsync(sponsoringOrg, sponsorshipsData) : - (sponsorshipsData, Array.Empty()); - - await RecordEvent(sponsoringOrg); - - return (new OrganizationSponsorshipSyncData + if (!existingSponsorshipsDict.TryGetValue(selfHostedSponsorship.SponsoringOrganizationUserId, out var cloudSponsorship)) { - SponsorshipsBatch = processedSponsorshipsData - }, sponsorshipsToEmailOffer); - } - - private async Task<(IEnumerable data, IEnumerable toOffer)> DoSyncAsync(Organization sponsoringOrg, IEnumerable sponsorshipsData) - { - var existingSponsorshipsDict = (await _organizationSponsorshipRepository.GetManyBySponsoringOrganizationAsync(sponsoringOrg.Id)) - .ToDictionary(i => i.SponsoringOrganizationUserId); - - var sponsorshipsToUpsert = new List(); - var sponsorshipIdsToDelete = new List(); - var sponsorshipsToReturn = new List(); - - foreach (var selfHostedSponsorship in sponsorshipsData) - { - var requiredSponsoringProductType = StaticStore.GetSponsoredPlan(selfHostedSponsorship.PlanSponsorshipType)?.SponsoringProductType; - if (requiredSponsoringProductType == null - || StaticStore.GetPlan(sponsoringOrg.PlanType).Product != requiredSponsoringProductType.Value) + if (selfHostedSponsorship.ToDelete && selfHostedSponsorship.LastSyncDate == null) { - continue; // prevent unsupported sponsorships + continue; // prevent invalid sponsorships in cloud. These should have been deleted by self hosted } - - if (!existingSponsorshipsDict.TryGetValue(selfHostedSponsorship.SponsoringOrganizationUserId, out var cloudSponsorship)) + if (OrgDisabledForMoreThanGracePeriod(sponsoringOrg)) { - if (selfHostedSponsorship.ToDelete && selfHostedSponsorship.LastSyncDate == null) - { - continue; // prevent invalid sponsorships in cloud. These should have been deleted by self hosted - } - if (OrgDisabledForMoreThanGracePeriod(sponsoringOrg)) - { - continue; // prevent new sponsorships from disabled orgs - } - cloudSponsorship = new OrganizationSponsorship - { - SponsoringOrganizationId = sponsoringOrg.Id, - SponsoringOrganizationUserId = selfHostedSponsorship.SponsoringOrganizationUserId, - FriendlyName = selfHostedSponsorship.FriendlyName, - OfferedToEmail = selfHostedSponsorship.OfferedToEmail, - PlanSponsorshipType = selfHostedSponsorship.PlanSponsorshipType, - LastSyncDate = DateTime.UtcNow, - }; + continue; // prevent new sponsorships from disabled orgs + } + cloudSponsorship = new OrganizationSponsorship + { + SponsoringOrganizationId = sponsoringOrg.Id, + SponsoringOrganizationUserId = selfHostedSponsorship.SponsoringOrganizationUserId, + FriendlyName = selfHostedSponsorship.FriendlyName, + OfferedToEmail = selfHostedSponsorship.OfferedToEmail, + PlanSponsorshipType = selfHostedSponsorship.PlanSponsorshipType, + LastSyncDate = DateTime.UtcNow, + }; + } + else + { + cloudSponsorship.LastSyncDate = DateTime.UtcNow; + } + + if (selfHostedSponsorship.ToDelete) + { + if (cloudSponsorship.SponsoredOrganizationId == null) + { + sponsorshipIdsToDelete.Add(cloudSponsorship.Id); + selfHostedSponsorship.CloudSponsorshipRemoved = true; } else { - cloudSponsorship.LastSyncDate = DateTime.UtcNow; + cloudSponsorship.ToDelete = true; } - - if (selfHostedSponsorship.ToDelete) - { - if (cloudSponsorship.SponsoredOrganizationId == null) - { - sponsorshipIdsToDelete.Add(cloudSponsorship.Id); - selfHostedSponsorship.CloudSponsorshipRemoved = true; - } - else - { - cloudSponsorship.ToDelete = true; - } - } - sponsorshipsToUpsert.Add(cloudSponsorship); - - selfHostedSponsorship.ValidUntil = cloudSponsorship.ValidUntil; - selfHostedSponsorship.LastSyncDate = DateTime.UtcNow; - sponsorshipsToReturn.Add(selfHostedSponsorship); - } - var sponsorshipsToEmailOffer = sponsorshipsToUpsert.Where(s => s.Id == default).ToArray(); - if (sponsorshipsToUpsert.Any()) - { - await _organizationSponsorshipRepository.UpsertManyAsync(sponsorshipsToUpsert); - } - if (sponsorshipIdsToDelete.Any()) - { - await _organizationSponsorshipRepository.DeleteManyAsync(sponsorshipIdsToDelete); } + sponsorshipsToUpsert.Add(cloudSponsorship); - return (sponsorshipsToReturn, sponsorshipsToEmailOffer); + selfHostedSponsorship.ValidUntil = cloudSponsorship.ValidUntil; + selfHostedSponsorship.LastSyncDate = DateTime.UtcNow; + sponsorshipsToReturn.Add(selfHostedSponsorship); } - - /// - /// True if Organization is disabled and the expiration date is more than three months ago - /// - /// - private bool OrgDisabledForMoreThanGracePeriod(Organization organization) => - !organization.Enabled && - ( - !organization.ExpirationDate.HasValue || - DateTime.UtcNow.Subtract(organization.ExpirationDate.Value).TotalDays > 93 - ); - - private async Task RecordEvent(Organization organization) + var sponsorshipsToEmailOffer = sponsorshipsToUpsert.Where(s => s.Id == default).ToArray(); + if (sponsorshipsToUpsert.Any()) { - await _eventService.LogOrganizationEventAsync(organization, EventType.Organization_SponsorshipsSynced); + await _organizationSponsorshipRepository.UpsertManyAsync(sponsorshipsToUpsert); } + if (sponsorshipIdsToDelete.Any()) + { + await _organizationSponsorshipRepository.DeleteManyAsync(sponsorshipIdsToDelete); + } + + return (sponsorshipsToReturn, sponsorshipsToEmailOffer); + } + + /// + /// True if Organization is disabled and the expiration date is more than three months ago + /// + /// + private bool OrgDisabledForMoreThanGracePeriod(Organization organization) => + !organization.Enabled && + ( + !organization.ExpirationDate.HasValue || + DateTime.UtcNow.Subtract(organization.ExpirationDate.Value).TotalDays > 93 + ); + + private async Task RecordEvent(Organization organization) + { + await _eventService.LogOrganizationEventAsync(organization, EventType.Organization_SponsorshipsSynced); } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/OrganizationSponsorshipRenewCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/OrganizationSponsorshipRenewCommand.cs index 148b525d7..1d7b66a66 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/OrganizationSponsorshipRenewCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/OrganizationSponsorshipRenewCommand.cs @@ -1,28 +1,27 @@ using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; + +public class OrganizationSponsorshipRenewCommand : IOrganizationSponsorshipRenewCommand { - public class OrganizationSponsorshipRenewCommand : IOrganizationSponsorshipRenewCommand + private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; + + public OrganizationSponsorshipRenewCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository) { - private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; + _organizationSponsorshipRepository = organizationSponsorshipRepository; + } - public OrganizationSponsorshipRenewCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository) + public async Task UpdateExpirationDateAsync(Guid organizationId, DateTime expireDate) + { + var sponsorship = await _organizationSponsorshipRepository.GetBySponsoredOrganizationIdAsync(organizationId); + + if (sponsorship == null) { - _organizationSponsorshipRepository = organizationSponsorshipRepository; + return; } - public async Task UpdateExpirationDateAsync(Guid organizationId, DateTime expireDate) - { - var sponsorship = await _organizationSponsorshipRepository.GetBySponsoredOrganizationIdAsync(organizationId); - - if (sponsorship == null) - { - return; - } - - sponsorship.ValidUntil = expireDate; - await _organizationSponsorshipRepository.UpsertAsync(sponsorship); - } + sponsorship.ValidUntil = expireDate; + await _organizationSponsorshipRepository.UpsertAsync(sponsorship); } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/RemoveSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/RemoveSponsorshipCommand.cs index 136c1681b..1e05f8bc4 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/RemoveSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/RemoveSponsorshipCommand.cs @@ -3,24 +3,23 @@ using Bit.Core.Exceptions; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; + +public class RemoveSponsorshipCommand : CancelSponsorshipCommand, IRemoveSponsorshipCommand { - public class RemoveSponsorshipCommand : CancelSponsorshipCommand, IRemoveSponsorshipCommand + public RemoveSponsorshipCommand( + IOrganizationSponsorshipRepository organizationSponsorshipRepository, + IOrganizationRepository organizationRepository) : base(organizationSponsorshipRepository, organizationRepository) { - public RemoveSponsorshipCommand( - IOrganizationSponsorshipRepository organizationSponsorshipRepository, - IOrganizationRepository organizationRepository) : base(organizationSponsorshipRepository, organizationRepository) + } + + public async Task RemoveSponsorshipAsync(OrganizationSponsorship sponsorship) + { + if (sponsorship == null || sponsorship.SponsoredOrganizationId == null) { + throw new BadRequestException("The requested organization is not currently being sponsored."); } - public async Task RemoveSponsorshipAsync(OrganizationSponsorship sponsorship) - { - if (sponsorship == null || sponsorship.SponsoredOrganizationId == null) - { - throw new BadRequestException("The requested organization is not currently being sponsored."); - } - - await MarkToDeleteSponsorshipAsync(sponsorship); - } + await MarkToDeleteSponsorshipAsync(sponsorship); } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SendSponsorshipOfferCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SendSponsorshipOfferCommand.cs index b77706051..5f9a62d25 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SendSponsorshipOfferCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SendSponsorshipOfferCommand.cs @@ -7,64 +7,63 @@ using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Tokens; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; + +public class SendSponsorshipOfferCommand : ISendSponsorshipOfferCommand { - public class SendSponsorshipOfferCommand : ISendSponsorshipOfferCommand + private readonly IUserRepository _userRepository; + private readonly IMailService _mailService; + private readonly IDataProtectorTokenFactory _tokenFactory; + + public SendSponsorshipOfferCommand(IUserRepository userRepository, + IMailService mailService, + IDataProtectorTokenFactory tokenFactory) { - private readonly IUserRepository _userRepository; - private readonly IMailService _mailService; - private readonly IDataProtectorTokenFactory _tokenFactory; + _userRepository = userRepository; + _mailService = mailService; + _tokenFactory = tokenFactory; + } - public SendSponsorshipOfferCommand(IUserRepository userRepository, - IMailService mailService, - IDataProtectorTokenFactory tokenFactory) - { - _userRepository = userRepository; - _mailService = mailService; - _tokenFactory = tokenFactory; - } - - public async Task BulkSendSponsorshipOfferAsync(string sponsoringOrgName, IEnumerable sponsorships) - { - var invites = new List<(string, bool, string)>(); - foreach (var sponsorship in sponsorships) - { - var user = await _userRepository.GetByEmailAsync(sponsorship.OfferedToEmail); - var isExistingAccount = user != null; - invites.Add((sponsorship.OfferedToEmail, user != null, _tokenFactory.Protect(new OrganizationSponsorshipOfferTokenable(sponsorship)))); - } - - await _mailService.BulkSendFamiliesForEnterpriseOfferEmailAsync(sponsoringOrgName, invites); - } - - public async Task SendSponsorshipOfferAsync(OrganizationSponsorship sponsorship, string sponsoringOrgName) + public async Task BulkSendSponsorshipOfferAsync(string sponsoringOrgName, IEnumerable sponsorships) + { + var invites = new List<(string, bool, string)>(); + foreach (var sponsorship in sponsorships) { var user = await _userRepository.GetByEmailAsync(sponsorship.OfferedToEmail); var isExistingAccount = user != null; - - await _mailService.SendFamiliesForEnterpriseOfferEmailAsync(sponsoringOrgName, sponsorship.OfferedToEmail, - isExistingAccount, _tokenFactory.Protect(new OrganizationSponsorshipOfferTokenable(sponsorship))); + invites.Add((sponsorship.OfferedToEmail, user != null, _tokenFactory.Protect(new OrganizationSponsorshipOfferTokenable(sponsorship)))); } - public async Task SendSponsorshipOfferAsync(Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, - OrganizationSponsorship sponsorship) + await _mailService.BulkSendFamiliesForEnterpriseOfferEmailAsync(sponsoringOrgName, invites); + } + + public async Task SendSponsorshipOfferAsync(OrganizationSponsorship sponsorship, string sponsoringOrgName) + { + var user = await _userRepository.GetByEmailAsync(sponsorship.OfferedToEmail); + var isExistingAccount = user != null; + + await _mailService.SendFamiliesForEnterpriseOfferEmailAsync(sponsoringOrgName, sponsorship.OfferedToEmail, + isExistingAccount, _tokenFactory.Protect(new OrganizationSponsorshipOfferTokenable(sponsorship))); + } + + public async Task SendSponsorshipOfferAsync(Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, + OrganizationSponsorship sponsorship) + { + if (sponsoringOrg == null) { - if (sponsoringOrg == null) - { - throw new BadRequestException("Cannot find the requested sponsoring organization."); - } - - if (sponsoringOrgUser == null || sponsoringOrgUser.Status != OrganizationUserStatusType.Confirmed) - { - throw new BadRequestException("Only confirmed users can sponsor other organizations."); - } - - if (sponsorship == null || sponsorship.OfferedToEmail == null) - { - throw new BadRequestException("Cannot find an outstanding sponsorship offer for this organization."); - } - - await SendSponsorshipOfferAsync(sponsorship, sponsoringOrg.Name); + throw new BadRequestException("Cannot find the requested sponsoring organization."); } + + if (sponsoringOrgUser == null || sponsoringOrgUser.Status != OrganizationUserStatusType.Confirmed) + { + throw new BadRequestException("Only confirmed users can sponsor other organizations."); + } + + if (sponsorship == null || sponsorship.OfferedToEmail == null) + { + throw new BadRequestException("Cannot find an outstanding sponsorship offer for this organization."); + } + + await SendSponsorshipOfferAsync(sponsorship, sponsoringOrg.Name); } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommand.cs index 698ec549d..9230e7d13 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommand.cs @@ -5,63 +5,62 @@ using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Utilities; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; + +public class SetUpSponsorshipCommand : ISetUpSponsorshipCommand { - public class SetUpSponsorshipCommand : ISetUpSponsorshipCommand + private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly IPaymentService _paymentService; + + public SetUpSponsorshipCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository, IOrganizationRepository organizationRepository, IPaymentService paymentService) { - private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly IPaymentService _paymentService; + _organizationSponsorshipRepository = organizationSponsorshipRepository; + _organizationRepository = organizationRepository; + _paymentService = paymentService; + } - public SetUpSponsorshipCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository, IOrganizationRepository organizationRepository, IPaymentService paymentService) + public async Task SetUpSponsorshipAsync(OrganizationSponsorship sponsorship, + Organization sponsoredOrganization) + { + if (sponsorship == null) { - _organizationSponsorshipRepository = organizationSponsorshipRepository; - _organizationRepository = organizationRepository; - _paymentService = paymentService; + throw new BadRequestException("No unredeemed sponsorship offer exists for you."); } - public async Task SetUpSponsorshipAsync(OrganizationSponsorship sponsorship, - Organization sponsoredOrganization) + var existingOrgSponsorship = await _organizationSponsorshipRepository + .GetBySponsoredOrganizationIdAsync(sponsoredOrganization.Id); + if (existingOrgSponsorship != null) { - if (sponsorship == null) - { - throw new BadRequestException("No unredeemed sponsorship offer exists for you."); - } - - var existingOrgSponsorship = await _organizationSponsorshipRepository - .GetBySponsoredOrganizationIdAsync(sponsoredOrganization.Id); - if (existingOrgSponsorship != null) - { - throw new BadRequestException("Cannot redeem a sponsorship offer for an organization that is already sponsored. Revoke existing sponsorship first."); - } - - if (sponsorship.PlanSponsorshipType == null) - { - throw new BadRequestException("Cannot set up sponsorship without a known sponsorship type."); - } - - // Do not allow self-hosted sponsorships that haven't been synced for > 0.5 year - if (sponsorship.LastSyncDate != null && DateTime.UtcNow.Subtract(sponsorship.LastSyncDate.Value).TotalDays > 182.5) - { - await _organizationSponsorshipRepository.DeleteAsync(sponsorship); - throw new BadRequestException("This sponsorship offer is more than 6 months old and has expired."); - } - - // Check org to sponsor's product type - var requiredSponsoredProductType = StaticStore.GetSponsoredPlan(sponsorship.PlanSponsorshipType.Value)?.SponsoredProductType; - if (requiredSponsoredProductType == null || - sponsoredOrganization == null || - StaticStore.GetPlan(sponsoredOrganization.PlanType).Product != requiredSponsoredProductType.Value) - { - throw new BadRequestException("Can only redeem sponsorship offer on families organizations."); - } - - await _paymentService.SponsorOrganizationAsync(sponsoredOrganization, sponsorship); - await _organizationRepository.UpsertAsync(sponsoredOrganization); - - sponsorship.SponsoredOrganizationId = sponsoredOrganization.Id; - sponsorship.OfferedToEmail = null; - await _organizationSponsorshipRepository.UpsertAsync(sponsorship); + throw new BadRequestException("Cannot redeem a sponsorship offer for an organization that is already sponsored. Revoke existing sponsorship first."); } + + if (sponsorship.PlanSponsorshipType == null) + { + throw new BadRequestException("Cannot set up sponsorship without a known sponsorship type."); + } + + // Do not allow self-hosted sponsorships that haven't been synced for > 0.5 year + if (sponsorship.LastSyncDate != null && DateTime.UtcNow.Subtract(sponsorship.LastSyncDate.Value).TotalDays > 182.5) + { + await _organizationSponsorshipRepository.DeleteAsync(sponsorship); + throw new BadRequestException("This sponsorship offer is more than 6 months old and has expired."); + } + + // Check org to sponsor's product type + var requiredSponsoredProductType = StaticStore.GetSponsoredPlan(sponsorship.PlanSponsorshipType.Value)?.SponsoredProductType; + if (requiredSponsoredProductType == null || + sponsoredOrganization == null || + StaticStore.GetPlan(sponsoredOrganization.PlanType).Product != requiredSponsoredProductType.Value) + { + throw new BadRequestException("Can only redeem sponsorship offer on families organizations."); + } + + await _paymentService.SponsorOrganizationAsync(sponsoredOrganization, sponsorship); + await _organizationRepository.UpsertAsync(sponsoredOrganization); + + sponsorship.SponsoredOrganizationId = sponsoredOrganization.Id; + sponsorship.OfferedToEmail = null; + await _organizationSponsorshipRepository.UpsertAsync(sponsorship); } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateBillingSyncKeyCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateBillingSyncKeyCommand.cs index f1032f0b2..19c4398a7 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateBillingSyncKeyCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateBillingSyncKeyCommand.cs @@ -3,38 +3,37 @@ using Bit.Core.Exceptions; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; + +public class ValidateBillingSyncKeyCommand : IValidateBillingSyncKeyCommand { - public class ValidateBillingSyncKeyCommand : IValidateBillingSyncKeyCommand + private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; + private readonly IOrganizationApiKeyRepository _apiKeyRepository; + + public ValidateBillingSyncKeyCommand( + IOrganizationSponsorshipRepository organizationSponsorshipRepository, + IOrganizationApiKeyRepository organizationApiKeyRepository) { - private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; - private readonly IOrganizationApiKeyRepository _apiKeyRepository; + _organizationSponsorshipRepository = organizationSponsorshipRepository; + _apiKeyRepository = organizationApiKeyRepository; + } - public ValidateBillingSyncKeyCommand( - IOrganizationSponsorshipRepository organizationSponsorshipRepository, - IOrganizationApiKeyRepository organizationApiKeyRepository) + public async Task ValidateBillingSyncKeyAsync(Organization organization, string billingSyncKey) + { + if (organization == null) { - _organizationSponsorshipRepository = organizationSponsorshipRepository; - _apiKeyRepository = organizationApiKeyRepository; + throw new BadRequestException("Invalid organization"); } - - public async Task ValidateBillingSyncKeyAsync(Organization organization, string billingSyncKey) + if (string.IsNullOrWhiteSpace(billingSyncKey)) { - if (organization == null) - { - throw new BadRequestException("Invalid organization"); - } - if (string.IsNullOrWhiteSpace(billingSyncKey)) - { - return false; - } - - var orgApiKey = (await _apiKeyRepository.GetManyByOrganizationIdTypeAsync(organization.Id, Enums.OrganizationApiKeyType.BillingSync)).FirstOrDefault(); - if (string.Equals(orgApiKey.ApiKey, billingSyncKey)) - { - return true; - } return false; } + + var orgApiKey = (await _apiKeyRepository.GetManyByOrganizationIdTypeAsync(organization.Id, Enums.OrganizationApiKeyType.BillingSync)).FirstOrDefault(); + if (string.Equals(orgApiKey.ApiKey, billingSyncKey)) + { + return true; + } + return false; } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateRedemptionTokenCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateRedemptionTokenCommand.cs index 179f5b3ac..fc3f5b132 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateRedemptionTokenCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateRedemptionTokenCommand.cs @@ -4,34 +4,33 @@ using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterpri using Bit.Core.Repositories; using Bit.Core.Tokens; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; + +public class ValidateRedemptionTokenCommand : IValidateRedemptionTokenCommand { - public class ValidateRedemptionTokenCommand : IValidateRedemptionTokenCommand + private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; + private readonly IDataProtectorTokenFactory _dataProtectorTokenFactory; + + public ValidateRedemptionTokenCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository, + IDataProtectorTokenFactory dataProtectorTokenFactory) { - private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; - private readonly IDataProtectorTokenFactory _dataProtectorTokenFactory; + _organizationSponsorshipRepository = organizationSponsorshipRepository; + _dataProtectorTokenFactory = dataProtectorTokenFactory; + } - public ValidateRedemptionTokenCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository, - IDataProtectorTokenFactory dataProtectorTokenFactory) + public async Task<(bool valid, OrganizationSponsorship sponsorship)> ValidateRedemptionTokenAsync(string encryptedToken, string sponsoredUserEmail) + { + + if (!_dataProtectorTokenFactory.TryUnprotect(encryptedToken, out var tokenable)) { - _organizationSponsorshipRepository = organizationSponsorshipRepository; - _dataProtectorTokenFactory = dataProtectorTokenFactory; + return (false, null); } - public async Task<(bool valid, OrganizationSponsorship sponsorship)> ValidateRedemptionTokenAsync(string encryptedToken, string sponsoredUserEmail) + var sponsorship = await _organizationSponsorshipRepository.GetByIdAsync(tokenable.Id); + if (!tokenable.IsValid(sponsorship, sponsoredUserEmail)) { - - if (!_dataProtectorTokenFactory.TryUnprotect(encryptedToken, out var tokenable)) - { - return (false, null); - } - - var sponsorship = await _organizationSponsorshipRepository.GetByIdAsync(tokenable.Id); - if (!tokenable.IsValid(sponsorship, sponsoredUserEmail)) - { - return (false, sponsorship); - } - return (true, sponsorship); + return (false, sponsorship); } + return (true, sponsorship); } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommand.cs index 3b0bf3f14..3f2d7af5e 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommand.cs @@ -4,112 +4,111 @@ using Bit.Core.Repositories; using Bit.Core.Services; using Microsoft.Extensions.Logging; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; + +public class ValidateSponsorshipCommand : CancelSponsorshipCommand, IValidateSponsorshipCommand { - public class ValidateSponsorshipCommand : CancelSponsorshipCommand, IValidateSponsorshipCommand + private readonly IPaymentService _paymentService; + private readonly IMailService _mailService; + private readonly ILogger _logger; + + public ValidateSponsorshipCommand( + IOrganizationSponsorshipRepository organizationSponsorshipRepository, + IOrganizationRepository organizationRepository, + IPaymentService paymentService, + IMailService mailService, + ILogger logger) : base(organizationSponsorshipRepository, organizationRepository) { - private readonly IPaymentService _paymentService; - private readonly IMailService _mailService; - private readonly ILogger _logger; - - public ValidateSponsorshipCommand( - IOrganizationSponsorshipRepository organizationSponsorshipRepository, - IOrganizationRepository organizationRepository, - IPaymentService paymentService, - IMailService mailService, - ILogger logger) : base(organizationSponsorshipRepository, organizationRepository) - { - _paymentService = paymentService; - _mailService = mailService; - _logger = logger; - } - - public async Task ValidateSponsorshipAsync(Guid sponsoredOrganizationId) - { - var sponsoredOrganization = await _organizationRepository.GetByIdAsync(sponsoredOrganizationId); - if (sponsoredOrganization == null) - { - return false; - } - - var existingSponsorship = await _organizationSponsorshipRepository - .GetBySponsoredOrganizationIdAsync(sponsoredOrganizationId); - - if (existingSponsorship == null) - { - await CancelSponsorshipAsync(sponsoredOrganization, null); - return false; - } - - if (existingSponsorship.SponsoringOrganizationId == null || existingSponsorship.SponsoringOrganizationUserId == default || existingSponsorship.PlanSponsorshipType == null) - { - await CancelSponsorshipAsync(sponsoredOrganization, existingSponsorship); - return false; - } - var sponsoredPlan = Utilities.StaticStore.GetSponsoredPlan(existingSponsorship.PlanSponsorshipType.Value); - - var sponsoringOrganization = await _organizationRepository - .GetByIdAsync(existingSponsorship.SponsoringOrganizationId.Value); - if (sponsoringOrganization == null) - { - await CancelSponsorshipAsync(sponsoredOrganization, existingSponsorship); - return false; - } - - var sponsoringOrgPlan = Utilities.StaticStore.GetPlan(sponsoringOrganization.PlanType); - if (OrgDisabledForMoreThanGracePeriod(sponsoringOrganization) || - sponsoredPlan.SponsoringProductType != sponsoringOrgPlan.Product || - existingSponsorship.ToDelete || - SponsorshipIsSelfHostedOutOfSync(existingSponsorship)) - { - await CancelSponsorshipAsync(sponsoredOrganization, existingSponsorship); - return false; - } - - return true; - } - - protected async Task CancelSponsorshipAsync(Organization sponsoredOrganization, OrganizationSponsorship sponsorship = null) - { - if (sponsoredOrganization != null) - { - await _paymentService.RemoveOrganizationSponsorshipAsync(sponsoredOrganization, sponsorship); - await _organizationRepository.UpsertAsync(sponsoredOrganization); - - try - { - if (sponsorship != null) - { - await _mailService.SendFamiliesForEnterpriseSponsorshipRevertingEmailAsync( - sponsoredOrganization.BillingEmailAddress(), - sponsorship.ValidUntil ?? DateTime.UtcNow.AddDays(15)); - } - } - catch (Exception e) - { - _logger.LogError("Error sending Family sponsorship removed email.", e); - } - } - await base.DeleteSponsorshipAsync(sponsorship); - } - - /// - /// True if Sponsorship is from a self-hosted instance that has failed to sync for more than 6 months - /// - /// - private bool SponsorshipIsSelfHostedOutOfSync(OrganizationSponsorship sponsorship) => - sponsorship.LastSyncDate.HasValue && - DateTime.UtcNow.Subtract(sponsorship.LastSyncDate.Value).TotalDays > 182.5; - - /// - /// True if Organization is disabled and the expiration date is more than three months ago - /// - /// - private bool OrgDisabledForMoreThanGracePeriod(Organization organization) => - !organization.Enabled && - ( - !organization.ExpirationDate.HasValue || - DateTime.UtcNow.Subtract(organization.ExpirationDate.Value).TotalDays > 93 - ); + _paymentService = paymentService; + _mailService = mailService; + _logger = logger; } + + public async Task ValidateSponsorshipAsync(Guid sponsoredOrganizationId) + { + var sponsoredOrganization = await _organizationRepository.GetByIdAsync(sponsoredOrganizationId); + if (sponsoredOrganization == null) + { + return false; + } + + var existingSponsorship = await _organizationSponsorshipRepository + .GetBySponsoredOrganizationIdAsync(sponsoredOrganizationId); + + if (existingSponsorship == null) + { + await CancelSponsorshipAsync(sponsoredOrganization, null); + return false; + } + + if (existingSponsorship.SponsoringOrganizationId == null || existingSponsorship.SponsoringOrganizationUserId == default || existingSponsorship.PlanSponsorshipType == null) + { + await CancelSponsorshipAsync(sponsoredOrganization, existingSponsorship); + return false; + } + var sponsoredPlan = Utilities.StaticStore.GetSponsoredPlan(existingSponsorship.PlanSponsorshipType.Value); + + var sponsoringOrganization = await _organizationRepository + .GetByIdAsync(existingSponsorship.SponsoringOrganizationId.Value); + if (sponsoringOrganization == null) + { + await CancelSponsorshipAsync(sponsoredOrganization, existingSponsorship); + return false; + } + + var sponsoringOrgPlan = Utilities.StaticStore.GetPlan(sponsoringOrganization.PlanType); + if (OrgDisabledForMoreThanGracePeriod(sponsoringOrganization) || + sponsoredPlan.SponsoringProductType != sponsoringOrgPlan.Product || + existingSponsorship.ToDelete || + SponsorshipIsSelfHostedOutOfSync(existingSponsorship)) + { + await CancelSponsorshipAsync(sponsoredOrganization, existingSponsorship); + return false; + } + + return true; + } + + protected async Task CancelSponsorshipAsync(Organization sponsoredOrganization, OrganizationSponsorship sponsorship = null) + { + if (sponsoredOrganization != null) + { + await _paymentService.RemoveOrganizationSponsorshipAsync(sponsoredOrganization, sponsorship); + await _organizationRepository.UpsertAsync(sponsoredOrganization); + + try + { + if (sponsorship != null) + { + await _mailService.SendFamiliesForEnterpriseSponsorshipRevertingEmailAsync( + sponsoredOrganization.BillingEmailAddress(), + sponsorship.ValidUntil ?? DateTime.UtcNow.AddDays(15)); + } + } + catch (Exception e) + { + _logger.LogError("Error sending Family sponsorship removed email.", e); + } + } + await base.DeleteSponsorshipAsync(sponsorship); + } + + /// + /// True if Sponsorship is from a self-hosted instance that has failed to sync for more than 6 months + /// + /// + private bool SponsorshipIsSelfHostedOutOfSync(OrganizationSponsorship sponsorship) => + sponsorship.LastSyncDate.HasValue && + DateTime.UtcNow.Subtract(sponsorship.LastSyncDate.Value).TotalDays > 182.5; + + /// + /// True if Organization is disabled and the expiration date is more than three months ago + /// + /// + private bool OrgDisabledForMoreThanGracePeriod(Organization organization) => + !organization.Enabled && + ( + !organization.ExpirationDate.HasValue || + DateTime.UtcNow.Subtract(organization.ExpirationDate.Value).TotalDays > 93 + ); } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs index 6d186726a..69e6c3232 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs @@ -6,77 +6,76 @@ using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Utilities; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise -{ - public class CreateSponsorshipCommand : ICreateSponsorshipCommand - { - private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; - private readonly IUserService _userService; +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise; - public CreateSponsorshipCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository, - IUserService userService) +public class CreateSponsorshipCommand : ICreateSponsorshipCommand +{ + private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; + private readonly IUserService _userService; + + public CreateSponsorshipCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository, + IUserService userService) + { + _organizationSponsorshipRepository = organizationSponsorshipRepository; + _userService = userService; + } + + public async Task CreateSponsorshipAsync(Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, + PlanSponsorshipType sponsorshipType, string sponsoredEmail, string friendlyName) + { + var sponsoringUser = await _userService.GetUserByIdAsync(sponsoringOrgUser.UserId.Value); + if (sponsoringUser == null || string.Equals(sponsoringUser.Email, sponsoredEmail, System.StringComparison.InvariantCultureIgnoreCase)) { - _organizationSponsorshipRepository = organizationSponsorshipRepository; - _userService = userService; + throw new BadRequestException("Cannot offer a Families Organization Sponsorship to yourself. Choose a different email."); } - public async Task CreateSponsorshipAsync(Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, - PlanSponsorshipType sponsorshipType, string sponsoredEmail, string friendlyName) + var requiredSponsoringProductType = StaticStore.GetSponsoredPlan(sponsorshipType)?.SponsoringProductType; + if (requiredSponsoringProductType == null || + sponsoringOrg == null || + StaticStore.GetPlan(sponsoringOrg.PlanType).Product != requiredSponsoringProductType.Value) { - var sponsoringUser = await _userService.GetUserByIdAsync(sponsoringOrgUser.UserId.Value); - if (sponsoringUser == null || string.Equals(sponsoringUser.Email, sponsoredEmail, System.StringComparison.InvariantCultureIgnoreCase)) - { - throw new BadRequestException("Cannot offer a Families Organization Sponsorship to yourself. Choose a different email."); - } + throw new BadRequestException("Specified Organization cannot sponsor other organizations."); + } - var requiredSponsoringProductType = StaticStore.GetSponsoredPlan(sponsorshipType)?.SponsoringProductType; - if (requiredSponsoringProductType == null || - sponsoringOrg == null || - StaticStore.GetPlan(sponsoringOrg.PlanType).Product != requiredSponsoringProductType.Value) - { - throw new BadRequestException("Specified Organization cannot sponsor other organizations."); - } + if (sponsoringOrgUser == null || sponsoringOrgUser.Status != OrganizationUserStatusType.Confirmed) + { + throw new BadRequestException("Only confirmed users can sponsor other organizations."); + } - if (sponsoringOrgUser == null || sponsoringOrgUser.Status != OrganizationUserStatusType.Confirmed) - { - throw new BadRequestException("Only confirmed users can sponsor other organizations."); - } + var existingOrgSponsorship = await _organizationSponsorshipRepository + .GetBySponsoringOrganizationUserIdAsync(sponsoringOrgUser.Id); + if (existingOrgSponsorship?.SponsoredOrganizationId != null) + { + throw new BadRequestException("Can only sponsor one organization per Organization User."); + } - var existingOrgSponsorship = await _organizationSponsorshipRepository - .GetBySponsoringOrganizationUserIdAsync(sponsoringOrgUser.Id); - if (existingOrgSponsorship?.SponsoredOrganizationId != null) - { - throw new BadRequestException("Can only sponsor one organization per Organization User."); - } + var sponsorship = new OrganizationSponsorship + { + SponsoringOrganizationId = sponsoringOrg.Id, + SponsoringOrganizationUserId = sponsoringOrgUser.Id, + FriendlyName = friendlyName, + OfferedToEmail = sponsoredEmail, + PlanSponsorshipType = sponsorshipType, + }; - var sponsorship = new OrganizationSponsorship - { - SponsoringOrganizationId = sponsoringOrg.Id, - SponsoringOrganizationUserId = sponsoringOrgUser.Id, - FriendlyName = friendlyName, - OfferedToEmail = sponsoredEmail, - PlanSponsorshipType = sponsorshipType, - }; + if (existingOrgSponsorship != null) + { + // Replace existing invalid offer with our new sponsorship offer + sponsorship.Id = existingOrgSponsorship.Id; + } - if (existingOrgSponsorship != null) + try + { + await _organizationSponsorshipRepository.UpsertAsync(sponsorship); + return sponsorship; + } + catch + { + if (sponsorship.Id != default) { - // Replace existing invalid offer with our new sponsorship offer - sponsorship.Id = existingOrgSponsorship.Id; - } - - try - { - await _organizationSponsorshipRepository.UpsertAsync(sponsorship); - return sponsorship; - } - catch - { - if (sponsorship.Id != default) - { - await _organizationSponsorshipRepository.DeleteAsync(sponsorship); - } - throw; + await _organizationSponsorshipRepository.DeleteAsync(sponsorship); } + throw; } } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ICreateSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ICreateSponsorshipCommand.cs index c321524e7..1ba4b3662 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ICreateSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ICreateSponsorshipCommand.cs @@ -1,11 +1,10 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; + +public interface ICreateSponsorshipCommand { - public interface ICreateSponsorshipCommand - { - Task CreateSponsorshipAsync(Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, - PlanSponsorshipType sponsorshipType, string sponsoredEmail, string friendlyName); - } + Task CreateSponsorshipAsync(Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, + PlanSponsorshipType sponsorshipType, string sponsoredEmail, string friendlyName); } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IOrganizationSponsorshipRenewCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IOrganizationSponsorshipRenewCommand.cs index 762d166bd..9d04c280d 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IOrganizationSponsorshipRenewCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IOrganizationSponsorshipRenewCommand.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; + +public interface IOrganizationSponsorshipRenewCommand { - public interface IOrganizationSponsorshipRenewCommand - { - Task UpdateExpirationDateAsync(Guid organizationId, DateTime expireDate); - } + Task UpdateExpirationDateAsync(Guid organizationId, DateTime expireDate); } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IRemoveSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IRemoveSponsorshipCommand.cs index 21a8fec89..a37e6cee9 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IRemoveSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IRemoveSponsorshipCommand.cs @@ -1,9 +1,8 @@ using Bit.Core.Entities; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; + +public interface IRemoveSponsorshipCommand { - public interface IRemoveSponsorshipCommand - { - Task RemoveSponsorshipAsync(OrganizationSponsorship sponsorship); - } + Task RemoveSponsorshipAsync(OrganizationSponsorship sponsorship); } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IRevokeSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IRevokeSponsorshipCommand.cs index 18ca25d17..48a496494 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IRevokeSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IRevokeSponsorshipCommand.cs @@ -1,9 +1,8 @@ using Bit.Core.Entities; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; + +public interface IRevokeSponsorshipCommand { - public interface IRevokeSponsorshipCommand - { - Task RevokeSponsorshipAsync(OrganizationSponsorship sponsorship); - } + Task RevokeSponsorshipAsync(OrganizationSponsorship sponsorship); } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ISendSponsorshipOfferCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ISendSponsorshipOfferCommand.cs index a047c4d66..9795ed00f 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ISendSponsorshipOfferCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ISendSponsorshipOfferCommand.cs @@ -1,12 +1,11 @@ using Bit.Core.Entities; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; + +public interface ISendSponsorshipOfferCommand { - public interface ISendSponsorshipOfferCommand - { - Task BulkSendSponsorshipOfferAsync(string sponsoringOrgName, IEnumerable invites); - Task SendSponsorshipOfferAsync(OrganizationSponsorship sponsorship, string sponsoringOrgName); - Task SendSponsorshipOfferAsync(Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, - OrganizationSponsorship sponsorship); - } + Task BulkSendSponsorshipOfferAsync(string sponsoringOrgName, IEnumerable invites); + Task SendSponsorshipOfferAsync(OrganizationSponsorship sponsorship, string sponsoringOrgName); + Task SendSponsorshipOfferAsync(Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, + OrganizationSponsorship sponsorship); } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ISetUpSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ISetUpSponsorshipCommand.cs index d4c5e9b0e..4c57c9072 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ISetUpSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ISetUpSponsorshipCommand.cs @@ -1,10 +1,9 @@ using Bit.Core.Entities; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; + +public interface ISetUpSponsorshipCommand { - public interface ISetUpSponsorshipCommand - { - Task SetUpSponsorshipAsync(OrganizationSponsorship sponsorship, - Organization sponsoredOrganization); - } + Task SetUpSponsorshipAsync(OrganizationSponsorship sponsorship, + Organization sponsoredOrganization); } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ISyncOrganizationSponsorshipsCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ISyncOrganizationSponsorshipsCommand.cs index 9533e4bfd..0b8bb6444 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ISyncOrganizationSponsorshipsCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/ISyncOrganizationSponsorshipsCommand.cs @@ -1,15 +1,14 @@ using Bit.Core.Entities; using Bit.Core.Models.Data.Organizations.OrganizationSponsorships; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces -{ - public interface ISelfHostedSyncSponsorshipsCommand - { - Task SyncOrganization(Guid organizationId, Guid cloudOrganizationId, OrganizationConnection billingSyncConnection); - } +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; - public interface ICloudSyncSponsorshipsCommand - { - Task<(OrganizationSponsorshipSyncData, IEnumerable)> SyncOrganization(Organization sponsoringOrg, IEnumerable sponsorshipsData); - } +public interface ISelfHostedSyncSponsorshipsCommand +{ + Task SyncOrganization(Guid organizationId, Guid cloudOrganizationId, OrganizationConnection billingSyncConnection); +} + +public interface ICloudSyncSponsorshipsCommand +{ + Task<(OrganizationSponsorshipSyncData, IEnumerable)> SyncOrganization(Organization sponsoringOrg, IEnumerable sponsorshipsData); } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IValidateBillingSyncKeyCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IValidateBillingSyncKeyCommand.cs index 1ac3e1c0d..53e926903 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IValidateBillingSyncKeyCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IValidateBillingSyncKeyCommand.cs @@ -1,9 +1,8 @@ using Bit.Core.Entities; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; + +public interface IValidateBillingSyncKeyCommand { - public interface IValidateBillingSyncKeyCommand - { - Task ValidateBillingSyncKeyAsync(Organization organization, string billingSyncKey); - } + Task ValidateBillingSyncKeyAsync(Organization organization, string billingSyncKey); } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IValidateRedemptionTokenCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IValidateRedemptionTokenCommand.cs index a7db2ed2e..714e9e2b5 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IValidateRedemptionTokenCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IValidateRedemptionTokenCommand.cs @@ -1,9 +1,8 @@ using Bit.Core.Entities; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; + +public interface IValidateRedemptionTokenCommand { - public interface IValidateRedemptionTokenCommand - { - Task<(bool valid, OrganizationSponsorship sponsorship)> ValidateRedemptionTokenAsync(string encryptedToken, string sponsoredUserEmail); - } + Task<(bool valid, OrganizationSponsorship sponsorship)> ValidateRedemptionTokenAsync(string encryptedToken, string sponsoredUserEmail); } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IValidateSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IValidateSponsorshipCommand.cs index 0d0124677..47b2e47c2 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IValidateSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Interfaces/IValidateSponsorshipCommand.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; + +public interface IValidateSponsorshipCommand { - public interface IValidateSponsorshipCommand - { - Task ValidateSponsorshipAsync(Guid sponsoredOrganizationId); - } + Task ValidateSponsorshipAsync(Guid sponsoredOrganizationId); } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedRevokeSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedRevokeSponsorshipCommand.cs index aad92f43c..820d27758 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedRevokeSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedRevokeSponsorshipCommand.cs @@ -3,31 +3,30 @@ using Bit.Core.Exceptions; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.SelfHosted +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.SelfHosted; + +public class SelfHostedRevokeSponsorshipCommand : CancelSponsorshipCommand, IRevokeSponsorshipCommand { - public class SelfHostedRevokeSponsorshipCommand : CancelSponsorshipCommand, IRevokeSponsorshipCommand + public SelfHostedRevokeSponsorshipCommand( + IOrganizationSponsorshipRepository organizationSponsorshipRepository, + IOrganizationRepository organizationRepository) : base(organizationSponsorshipRepository, organizationRepository) { - public SelfHostedRevokeSponsorshipCommand( - IOrganizationSponsorshipRepository organizationSponsorshipRepository, - IOrganizationRepository organizationRepository) : base(organizationSponsorshipRepository, organizationRepository) + } + + public async Task RevokeSponsorshipAsync(OrganizationSponsorship sponsorship) + { + if (sponsorship == null) { + throw new BadRequestException("You are not currently sponsoring an organization."); } - public async Task RevokeSponsorshipAsync(OrganizationSponsorship sponsorship) + if (sponsorship.LastSyncDate == null) { - if (sponsorship == null) - { - throw new BadRequestException("You are not currently sponsoring an organization."); - } - - if (sponsorship.LastSyncDate == null) - { - await base.DeleteSponsorshipAsync(sponsorship); - } - else - { - await MarkToDeleteSponsorshipAsync(sponsorship); - } + await base.DeleteSponsorshipAsync(sponsorship); + } + else + { + await MarkToDeleteSponsorshipAsync(sponsorship); } } } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedSyncSponsorshipsCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedSyncSponsorshipsCommand.cs index 4f12c1cf3..df293c3a7 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedSyncSponsorshipsCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedSyncSponsorshipsCommand.cs @@ -11,121 +11,120 @@ using Bit.Core.Settings; using Bit.Core.Utilities; using Microsoft.Extensions.Logging; -namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.SelfHosted -{ - public class SelfHostedSyncSponsorshipsCommand : BaseIdentityClientService, ISelfHostedSyncSponsorshipsCommand - { - private readonly IGlobalSettings _globalSettings; - private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IOrganizationConnectionRepository _organizationConnectionRepository; +namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.SelfHosted; - public SelfHostedSyncSponsorshipsCommand( - IHttpClientFactory httpFactory, - IOrganizationSponsorshipRepository organizationSponsorshipRepository, - IOrganizationUserRepository organizationUserRepository, - IOrganizationConnectionRepository organizationConnectionRepository, - IGlobalSettings globalSettings, - ILogger logger) - : base( - httpFactory, - globalSettings.Installation.ApiUri, - globalSettings.Installation.IdentityUri, - "api.installation", - $"installation.{globalSettings.Installation.Id}", - globalSettings.Installation.Key, - logger) +public class SelfHostedSyncSponsorshipsCommand : BaseIdentityClientService, ISelfHostedSyncSponsorshipsCommand +{ + private readonly IGlobalSettings _globalSettings; + private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IOrganizationConnectionRepository _organizationConnectionRepository; + + public SelfHostedSyncSponsorshipsCommand( + IHttpClientFactory httpFactory, + IOrganizationSponsorshipRepository organizationSponsorshipRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationConnectionRepository organizationConnectionRepository, + IGlobalSettings globalSettings, + ILogger logger) + : base( + httpFactory, + globalSettings.Installation.ApiUri, + globalSettings.Installation.IdentityUri, + "api.installation", + $"installation.{globalSettings.Installation.Id}", + globalSettings.Installation.Key, + logger) + { + _globalSettings = globalSettings; + _organizationUserRepository = organizationUserRepository; + _organizationSponsorshipRepository = organizationSponsorshipRepository; + _organizationConnectionRepository = organizationConnectionRepository; + } + + public async Task SyncOrganization(Guid organizationId, Guid cloudOrganizationId, OrganizationConnection billingSyncConnection) + { + if (!_globalSettings.EnableCloudCommunication) { - _globalSettings = globalSettings; - _organizationUserRepository = organizationUserRepository; - _organizationSponsorshipRepository = organizationSponsorshipRepository; - _organizationConnectionRepository = organizationConnectionRepository; + throw new BadRequestException("Failed to sync instance with cloud - Cloud communication is disabled in global settings"); + } + if (!billingSyncConnection.Enabled) + { + throw new BadRequestException($"Billing Sync Key disabled for organization {organizationId}"); + } + if (string.IsNullOrWhiteSpace(billingSyncConnection.Config)) + { + throw new BadRequestException($"No Billing Sync Key known for organization {organizationId}"); + } + var billingSyncConfig = billingSyncConnection.GetConfig(); + if (billingSyncConfig == null || string.IsNullOrWhiteSpace(billingSyncConfig.BillingSyncKey)) + { + throw new BadRequestException($"Failed to get Billing Sync Key for organization {organizationId}"); } - public async Task SyncOrganization(Guid organizationId, Guid cloudOrganizationId, OrganizationConnection billingSyncConnection) + var organizationSponsorshipsDict = (await _organizationSponsorshipRepository.GetManyBySponsoringOrganizationAsync(organizationId)) + .ToDictionary(i => i.SponsoringOrganizationUserId); + if (!organizationSponsorshipsDict.Any()) { - if (!_globalSettings.EnableCloudCommunication) - { - throw new BadRequestException("Failed to sync instance with cloud - Cloud communication is disabled in global settings"); - } - if (!billingSyncConnection.Enabled) - { - throw new BadRequestException($"Billing Sync Key disabled for organization {organizationId}"); - } - if (string.IsNullOrWhiteSpace(billingSyncConnection.Config)) - { - throw new BadRequestException($"No Billing Sync Key known for organization {organizationId}"); - } - var billingSyncConfig = billingSyncConnection.GetConfig(); - if (billingSyncConfig == null || string.IsNullOrWhiteSpace(billingSyncConfig.BillingSyncKey)) - { - throw new BadRequestException($"Failed to get Billing Sync Key for organization {organizationId}"); - } + _logger.LogInformation($"No existing sponsorships to sync for organization {organizationId}"); + return; + } + var syncedSponsorships = new List(); - var organizationSponsorshipsDict = (await _organizationSponsorshipRepository.GetManyBySponsoringOrganizationAsync(organizationId)) - .ToDictionary(i => i.SponsoringOrganizationUserId); - if (!organizationSponsorshipsDict.Any()) + foreach (var orgSponsorshipsBatch in CoreHelpers.Batch(organizationSponsorshipsDict.Values, 1000)) + { + var response = await SendAsync(HttpMethod.Post, "organization/sponsorship/sync", new OrganizationSponsorshipSyncRequestModel { - _logger.LogInformation($"No existing sponsorships to sync for organization {organizationId}"); - return; - } - var syncedSponsorships = new List(); - - foreach (var orgSponsorshipsBatch in CoreHelpers.Batch(organizationSponsorshipsDict.Values, 1000)) - { - var response = await SendAsync(HttpMethod.Post, "organization/sponsorship/sync", new OrganizationSponsorshipSyncRequestModel - { - BillingSyncKey = billingSyncConfig.BillingSyncKey, - SponsoringOrganizationCloudId = cloudOrganizationId, - SponsorshipsBatch = orgSponsorshipsBatch.Select(s => new OrganizationSponsorshipRequestModel(s)) - }); - - if (response == null) - { - _logger.LogDebug("Organization sync failed for '{OrgId}'", organizationId); - throw new BadRequestException("Organization sync failed"); - } - - syncedSponsorships.AddRange(response.ToOrganizationSponsorshipSync().SponsorshipsBatch); - } - - var sponsorshipsToDelete = syncedSponsorships.Where(s => s.CloudSponsorshipRemoved).Select(i => organizationSponsorshipsDict[i.SponsoringOrganizationUserId].Id); - var sponsorshipsToUpsert = syncedSponsorships.Where(s => !s.CloudSponsorshipRemoved).Select(i => - { - var existingSponsorship = organizationSponsorshipsDict[i.SponsoringOrganizationUserId]; - if (existingSponsorship != null) - { - existingSponsorship.LastSyncDate = i.LastSyncDate; - existingSponsorship.ValidUntil = i.ValidUntil; - existingSponsorship.ToDelete = i.ToDelete; - } - else - { - // shouldn't occur, added in case self hosted loses a sponsorship - existingSponsorship = new OrganizationSponsorship - { - SponsoringOrganizationId = organizationId, - SponsoringOrganizationUserId = i.SponsoringOrganizationUserId, - FriendlyName = i.FriendlyName, - OfferedToEmail = i.OfferedToEmail, - PlanSponsorshipType = i.PlanSponsorshipType, - LastSyncDate = i.LastSyncDate, - ValidUntil = i.ValidUntil, - ToDelete = i.ToDelete - }; - } - return existingSponsorship; + BillingSyncKey = billingSyncConfig.BillingSyncKey, + SponsoringOrganizationCloudId = cloudOrganizationId, + SponsorshipsBatch = orgSponsorshipsBatch.Select(s => new OrganizationSponsorshipRequestModel(s)) }); - if (sponsorshipsToDelete.Any()) + if (response == null) { - await _organizationSponsorshipRepository.DeleteManyAsync(sponsorshipsToDelete); - } - if (sponsorshipsToUpsert.Any()) - { - await _organizationSponsorshipRepository.UpsertManyAsync(sponsorshipsToUpsert); + _logger.LogDebug("Organization sync failed for '{OrgId}'", organizationId); + throw new BadRequestException("Organization sync failed"); } + + syncedSponsorships.AddRange(response.ToOrganizationSponsorshipSync().SponsorshipsBatch); } + var sponsorshipsToDelete = syncedSponsorships.Where(s => s.CloudSponsorshipRemoved).Select(i => organizationSponsorshipsDict[i.SponsoringOrganizationUserId].Id); + var sponsorshipsToUpsert = syncedSponsorships.Where(s => !s.CloudSponsorshipRemoved).Select(i => + { + var existingSponsorship = organizationSponsorshipsDict[i.SponsoringOrganizationUserId]; + if (existingSponsorship != null) + { + existingSponsorship.LastSyncDate = i.LastSyncDate; + existingSponsorship.ValidUntil = i.ValidUntil; + existingSponsorship.ToDelete = i.ToDelete; + } + else + { + // shouldn't occur, added in case self hosted loses a sponsorship + existingSponsorship = new OrganizationSponsorship + { + SponsoringOrganizationId = organizationId, + SponsoringOrganizationUserId = i.SponsoringOrganizationUserId, + FriendlyName = i.FriendlyName, + OfferedToEmail = i.OfferedToEmail, + PlanSponsorshipType = i.PlanSponsorshipType, + LastSyncDate = i.LastSyncDate, + ValidUntil = i.ValidUntil, + ToDelete = i.ToDelete + }; + } + return existingSponsorship; + }); + + if (sponsorshipsToDelete.Any()) + { + await _organizationSponsorshipRepository.DeleteManyAsync(sponsorshipsToDelete); + } + if (sponsorshipsToUpsert.Any()) + { + await _organizationSponsorshipRepository.UpsertManyAsync(sponsorshipsToUpsert); + } } + } diff --git a/src/Core/Repositories/ICipherRepository.cs b/src/Core/Repositories/ICipherRepository.cs index 5e071a55f..56f761935 100644 --- a/src/Core/Repositories/ICipherRepository.cs +++ b/src/Core/Repositories/ICipherRepository.cs @@ -2,38 +2,37 @@ using Bit.Core.Models.Data; using Core.Models.Data; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface ICipherRepository : IRepository { - public interface ICipherRepository : IRepository - { - Task GetByIdAsync(Guid id, Guid userId); - Task GetOrganizationDetailsByIdAsync(Guid id); - Task> GetManyOrganizationDetailsByOrganizationIdAsync(Guid organizationId); - Task GetCanEditByIdAsync(Guid userId, Guid cipherId); - Task> GetManyByUserIdAsync(Guid userId, bool withOrganizations = true); - Task> GetManyByOrganizationIdAsync(Guid organizationId); - Task CreateAsync(Cipher cipher, IEnumerable collectionIds); - Task CreateAsync(CipherDetails cipher); - Task CreateAsync(CipherDetails cipher, IEnumerable collectionIds); - Task ReplaceAsync(CipherDetails cipher); - Task UpsertAsync(CipherDetails cipher); - Task ReplaceAsync(Cipher obj, IEnumerable collectionIds); - Task UpdatePartialAsync(Guid id, Guid userId, Guid? folderId, bool favorite); - Task UpdateAttachmentAsync(CipherAttachment attachment); - Task DeleteAttachmentAsync(Guid cipherId, string attachmentId); - Task DeleteAsync(IEnumerable ids, Guid userId); - Task DeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId); - Task MoveAsync(IEnumerable ids, Guid? folderId, Guid userId); - Task DeleteByUserIdAsync(Guid userId); - Task DeleteByOrganizationIdAsync(Guid organizationId); - Task UpdateUserKeysAndCiphersAsync(User user, IEnumerable ciphers, IEnumerable folders, IEnumerable sends); - Task UpdateCiphersAsync(Guid userId, IEnumerable ciphers); - Task CreateAsync(IEnumerable ciphers, IEnumerable folders); - Task CreateAsync(IEnumerable ciphers, IEnumerable collections, - IEnumerable collectionCiphers); - Task SoftDeleteAsync(IEnumerable ids, Guid userId); - Task SoftDeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId); - Task RestoreAsync(IEnumerable ids, Guid userId); - Task DeleteDeletedAsync(DateTime deletedDateBefore); - } + Task GetByIdAsync(Guid id, Guid userId); + Task GetOrganizationDetailsByIdAsync(Guid id); + Task> GetManyOrganizationDetailsByOrganizationIdAsync(Guid organizationId); + Task GetCanEditByIdAsync(Guid userId, Guid cipherId); + Task> GetManyByUserIdAsync(Guid userId, bool withOrganizations = true); + Task> GetManyByOrganizationIdAsync(Guid organizationId); + Task CreateAsync(Cipher cipher, IEnumerable collectionIds); + Task CreateAsync(CipherDetails cipher); + Task CreateAsync(CipherDetails cipher, IEnumerable collectionIds); + Task ReplaceAsync(CipherDetails cipher); + Task UpsertAsync(CipherDetails cipher); + Task ReplaceAsync(Cipher obj, IEnumerable collectionIds); + Task UpdatePartialAsync(Guid id, Guid userId, Guid? folderId, bool favorite); + Task UpdateAttachmentAsync(CipherAttachment attachment); + Task DeleteAttachmentAsync(Guid cipherId, string attachmentId); + Task DeleteAsync(IEnumerable ids, Guid userId); + Task DeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId); + Task MoveAsync(IEnumerable ids, Guid? folderId, Guid userId); + Task DeleteByUserIdAsync(Guid userId); + Task DeleteByOrganizationIdAsync(Guid organizationId); + Task UpdateUserKeysAndCiphersAsync(User user, IEnumerable ciphers, IEnumerable folders, IEnumerable sends); + Task UpdateCiphersAsync(Guid userId, IEnumerable ciphers); + Task CreateAsync(IEnumerable ciphers, IEnumerable folders); + Task CreateAsync(IEnumerable ciphers, IEnumerable collections, + IEnumerable collectionCiphers); + Task SoftDeleteAsync(IEnumerable ids, Guid userId); + Task SoftDeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId); + Task RestoreAsync(IEnumerable ids, Guid userId); + Task DeleteDeletedAsync(DateTime deletedDateBefore); } diff --git a/src/Core/Repositories/ICollectionCipherRepository.cs b/src/Core/Repositories/ICollectionCipherRepository.cs index b79c65737..272128810 100644 --- a/src/Core/Repositories/ICollectionCipherRepository.cs +++ b/src/Core/Repositories/ICollectionCipherRepository.cs @@ -1,15 +1,14 @@ using Bit.Core.Entities; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface ICollectionCipherRepository { - public interface ICollectionCipherRepository - { - Task> GetManyByUserIdAsync(Guid userId); - Task> GetManyByOrganizationIdAsync(Guid organizationId); - Task> GetManyByUserIdCipherIdAsync(Guid userId, Guid cipherId); - Task UpdateCollectionsAsync(Guid cipherId, Guid userId, IEnumerable collectionIds); - Task UpdateCollectionsForAdminAsync(Guid cipherId, Guid organizationId, IEnumerable collectionIds); - Task UpdateCollectionsForCiphersAsync(IEnumerable cipherIds, Guid userId, Guid organizationId, - IEnumerable collectionIds); - } + Task> GetManyByUserIdAsync(Guid userId); + Task> GetManyByOrganizationIdAsync(Guid organizationId); + Task> GetManyByUserIdCipherIdAsync(Guid userId, Guid cipherId); + Task UpdateCollectionsAsync(Guid cipherId, Guid userId, IEnumerable collectionIds); + Task UpdateCollectionsForAdminAsync(Guid cipherId, Guid organizationId, IEnumerable collectionIds); + Task UpdateCollectionsForCiphersAsync(IEnumerable cipherIds, Guid userId, Guid organizationId, + IEnumerable collectionIds); } diff --git a/src/Core/Repositories/ICollectionRepository.cs b/src/Core/Repositories/ICollectionRepository.cs index e53379997..dda042aa8 100644 --- a/src/Core/Repositories/ICollectionRepository.cs +++ b/src/Core/Repositories/ICollectionRepository.cs @@ -1,20 +1,19 @@ using Bit.Core.Entities; using Bit.Core.Models.Data; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface ICollectionRepository : IRepository { - public interface ICollectionRepository : IRepository - { - Task GetCountByOrganizationIdAsync(Guid organizationId); - Task>> GetByIdWithGroupsAsync(Guid id); - Task>> GetByIdWithGroupsAsync(Guid id, Guid userId); - Task> GetManyByOrganizationIdAsync(Guid organizationId); - Task GetByIdAsync(Guid id, Guid userId); - Task> GetManyByUserIdAsync(Guid userId); - Task CreateAsync(Collection obj, IEnumerable groups); - Task ReplaceAsync(Collection obj, IEnumerable groups); - Task DeleteUserAsync(Guid collectionId, Guid organizationUserId); - Task UpdateUsersAsync(Guid id, IEnumerable users); - Task> GetManyUsersByIdAsync(Guid id); - } + Task GetCountByOrganizationIdAsync(Guid organizationId); + Task>> GetByIdWithGroupsAsync(Guid id); + Task>> GetByIdWithGroupsAsync(Guid id, Guid userId); + Task> GetManyByOrganizationIdAsync(Guid organizationId); + Task GetByIdAsync(Guid id, Guid userId); + Task> GetManyByUserIdAsync(Guid userId); + Task CreateAsync(Collection obj, IEnumerable groups); + Task ReplaceAsync(Collection obj, IEnumerable groups); + Task DeleteUserAsync(Guid collectionId, Guid organizationUserId); + Task UpdateUsersAsync(Guid id, IEnumerable users); + Task> GetManyUsersByIdAsync(Guid id); } diff --git a/src/Core/Repositories/IDeviceRepository.cs b/src/Core/Repositories/IDeviceRepository.cs index 85221d446..5424d5fe3 100644 --- a/src/Core/Repositories/IDeviceRepository.cs +++ b/src/Core/Repositories/IDeviceRepository.cs @@ -1,13 +1,12 @@ using Bit.Core.Entities; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface IDeviceRepository : IRepository { - public interface IDeviceRepository : IRepository - { - Task GetByIdAsync(Guid id, Guid userId); - Task GetByIdentifierAsync(string identifier); - Task GetByIdentifierAsync(string identifier, Guid userId); - Task> GetManyByUserIdAsync(Guid userId); - Task ClearPushTokenAsync(Guid id); - } + Task GetByIdAsync(Guid id, Guid userId); + Task GetByIdentifierAsync(string identifier); + Task GetByIdentifierAsync(string identifier, Guid userId); + Task> GetManyByUserIdAsync(Guid userId); + Task ClearPushTokenAsync(Guid id); } diff --git a/src/Core/Repositories/IEmergencyAccessRepository.cs b/src/Core/Repositories/IEmergencyAccessRepository.cs index 449bfe631..790f7191c 100644 --- a/src/Core/Repositories/IEmergencyAccessRepository.cs +++ b/src/Core/Repositories/IEmergencyAccessRepository.cs @@ -1,15 +1,14 @@ using Bit.Core.Entities; using Bit.Core.Models.Data; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface IEmergencyAccessRepository : IRepository { - public interface IEmergencyAccessRepository : IRepository - { - Task GetCountByGrantorIdEmailAsync(Guid grantorId, string email, bool onlyRegisteredUsers); - Task> GetManyDetailsByGrantorIdAsync(Guid grantorId); - Task> GetManyDetailsByGranteeIdAsync(Guid granteeId); - Task GetDetailsByIdGrantorIdAsync(Guid id, Guid grantorId); - Task> GetManyToNotifyAsync(); - Task> GetExpiredRecoveriesAsync(); - } + Task GetCountByGrantorIdEmailAsync(Guid grantorId, string email, bool onlyRegisteredUsers); + Task> GetManyDetailsByGrantorIdAsync(Guid grantorId); + Task> GetManyDetailsByGranteeIdAsync(Guid granteeId); + Task GetDetailsByIdGrantorIdAsync(Guid id, Guid grantorId); + Task> GetManyToNotifyAsync(); + Task> GetExpiredRecoveriesAsync(); } diff --git a/src/Core/Repositories/IEventRepository.cs b/src/Core/Repositories/IEventRepository.cs index c2af5c0e0..bac3cb534 100644 --- a/src/Core/Repositories/IEventRepository.cs +++ b/src/Core/Repositories/IEventRepository.cs @@ -1,23 +1,22 @@ using Bit.Core.Entities; using Bit.Core.Models.Data; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface IEventRepository { - public interface IEventRepository - { - Task> GetManyByUserAsync(Guid userId, DateTime startDate, DateTime endDate, - PageOptions pageOptions); - Task> GetManyByOrganizationAsync(Guid organizationId, DateTime startDate, DateTime endDate, - PageOptions pageOptions); - Task> GetManyByOrganizationActingUserAsync(Guid organizationId, Guid actingUserId, - DateTime startDate, DateTime endDate, PageOptions pageOptions); - Task> GetManyByProviderAsync(Guid providerId, DateTime startDate, DateTime endDate, - PageOptions pageOptions); - Task> GetManyByProviderActingUserAsync(Guid providerId, Guid actingUserId, - DateTime startDate, DateTime endDate, PageOptions pageOptions); - Task> GetManyByCipherAsync(Cipher cipher, DateTime startDate, DateTime endDate, - PageOptions pageOptions); - Task CreateAsync(IEvent e); - Task CreateManyAsync(IEnumerable e); - } + Task> GetManyByUserAsync(Guid userId, DateTime startDate, DateTime endDate, + PageOptions pageOptions); + Task> GetManyByOrganizationAsync(Guid organizationId, DateTime startDate, DateTime endDate, + PageOptions pageOptions); + Task> GetManyByOrganizationActingUserAsync(Guid organizationId, Guid actingUserId, + DateTime startDate, DateTime endDate, PageOptions pageOptions); + Task> GetManyByProviderAsync(Guid providerId, DateTime startDate, DateTime endDate, + PageOptions pageOptions); + Task> GetManyByProviderActingUserAsync(Guid providerId, Guid actingUserId, + DateTime startDate, DateTime endDate, PageOptions pageOptions); + Task> GetManyByCipherAsync(Cipher cipher, DateTime startDate, DateTime endDate, + PageOptions pageOptions); + Task CreateAsync(IEvent e); + Task CreateManyAsync(IEnumerable e); } diff --git a/src/Core/Repositories/IFolderRepository.cs b/src/Core/Repositories/IFolderRepository.cs index c174f4fb1..b93ca097b 100644 --- a/src/Core/Repositories/IFolderRepository.cs +++ b/src/Core/Repositories/IFolderRepository.cs @@ -1,10 +1,9 @@ using Bit.Core.Entities; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface IFolderRepository : IRepository { - public interface IFolderRepository : IRepository - { - Task GetByIdAsync(Guid id, Guid userId); - Task> GetManyByUserIdAsync(Guid userId); - } + Task GetByIdAsync(Guid id, Guid userId); + Task> GetManyByUserIdAsync(Guid userId); } diff --git a/src/Core/Repositories/IGrantRepository.cs b/src/Core/Repositories/IGrantRepository.cs index edab4c815..14f4fcb03 100644 --- a/src/Core/Repositories/IGrantRepository.cs +++ b/src/Core/Repositories/IGrantRepository.cs @@ -1,13 +1,12 @@ using Bit.Core.Entities; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface IGrantRepository { - public interface IGrantRepository - { - Task GetByKeyAsync(string key); - Task> GetManyAsync(string subjectId, string sessionId, string clientId, string type); - Task SaveAsync(Grant obj); - Task DeleteByKeyAsync(string key); - Task DeleteManyAsync(string subjectId, string sessionId, string clientId, string type); - } + Task GetByKeyAsync(string key); + Task> GetManyAsync(string subjectId, string sessionId, string clientId, string type); + Task SaveAsync(Grant obj); + Task DeleteByKeyAsync(string key); + Task DeleteManyAsync(string subjectId, string sessionId, string clientId, string type); } diff --git a/src/Core/Repositories/IGroupRepository.cs b/src/Core/Repositories/IGroupRepository.cs index e8cdc43bc..d7b9b664d 100644 --- a/src/Core/Repositories/IGroupRepository.cs +++ b/src/Core/Repositories/IGroupRepository.cs @@ -1,18 +1,17 @@ using Bit.Core.Entities; using Bit.Core.Models.Data; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface IGroupRepository : IRepository { - public interface IGroupRepository : IRepository - { - Task>> GetByIdWithCollectionsAsync(Guid id); - Task> GetManyByOrganizationIdAsync(Guid organizationId); - Task> GetManyIdsByUserIdAsync(Guid organizationUserId); - Task> GetManyUserIdsByIdAsync(Guid id); - Task> GetManyGroupUsersByOrganizationIdAsync(Guid organizationId); - Task CreateAsync(Group obj, IEnumerable collections); - Task ReplaceAsync(Group obj, IEnumerable collections); - Task DeleteUserAsync(Guid groupId, Guid organizationUserId); - Task UpdateUsersAsync(Guid groupId, IEnumerable organizationUserIds); - } + Task>> GetByIdWithCollectionsAsync(Guid id); + Task> GetManyByOrganizationIdAsync(Guid organizationId); + Task> GetManyIdsByUserIdAsync(Guid organizationUserId); + Task> GetManyUserIdsByIdAsync(Guid id); + Task> GetManyGroupUsersByOrganizationIdAsync(Guid organizationId); + Task CreateAsync(Group obj, IEnumerable collections); + Task ReplaceAsync(Group obj, IEnumerable collections); + Task DeleteUserAsync(Guid groupId, Guid organizationUserId); + Task UpdateUsersAsync(Guid groupId, IEnumerable organizationUserIds); } diff --git a/src/Core/Repositories/IInstallationDeviceRepository.cs b/src/Core/Repositories/IInstallationDeviceRepository.cs index 394b80837..bdbeaf297 100644 --- a/src/Core/Repositories/IInstallationDeviceRepository.cs +++ b/src/Core/Repositories/IInstallationDeviceRepository.cs @@ -1,11 +1,10 @@ using Bit.Core.Models.Data; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface IInstallationDeviceRepository { - public interface IInstallationDeviceRepository - { - Task UpsertAsync(InstallationDeviceEntity entity); - Task UpsertManyAsync(IList entities); - Task DeleteAsync(InstallationDeviceEntity entity); - } + Task UpsertAsync(InstallationDeviceEntity entity); + Task UpsertManyAsync(IList entities); + Task DeleteAsync(InstallationDeviceEntity entity); } diff --git a/src/Core/Repositories/IInstallationRepository.cs b/src/Core/Repositories/IInstallationRepository.cs index f88e81e5f..65ee34aaf 100644 --- a/src/Core/Repositories/IInstallationRepository.cs +++ b/src/Core/Repositories/IInstallationRepository.cs @@ -1,8 +1,7 @@ using Bit.Core.Entities; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface IInstallationRepository : IRepository { - public interface IInstallationRepository : IRepository - { - } } diff --git a/src/Core/Repositories/IMaintenanceRepository.cs b/src/Core/Repositories/IMaintenanceRepository.cs index c1dc098c6..a89c38bd0 100644 --- a/src/Core/Repositories/IMaintenanceRepository.cs +++ b/src/Core/Repositories/IMaintenanceRepository.cs @@ -1,11 +1,10 @@ -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface IMaintenanceRepository { - public interface IMaintenanceRepository - { - Task UpdateStatisticsAsync(); - Task DisableCipherAutoStatsAsync(); - Task RebuildIndexesAsync(); - Task DeleteExpiredGrantsAsync(); - Task DeleteExpiredSponsorshipsAsync(DateTime validUntilBeforeDate); - } + Task UpdateStatisticsAsync(); + Task DisableCipherAutoStatsAsync(); + Task RebuildIndexesAsync(); + Task DeleteExpiredGrantsAsync(); + Task DeleteExpiredSponsorshipsAsync(DateTime validUntilBeforeDate); } diff --git a/src/Core/Repositories/IMetaDataRepository.cs b/src/Core/Repositories/IMetaDataRepository.cs index 69895b9c8..e087234da 100644 --- a/src/Core/Repositories/IMetaDataRepository.cs +++ b/src/Core/Repositories/IMetaDataRepository.cs @@ -1,11 +1,10 @@ -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface IMetaDataRepository { - public interface IMetaDataRepository - { - Task DeleteAsync(string objectName, string id); - Task> GetAsync(string objectName, string id); - Task GetAsync(string objectName, string id, string prop); - Task UpsertAsync(string objectName, string id, IDictionary dict); - Task UpsertAsync(string objectName, string id, KeyValuePair keyValuePair); - } + Task DeleteAsync(string objectName, string id); + Task> GetAsync(string objectName, string id); + Task GetAsync(string objectName, string id, string prop); + Task UpsertAsync(string objectName, string id, IDictionary dict); + Task UpsertAsync(string objectName, string id, KeyValuePair keyValuePair); } diff --git a/src/Core/Repositories/IOrganizationApiKeyRepository.cs b/src/Core/Repositories/IOrganizationApiKeyRepository.cs index 8b1b24978..778db9d73 100644 --- a/src/Core/Repositories/IOrganizationApiKeyRepository.cs +++ b/src/Core/Repositories/IOrganizationApiKeyRepository.cs @@ -1,10 +1,9 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface IOrganizationApiKeyRepository : IRepository { - public interface IOrganizationApiKeyRepository : IRepository - { - Task> GetManyByOrganizationIdTypeAsync(Guid organizationId, OrganizationApiKeyType? type = null); - } + Task> GetManyByOrganizationIdTypeAsync(Guid organizationId, OrganizationApiKeyType? type = null); } diff --git a/src/Core/Repositories/IOrganizationConnectionRepository.cs b/src/Core/Repositories/IOrganizationConnectionRepository.cs index b87a82d14..a3bdbb037 100644 --- a/src/Core/Repositories/IOrganizationConnectionRepository.cs +++ b/src/Core/Repositories/IOrganizationConnectionRepository.cs @@ -1,11 +1,10 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface IOrganizationConnectionRepository : IRepository { - public interface IOrganizationConnectionRepository : IRepository - { - Task> GetByOrganizationIdTypeAsync(Guid organizationId, OrganizationConnectionType type); - Task> GetEnabledByOrganizationIdTypeAsync(Guid organizationId, OrganizationConnectionType type); - } + Task> GetByOrganizationIdTypeAsync(Guid organizationId, OrganizationConnectionType type); + Task> GetEnabledByOrganizationIdTypeAsync(Guid organizationId, OrganizationConnectionType type); } diff --git a/src/Core/Repositories/IOrganizationRepository.cs b/src/Core/Repositories/IOrganizationRepository.cs index 392a925c3..690bff913 100644 --- a/src/Core/Repositories/IOrganizationRepository.cs +++ b/src/Core/Repositories/IOrganizationRepository.cs @@ -1,15 +1,14 @@ using Bit.Core.Entities; using Bit.Core.Models.Data.Organizations; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface IOrganizationRepository : IRepository { - public interface IOrganizationRepository : IRepository - { - Task GetByIdentifierAsync(string identifier); - Task> GetManyByEnabledAsync(); - Task> GetManyByUserIdAsync(Guid userId); - Task> SearchAsync(string name, string userEmail, bool? paid, int skip, int take); - Task UpdateStorageAsync(Guid id); - Task> GetManyAbilitiesAsync(); - } + Task GetByIdentifierAsync(string identifier); + Task> GetManyByEnabledAsync(); + Task> GetManyByUserIdAsync(Guid userId); + Task> SearchAsync(string name, string userEmail, bool? paid, int skip, int take); + Task UpdateStorageAsync(Guid id); + Task> GetManyAbilitiesAsync(); } diff --git a/src/Core/Repositories/IOrganizationSponsorshipRepository.cs b/src/Core/Repositories/IOrganizationSponsorshipRepository.cs index 2ef24580c..232fd1b9d 100644 --- a/src/Core/Repositories/IOrganizationSponsorshipRepository.cs +++ b/src/Core/Repositories/IOrganizationSponsorshipRepository.cs @@ -1,16 +1,15 @@ using Bit.Core.Entities; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface IOrganizationSponsorshipRepository : IRepository { - public interface IOrganizationSponsorshipRepository : IRepository - { - Task> CreateManyAsync(IEnumerable organizationSponsorships); - Task ReplaceManyAsync(IEnumerable organizationSponsorships); - Task UpsertManyAsync(IEnumerable organizationSponsorships); - Task DeleteManyAsync(IEnumerable organizationSponsorshipIds); - Task> GetManyBySponsoringOrganizationAsync(Guid sponsoringOrganizationId); - Task GetBySponsoringOrganizationUserIdAsync(Guid sponsoringOrganizationUserId); - Task GetBySponsoredOrganizationIdAsync(Guid sponsoredOrganizationId); - Task GetLatestSyncDateBySponsoringOrganizationIdAsync(Guid sponsoringOrganizationId); - } + Task> CreateManyAsync(IEnumerable organizationSponsorships); + Task ReplaceManyAsync(IEnumerable organizationSponsorships); + Task UpsertManyAsync(IEnumerable organizationSponsorships); + Task DeleteManyAsync(IEnumerable organizationSponsorshipIds); + Task> GetManyBySponsoringOrganizationAsync(Guid sponsoringOrganizationId); + Task GetBySponsoringOrganizationUserIdAsync(Guid sponsoringOrganizationUserId); + Task GetBySponsoredOrganizationIdAsync(Guid sponsoredOrganizationId); + Task GetLatestSyncDateBySponsoringOrganizationIdAsync(Guid sponsoringOrganizationId); } diff --git a/src/Core/Repositories/IOrganizationUserRepository.cs b/src/Core/Repositories/IOrganizationUserRepository.cs index 8597adb52..f8909f784 100644 --- a/src/Core/Repositories/IOrganizationUserRepository.cs +++ b/src/Core/Repositories/IOrganizationUserRepository.cs @@ -3,40 +3,39 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface IOrganizationUserRepository : IRepository { - public interface IOrganizationUserRepository : IRepository - { - Task GetCountByOrganizationIdAsync(Guid organizationId); - Task GetCountByFreeOrganizationAdminUserAsync(Guid userId); - Task GetCountByOnlyOwnerAsync(Guid userId); - Task> GetManyByUserAsync(Guid userId); - Task> GetManyByOrganizationAsync(Guid organizationId, OrganizationUserType? type); - Task GetCountByOrganizationAsync(Guid organizationId, string email, bool onlyRegisteredUsers); - Task> SelectKnownEmailsAsync(Guid organizationId, IEnumerable emails, bool onlyRegisteredUsers); - Task GetByOrganizationAsync(Guid organizationId, Guid userId); - Task>> GetByIdWithCollectionsAsync(Guid id); - Task GetDetailsByIdAsync(Guid id); - Task>> - GetDetailsByIdWithCollectionsAsync(Guid id); - Task> GetManyDetailsByOrganizationAsync(Guid organizationId); - Task> GetManyDetailsByUserAsync(Guid userId, - OrganizationUserStatusType? status = null); - Task GetDetailsByUserAsync(Guid userId, Guid organizationId, - OrganizationUserStatusType? status = null); - Task UpdateGroupsAsync(Guid orgUserId, IEnumerable groupIds); - Task UpsertManyAsync(IEnumerable organizationUsers); - Task CreateAsync(OrganizationUser obj, IEnumerable collections); - Task> CreateManyAsync(IEnumerable organizationIdUsers); - Task ReplaceAsync(OrganizationUser obj, IEnumerable collections); - Task ReplaceManyAsync(IEnumerable organizationUsers); - Task> GetManyByManyUsersAsync(IEnumerable userIds); - Task> GetManyAsync(IEnumerable Ids); - Task DeleteManyAsync(IEnumerable userIds); - Task GetByOrganizationEmailAsync(Guid organizationId, string email); - Task> GetManyPublicKeysByOrganizationUserAsync(Guid organizationId, IEnumerable Ids); - Task> GetManyByMinimumRoleAsync(Guid organizationId, OrganizationUserType minRole); - Task RevokeAsync(Guid id); - Task RestoreAsync(Guid id, OrganizationUserStatusType status); - } + Task GetCountByOrganizationIdAsync(Guid organizationId); + Task GetCountByFreeOrganizationAdminUserAsync(Guid userId); + Task GetCountByOnlyOwnerAsync(Guid userId); + Task> GetManyByUserAsync(Guid userId); + Task> GetManyByOrganizationAsync(Guid organizationId, OrganizationUserType? type); + Task GetCountByOrganizationAsync(Guid organizationId, string email, bool onlyRegisteredUsers); + Task> SelectKnownEmailsAsync(Guid organizationId, IEnumerable emails, bool onlyRegisteredUsers); + Task GetByOrganizationAsync(Guid organizationId, Guid userId); + Task>> GetByIdWithCollectionsAsync(Guid id); + Task GetDetailsByIdAsync(Guid id); + Task>> + GetDetailsByIdWithCollectionsAsync(Guid id); + Task> GetManyDetailsByOrganizationAsync(Guid organizationId); + Task> GetManyDetailsByUserAsync(Guid userId, + OrganizationUserStatusType? status = null); + Task GetDetailsByUserAsync(Guid userId, Guid organizationId, + OrganizationUserStatusType? status = null); + Task UpdateGroupsAsync(Guid orgUserId, IEnumerable groupIds); + Task UpsertManyAsync(IEnumerable organizationUsers); + Task CreateAsync(OrganizationUser obj, IEnumerable collections); + Task> CreateManyAsync(IEnumerable organizationIdUsers); + Task ReplaceAsync(OrganizationUser obj, IEnumerable collections); + Task ReplaceManyAsync(IEnumerable organizationUsers); + Task> GetManyByManyUsersAsync(IEnumerable userIds); + Task> GetManyAsync(IEnumerable Ids); + Task DeleteManyAsync(IEnumerable userIds); + Task GetByOrganizationEmailAsync(Guid organizationId, string email); + Task> GetManyPublicKeysByOrganizationUserAsync(Guid organizationId, IEnumerable Ids); + Task> GetManyByMinimumRoleAsync(Guid organizationId, OrganizationUserType minRole); + Task RevokeAsync(Guid id); + Task RestoreAsync(Guid id, OrganizationUserStatusType status); } diff --git a/src/Core/Repositories/IPolicyRepository.cs b/src/Core/Repositories/IPolicyRepository.cs index 34206770e..ce965e174 100644 --- a/src/Core/Repositories/IPolicyRepository.cs +++ b/src/Core/Repositories/IPolicyRepository.cs @@ -1,16 +1,15 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface IPolicyRepository : IRepository { - public interface IPolicyRepository : IRepository - { - Task GetByOrganizationIdTypeAsync(Guid organizationId, PolicyType type); - Task> GetManyByOrganizationIdAsync(Guid organizationId); - Task> GetManyByUserIdAsync(Guid userId); - Task> GetManyByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType, - OrganizationUserStatusType minStatus = OrganizationUserStatusType.Accepted); - Task GetCountByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType, - OrganizationUserStatusType minStatus = OrganizationUserStatusType.Accepted); - } + Task GetByOrganizationIdTypeAsync(Guid organizationId, PolicyType type); + Task> GetManyByOrganizationIdAsync(Guid organizationId); + Task> GetManyByUserIdAsync(Guid userId); + Task> GetManyByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType, + OrganizationUserStatusType minStatus = OrganizationUserStatusType.Accepted); + Task GetCountByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType, + OrganizationUserStatusType minStatus = OrganizationUserStatusType.Accepted); } diff --git a/src/Core/Repositories/IProviderOrganizationRepository.cs b/src/Core/Repositories/IProviderOrganizationRepository.cs index 7c2cfb3b1..b546d8d2e 100644 --- a/src/Core/Repositories/IProviderOrganizationRepository.cs +++ b/src/Core/Repositories/IProviderOrganizationRepository.cs @@ -1,11 +1,10 @@ using Bit.Core.Entities.Provider; using Bit.Core.Models.Data; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface IProviderOrganizationRepository : IRepository { - public interface IProviderOrganizationRepository : IRepository - { - Task> GetManyDetailsByProviderAsync(Guid providerId); - Task GetByOrganizationId(Guid organizationId); - } + Task> GetManyDetailsByProviderAsync(Guid providerId); + Task GetByOrganizationId(Guid organizationId); } diff --git a/src/Core/Repositories/IProviderRepository.cs b/src/Core/Repositories/IProviderRepository.cs index 5b3700d81..8d92fb6d2 100644 --- a/src/Core/Repositories/IProviderRepository.cs +++ b/src/Core/Repositories/IProviderRepository.cs @@ -1,11 +1,10 @@ using Bit.Core.Entities.Provider; using Bit.Core.Models.Data; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface IProviderRepository : IRepository { - public interface IProviderRepository : IRepository - { - Task> SearchAsync(string name, string userEmail, int skip, int take); - Task> GetManyAbilitiesAsync(); - } + Task> SearchAsync(string name, string userEmail, int skip, int take); + Task> GetManyAbilitiesAsync(); } diff --git a/src/Core/Repositories/IProviderUserRepository.cs b/src/Core/Repositories/IProviderUserRepository.cs index 14882a0a5..4a5db368e 100644 --- a/src/Core/Repositories/IProviderUserRepository.cs +++ b/src/Core/Repositories/IProviderUserRepository.cs @@ -2,21 +2,20 @@ using Bit.Core.Enums.Provider; using Bit.Core.Models.Data; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface IProviderUserRepository : IRepository { - public interface IProviderUserRepository : IRepository - { - Task GetCountByProviderAsync(Guid providerId, string email, bool onlyRegisteredUsers); - Task> GetManyAsync(IEnumerable ids); - Task> GetManyByUserAsync(Guid userId); - Task GetByProviderUserAsync(Guid providerId, Guid userId); - Task> GetManyByProviderAsync(Guid providerId, ProviderUserType? type = null); - Task> GetManyDetailsByProviderAsync(Guid providerId); - Task> GetManyDetailsByUserAsync(Guid userId, - ProviderUserStatusType? status = null); - Task> GetManyOrganizationDetailsByUserAsync(Guid userId, ProviderUserStatusType? status = null); - Task DeleteManyAsync(IEnumerable userIds); - Task> GetManyPublicKeysByProviderUserAsync(Guid providerId, IEnumerable Ids); - Task GetCountByOnlyOwnerAsync(Guid userId); - } + Task GetCountByProviderAsync(Guid providerId, string email, bool onlyRegisteredUsers); + Task> GetManyAsync(IEnumerable ids); + Task> GetManyByUserAsync(Guid userId); + Task GetByProviderUserAsync(Guid providerId, Guid userId); + Task> GetManyByProviderAsync(Guid providerId, ProviderUserType? type = null); + Task> GetManyDetailsByProviderAsync(Guid providerId); + Task> GetManyDetailsByUserAsync(Guid userId, + ProviderUserStatusType? status = null); + Task> GetManyOrganizationDetailsByUserAsync(Guid userId, ProviderUserStatusType? status = null); + Task DeleteManyAsync(IEnumerable userIds); + Task> GetManyPublicKeysByProviderUserAsync(Guid providerId, IEnumerable Ids); + Task GetCountByOnlyOwnerAsync(Guid userId); } diff --git a/src/Core/Repositories/IRepository.cs b/src/Core/Repositories/IRepository.cs index 3316bef51..18bb81ff8 100644 --- a/src/Core/Repositories/IRepository.cs +++ b/src/Core/Repositories/IRepository.cs @@ -1,13 +1,12 @@ using Bit.Core.Entities; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface IRepository where TId : IEquatable where T : class, ITableObject { - public interface IRepository where TId : IEquatable where T : class, ITableObject - { - Task GetByIdAsync(TId id); - Task CreateAsync(T obj); - Task ReplaceAsync(T obj); - Task UpsertAsync(T obj); - Task DeleteAsync(T obj); - } + Task GetByIdAsync(TId id); + Task CreateAsync(T obj); + Task ReplaceAsync(T obj); + Task UpsertAsync(T obj); + Task DeleteAsync(T obj); } diff --git a/src/Core/Repositories/ISendRepository.cs b/src/Core/Repositories/ISendRepository.cs index 4a4fe5ebf..b35a059d3 100644 --- a/src/Core/Repositories/ISendRepository.cs +++ b/src/Core/Repositories/ISendRepository.cs @@ -1,10 +1,9 @@ using Bit.Core.Entities; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface ISendRepository : IRepository { - public interface ISendRepository : IRepository - { - Task> GetManyByUserIdAsync(Guid userId); - Task> GetManyByDeletionDateAsync(DateTime deletionDateBefore); - } + Task> GetManyByUserIdAsync(Guid userId); + Task> GetManyByDeletionDateAsync(DateTime deletionDateBefore); } diff --git a/src/Core/Repositories/ISsoConfigRepository.cs b/src/Core/Repositories/ISsoConfigRepository.cs index 2350e0a4a..8f6561819 100644 --- a/src/Core/Repositories/ISsoConfigRepository.cs +++ b/src/Core/Repositories/ISsoConfigRepository.cs @@ -1,11 +1,10 @@ using Bit.Core.Entities; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface ISsoConfigRepository : IRepository { - public interface ISsoConfigRepository : IRepository - { - Task GetByOrganizationIdAsync(Guid organizationId); - Task GetByIdentifierAsync(string identifier); - Task> GetManyByRevisionNotBeforeDate(DateTime? notBefore); - } + Task GetByOrganizationIdAsync(Guid organizationId); + Task GetByIdentifierAsync(string identifier); + Task> GetManyByRevisionNotBeforeDate(DateTime? notBefore); } diff --git a/src/Core/Repositories/ISsoUserRepository.cs b/src/Core/Repositories/ISsoUserRepository.cs index 6dcada929..653734450 100644 --- a/src/Core/Repositories/ISsoUserRepository.cs +++ b/src/Core/Repositories/ISsoUserRepository.cs @@ -1,10 +1,9 @@ using Bit.Core.Entities; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface ISsoUserRepository : IRepository { - public interface ISsoUserRepository : IRepository - { - Task DeleteAsync(Guid userId, Guid? organizationId); - Task GetByUserIdOrganizationIdAsync(Guid organizationId, Guid userId); - } + Task DeleteAsync(Guid userId, Guid? organizationId); + Task GetByUserIdOrganizationIdAsync(Guid organizationId, Guid userId); } diff --git a/src/Core/Repositories/ITaxRateRepository.cs b/src/Core/Repositories/ITaxRateRepository.cs index 779c2c714..a8557a789 100644 --- a/src/Core/Repositories/ITaxRateRepository.cs +++ b/src/Core/Repositories/ITaxRateRepository.cs @@ -1,12 +1,11 @@ using Bit.Core.Entities; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface ITaxRateRepository : IRepository { - public interface ITaxRateRepository : IRepository - { - Task> SearchAsync(int skip, int count); - Task> GetAllActiveAsync(); - Task ArchiveAsync(TaxRate model); - Task> GetByLocationAsync(TaxRate taxRate); - } + Task> SearchAsync(int skip, int count); + Task> GetAllActiveAsync(); + Task ArchiveAsync(TaxRate model); + Task> GetByLocationAsync(TaxRate taxRate); } diff --git a/src/Core/Repositories/ITransactionRepository.cs b/src/Core/Repositories/ITransactionRepository.cs index 6fb9b27b7..82b6f961b 100644 --- a/src/Core/Repositories/ITransactionRepository.cs +++ b/src/Core/Repositories/ITransactionRepository.cs @@ -1,12 +1,11 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface ITransactionRepository : IRepository { - public interface ITransactionRepository : IRepository - { - Task> GetManyByUserIdAsync(Guid userId); - Task> GetManyByOrganizationIdAsync(Guid organizationId); - Task GetByGatewayIdAsync(GatewayType gatewayType, string gatewayId); - } + Task> GetManyByUserIdAsync(Guid userId); + Task> GetManyByOrganizationIdAsync(Guid organizationId); + Task GetByGatewayIdAsync(GatewayType gatewayType, string gatewayId); } diff --git a/src/Core/Repositories/IUserRepository.cs b/src/Core/Repositories/IUserRepository.cs index 8ed89f7e0..0c6ee8571 100644 --- a/src/Core/Repositories/IUserRepository.cs +++ b/src/Core/Repositories/IUserRepository.cs @@ -1,19 +1,18 @@ using Bit.Core.Entities; using Bit.Core.Models.Data; -namespace Bit.Core.Repositories +namespace Bit.Core.Repositories; + +public interface IUserRepository : IRepository { - public interface IUserRepository : IRepository - { - Task GetByEmailAsync(string email); - Task GetBySsoUserAsync(string externalId, Guid? organizationId); - Task GetKdfInformationByEmailAsync(string email); - Task> SearchAsync(string email, int skip, int take); - Task> GetManyByPremiumAsync(bool premium); - Task GetPublicKeyAsync(Guid id); - Task GetAccountRevisionDateAsync(Guid id); - Task UpdateStorageAsync(Guid id); - Task UpdateRenewalReminderDateAsync(Guid id, DateTime renewalReminderDate); - Task> GetManyAsync(IEnumerable ids); - } + Task GetByEmailAsync(string email); + Task GetBySsoUserAsync(string externalId, Guid? organizationId); + Task GetKdfInformationByEmailAsync(string email); + Task> SearchAsync(string email, int skip, int take); + Task> GetManyByPremiumAsync(bool premium); + Task GetPublicKeyAsync(Guid id); + Task GetAccountRevisionDateAsync(Guid id); + Task UpdateStorageAsync(Guid id); + Task UpdateRenewalReminderDateAsync(Guid id, DateTime renewalReminderDate); + Task> GetManyAsync(IEnumerable ids); } diff --git a/src/Core/Repositories/Noop/InstallationDeviceRepository.cs b/src/Core/Repositories/Noop/InstallationDeviceRepository.cs index eb446547a..b70445901 100644 --- a/src/Core/Repositories/Noop/InstallationDeviceRepository.cs +++ b/src/Core/Repositories/Noop/InstallationDeviceRepository.cs @@ -1,22 +1,21 @@ using Bit.Core.Models.Data; -namespace Bit.Core.Repositories.Noop +namespace Bit.Core.Repositories.Noop; + +public class InstallationDeviceRepository : IInstallationDeviceRepository { - public class InstallationDeviceRepository : IInstallationDeviceRepository + public Task UpsertAsync(InstallationDeviceEntity entity) { - public Task UpsertAsync(InstallationDeviceEntity entity) - { - return Task.FromResult(0); - } + return Task.FromResult(0); + } - public Task UpsertManyAsync(IList entities) - { - return Task.FromResult(0); - } + public Task UpsertManyAsync(IList entities) + { + return Task.FromResult(0); + } - public Task DeleteAsync(InstallationDeviceEntity entity) - { - return Task.FromResult(0); - } + public Task DeleteAsync(InstallationDeviceEntity entity) + { + return Task.FromResult(0); } } diff --git a/src/Core/Repositories/Noop/MetaDataRepository.cs b/src/Core/Repositories/Noop/MetaDataRepository.cs index 1f4658455..bc235c683 100644 --- a/src/Core/Repositories/Noop/MetaDataRepository.cs +++ b/src/Core/Repositories/Noop/MetaDataRepository.cs @@ -1,30 +1,29 @@ -namespace Bit.Core.Repositories.Noop +namespace Bit.Core.Repositories.Noop; + +public class MetaDataRepository : IMetaDataRepository { - public class MetaDataRepository : IMetaDataRepository + public Task DeleteAsync(string objectName, string id) { - public Task DeleteAsync(string objectName, string id) - { - return Task.FromResult(0); - } + return Task.FromResult(0); + } - public Task> GetAsync(string objectName, string id) - { - return Task.FromResult(null as IDictionary); - } + public Task> GetAsync(string objectName, string id) + { + return Task.FromResult(null as IDictionary); + } - public Task GetAsync(string objectName, string id, string prop) - { - return Task.FromResult(null as string); - } + public Task GetAsync(string objectName, string id, string prop) + { + return Task.FromResult(null as string); + } - public Task UpsertAsync(string objectName, string id, IDictionary dict) - { - return Task.FromResult(0); - } + public Task UpsertAsync(string objectName, string id, IDictionary dict) + { + return Task.FromResult(0); + } - public Task UpsertAsync(string objectName, string id, KeyValuePair keyValuePair) - { - return Task.FromResult(0); - } + public Task UpsertAsync(string objectName, string id, KeyValuePair keyValuePair) + { + return Task.FromResult(0); } } diff --git a/src/Core/Repositories/TableStorage/EventRepository.cs b/src/Core/Repositories/TableStorage/EventRepository.cs index 9ee541b8a..514b61099 100644 --- a/src/Core/Repositories/TableStorage/EventRepository.cs +++ b/src/Core/Repositories/TableStorage/EventRepository.cs @@ -4,185 +4,184 @@ using Bit.Core.Settings; using Bit.Core.Utilities; using Microsoft.Azure.Cosmos.Table; -namespace Bit.Core.Repositories.TableStorage +namespace Bit.Core.Repositories.TableStorage; + +public class EventRepository : IEventRepository { - public class EventRepository : IEventRepository + private readonly CloudTable _table; + + public EventRepository(GlobalSettings globalSettings) + : this(globalSettings.Events.ConnectionString) + { } + + public EventRepository(string storageConnectionString) { - private readonly CloudTable _table; + var storageAccount = CloudStorageAccount.Parse(storageConnectionString); + var tableClient = storageAccount.CreateCloudTableClient(); + _table = tableClient.GetTableReference("event"); + } - public EventRepository(GlobalSettings globalSettings) - : this(globalSettings.Events.ConnectionString) - { } + public async Task> GetManyByUserAsync(Guid userId, DateTime startDate, DateTime endDate, + PageOptions pageOptions) + { + return await GetManyAsync($"UserId={userId}", "Date={{0}}", startDate, endDate, pageOptions); + } - public EventRepository(string storageConnectionString) + public async Task> GetManyByOrganizationAsync(Guid organizationId, + DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + return await GetManyAsync($"OrganizationId={organizationId}", "Date={0}", startDate, endDate, pageOptions); + } + + public async Task> GetManyByOrganizationActingUserAsync(Guid organizationId, Guid actingUserId, + DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + return await GetManyAsync($"OrganizationId={organizationId}", + $"ActingUserId={actingUserId}__Date={{0}}", startDate, endDate, pageOptions); + } + + public async Task> GetManyByProviderAsync(Guid providerId, + DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + return await GetManyAsync($"ProviderId={providerId}", "Date={0}", startDate, endDate, pageOptions); + } + + public async Task> GetManyByProviderActingUserAsync(Guid providerId, Guid actingUserId, + DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + return await GetManyAsync($"ProviderId={providerId}", + $"ActingUserId={actingUserId}__Date={{0}}", startDate, endDate, pageOptions); + } + + public async Task> GetManyByCipherAsync(Cipher cipher, DateTime startDate, DateTime endDate, + PageOptions pageOptions) + { + var partitionKey = cipher.OrganizationId.HasValue ? + $"OrganizationId={cipher.OrganizationId}" : $"UserId={cipher.UserId}"; + return await GetManyAsync(partitionKey, $"CipherId={cipher.Id}__Date={{0}}", startDate, endDate, pageOptions); + } + + public async Task CreateAsync(IEvent e) + { + if (!(e is EventTableEntity entity)) { - var storageAccount = CloudStorageAccount.Parse(storageConnectionString); - var tableClient = storageAccount.CreateCloudTableClient(); - _table = tableClient.GetTableReference("event"); + throw new ArgumentException(nameof(e)); } - public async Task> GetManyByUserAsync(Guid userId, DateTime startDate, DateTime endDate, - PageOptions pageOptions) + await CreateEntityAsync(entity); + } + + public async Task CreateManyAsync(IEnumerable e) + { + if (!e?.Any() ?? true) { - return await GetManyAsync($"UserId={userId}", "Date={{0}}", startDate, endDate, pageOptions); + return; } - public async Task> GetManyByOrganizationAsync(Guid organizationId, - DateTime startDate, DateTime endDate, PageOptions pageOptions) + if (!e.Skip(1).Any()) { - return await GetManyAsync($"OrganizationId={organizationId}", "Date={0}", startDate, endDate, pageOptions); + await CreateAsync(e.First()); + return; } - public async Task> GetManyByOrganizationActingUserAsync(Guid organizationId, Guid actingUserId, - DateTime startDate, DateTime endDate, PageOptions pageOptions) + var entities = e.Where(ev => ev is EventTableEntity).Select(ev => ev as EventTableEntity); + var entityGroups = entities.GroupBy(ent => ent.PartitionKey); + foreach (var group in entityGroups) { - return await GetManyAsync($"OrganizationId={organizationId}", - $"ActingUserId={actingUserId}__Date={{0}}", startDate, endDate, pageOptions); - } - - public async Task> GetManyByProviderAsync(Guid providerId, - DateTime startDate, DateTime endDate, PageOptions pageOptions) - { - return await GetManyAsync($"ProviderId={providerId}", "Date={0}", startDate, endDate, pageOptions); - } - - public async Task> GetManyByProviderActingUserAsync(Guid providerId, Guid actingUserId, - DateTime startDate, DateTime endDate, PageOptions pageOptions) - { - return await GetManyAsync($"ProviderId={providerId}", - $"ActingUserId={actingUserId}__Date={{0}}", startDate, endDate, pageOptions); - } - - public async Task> GetManyByCipherAsync(Cipher cipher, DateTime startDate, DateTime endDate, - PageOptions pageOptions) - { - var partitionKey = cipher.OrganizationId.HasValue ? - $"OrganizationId={cipher.OrganizationId}" : $"UserId={cipher.UserId}"; - return await GetManyAsync(partitionKey, $"CipherId={cipher.Id}__Date={{0}}", startDate, endDate, pageOptions); - } - - public async Task CreateAsync(IEvent e) - { - if (!(e is EventTableEntity entity)) + var groupEntities = group.ToList(); + if (groupEntities.Count == 1) { - throw new ArgumentException(nameof(e)); + await CreateEntityAsync(groupEntities.First()); + continue; } - await CreateEntityAsync(entity); - } - - public async Task CreateManyAsync(IEnumerable e) - { - if (!e?.Any() ?? true) + // A batch insert can only contain 100 entities at a time + var iterations = groupEntities.Count / 100; + for (var i = 0; i <= iterations; i++) { - return; - } - - if (!e.Skip(1).Any()) - { - await CreateAsync(e.First()); - return; - } - - var entities = e.Where(ev => ev is EventTableEntity).Select(ev => ev as EventTableEntity); - var entityGroups = entities.GroupBy(ent => ent.PartitionKey); - foreach (var group in entityGroups) - { - var groupEntities = group.ToList(); - if (groupEntities.Count == 1) + var batch = new TableBatchOperation(); + var batchEntities = groupEntities.Skip(i * 100).Take(100); + if (!batchEntities.Any()) { - await CreateEntityAsync(groupEntities.First()); - continue; + break; } - // A batch insert can only contain 100 entities at a time - var iterations = groupEntities.Count / 100; - for (var i = 0; i <= iterations; i++) + foreach (var entity in batchEntities) { - var batch = new TableBatchOperation(); - var batchEntities = groupEntities.Skip(i * 100).Take(100); - if (!batchEntities.Any()) - { - break; - } - - foreach (var entity in batchEntities) - { - batch.InsertOrReplace(entity); - } - - await _table.ExecuteBatchAsync(batch); + batch.InsertOrReplace(entity); } + + await _table.ExecuteBatchAsync(batch); } } - - public async Task CreateEntityAsync(ITableEntity entity) - { - await _table.ExecuteAsync(TableOperation.InsertOrReplace(entity)); - } - - public async Task> GetManyAsync(string partitionKey, string rowKey, - DateTime startDate, DateTime endDate, PageOptions pageOptions) - { - var start = CoreHelpers.DateTimeToTableStorageKey(startDate); - var end = CoreHelpers.DateTimeToTableStorageKey(endDate); - var filter = MakeFilter(partitionKey, string.Format(rowKey, start), string.Format(rowKey, end)); - - var query = new TableQuery().Where(filter).Take(pageOptions.PageSize); - var result = new PagedResult(); - var continuationToken = DeserializeContinuationToken(pageOptions?.ContinuationToken); - - var queryResults = await _table.ExecuteQuerySegmentedAsync(query, continuationToken); - result.ContinuationToken = SerializeContinuationToken(queryResults.ContinuationToken); - result.Data.AddRange(queryResults.Results); - - return result; - } - - private string MakeFilter(string partitionKey, string rowStart, string rowEnd) - { - var rowFilter = TableQuery.CombineFilters( - TableQuery.GenerateFilterCondition("RowKey", QueryComparisons.LessThanOrEqual, $"{rowStart}`"), - TableOperators.And, - TableQuery.GenerateFilterCondition("RowKey", QueryComparisons.GreaterThanOrEqual, $"{rowEnd}_")); - - return TableQuery.CombineFilters( - TableQuery.GenerateFilterCondition("PartitionKey", QueryComparisons.Equal, partitionKey), - TableOperators.And, - rowFilter); - } - - private string SerializeContinuationToken(TableContinuationToken token) - { - if (token == null) - { - return null; - } - - return string.Format("{0}__{1}__{2}__{3}", (int)token.TargetLocation, token.NextTableName, - token.NextPartitionKey, token.NextRowKey); - } - - private TableContinuationToken DeserializeContinuationToken(string token) - { - if (string.IsNullOrWhiteSpace(token)) - { - return null; - } - - var tokenParts = token.Split(new string[] { "__" }, StringSplitOptions.None); - if (tokenParts.Length < 4 || !Enum.TryParse(tokenParts[0], out StorageLocation tLoc)) - { - return null; - } - - return new TableContinuationToken - { - TargetLocation = tLoc, - NextTableName = string.IsNullOrWhiteSpace(tokenParts[1]) ? null : tokenParts[1], - NextPartitionKey = string.IsNullOrWhiteSpace(tokenParts[2]) ? null : tokenParts[2], - NextRowKey = string.IsNullOrWhiteSpace(tokenParts[3]) ? null : tokenParts[3] - }; - } + } + + public async Task CreateEntityAsync(ITableEntity entity) + { + await _table.ExecuteAsync(TableOperation.InsertOrReplace(entity)); + } + + public async Task> GetManyAsync(string partitionKey, string rowKey, + DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + var start = CoreHelpers.DateTimeToTableStorageKey(startDate); + var end = CoreHelpers.DateTimeToTableStorageKey(endDate); + var filter = MakeFilter(partitionKey, string.Format(rowKey, start), string.Format(rowKey, end)); + + var query = new TableQuery().Where(filter).Take(pageOptions.PageSize); + var result = new PagedResult(); + var continuationToken = DeserializeContinuationToken(pageOptions?.ContinuationToken); + + var queryResults = await _table.ExecuteQuerySegmentedAsync(query, continuationToken); + result.ContinuationToken = SerializeContinuationToken(queryResults.ContinuationToken); + result.Data.AddRange(queryResults.Results); + + return result; + } + + private string MakeFilter(string partitionKey, string rowStart, string rowEnd) + { + var rowFilter = TableQuery.CombineFilters( + TableQuery.GenerateFilterCondition("RowKey", QueryComparisons.LessThanOrEqual, $"{rowStart}`"), + TableOperators.And, + TableQuery.GenerateFilterCondition("RowKey", QueryComparisons.GreaterThanOrEqual, $"{rowEnd}_")); + + return TableQuery.CombineFilters( + TableQuery.GenerateFilterCondition("PartitionKey", QueryComparisons.Equal, partitionKey), + TableOperators.And, + rowFilter); + } + + private string SerializeContinuationToken(TableContinuationToken token) + { + if (token == null) + { + return null; + } + + return string.Format("{0}__{1}__{2}__{3}", (int)token.TargetLocation, token.NextTableName, + token.NextPartitionKey, token.NextRowKey); + } + + private TableContinuationToken DeserializeContinuationToken(string token) + { + if (string.IsNullOrWhiteSpace(token)) + { + return null; + } + + var tokenParts = token.Split(new string[] { "__" }, StringSplitOptions.None); + if (tokenParts.Length < 4 || !Enum.TryParse(tokenParts[0], out StorageLocation tLoc)) + { + return null; + } + + return new TableContinuationToken + { + TargetLocation = tLoc, + NextTableName = string.IsNullOrWhiteSpace(tokenParts[1]) ? null : tokenParts[1], + NextPartitionKey = string.IsNullOrWhiteSpace(tokenParts[2]) ? null : tokenParts[2], + NextRowKey = string.IsNullOrWhiteSpace(tokenParts[3]) ? null : tokenParts[3] + }; } } diff --git a/src/Core/Repositories/TableStorage/InstallationDeviceRepository.cs b/src/Core/Repositories/TableStorage/InstallationDeviceRepository.cs index 125360e6b..32b466d1b 100644 --- a/src/Core/Repositories/TableStorage/InstallationDeviceRepository.cs +++ b/src/Core/Repositories/TableStorage/InstallationDeviceRepository.cs @@ -3,83 +3,82 @@ using Bit.Core.Models.Data; using Bit.Core.Settings; using Microsoft.Azure.Cosmos.Table; -namespace Bit.Core.Repositories.TableStorage +namespace Bit.Core.Repositories.TableStorage; + +public class InstallationDeviceRepository : IInstallationDeviceRepository { - public class InstallationDeviceRepository : IInstallationDeviceRepository + private readonly CloudTable _table; + + public InstallationDeviceRepository(GlobalSettings globalSettings) + : this(globalSettings.Events.ConnectionString) + { } + + public InstallationDeviceRepository(string storageConnectionString) { - private readonly CloudTable _table; + var storageAccount = CloudStorageAccount.Parse(storageConnectionString); + var tableClient = storageAccount.CreateCloudTableClient(); + _table = tableClient.GetTableReference("installationdevice"); + } - public InstallationDeviceRepository(GlobalSettings globalSettings) - : this(globalSettings.Events.ConnectionString) - { } + public async Task UpsertAsync(InstallationDeviceEntity entity) + { + await _table.ExecuteAsync(TableOperation.InsertOrReplace(entity)); + } - public InstallationDeviceRepository(string storageConnectionString) + public async Task UpsertManyAsync(IList entities) + { + if (!entities?.Any() ?? true) { - var storageAccount = CloudStorageAccount.Parse(storageConnectionString); - var tableClient = storageAccount.CreateCloudTableClient(); - _table = tableClient.GetTableReference("installationdevice"); + return; } - public async Task UpsertAsync(InstallationDeviceEntity entity) + if (entities.Count == 1) { - await _table.ExecuteAsync(TableOperation.InsertOrReplace(entity)); + await UpsertAsync(entities.First()); + return; } - public async Task UpsertManyAsync(IList entities) + var entityGroups = entities.GroupBy(ent => ent.PartitionKey); + foreach (var group in entityGroups) { - if (!entities?.Any() ?? true) + var groupEntities = group.ToList(); + if (groupEntities.Count == 1) { - return; + await UpsertAsync(groupEntities.First()); + continue; } - if (entities.Count == 1) + // A batch insert can only contain 100 entities at a time + var iterations = groupEntities.Count / 100; + for (var i = 0; i <= iterations; i++) { - await UpsertAsync(entities.First()); - return; - } - - var entityGroups = entities.GroupBy(ent => ent.PartitionKey); - foreach (var group in entityGroups) - { - var groupEntities = group.ToList(); - if (groupEntities.Count == 1) + var batch = new TableBatchOperation(); + var batchEntities = groupEntities.Skip(i * 100).Take(100); + if (!batchEntities.Any()) { - await UpsertAsync(groupEntities.First()); - continue; + break; } - // A batch insert can only contain 100 entities at a time - var iterations = groupEntities.Count / 100; - for (var i = 0; i <= iterations; i++) + foreach (var entity in batchEntities) { - var batch = new TableBatchOperation(); - var batchEntities = groupEntities.Skip(i * 100).Take(100); - if (!batchEntities.Any()) - { - break; - } - - foreach (var entity in batchEntities) - { - batch.InsertOrReplace(entity); - } - - await _table.ExecuteBatchAsync(batch); + batch.InsertOrReplace(entity); } - } - } - public async Task DeleteAsync(InstallationDeviceEntity entity) - { - try - { - entity.ETag = "*"; - await _table.ExecuteAsync(TableOperation.Delete(entity)); - } - catch (StorageException e) when (e.RequestInformation.HttpStatusCode != (int)HttpStatusCode.NotFound) - { - throw; + await _table.ExecuteBatchAsync(batch); } } } + + public async Task DeleteAsync(InstallationDeviceEntity entity) + { + try + { + entity.ETag = "*"; + await _table.ExecuteAsync(TableOperation.Delete(entity)); + } + catch (StorageException e) when (e.RequestInformation.HttpStatusCode != (int)HttpStatusCode.NotFound) + { + throw; + } + } } diff --git a/src/Core/Repositories/TableStorage/MetaDataRepository.cs b/src/Core/Repositories/TableStorage/MetaDataRepository.cs index 83ae04e4b..c70426e2a 100644 --- a/src/Core/Repositories/TableStorage/MetaDataRepository.cs +++ b/src/Core/Repositories/TableStorage/MetaDataRepository.cs @@ -3,92 +3,91 @@ using Bit.Core.Models.Data; using Bit.Core.Settings; using Microsoft.Azure.Cosmos.Table; -namespace Bit.Core.Repositories.TableStorage +namespace Bit.Core.Repositories.TableStorage; + +public class MetaDataRepository : IMetaDataRepository { - public class MetaDataRepository : IMetaDataRepository + private readonly CloudTable _table; + + public MetaDataRepository(GlobalSettings globalSettings) + : this(globalSettings.Events.ConnectionString) + { } + + public MetaDataRepository(string storageConnectionString) { - private readonly CloudTable _table; + var storageAccount = CloudStorageAccount.Parse(storageConnectionString); + var tableClient = storageAccount.CreateCloudTableClient(); + _table = tableClient.GetTableReference("metadata"); + } - public MetaDataRepository(GlobalSettings globalSettings) - : this(globalSettings.Events.ConnectionString) - { } + public async Task> GetAsync(string objectName, string id) + { + var query = new TableQuery().Where( + TableQuery.GenerateFilterCondition("PartitionKey", QueryComparisons.Equal, $"{objectName}_{id}")); + var queryResults = await _table.ExecuteQuerySegmentedAsync(query, null); + return queryResults.Results.FirstOrDefault()?.ToDictionary(d => d.Key, d => d.Value.StringValue); + } - public MetaDataRepository(string storageConnectionString) + public async Task GetAsync(string objectName, string id, string prop) + { + var dict = await GetAsync(objectName, id); + if (dict != null && dict.ContainsKey(prop)) { - var storageAccount = CloudStorageAccount.Parse(storageConnectionString); - var tableClient = storageAccount.CreateCloudTableClient(); - _table = tableClient.GetTableReference("metadata"); + return dict[prop]; } + return null; + } - public async Task> GetAsync(string objectName, string id) + public async Task UpsertAsync(string objectName, string id, KeyValuePair keyValuePair) + { + var query = new TableQuery().Where( + TableQuery.GenerateFilterCondition("PartitionKey", QueryComparisons.Equal, $"{objectName}_{id}")); + var queryResults = await _table.ExecuteQuerySegmentedAsync(query, null); + var entity = queryResults.Results.FirstOrDefault(); + if (entity == null) { - var query = new TableQuery().Where( - TableQuery.GenerateFilterCondition("PartitionKey", QueryComparisons.Equal, $"{objectName}_{id}")); - var queryResults = await _table.ExecuteQuerySegmentedAsync(query, null); - return queryResults.Results.FirstOrDefault()?.ToDictionary(d => d.Key, d => d.Value.StringValue); - } - - public async Task GetAsync(string objectName, string id, string prop) - { - var dict = await GetAsync(objectName, id); - if (dict != null && dict.ContainsKey(prop)) - { - return dict[prop]; - } - return null; - } - - public async Task UpsertAsync(string objectName, string id, KeyValuePair keyValuePair) - { - var query = new TableQuery().Where( - TableQuery.GenerateFilterCondition("PartitionKey", QueryComparisons.Equal, $"{objectName}_{id}")); - var queryResults = await _table.ExecuteQuerySegmentedAsync(query, null); - var entity = queryResults.Results.FirstOrDefault(); - if (entity == null) - { - entity = new DictionaryEntity - { - PartitionKey = $"{objectName}_{id}", - RowKey = string.Empty - }; - } - if (entity.ContainsKey(keyValuePair.Key)) - { - entity.Remove(keyValuePair.Key); - } - entity.Add(keyValuePair.Key, keyValuePair.Value); - await _table.ExecuteAsync(TableOperation.InsertOrReplace(entity)); - } - - public async Task UpsertAsync(string objectName, string id, IDictionary dict) - { - var entity = new DictionaryEntity + entity = new DictionaryEntity { PartitionKey = $"{objectName}_{id}", RowKey = string.Empty }; - foreach (var item in dict) - { - entity.Add(item.Key, item.Value); - } - await _table.ExecuteAsync(TableOperation.InsertOrReplace(entity)); } - - public async Task DeleteAsync(string objectName, string id) + if (entity.ContainsKey(keyValuePair.Key)) { - try + entity.Remove(keyValuePair.Key); + } + entity.Add(keyValuePair.Key, keyValuePair.Value); + await _table.ExecuteAsync(TableOperation.InsertOrReplace(entity)); + } + + public async Task UpsertAsync(string objectName, string id, IDictionary dict) + { + var entity = new DictionaryEntity + { + PartitionKey = $"{objectName}_{id}", + RowKey = string.Empty + }; + foreach (var item in dict) + { + entity.Add(item.Key, item.Value); + } + await _table.ExecuteAsync(TableOperation.InsertOrReplace(entity)); + } + + public async Task DeleteAsync(string objectName, string id) + { + try + { + await _table.ExecuteAsync(TableOperation.Delete(new DictionaryEntity { - await _table.ExecuteAsync(TableOperation.Delete(new DictionaryEntity - { - PartitionKey = $"{objectName}_{id}", - RowKey = string.Empty, - ETag = "*" - })); - } - catch (StorageException e) when (e.RequestInformation.HttpStatusCode != (int)HttpStatusCode.NotFound) - { - throw; - } + PartitionKey = $"{objectName}_{id}", + RowKey = string.Empty, + ETag = "*" + })); + } + catch (StorageException e) when (e.RequestInformation.HttpStatusCode != (int)HttpStatusCode.NotFound) + { + throw; } } } diff --git a/src/Core/Resources/SharedResources.cs b/src/Core/Resources/SharedResources.cs index 543ec227a..39eea7c6c 100644 --- a/src/Core/Resources/SharedResources.cs +++ b/src/Core/Resources/SharedResources.cs @@ -1,6 +1,5 @@ -namespace Bit.Core.Resources +namespace Bit.Core.Resources; + +public class SharedResources { - public class SharedResources - { - } } diff --git a/src/Core/Services/IAppleIapService.cs b/src/Core/Services/IAppleIapService.cs index aef7e2c88..b258b9e3b 100644 --- a/src/Core/Services/IAppleIapService.cs +++ b/src/Core/Services/IAppleIapService.cs @@ -1,11 +1,10 @@ using Bit.Billing.Models; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface IAppleIapService { - public interface IAppleIapService - { - Task GetVerifiedReceiptStatusAsync(string receiptData); - Task SaveReceiptAsync(AppleReceiptStatus receiptStatus, Guid userId); - Task> GetReceiptAsync(string originalTransactionId); - } + Task GetVerifiedReceiptStatusAsync(string receiptData); + Task SaveReceiptAsync(AppleReceiptStatus receiptStatus, Guid userId); + Task> GetReceiptAsync(string originalTransactionId); } diff --git a/src/Core/Services/IApplicationCacheService.cs b/src/Core/Services/IApplicationCacheService.cs index 08efe7b7c..7c21fac76 100644 --- a/src/Core/Services/IApplicationCacheService.cs +++ b/src/Core/Services/IApplicationCacheService.cs @@ -3,14 +3,13 @@ using Bit.Core.Entities.Provider; using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface IApplicationCacheService { - public interface IApplicationCacheService - { - Task> GetOrganizationAbilitiesAsync(); - Task> GetProviderAbilitiesAsync(); - Task UpsertOrganizationAbilityAsync(Organization organization); - Task UpsertProviderAbilityAsync(Provider provider); - Task DeleteOrganizationAbilityAsync(Guid organizationId); - } + Task> GetOrganizationAbilitiesAsync(); + Task> GetProviderAbilitiesAsync(); + Task UpsertOrganizationAbilityAsync(Organization organization); + Task UpsertProviderAbilityAsync(Provider provider); + Task DeleteOrganizationAbilityAsync(Guid organizationId); } diff --git a/src/Core/Services/IAttachmentStorageService.cs b/src/Core/Services/IAttachmentStorageService.cs index c0b11a021..964b711f0 100644 --- a/src/Core/Services/IAttachmentStorageService.cs +++ b/src/Core/Services/IAttachmentStorageService.cs @@ -2,22 +2,21 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface IAttachmentStorageService { - public interface IAttachmentStorageService - { - FileUploadType FileUploadType { get; } - Task UploadNewAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentData); - Task UploadShareAttachmentAsync(Stream stream, Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData); - Task StartShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData); - Task RollbackShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData, string originalContainer); - Task CleanupAsync(Guid cipherId); - Task DeleteAttachmentAsync(Guid cipherId, CipherAttachment.MetaData attachmentData); - Task DeleteAttachmentsForCipherAsync(Guid cipherId); - Task DeleteAttachmentsForOrganizationAsync(Guid organizationId); - Task DeleteAttachmentsForUserAsync(Guid userId); - Task GetAttachmentUploadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData); - Task GetAttachmentDownloadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData); - Task<(bool, long?)> ValidateFileAsync(Cipher cipher, CipherAttachment.MetaData attachmentData, long leeway); - } + FileUploadType FileUploadType { get; } + Task UploadNewAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentData); + Task UploadShareAttachmentAsync(Stream stream, Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData); + Task StartShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData); + Task RollbackShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData, string originalContainer); + Task CleanupAsync(Guid cipherId); + Task DeleteAttachmentAsync(Guid cipherId, CipherAttachment.MetaData attachmentData); + Task DeleteAttachmentsForCipherAsync(Guid cipherId); + Task DeleteAttachmentsForOrganizationAsync(Guid organizationId); + Task DeleteAttachmentsForUserAsync(Guid userId); + Task GetAttachmentUploadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData); + Task GetAttachmentDownloadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData); + Task<(bool, long?)> ValidateFileAsync(Cipher cipher, CipherAttachment.MetaData attachmentData, long leeway); } diff --git a/src/Core/Services/IBlockIpService.cs b/src/Core/Services/IBlockIpService.cs index 547a7cede..87af1a2ce 100644 --- a/src/Core/Services/IBlockIpService.cs +++ b/src/Core/Services/IBlockIpService.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface IBlockIpService { - public interface IBlockIpService - { - Task BlockIpAsync(string ipAddress, bool permanentBlock); - } + Task BlockIpAsync(string ipAddress, bool permanentBlock); } diff --git a/src/Core/Services/ICaptchaValidationService.cs b/src/Core/Services/ICaptchaValidationService.cs index d908be7c2..50faad31f 100644 --- a/src/Core/Services/ICaptchaValidationService.cs +++ b/src/Core/Services/ICaptchaValidationService.cs @@ -2,15 +2,14 @@ using Bit.Core.Entities; using Bit.Core.Models.Business; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface ICaptchaValidationService { - public interface ICaptchaValidationService - { - string SiteKey { get; } - string SiteKeyResponseKeyName { get; } - bool RequireCaptchaValidation(ICurrentContext currentContext, User user = null); - Task ValidateCaptchaResponseAsync(string captchResponse, string clientIpAddress, - User user = null); - string GenerateCaptchaBypassToken(User user); - } + string SiteKey { get; } + string SiteKeyResponseKeyName { get; } + bool RequireCaptchaValidation(ICurrentContext currentContext, User user = null); + Task ValidateCaptchaResponseAsync(string captchResponse, string clientIpAddress, + User user = null); + string GenerateCaptchaBypassToken(User user); } diff --git a/src/Core/Services/ICipherService.cs b/src/Core/Services/ICipherService.cs index 9afeb5926..ad93990c2 100644 --- a/src/Core/Services/ICipherService.cs +++ b/src/Core/Services/ICipherService.cs @@ -2,43 +2,42 @@ using Bit.Core.Models.Data; using Core.Models.Data; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface ICipherService { - public interface ICipherService - { - Task SaveAsync(Cipher cipher, Guid savingUserId, DateTime? lastKnownRevisionDate, IEnumerable collectionIds = null, - bool skipPermissionCheck = false, bool limitCollectionScope = true); - Task SaveDetailsAsync(CipherDetails cipher, Guid savingUserId, DateTime? lastKnownRevisionDate, - IEnumerable collectionIds = null, bool skipPermissionCheck = false); - Task<(string attachmentId, string uploadUrl)> CreateAttachmentForDelayedUploadAsync(Cipher cipher, - string key, string fileName, long fileSize, bool adminRequest, Guid savingUserId); - Task CreateAttachmentAsync(Cipher cipher, Stream stream, string fileName, string key, - long requestLength, Guid savingUserId, bool orgAdmin = false); - Task CreateAttachmentShareAsync(Cipher cipher, Stream stream, long requestLength, string attachmentId, - Guid organizationShareId); - Task DeleteAsync(Cipher cipher, Guid deletingUserId, bool orgAdmin = false); - Task DeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false); - Task DeleteAttachmentAsync(Cipher cipher, string attachmentId, Guid deletingUserId, bool orgAdmin = false); - Task PurgeAsync(Guid organizationId); - Task MoveManyAsync(IEnumerable cipherIds, Guid? destinationFolderId, Guid movingUserId); - Task SaveFolderAsync(Folder folder); - Task DeleteFolderAsync(Folder folder); - Task ShareAsync(Cipher originalCipher, Cipher cipher, Guid organizationId, IEnumerable collectionIds, - Guid userId, DateTime? lastKnownRevisionDate); - Task ShareManyAsync(IEnumerable<(Cipher cipher, DateTime? lastKnownRevisionDate)> ciphers, Guid organizationId, - IEnumerable collectionIds, Guid sharingUserId); - Task SaveCollectionsAsync(Cipher cipher, IEnumerable collectionIds, Guid savingUserId, bool orgAdmin); - Task ImportCiphersAsync(List folders, List ciphers, - IEnumerable> folderRelationships); - Task ImportCiphersAsync(List collections, List ciphers, - IEnumerable> collectionRelationships, Guid importingUserId); - Task SoftDeleteAsync(Cipher cipher, Guid deletingUserId, bool orgAdmin = false); - Task SoftDeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false); - Task RestoreAsync(Cipher cipher, Guid restoringUserId, bool orgAdmin = false); - Task RestoreManyAsync(IEnumerable ciphers, Guid restoringUserId); - Task UploadFileForExistingAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentId); - Task GetAttachmentDownloadDataAsync(Cipher cipher, string attachmentId); - Task ValidateCipherAttachmentFile(Cipher cipher, CipherAttachment.MetaData attachmentData); - Task<(IEnumerable, Dictionary>)> GetOrganizationCiphers(Guid userId, Guid organizationId); - } + Task SaveAsync(Cipher cipher, Guid savingUserId, DateTime? lastKnownRevisionDate, IEnumerable collectionIds = null, + bool skipPermissionCheck = false, bool limitCollectionScope = true); + Task SaveDetailsAsync(CipherDetails cipher, Guid savingUserId, DateTime? lastKnownRevisionDate, + IEnumerable collectionIds = null, bool skipPermissionCheck = false); + Task<(string attachmentId, string uploadUrl)> CreateAttachmentForDelayedUploadAsync(Cipher cipher, + string key, string fileName, long fileSize, bool adminRequest, Guid savingUserId); + Task CreateAttachmentAsync(Cipher cipher, Stream stream, string fileName, string key, + long requestLength, Guid savingUserId, bool orgAdmin = false); + Task CreateAttachmentShareAsync(Cipher cipher, Stream stream, long requestLength, string attachmentId, + Guid organizationShareId); + Task DeleteAsync(Cipher cipher, Guid deletingUserId, bool orgAdmin = false); + Task DeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false); + Task DeleteAttachmentAsync(Cipher cipher, string attachmentId, Guid deletingUserId, bool orgAdmin = false); + Task PurgeAsync(Guid organizationId); + Task MoveManyAsync(IEnumerable cipherIds, Guid? destinationFolderId, Guid movingUserId); + Task SaveFolderAsync(Folder folder); + Task DeleteFolderAsync(Folder folder); + Task ShareAsync(Cipher originalCipher, Cipher cipher, Guid organizationId, IEnumerable collectionIds, + Guid userId, DateTime? lastKnownRevisionDate); + Task ShareManyAsync(IEnumerable<(Cipher cipher, DateTime? lastKnownRevisionDate)> ciphers, Guid organizationId, + IEnumerable collectionIds, Guid sharingUserId); + Task SaveCollectionsAsync(Cipher cipher, IEnumerable collectionIds, Guid savingUserId, bool orgAdmin); + Task ImportCiphersAsync(List folders, List ciphers, + IEnumerable> folderRelationships); + Task ImportCiphersAsync(List collections, List ciphers, + IEnumerable> collectionRelationships, Guid importingUserId); + Task SoftDeleteAsync(Cipher cipher, Guid deletingUserId, bool orgAdmin = false); + Task SoftDeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false); + Task RestoreAsync(Cipher cipher, Guid restoringUserId, bool orgAdmin = false); + Task RestoreManyAsync(IEnumerable ciphers, Guid restoringUserId); + Task UploadFileForExistingAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentId); + Task GetAttachmentDownloadDataAsync(Cipher cipher, string attachmentId); + Task ValidateCipherAttachmentFile(Cipher cipher, CipherAttachment.MetaData attachmentData); + Task<(IEnumerable, Dictionary>)> GetOrganizationCiphers(Guid userId, Guid organizationId); } diff --git a/src/Core/Services/ICollectionService.cs b/src/Core/Services/ICollectionService.cs index 015474b6f..7ae3562ea 100644 --- a/src/Core/Services/ICollectionService.cs +++ b/src/Core/Services/ICollectionService.cs @@ -1,13 +1,12 @@ using Bit.Core.Entities; using Bit.Core.Models.Data; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface ICollectionService { - public interface ICollectionService - { - Task SaveAsync(Collection collection, IEnumerable groups = null, Guid? assignUserId = null); - Task DeleteAsync(Collection collection); - Task DeleteUserAsync(Collection collection, Guid organizationUserId); - Task> GetOrganizationCollections(Guid organizationId); - } + Task SaveAsync(Collection collection, IEnumerable groups = null, Guid? assignUserId = null); + Task DeleteAsync(Collection collection); + Task DeleteUserAsync(Collection collection, Guid organizationUserId); + Task> GetOrganizationCollections(Guid organizationId); } diff --git a/src/Core/Services/IDeviceService.cs b/src/Core/Services/IDeviceService.cs index 6455e6a32..3109cc107 100644 --- a/src/Core/Services/IDeviceService.cs +++ b/src/Core/Services/IDeviceService.cs @@ -1,11 +1,10 @@ using Bit.Core.Entities; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface IDeviceService { - public interface IDeviceService - { - Task SaveAsync(Device device); - Task ClearTokenAsync(Device device); - Task DeleteAsync(Device device); - } + Task SaveAsync(Device device); + Task ClearTokenAsync(Device device); + Task DeleteAsync(Device device); } diff --git a/src/Core/Services/IEmergencyAccessService.cs b/src/Core/Services/IEmergencyAccessService.cs index f975bfe76..96edb752c 100644 --- a/src/Core/Services/IEmergencyAccessService.cs +++ b/src/Core/Services/IEmergencyAccessService.cs @@ -2,26 +2,25 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface IEmergencyAccessService { - public interface IEmergencyAccessService - { - Task InviteAsync(User invitingUser, string email, EmergencyAccessType type, int waitTime); - Task ResendInviteAsync(User invitingUser, Guid emergencyAccessId); - Task AcceptUserAsync(Guid emergencyAccessId, User user, string token, IUserService userService); - Task DeleteAsync(Guid emergencyAccessId, Guid grantorId); - Task ConfirmUserAsync(Guid emergencyAccessId, string key, Guid grantorId); - Task GetAsync(Guid emergencyAccessId, Guid userId); - Task SaveAsync(EmergencyAccess emergencyAccess, User savingUser); - Task InitiateAsync(Guid id, User initiatingUser); - Task ApproveAsync(Guid id, User approvingUser); - Task RejectAsync(Guid id, User rejectingUser); - Task> GetPoliciesAsync(Guid id, User requestingUser); - Task<(EmergencyAccess, User)> TakeoverAsync(Guid id, User initiatingUser); - Task PasswordAsync(Guid id, User user, string newMasterPasswordHash, string key); - Task SendNotificationsAsync(); - Task HandleTimedOutRequestsAsync(); - Task ViewAsync(Guid id, User user); - Task GetAttachmentDownloadAsync(Guid id, Guid cipherId, string attachmentId, User user); - } + Task InviteAsync(User invitingUser, string email, EmergencyAccessType type, int waitTime); + Task ResendInviteAsync(User invitingUser, Guid emergencyAccessId); + Task AcceptUserAsync(Guid emergencyAccessId, User user, string token, IUserService userService); + Task DeleteAsync(Guid emergencyAccessId, Guid grantorId); + Task ConfirmUserAsync(Guid emergencyAccessId, string key, Guid grantorId); + Task GetAsync(Guid emergencyAccessId, Guid userId); + Task SaveAsync(EmergencyAccess emergencyAccess, User savingUser); + Task InitiateAsync(Guid id, User initiatingUser); + Task ApproveAsync(Guid id, User approvingUser); + Task RejectAsync(Guid id, User rejectingUser); + Task> GetPoliciesAsync(Guid id, User requestingUser); + Task<(EmergencyAccess, User)> TakeoverAsync(Guid id, User initiatingUser); + Task PasswordAsync(Guid id, User user, string newMasterPasswordHash, string key); + Task SendNotificationsAsync(); + Task HandleTimedOutRequestsAsync(); + Task ViewAsync(Guid id, User user); + Task GetAttachmentDownloadAsync(Guid id, Guid cipherId, string attachmentId, User user); } diff --git a/src/Core/Services/IEventService.cs b/src/Core/Services/IEventService.cs index fa9848584..fd0ca4491 100644 --- a/src/Core/Services/IEventService.cs +++ b/src/Core/Services/IEventService.cs @@ -2,21 +2,20 @@ using Bit.Core.Entities.Provider; using Bit.Core.Enums; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface IEventService { - public interface IEventService - { - Task LogUserEventAsync(Guid userId, EventType type, DateTime? date = null); - Task LogCipherEventAsync(Cipher cipher, EventType type, DateTime? date = null); - Task LogCipherEventsAsync(IEnumerable> events); - Task LogCollectionEventAsync(Collection collection, EventType type, DateTime? date = null); - Task LogGroupEventAsync(Group group, EventType type, DateTime? date = null); - Task LogPolicyEventAsync(Policy policy, EventType type, DateTime? date = null); - Task LogOrganizationUserEventAsync(OrganizationUser organizationUser, EventType type, DateTime? date = null); - Task LogOrganizationUserEventsAsync(IEnumerable<(OrganizationUser, EventType, DateTime?)> events); - Task LogOrganizationEventAsync(Organization organization, EventType type, DateTime? date = null); - Task LogProviderUserEventAsync(ProviderUser providerUser, EventType type, DateTime? date = null); - Task LogProviderUsersEventAsync(IEnumerable<(ProviderUser, EventType, DateTime?)> events); - Task LogProviderOrganizationEventAsync(ProviderOrganization providerOrganization, EventType type, DateTime? date = null); - } + Task LogUserEventAsync(Guid userId, EventType type, DateTime? date = null); + Task LogCipherEventAsync(Cipher cipher, EventType type, DateTime? date = null); + Task LogCipherEventsAsync(IEnumerable> events); + Task LogCollectionEventAsync(Collection collection, EventType type, DateTime? date = null); + Task LogGroupEventAsync(Group group, EventType type, DateTime? date = null); + Task LogPolicyEventAsync(Policy policy, EventType type, DateTime? date = null); + Task LogOrganizationUserEventAsync(OrganizationUser organizationUser, EventType type, DateTime? date = null); + Task LogOrganizationUserEventsAsync(IEnumerable<(OrganizationUser, EventType, DateTime?)> events); + Task LogOrganizationEventAsync(Organization organization, EventType type, DateTime? date = null); + Task LogProviderUserEventAsync(ProviderUser providerUser, EventType type, DateTime? date = null); + Task LogProviderUsersEventAsync(IEnumerable<(ProviderUser, EventType, DateTime?)> events); + Task LogProviderOrganizationEventAsync(ProviderOrganization providerOrganization, EventType type, DateTime? date = null); } diff --git a/src/Core/Services/IEventWriteService.cs b/src/Core/Services/IEventWriteService.cs index dc3318937..cbe8790d3 100644 --- a/src/Core/Services/IEventWriteService.cs +++ b/src/Core/Services/IEventWriteService.cs @@ -1,10 +1,9 @@ using Bit.Core.Models.Data; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface IEventWriteService { - public interface IEventWriteService - { - Task CreateAsync(IEvent e); - Task CreateManyAsync(IEnumerable e); - } + Task CreateAsync(IEvent e); + Task CreateManyAsync(IEnumerable e); } diff --git a/src/Core/Services/IGroupService.cs b/src/Core/Services/IGroupService.cs index 82fd9792a..494d3e6c0 100644 --- a/src/Core/Services/IGroupService.cs +++ b/src/Core/Services/IGroupService.cs @@ -1,12 +1,11 @@ using Bit.Core.Entities; using Bit.Core.Models.Data; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface IGroupService { - public interface IGroupService - { - Task SaveAsync(Group group, IEnumerable collections = null); - Task DeleteAsync(Group group); - Task DeleteUserAsync(Group group, Guid organizationUserId); - } + Task SaveAsync(Group group, IEnumerable collections = null); + Task DeleteAsync(Group group); + Task DeleteUserAsync(Group group, Guid organizationUserId); } diff --git a/src/Core/Services/II18nService.cs b/src/Core/Services/II18nService.cs index a66e14883..ee92664d8 100644 --- a/src/Core/Services/II18nService.cs +++ b/src/Core/Services/II18nService.cs @@ -1,12 +1,11 @@ using Microsoft.Extensions.Localization; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface II18nService { - public interface II18nService - { - LocalizedString GetLocalizedHtmlString(string key); - LocalizedString GetLocalizedHtmlString(string key, params object[] args); - string Translate(string key, params object[] args); - string T(string key, params object[] args); - } + LocalizedString GetLocalizedHtmlString(string key); + LocalizedString GetLocalizedHtmlString(string key, params object[] args); + string Translate(string key, params object[] args); + string T(string key, params object[] args); } diff --git a/src/Core/Services/ILicensingService.cs b/src/Core/Services/ILicensingService.cs index fd3ad9afe..bf3b5ee42 100644 --- a/src/Core/Services/ILicensingService.cs +++ b/src/Core/Services/ILicensingService.cs @@ -1,17 +1,16 @@ using Bit.Core.Entities; using Bit.Core.Models.Business; -namespace Bit.Core.Services -{ - public interface ILicensingService - { - Task ValidateOrganizationsAsync(); - Task ValidateUsersAsync(); - Task ValidateUserPremiumAsync(User user); - bool VerifyLicense(ILicense license); - byte[] SignLicense(ILicense license); - Task ReadOrganizationLicenseAsync(Organization organization); - Task ReadOrganizationLicenseAsync(Guid organizationId); +namespace Bit.Core.Services; + +public interface ILicensingService +{ + Task ValidateOrganizationsAsync(); + Task ValidateUsersAsync(); + Task ValidateUserPremiumAsync(User user); + bool VerifyLicense(ILicense license); + byte[] SignLicense(ILicense license); + Task ReadOrganizationLicenseAsync(Organization organization); + Task ReadOrganizationLicenseAsync(Guid organizationId); - } } diff --git a/src/Core/Services/IMailDeliveryService.cs b/src/Core/Services/IMailDeliveryService.cs index 1c42e39e2..924736722 100644 --- a/src/Core/Services/IMailDeliveryService.cs +++ b/src/Core/Services/IMailDeliveryService.cs @@ -1,9 +1,8 @@ using Bit.Core.Models.Mail; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface IMailDeliveryService { - public interface IMailDeliveryService - { - Task SendEmailAsync(MailMessage message); - } + Task SendEmailAsync(MailMessage message); } diff --git a/src/Core/Services/IMailEnqueuingService.cs b/src/Core/Services/IMailEnqueuingService.cs index 1e681b6ca..19dc33f19 100644 --- a/src/Core/Services/IMailEnqueuingService.cs +++ b/src/Core/Services/IMailEnqueuingService.cs @@ -1,10 +1,9 @@ using Bit.Core.Models.Mail; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface IMailEnqueuingService { - public interface IMailEnqueuingService - { - Task EnqueueAsync(IMailQueueMessage message, Func fallback); - Task EnqueueManyAsync(IEnumerable messages, Func fallback); - } + Task EnqueueAsync(IMailQueueMessage message, Func fallback); + Task EnqueueManyAsync(IEnumerable messages, Func fallback); } diff --git a/src/Core/Services/IMailService.cs b/src/Core/Services/IMailService.cs index 7be31e82a..3af89108c 100644 --- a/src/Core/Services/IMailService.cs +++ b/src/Core/Services/IMailService.cs @@ -3,56 +3,55 @@ using Bit.Core.Entities.Provider; using Bit.Core.Models.Business; using Bit.Core.Models.Mail; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface IMailService { - public interface IMailService - { - Task SendWelcomeEmailAsync(User user); - Task SendVerifyEmailEmailAsync(string email, Guid userId, string token); - Task SendVerifyDeleteEmailAsync(string email, Guid userId, string token); - Task SendChangeEmailAlreadyExistsEmailAsync(string fromEmail, string toEmail); - Task SendChangeEmailEmailAsync(string newEmailAddress, string token); - Task SendTwoFactorEmailAsync(string email, string token); - Task SendNewDeviceLoginTwoFactorEmailAsync(string email, string token); - Task SendNoMasterPasswordHintEmailAsync(string email); - Task SendMasterPasswordHintEmailAsync(string email, string hint); - Task SendOrganizationInviteEmailAsync(string organizationName, OrganizationUser orgUser, ExpiringToken token); - Task BulkSendOrganizationInviteEmailAsync(string organizationName, IEnumerable<(OrganizationUser orgUser, ExpiringToken token)> invites); - Task SendOrganizationMaxSeatLimitReachedEmailAsync(Organization organization, int maxSeatCount, IEnumerable ownerEmails); - Task SendOrganizationAutoscaledEmailAsync(Organization organization, int initialSeatCount, IEnumerable ownerEmails); - Task SendOrganizationAcceptedEmailAsync(Organization organization, string userIdentifier, IEnumerable adminEmails); - Task SendOrganizationConfirmedEmailAsync(string organizationName, string email); - Task SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(string organizationName, string email); - Task SendPasswordlessSignInAsync(string returnUrl, string token, string email); - Task SendInvoiceUpcomingAsync(string email, decimal amount, DateTime dueDate, List items, - bool mentionInvoices); - Task SendPaymentFailedAsync(string email, decimal amount, bool mentionInvoices); - Task SendAddedCreditAsync(string email, decimal amount); - Task SendLicenseExpiredAsync(IEnumerable emails, string organizationName = null); - Task SendNewDeviceLoggedInEmail(string email, string deviceType, DateTime timestamp, string ip); - Task SendRecoverTwoFactorEmail(string email, DateTime timestamp, string ip); - Task SendOrganizationUserRemovedForPolicySingleOrgEmailAsync(string organizationName, string email); - Task SendEmergencyAccessInviteEmailAsync(EmergencyAccess emergencyAccess, string name, string token); - Task SendEmergencyAccessAcceptedEmailAsync(string granteeEmail, string email); - Task SendEmergencyAccessConfirmedEmailAsync(string grantorName, string email); - Task SendEmergencyAccessRecoveryInitiated(EmergencyAccess emergencyAccess, string initiatingName, string email); - Task SendEmergencyAccessRecoveryApproved(EmergencyAccess emergencyAccess, string approvingName, string email); - Task SendEmergencyAccessRecoveryRejected(EmergencyAccess emergencyAccess, string rejectingName, string email); - Task SendEmergencyAccessRecoveryReminder(EmergencyAccess emergencyAccess, string initiatingName, string email); - Task SendEmergencyAccessRecoveryTimedOut(EmergencyAccess ea, string initiatingName, string email); - Task SendEnqueuedMailMessageAsync(IMailQueueMessage queueMessage); - Task SendAdminResetPasswordEmailAsync(string email, string userName, string orgName); - Task SendProviderSetupInviteEmailAsync(Provider provider, string token, string email); - Task SendProviderInviteEmailAsync(string providerName, ProviderUser providerUser, string token, string email); - Task SendProviderConfirmedEmailAsync(string providerName, string email); - Task SendProviderUserRemoved(string providerName, string email); - Task SendUpdatedTempPasswordEmailAsync(string email, string userName); - Task SendFamiliesForEnterpriseOfferEmailAsync(string sponsorOrgName, string email, bool existingAccount, string token); - Task BulkSendFamiliesForEnterpriseOfferEmailAsync(string SponsorOrgName, IEnumerable<(string Email, bool ExistingAccount, string Token)> invites); - Task SendFamiliesForEnterpriseRedeemedEmailsAsync(string familyUserEmail, string sponsorEmail); - Task SendFamiliesForEnterpriseSponsorshipRevertingEmailAsync(string email, DateTime expirationDate); - Task SendOTPEmailAsync(string email, string token); - Task SendFailedLoginAttemptsEmailAsync(string email, DateTime utcNow, string ip); - Task SendFailedTwoFactorAttemptsEmailAsync(string email, DateTime utcNow, string ip); - } + Task SendWelcomeEmailAsync(User user); + Task SendVerifyEmailEmailAsync(string email, Guid userId, string token); + Task SendVerifyDeleteEmailAsync(string email, Guid userId, string token); + Task SendChangeEmailAlreadyExistsEmailAsync(string fromEmail, string toEmail); + Task SendChangeEmailEmailAsync(string newEmailAddress, string token); + Task SendTwoFactorEmailAsync(string email, string token); + Task SendNewDeviceLoginTwoFactorEmailAsync(string email, string token); + Task SendNoMasterPasswordHintEmailAsync(string email); + Task SendMasterPasswordHintEmailAsync(string email, string hint); + Task SendOrganizationInviteEmailAsync(string organizationName, OrganizationUser orgUser, ExpiringToken token); + Task BulkSendOrganizationInviteEmailAsync(string organizationName, IEnumerable<(OrganizationUser orgUser, ExpiringToken token)> invites); + Task SendOrganizationMaxSeatLimitReachedEmailAsync(Organization organization, int maxSeatCount, IEnumerable ownerEmails); + Task SendOrganizationAutoscaledEmailAsync(Organization organization, int initialSeatCount, IEnumerable ownerEmails); + Task SendOrganizationAcceptedEmailAsync(Organization organization, string userIdentifier, IEnumerable adminEmails); + Task SendOrganizationConfirmedEmailAsync(string organizationName, string email); + Task SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(string organizationName, string email); + Task SendPasswordlessSignInAsync(string returnUrl, string token, string email); + Task SendInvoiceUpcomingAsync(string email, decimal amount, DateTime dueDate, List items, + bool mentionInvoices); + Task SendPaymentFailedAsync(string email, decimal amount, bool mentionInvoices); + Task SendAddedCreditAsync(string email, decimal amount); + Task SendLicenseExpiredAsync(IEnumerable emails, string organizationName = null); + Task SendNewDeviceLoggedInEmail(string email, string deviceType, DateTime timestamp, string ip); + Task SendRecoverTwoFactorEmail(string email, DateTime timestamp, string ip); + Task SendOrganizationUserRemovedForPolicySingleOrgEmailAsync(string organizationName, string email); + Task SendEmergencyAccessInviteEmailAsync(EmergencyAccess emergencyAccess, string name, string token); + Task SendEmergencyAccessAcceptedEmailAsync(string granteeEmail, string email); + Task SendEmergencyAccessConfirmedEmailAsync(string grantorName, string email); + Task SendEmergencyAccessRecoveryInitiated(EmergencyAccess emergencyAccess, string initiatingName, string email); + Task SendEmergencyAccessRecoveryApproved(EmergencyAccess emergencyAccess, string approvingName, string email); + Task SendEmergencyAccessRecoveryRejected(EmergencyAccess emergencyAccess, string rejectingName, string email); + Task SendEmergencyAccessRecoveryReminder(EmergencyAccess emergencyAccess, string initiatingName, string email); + Task SendEmergencyAccessRecoveryTimedOut(EmergencyAccess ea, string initiatingName, string email); + Task SendEnqueuedMailMessageAsync(IMailQueueMessage queueMessage); + Task SendAdminResetPasswordEmailAsync(string email, string userName, string orgName); + Task SendProviderSetupInviteEmailAsync(Provider provider, string token, string email); + Task SendProviderInviteEmailAsync(string providerName, ProviderUser providerUser, string token, string email); + Task SendProviderConfirmedEmailAsync(string providerName, string email); + Task SendProviderUserRemoved(string providerName, string email); + Task SendUpdatedTempPasswordEmailAsync(string email, string userName); + Task SendFamiliesForEnterpriseOfferEmailAsync(string sponsorOrgName, string email, bool existingAccount, string token); + Task BulkSendFamiliesForEnterpriseOfferEmailAsync(string SponsorOrgName, IEnumerable<(string Email, bool ExistingAccount, string Token)> invites); + Task SendFamiliesForEnterpriseRedeemedEmailsAsync(string familyUserEmail, string sponsorEmail); + Task SendFamiliesForEnterpriseSponsorshipRevertingEmailAsync(string email, DateTime expirationDate); + Task SendOTPEmailAsync(string email, string token); + Task SendFailedLoginAttemptsEmailAsync(string email, DateTime utcNow, string ip); + Task SendFailedTwoFactorAttemptsEmailAsync(string email, DateTime utcNow, string ip); } diff --git a/src/Core/Services/IOrganizationService.cs b/src/Core/Services/IOrganizationService.cs index 076cd3eb8..3bd3e1f6e 100644 --- a/src/Core/Services/IOrganizationService.cs +++ b/src/Core/Services/IOrganizationService.cs @@ -3,66 +3,65 @@ using Bit.Core.Enums; using Bit.Core.Models.Business; using Bit.Core.Models.Data; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface IOrganizationService { - public interface IOrganizationService - { - Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken, PaymentMethodType paymentMethodType, - TaxInfo taxInfo); - Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null); - Task ReinstateSubscriptionAsync(Guid organizationId); - Task> UpgradePlanAsync(Guid organizationId, OrganizationUpgrade upgrade); - Task AdjustStorageAsync(Guid organizationId, short storageAdjustmentGb); - Task UpdateSubscription(Guid organizationId, int seatAdjustment, int? maxAutoscaleSeats); - Task AutoAddSeatsAsync(Organization organization, int seatsToAdd, DateTime? prorationDate = null); - Task AdjustSeatsAsync(Guid organizationId, int seatAdjustment, DateTime? prorationDate = null); - Task VerifyBankAsync(Guid organizationId, int amount1, int amount2); - Task> SignUpAsync(OrganizationSignup organizationSignup, bool provider = false); - Task> SignUpAsync(OrganizationLicense license, User owner, - string ownerKey, string collectionName, string publicKey, string privateKey); - Task UpdateLicenseAsync(Guid organizationId, OrganizationLicense license); - Task DeleteAsync(Organization organization); - Task EnableAsync(Guid organizationId, DateTime? expirationDate); - Task DisableAsync(Guid organizationId, DateTime? expirationDate); - Task UpdateExpirationDateAsync(Guid organizationId, DateTime? expirationDate); - Task EnableAsync(Guid organizationId); - Task UpdateAsync(Organization organization, bool updateBilling = false); - Task UpdateTwoFactorProviderAsync(Organization organization, TwoFactorProviderType type); - Task DisableTwoFactorProviderAsync(Organization organization, TwoFactorProviderType type); - Task> InviteUsersAsync(Guid organizationId, Guid? invitingUserId, - IEnumerable<(OrganizationUserInvite invite, string externalId)> invites); - Task InviteUserAsync(Guid organizationId, Guid? invitingUserId, string email, - OrganizationUserType type, bool accessAll, string externalId, IEnumerable collections); - Task>> ResendInvitesAsync(Guid organizationId, Guid? invitingUserId, IEnumerable organizationUsersId); - Task ResendInviteAsync(Guid organizationId, Guid? invitingUserId, Guid organizationUserId); - Task AcceptUserAsync(Guid organizationUserId, User user, string token, - IUserService userService); - Task AcceptUserAsync(string orgIdentifier, User user, IUserService userService); - Task ConfirmUserAsync(Guid organizationId, Guid organizationUserId, string key, - Guid confirmingUserId, IUserService userService); - Task>> ConfirmUsersAsync(Guid organizationId, Dictionary keys, - Guid confirmingUserId, IUserService userService); - Task SaveUserAsync(OrganizationUser user, Guid? savingUserId, IEnumerable collections); - Task DeleteUserAsync(Guid organizationId, Guid organizationUserId, Guid? deletingUserId); - Task DeleteUserAsync(Guid organizationId, Guid userId); - Task>> DeleteUsersAsync(Guid organizationId, - IEnumerable organizationUserIds, Guid? deletingUserId); - Task UpdateUserGroupsAsync(OrganizationUser organizationUser, IEnumerable groupIds, Guid? loggedInUserId); - Task UpdateUserResetPasswordEnrollmentAsync(Guid organizationId, Guid userId, string resetPasswordKey, Guid? callingUserId); - Task GenerateLicenseAsync(Guid organizationId, Guid installationId); - Task GenerateLicenseAsync(Organization organization, Guid installationId, - int? version = null); - Task ImportAsync(Guid organizationId, Guid? importingUserId, IEnumerable groups, - IEnumerable newUsers, IEnumerable removeUserExternalIds, - bool overwriteExisting); - Task DeleteSsoUserAsync(Guid userId, Guid? organizationId); - Task UpdateOrganizationKeysAsync(Guid orgId, string publicKey, string privateKey); - Task HasConfirmedOwnersExceptAsync(Guid organizationId, IEnumerable organizationUsersId, bool includeProvider = true); - Task RevokeUserAsync(OrganizationUser organizationUser, Guid? revokingUserId); - Task>> RevokeUsersAsync(Guid organizationId, - IEnumerable organizationUserIds, Guid? revokingUserId); - Task RestoreUserAsync(OrganizationUser organizationUser, Guid? restoringUserId, IUserService userService); - Task>> RestoreUsersAsync(Guid organizationId, - IEnumerable organizationUserIds, Guid? restoringUserId, IUserService userService); - } + Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken, PaymentMethodType paymentMethodType, + TaxInfo taxInfo); + Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null); + Task ReinstateSubscriptionAsync(Guid organizationId); + Task> UpgradePlanAsync(Guid organizationId, OrganizationUpgrade upgrade); + Task AdjustStorageAsync(Guid organizationId, short storageAdjustmentGb); + Task UpdateSubscription(Guid organizationId, int seatAdjustment, int? maxAutoscaleSeats); + Task AutoAddSeatsAsync(Organization organization, int seatsToAdd, DateTime? prorationDate = null); + Task AdjustSeatsAsync(Guid organizationId, int seatAdjustment, DateTime? prorationDate = null); + Task VerifyBankAsync(Guid organizationId, int amount1, int amount2); + Task> SignUpAsync(OrganizationSignup organizationSignup, bool provider = false); + Task> SignUpAsync(OrganizationLicense license, User owner, + string ownerKey, string collectionName, string publicKey, string privateKey); + Task UpdateLicenseAsync(Guid organizationId, OrganizationLicense license); + Task DeleteAsync(Organization organization); + Task EnableAsync(Guid organizationId, DateTime? expirationDate); + Task DisableAsync(Guid organizationId, DateTime? expirationDate); + Task UpdateExpirationDateAsync(Guid organizationId, DateTime? expirationDate); + Task EnableAsync(Guid organizationId); + Task UpdateAsync(Organization organization, bool updateBilling = false); + Task UpdateTwoFactorProviderAsync(Organization organization, TwoFactorProviderType type); + Task DisableTwoFactorProviderAsync(Organization organization, TwoFactorProviderType type); + Task> InviteUsersAsync(Guid organizationId, Guid? invitingUserId, + IEnumerable<(OrganizationUserInvite invite, string externalId)> invites); + Task InviteUserAsync(Guid organizationId, Guid? invitingUserId, string email, + OrganizationUserType type, bool accessAll, string externalId, IEnumerable collections); + Task>> ResendInvitesAsync(Guid organizationId, Guid? invitingUserId, IEnumerable organizationUsersId); + Task ResendInviteAsync(Guid organizationId, Guid? invitingUserId, Guid organizationUserId); + Task AcceptUserAsync(Guid organizationUserId, User user, string token, + IUserService userService); + Task AcceptUserAsync(string orgIdentifier, User user, IUserService userService); + Task ConfirmUserAsync(Guid organizationId, Guid organizationUserId, string key, + Guid confirmingUserId, IUserService userService); + Task>> ConfirmUsersAsync(Guid organizationId, Dictionary keys, + Guid confirmingUserId, IUserService userService); + Task SaveUserAsync(OrganizationUser user, Guid? savingUserId, IEnumerable collections); + Task DeleteUserAsync(Guid organizationId, Guid organizationUserId, Guid? deletingUserId); + Task DeleteUserAsync(Guid organizationId, Guid userId); + Task>> DeleteUsersAsync(Guid organizationId, + IEnumerable organizationUserIds, Guid? deletingUserId); + Task UpdateUserGroupsAsync(OrganizationUser organizationUser, IEnumerable groupIds, Guid? loggedInUserId); + Task UpdateUserResetPasswordEnrollmentAsync(Guid organizationId, Guid userId, string resetPasswordKey, Guid? callingUserId); + Task GenerateLicenseAsync(Guid organizationId, Guid installationId); + Task GenerateLicenseAsync(Organization organization, Guid installationId, + int? version = null); + Task ImportAsync(Guid organizationId, Guid? importingUserId, IEnumerable groups, + IEnumerable newUsers, IEnumerable removeUserExternalIds, + bool overwriteExisting); + Task DeleteSsoUserAsync(Guid userId, Guid? organizationId); + Task UpdateOrganizationKeysAsync(Guid orgId, string publicKey, string privateKey); + Task HasConfirmedOwnersExceptAsync(Guid organizationId, IEnumerable organizationUsersId, bool includeProvider = true); + Task RevokeUserAsync(OrganizationUser organizationUser, Guid? revokingUserId); + Task>> RevokeUsersAsync(Guid organizationId, + IEnumerable organizationUserIds, Guid? revokingUserId); + Task RestoreUserAsync(OrganizationUser organizationUser, Guid? restoringUserId, IUserService userService); + Task>> RestoreUsersAsync(Guid organizationId, + IEnumerable organizationUserIds, Guid? restoringUserId, IUserService userService); } diff --git a/src/Core/Services/IPaymentService.cs b/src/Core/Services/IPaymentService.cs index 562c70e3a..a9091808d 100644 --- a/src/Core/Services/IPaymentService.cs +++ b/src/Core/Services/IPaymentService.cs @@ -3,36 +3,35 @@ using Bit.Core.Enums; using Bit.Core.Models.Business; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface IPaymentService { - public interface IPaymentService - { - Task CancelAndRecoverChargesAsync(ISubscriber subscriber); - Task PurchaseOrganizationAsync(Organization org, PaymentMethodType paymentMethodType, - string paymentToken, Plan plan, short additionalStorageGb, int additionalSeats, - bool premiumAccessAddon, TaxInfo taxInfo); - Task SponsorOrganizationAsync(Organization org, OrganizationSponsorship sponsorship); - Task RemoveOrganizationSponsorshipAsync(Organization org, OrganizationSponsorship sponsorship); - Task UpgradeFreeOrganizationAsync(Organization org, Plan plan, - short additionalStorageGb, int additionalSeats, bool premiumAccessAddon, TaxInfo taxInfo); - Task PurchasePremiumAsync(User user, PaymentMethodType paymentMethodType, string paymentToken, - short additionalStorageGb, TaxInfo taxInfo); - Task AdjustSeatsAsync(Organization organization, Plan plan, int additionalSeats, DateTime? prorationDate = null); - Task AdjustStorageAsync(IStorableSubscriber storableSubscriber, int additionalStorage, string storagePlanId, DateTime? prorationDate = null); - Task CancelSubscriptionAsync(ISubscriber subscriber, bool endOfPeriod = false, - bool skipInAppPurchaseCheck = false); - Task ReinstateSubscriptionAsync(ISubscriber subscriber); - Task UpdatePaymentMethodAsync(ISubscriber subscriber, PaymentMethodType paymentMethodType, - string paymentToken, bool allowInAppPurchases = false, TaxInfo taxInfo = null); - Task CreditAccountAsync(ISubscriber subscriber, decimal creditAmount); - Task GetBillingAsync(ISubscriber subscriber); - Task GetBillingHistoryAsync(ISubscriber subscriber); - Task GetBillingBalanceAndSourceAsync(ISubscriber subscriber); - Task GetSubscriptionAsync(ISubscriber subscriber); - Task GetTaxInfoAsync(ISubscriber subscriber); - Task SaveTaxInfoAsync(ISubscriber subscriber, TaxInfo taxInfo); - Task CreateTaxRateAsync(TaxRate taxRate); - Task UpdateTaxRateAsync(TaxRate taxRate); - Task ArchiveTaxRateAsync(TaxRate taxRate); - } + Task CancelAndRecoverChargesAsync(ISubscriber subscriber); + Task PurchaseOrganizationAsync(Organization org, PaymentMethodType paymentMethodType, + string paymentToken, Plan plan, short additionalStorageGb, int additionalSeats, + bool premiumAccessAddon, TaxInfo taxInfo); + Task SponsorOrganizationAsync(Organization org, OrganizationSponsorship sponsorship); + Task RemoveOrganizationSponsorshipAsync(Organization org, OrganizationSponsorship sponsorship); + Task UpgradeFreeOrganizationAsync(Organization org, Plan plan, + short additionalStorageGb, int additionalSeats, bool premiumAccessAddon, TaxInfo taxInfo); + Task PurchasePremiumAsync(User user, PaymentMethodType paymentMethodType, string paymentToken, + short additionalStorageGb, TaxInfo taxInfo); + Task AdjustSeatsAsync(Organization organization, Plan plan, int additionalSeats, DateTime? prorationDate = null); + Task AdjustStorageAsync(IStorableSubscriber storableSubscriber, int additionalStorage, string storagePlanId, DateTime? prorationDate = null); + Task CancelSubscriptionAsync(ISubscriber subscriber, bool endOfPeriod = false, + bool skipInAppPurchaseCheck = false); + Task ReinstateSubscriptionAsync(ISubscriber subscriber); + Task UpdatePaymentMethodAsync(ISubscriber subscriber, PaymentMethodType paymentMethodType, + string paymentToken, bool allowInAppPurchases = false, TaxInfo taxInfo = null); + Task CreditAccountAsync(ISubscriber subscriber, decimal creditAmount); + Task GetBillingAsync(ISubscriber subscriber); + Task GetBillingHistoryAsync(ISubscriber subscriber); + Task GetBillingBalanceAndSourceAsync(ISubscriber subscriber); + Task GetSubscriptionAsync(ISubscriber subscriber); + Task GetTaxInfoAsync(ISubscriber subscriber); + Task SaveTaxInfoAsync(ISubscriber subscriber, TaxInfo taxInfo); + Task CreateTaxRateAsync(TaxRate taxRate); + Task UpdateTaxRateAsync(TaxRate taxRate); + Task ArchiveTaxRateAsync(TaxRate taxRate); } diff --git a/src/Core/Services/IPolicyService.cs b/src/Core/Services/IPolicyService.cs index d7487cbd4..5f1b4d366 100644 --- a/src/Core/Services/IPolicyService.cs +++ b/src/Core/Services/IPolicyService.cs @@ -1,10 +1,9 @@ using Bit.Core.Entities; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface IPolicyService { - public interface IPolicyService - { - Task SaveAsync(Policy policy, IUserService userService, IOrganizationService organizationService, - Guid? savingUserId); - } + Task SaveAsync(Policy policy, IUserService userService, IOrganizationService organizationService, + Guid? savingUserId); } diff --git a/src/Core/Services/IProviderService.cs b/src/Core/Services/IProviderService.cs index eb38afad2..c5cf039b2 100644 --- a/src/Core/Services/IProviderService.cs +++ b/src/Core/Services/IProviderService.cs @@ -3,29 +3,28 @@ using Bit.Core.Entities.Provider; using Bit.Core.Models.Business; using Bit.Core.Models.Business.Provider; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface IProviderService { - public interface IProviderService - { - Task CreateAsync(string ownerEmail); - Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key); - Task UpdateAsync(Provider provider, bool updateBilling = false); + Task CreateAsync(string ownerEmail); + Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key); + Task UpdateAsync(Provider provider, bool updateBilling = false); - Task> InviteUserAsync(ProviderUserInvite invite); - Task>> ResendInvitesAsync(ProviderUserInvite invite); - Task AcceptUserAsync(Guid providerUserId, User user, string token); - Task>> ConfirmUsersAsync(Guid providerId, Dictionary keys, Guid confirmingUserId); + Task> InviteUserAsync(ProviderUserInvite invite); + Task>> ResendInvitesAsync(ProviderUserInvite invite); + Task AcceptUserAsync(Guid providerUserId, User user, string token); + Task>> ConfirmUsersAsync(Guid providerId, Dictionary keys, Guid confirmingUserId); - Task SaveUserAsync(ProviderUser user, Guid savingUserId); - Task>> DeleteUsersAsync(Guid providerId, IEnumerable providerUserIds, - Guid deletingUserId); + Task SaveUserAsync(ProviderUser user, Guid savingUserId); + Task>> DeleteUsersAsync(Guid providerId, IEnumerable providerUserIds, + Guid deletingUserId); - Task AddOrganization(Guid providerId, Guid organizationId, Guid addingUserId, string key); - Task CreateOrganizationAsync(Guid providerId, OrganizationSignup organizationSignup, - string clientOwnerEmail, User user); - Task RemoveOrganizationAsync(Guid providerId, Guid providerOrganizationId, Guid removingUserId); - Task LogProviderAccessToOrganizationAsync(Guid organizationId); - Task ResendProviderSetupInviteEmailAsync(Guid providerId, Guid ownerId); - } + Task AddOrganization(Guid providerId, Guid organizationId, Guid addingUserId, string key); + Task CreateOrganizationAsync(Guid providerId, OrganizationSignup organizationSignup, + string clientOwnerEmail, User user); + Task RemoveOrganizationAsync(Guid providerId, Guid providerOrganizationId, Guid removingUserId); + Task LogProviderAccessToOrganizationAsync(Guid organizationId); + Task ResendProviderSetupInviteEmailAsync(Guid providerId, Guid ownerId); } diff --git a/src/Core/Services/IPushNotificationService.cs b/src/Core/Services/IPushNotificationService.cs index 9707b93dc..34e98515f 100644 --- a/src/Core/Services/IPushNotificationService.cs +++ b/src/Core/Services/IPushNotificationService.cs @@ -1,26 +1,25 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface IPushNotificationService { - public interface IPushNotificationService - { - Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds); - Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds); - Task PushSyncCipherDeleteAsync(Cipher cipher); - Task PushSyncFolderCreateAsync(Folder folder); - Task PushSyncFolderUpdateAsync(Folder folder); - Task PushSyncFolderDeleteAsync(Folder folder); - Task PushSyncCiphersAsync(Guid userId); - Task PushSyncVaultAsync(Guid userId); - Task PushSyncOrgKeysAsync(Guid userId); - Task PushSyncSettingsAsync(Guid userId); - Task PushLogOutAsync(Guid userId); - Task PushSyncSendCreateAsync(Send send); - Task PushSyncSendUpdateAsync(Send send); - Task PushSyncSendDeleteAsync(Send send); - Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, string deviceId = null); - Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, - string deviceId = null); - } + Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds); + Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds); + Task PushSyncCipherDeleteAsync(Cipher cipher); + Task PushSyncFolderCreateAsync(Folder folder); + Task PushSyncFolderUpdateAsync(Folder folder); + Task PushSyncFolderDeleteAsync(Folder folder); + Task PushSyncCiphersAsync(Guid userId); + Task PushSyncVaultAsync(Guid userId); + Task PushSyncOrgKeysAsync(Guid userId); + Task PushSyncSettingsAsync(Guid userId); + Task PushLogOutAsync(Guid userId); + Task PushSyncSendCreateAsync(Send send); + Task PushSyncSendUpdateAsync(Send send); + Task PushSyncSendDeleteAsync(Send send); + Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, string deviceId = null); + Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, + string deviceId = null); } diff --git a/src/Core/Services/IPushRegistrationService.cs b/src/Core/Services/IPushRegistrationService.cs index 14d2c82ef..985246de0 100644 --- a/src/Core/Services/IPushRegistrationService.cs +++ b/src/Core/Services/IPushRegistrationService.cs @@ -1,13 +1,12 @@ using Bit.Core.Enums; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface IPushRegistrationService { - public interface IPushRegistrationService - { - Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId, - string identifier, DeviceType type); - Task DeleteRegistrationAsync(string deviceId); - Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId); - Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId); - } + Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId, + string identifier, DeviceType type); + Task DeleteRegistrationAsync(string deviceId); + Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId); + Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId); } diff --git a/src/Core/Services/IReferenceEventService.cs b/src/Core/Services/IReferenceEventService.cs index fa85a2a3d..03339f08c 100644 --- a/src/Core/Services/IReferenceEventService.cs +++ b/src/Core/Services/IReferenceEventService.cs @@ -1,9 +1,8 @@ using Bit.Core.Models.Business; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface IReferenceEventService { - public interface IReferenceEventService - { - Task RaiseEventAsync(ReferenceEvent referenceEvent); - } + Task RaiseEventAsync(ReferenceEvent referenceEvent); } diff --git a/src/Core/Services/ISendService.cs b/src/Core/Services/ISendService.cs index 8ee97a629..a2b6b8c35 100644 --- a/src/Core/Services/ISendService.cs +++ b/src/Core/Services/ISendService.cs @@ -1,17 +1,16 @@ using Bit.Core.Entities; using Bit.Core.Models.Data; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface ISendService { - public interface ISendService - { - Task DeleteSendAsync(Send send); - Task SaveSendAsync(Send send); - Task SaveFileSendAsync(Send send, SendFileData data, long fileLength); - Task UploadFileToExistingSendAsync(Stream stream, Send send); - Task<(Send, bool, bool)> AccessAsync(Guid sendId, string password); - string HashPassword(string password); - Task<(string, bool, bool)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password); - Task ValidateSendFile(Send send); - } + Task DeleteSendAsync(Send send); + Task SaveSendAsync(Send send); + Task SaveFileSendAsync(Send send, SendFileData data, long fileLength); + Task UploadFileToExistingSendAsync(Stream stream, Send send); + Task<(Send, bool, bool)> AccessAsync(Guid sendId, string password); + string HashPassword(string password); + Task<(string, bool, bool)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password); + Task ValidateSendFile(Send send); } diff --git a/src/Core/Services/ISendStorageService.cs b/src/Core/Services/ISendStorageService.cs index f671d0077..63c0d44ca 100644 --- a/src/Core/Services/ISendStorageService.cs +++ b/src/Core/Services/ISendStorageService.cs @@ -1,17 +1,16 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface ISendFileStorageService { - public interface ISendFileStorageService - { - FileUploadType FileUploadType { get; } - Task UploadNewFileAsync(Stream stream, Send send, string fileId); - Task DeleteFileAsync(Send send, string fileId); - Task DeleteFilesForOrganizationAsync(Guid organizationId); - Task DeleteFilesForUserAsync(Guid userId); - Task GetSendFileDownloadUrlAsync(Send send, string fileId); - Task GetSendFileUploadUrlAsync(Send send, string fileId); - Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway); - } + FileUploadType FileUploadType { get; } + Task UploadNewFileAsync(Stream stream, Send send, string fileId); + Task DeleteFileAsync(Send send, string fileId); + Task DeleteFilesForOrganizationAsync(Guid organizationId); + Task DeleteFilesForUserAsync(Guid userId); + Task GetSendFileDownloadUrlAsync(Send send, string fileId); + Task GetSendFileUploadUrlAsync(Send send, string fileId); + Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway); } diff --git a/src/Core/Services/ISsoConfigService.cs b/src/Core/Services/ISsoConfigService.cs index d4d2cfcef..c25127d95 100644 --- a/src/Core/Services/ISsoConfigService.cs +++ b/src/Core/Services/ISsoConfigService.cs @@ -1,9 +1,8 @@ using Bit.Core.Entities; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface ISsoConfigService { - public interface ISsoConfigService - { - Task SaveAsync(SsoConfig config, Organization organization); - } + Task SaveAsync(SsoConfig config, Organization organization); } diff --git a/src/Core/Services/IStripeAdapter.cs b/src/Core/Services/IStripeAdapter.cs index ffb0e2a1c..ff922161c 100644 --- a/src/Core/Services/IStripeAdapter.cs +++ b/src/Core/Services/IStripeAdapter.cs @@ -1,40 +1,39 @@ using Bit.Core.Models.BitStripe; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface IStripeAdapter { - public interface IStripeAdapter - { - Task CustomerCreateAsync(Stripe.CustomerCreateOptions customerCreateOptions); - Task CustomerGetAsync(string id, Stripe.CustomerGetOptions options = null); - Task CustomerUpdateAsync(string id, Stripe.CustomerUpdateOptions options = null); - Task CustomerDeleteAsync(string id); - Task SubscriptionCreateAsync(Stripe.SubscriptionCreateOptions subscriptionCreateOptions); - Task SubscriptionGetAsync(string id, Stripe.SubscriptionGetOptions options = null); - Task> SubscriptionListAsync(StripeSubscriptionListOptions subscriptionSearchOptions); - Task SubscriptionUpdateAsync(string id, Stripe.SubscriptionUpdateOptions options = null); - Task SubscriptionCancelAsync(string Id, Stripe.SubscriptionCancelOptions options = null); - Task InvoiceUpcomingAsync(Stripe.UpcomingInvoiceOptions options); - Task InvoiceGetAsync(string id, Stripe.InvoiceGetOptions options); - Task> InvoiceListAsync(Stripe.InvoiceListOptions options); - Task InvoiceUpdateAsync(string id, Stripe.InvoiceUpdateOptions options); - Task InvoiceFinalizeInvoiceAsync(string id, Stripe.InvoiceFinalizeOptions options); - Task InvoiceSendInvoiceAsync(string id, Stripe.InvoiceSendOptions options); - Task InvoicePayAsync(string id, Stripe.InvoicePayOptions options = null); - Task InvoiceDeleteAsync(string id, Stripe.InvoiceDeleteOptions options = null); - Task InvoiceVoidInvoiceAsync(string id, Stripe.InvoiceVoidOptions options = null); - IEnumerable PaymentMethodListAutoPaging(Stripe.PaymentMethodListOptions options); - Task PaymentMethodAttachAsync(string id, Stripe.PaymentMethodAttachOptions options = null); - Task PaymentMethodDetachAsync(string id, Stripe.PaymentMethodDetachOptions options = null); - Task TaxRateCreateAsync(Stripe.TaxRateCreateOptions options); - Task TaxRateUpdateAsync(string id, Stripe.TaxRateUpdateOptions options); - Task TaxIdCreateAsync(string id, Stripe.TaxIdCreateOptions options); - Task TaxIdDeleteAsync(string customerId, string taxIdId, Stripe.TaxIdDeleteOptions options = null); - Task> ChargeListAsync(Stripe.ChargeListOptions options); - Task RefundCreateAsync(Stripe.RefundCreateOptions options); - Task CardDeleteAsync(string customerId, string cardId, Stripe.CardDeleteOptions options = null); - Task BankAccountCreateAsync(string customerId, Stripe.BankAccountCreateOptions options = null); - Task BankAccountDeleteAsync(string customerId, string bankAccount, Stripe.BankAccountDeleteOptions options = null); - Task> PriceListAsync(Stripe.PriceListOptions options = null); - Task> TestClockListAsync(); - } + Task CustomerCreateAsync(Stripe.CustomerCreateOptions customerCreateOptions); + Task CustomerGetAsync(string id, Stripe.CustomerGetOptions options = null); + Task CustomerUpdateAsync(string id, Stripe.CustomerUpdateOptions options = null); + Task CustomerDeleteAsync(string id); + Task SubscriptionCreateAsync(Stripe.SubscriptionCreateOptions subscriptionCreateOptions); + Task SubscriptionGetAsync(string id, Stripe.SubscriptionGetOptions options = null); + Task> SubscriptionListAsync(StripeSubscriptionListOptions subscriptionSearchOptions); + Task SubscriptionUpdateAsync(string id, Stripe.SubscriptionUpdateOptions options = null); + Task SubscriptionCancelAsync(string Id, Stripe.SubscriptionCancelOptions options = null); + Task InvoiceUpcomingAsync(Stripe.UpcomingInvoiceOptions options); + Task InvoiceGetAsync(string id, Stripe.InvoiceGetOptions options); + Task> InvoiceListAsync(Stripe.InvoiceListOptions options); + Task InvoiceUpdateAsync(string id, Stripe.InvoiceUpdateOptions options); + Task InvoiceFinalizeInvoiceAsync(string id, Stripe.InvoiceFinalizeOptions options); + Task InvoiceSendInvoiceAsync(string id, Stripe.InvoiceSendOptions options); + Task InvoicePayAsync(string id, Stripe.InvoicePayOptions options = null); + Task InvoiceDeleteAsync(string id, Stripe.InvoiceDeleteOptions options = null); + Task InvoiceVoidInvoiceAsync(string id, Stripe.InvoiceVoidOptions options = null); + IEnumerable PaymentMethodListAutoPaging(Stripe.PaymentMethodListOptions options); + Task PaymentMethodAttachAsync(string id, Stripe.PaymentMethodAttachOptions options = null); + Task PaymentMethodDetachAsync(string id, Stripe.PaymentMethodDetachOptions options = null); + Task TaxRateCreateAsync(Stripe.TaxRateCreateOptions options); + Task TaxRateUpdateAsync(string id, Stripe.TaxRateUpdateOptions options); + Task TaxIdCreateAsync(string id, Stripe.TaxIdCreateOptions options); + Task TaxIdDeleteAsync(string customerId, string taxIdId, Stripe.TaxIdDeleteOptions options = null); + Task> ChargeListAsync(Stripe.ChargeListOptions options); + Task RefundCreateAsync(Stripe.RefundCreateOptions options); + Task CardDeleteAsync(string customerId, string cardId, Stripe.CardDeleteOptions options = null); + Task BankAccountCreateAsync(string customerId, Stripe.BankAccountCreateOptions options = null); + Task BankAccountDeleteAsync(string customerId, string bankAccount, Stripe.BankAccountDeleteOptions options = null); + Task> PriceListAsync(Stripe.PriceListOptions options = null); + Task> TestClockListAsync(); } diff --git a/src/Core/Services/IStripeSyncService.cs b/src/Core/Services/IStripeSyncService.cs index 0219bc5d2..655998805 100644 --- a/src/Core/Services/IStripeSyncService.cs +++ b/src/Core/Services/IStripeSyncService.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface IStripeSyncService { - public interface IStripeSyncService - { - Task UpdateCustomerEmailAddress(string gatewayCustomerId, string emailAddress); - } + Task UpdateCustomerEmailAddress(string gatewayCustomerId, string emailAddress); } diff --git a/src/Core/Services/IUserService.cs b/src/Core/Services/IUserService.cs index 989bea85d..077f66756 100644 --- a/src/Core/Services/IUserService.cs +++ b/src/Core/Services/IUserService.cs @@ -6,77 +6,76 @@ using Bit.Core.Models.Business; using Fido2NetLib; using Microsoft.AspNetCore.Identity; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public interface IUserService { - public interface IUserService - { - Guid? GetProperUserId(ClaimsPrincipal principal); - Task GetUserByIdAsync(string userId); - Task GetUserByIdAsync(Guid userId); - Task GetUserByPrincipalAsync(ClaimsPrincipal principal); - Task GetAccountRevisionDateByIdAsync(Guid userId); - Task SaveUserAsync(User user, bool push = false); - Task RegisterUserAsync(User user, string masterPassword, string token, Guid? orgUserId); - Task RegisterUserAsync(User user); - Task SendMasterPasswordHintAsync(string email); - Task SendTwoFactorEmailAsync(User user, bool isBecauseNewDeviceLogin = false); - Task VerifyTwoFactorEmailAsync(User user, string token); - Task StartWebAuthnRegistrationAsync(User user); - Task DeleteWebAuthnKeyAsync(User user, int id); - Task CompleteWebAuthRegistrationAsync(User user, int value, string name, AuthenticatorAttestationRawResponse attestationResponse); - Task SendEmailVerificationAsync(User user); - Task ConfirmEmailAsync(User user, string token); - Task InitiateEmailChangeAsync(User user, string newEmail); - Task ChangeEmailAsync(User user, string masterPassword, string newEmail, string newMasterPassword, - string token, string key); - Task ChangePasswordAsync(User user, string masterPassword, string newMasterPassword, string passwordHint, string key); - Task SetPasswordAsync(User user, string newMasterPassword, string key, string orgIdentifier = null); - Task SetKeyConnectorKeyAsync(User user, string key, string orgIdentifier); - Task ConvertToKeyConnectorAsync(User user); - Task AdminResetPasswordAsync(OrganizationUserType type, Guid orgId, Guid id, string newMasterPassword, string key); - Task UpdateTempPasswordAsync(User user, string newMasterPassword, string key, string hint); - Task ChangeKdfAsync(User user, string masterPassword, string newMasterPassword, string key, - KdfType kdf, int kdfIterations); - Task UpdateKeyAsync(User user, string masterPassword, string key, string privateKey, - IEnumerable ciphers, IEnumerable folders, IEnumerable sends); - Task RefreshSecurityStampAsync(User user, string masterPasswordHash); - Task UpdateTwoFactorProviderAsync(User user, TwoFactorProviderType type, bool setEnabled = true, bool logEvent = true); - Task DisableTwoFactorProviderAsync(User user, TwoFactorProviderType type, - IOrganizationService organizationService); - Task RecoverTwoFactorAsync(string email, string masterPassword, string recoveryCode, - IOrganizationService organizationService); - Task GenerateUserTokenAsync(User user, string tokenProvider, string purpose); - Task DeleteAsync(User user); - Task DeleteAsync(User user, string token); - Task SendDeleteConfirmationAsync(string email); - Task> SignUpPremiumAsync(User user, string paymentToken, - PaymentMethodType paymentMethodType, short additionalStorageGb, UserLicense license, - TaxInfo taxInfo); - Task IapCheckAsync(User user, PaymentMethodType paymentMethodType); - Task UpdateLicenseAsync(User user, UserLicense license); - Task AdjustStorageAsync(User user, short storageAdjustmentGb); - Task ReplacePaymentMethodAsync(User user, string paymentToken, PaymentMethodType paymentMethodType, TaxInfo taxInfo); - Task CancelPremiumAsync(User user, bool? endOfPeriod = null, bool accountDelete = false); - Task ReinstatePremiumAsync(User user); - Task EnablePremiumAsync(Guid userId, DateTime? expirationDate); - Task EnablePremiumAsync(User user, DateTime? expirationDate); - Task DisablePremiumAsync(Guid userId, DateTime? expirationDate); - Task DisablePremiumAsync(User user, DateTime? expirationDate); - Task UpdatePremiumExpirationAsync(Guid userId, DateTime? expirationDate); - Task GenerateLicenseAsync(User user, SubscriptionInfo subscriptionInfo = null, - int? version = null); - Task CheckPasswordAsync(User user, string password); - Task CanAccessPremium(ITwoFactorProvidersUser user); - Task HasPremiumFromOrganization(ITwoFactorProvidersUser user); - Task TwoFactorIsEnabledAsync(ITwoFactorProvidersUser user); - Task TwoFactorProviderIsEnabledAsync(TwoFactorProviderType provider, ITwoFactorProvidersUser user); - Task GenerateSignInTokenAsync(User user, string purpose); - Task RotateApiKeyAsync(User user); - string GetUserName(ClaimsPrincipal principal); - Task SendOTPAsync(User user); - Task VerifyOTPAsync(User user, string token); - Task VerifySecretAsync(User user, string secret); - Task Needs2FABecauseNewDeviceAsync(User user, string deviceIdentifier, string grantType); - bool CanEditDeviceVerificationSettings(User user); - } + Guid? GetProperUserId(ClaimsPrincipal principal); + Task GetUserByIdAsync(string userId); + Task GetUserByIdAsync(Guid userId); + Task GetUserByPrincipalAsync(ClaimsPrincipal principal); + Task GetAccountRevisionDateByIdAsync(Guid userId); + Task SaveUserAsync(User user, bool push = false); + Task RegisterUserAsync(User user, string masterPassword, string token, Guid? orgUserId); + Task RegisterUserAsync(User user); + Task SendMasterPasswordHintAsync(string email); + Task SendTwoFactorEmailAsync(User user, bool isBecauseNewDeviceLogin = false); + Task VerifyTwoFactorEmailAsync(User user, string token); + Task StartWebAuthnRegistrationAsync(User user); + Task DeleteWebAuthnKeyAsync(User user, int id); + Task CompleteWebAuthRegistrationAsync(User user, int value, string name, AuthenticatorAttestationRawResponse attestationResponse); + Task SendEmailVerificationAsync(User user); + Task ConfirmEmailAsync(User user, string token); + Task InitiateEmailChangeAsync(User user, string newEmail); + Task ChangeEmailAsync(User user, string masterPassword, string newEmail, string newMasterPassword, + string token, string key); + Task ChangePasswordAsync(User user, string masterPassword, string newMasterPassword, string passwordHint, string key); + Task SetPasswordAsync(User user, string newMasterPassword, string key, string orgIdentifier = null); + Task SetKeyConnectorKeyAsync(User user, string key, string orgIdentifier); + Task ConvertToKeyConnectorAsync(User user); + Task AdminResetPasswordAsync(OrganizationUserType type, Guid orgId, Guid id, string newMasterPassword, string key); + Task UpdateTempPasswordAsync(User user, string newMasterPassword, string key, string hint); + Task ChangeKdfAsync(User user, string masterPassword, string newMasterPassword, string key, + KdfType kdf, int kdfIterations); + Task UpdateKeyAsync(User user, string masterPassword, string key, string privateKey, + IEnumerable ciphers, IEnumerable folders, IEnumerable sends); + Task RefreshSecurityStampAsync(User user, string masterPasswordHash); + Task UpdateTwoFactorProviderAsync(User user, TwoFactorProviderType type, bool setEnabled = true, bool logEvent = true); + Task DisableTwoFactorProviderAsync(User user, TwoFactorProviderType type, + IOrganizationService organizationService); + Task RecoverTwoFactorAsync(string email, string masterPassword, string recoveryCode, + IOrganizationService organizationService); + Task GenerateUserTokenAsync(User user, string tokenProvider, string purpose); + Task DeleteAsync(User user); + Task DeleteAsync(User user, string token); + Task SendDeleteConfirmationAsync(string email); + Task> SignUpPremiumAsync(User user, string paymentToken, + PaymentMethodType paymentMethodType, short additionalStorageGb, UserLicense license, + TaxInfo taxInfo); + Task IapCheckAsync(User user, PaymentMethodType paymentMethodType); + Task UpdateLicenseAsync(User user, UserLicense license); + Task AdjustStorageAsync(User user, short storageAdjustmentGb); + Task ReplacePaymentMethodAsync(User user, string paymentToken, PaymentMethodType paymentMethodType, TaxInfo taxInfo); + Task CancelPremiumAsync(User user, bool? endOfPeriod = null, bool accountDelete = false); + Task ReinstatePremiumAsync(User user); + Task EnablePremiumAsync(Guid userId, DateTime? expirationDate); + Task EnablePremiumAsync(User user, DateTime? expirationDate); + Task DisablePremiumAsync(Guid userId, DateTime? expirationDate); + Task DisablePremiumAsync(User user, DateTime? expirationDate); + Task UpdatePremiumExpirationAsync(Guid userId, DateTime? expirationDate); + Task GenerateLicenseAsync(User user, SubscriptionInfo subscriptionInfo = null, + int? version = null); + Task CheckPasswordAsync(User user, string password); + Task CanAccessPremium(ITwoFactorProvidersUser user); + Task HasPremiumFromOrganization(ITwoFactorProvidersUser user); + Task TwoFactorIsEnabledAsync(ITwoFactorProvidersUser user); + Task TwoFactorProviderIsEnabledAsync(TwoFactorProviderType provider, ITwoFactorProvidersUser user); + Task GenerateSignInTokenAsync(User user, string purpose); + Task RotateApiKeyAsync(User user); + string GetUserName(ClaimsPrincipal principal); + Task SendOTPAsync(User user); + Task VerifyOTPAsync(User user, string token); + Task VerifySecretAsync(User user, string secret); + Task Needs2FABecauseNewDeviceAsync(User user, string deviceIdentifier, string grantType); + bool CanEditDeviceVerificationSettings(User user); } diff --git a/src/Core/Services/Implementations/AmazonSesMailDeliveryService.cs b/src/Core/Services/Implementations/AmazonSesMailDeliveryService.cs index 98275b4ab..adf406cf0 100644 --- a/src/Core/Services/Implementations/AmazonSesMailDeliveryService.cs +++ b/src/Core/Services/Implementations/AmazonSesMailDeliveryService.cs @@ -7,137 +7,136 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class AmazonSesMailDeliveryService : IMailDeliveryService, IDisposable { - public class AmazonSesMailDeliveryService : IMailDeliveryService, IDisposable + private readonly GlobalSettings _globalSettings; + private readonly IWebHostEnvironment _hostingEnvironment; + private readonly ILogger _logger; + private readonly IAmazonSimpleEmailService _client; + private readonly string _source; + private readonly string _senderTag; + private readonly string _configSetName; + + public AmazonSesMailDeliveryService( + GlobalSettings globalSettings, + IWebHostEnvironment hostingEnvironment, + ILogger logger) + : this(globalSettings, hostingEnvironment, logger, + new AmazonSimpleEmailServiceClient( + globalSettings.Amazon.AccessKeyId, + globalSettings.Amazon.AccessKeySecret, + RegionEndpoint.GetBySystemName(globalSettings.Amazon.Region)) + ) { - private readonly GlobalSettings _globalSettings; - private readonly IWebHostEnvironment _hostingEnvironment; - private readonly ILogger _logger; - private readonly IAmazonSimpleEmailService _client; - private readonly string _source; - private readonly string _senderTag; - private readonly string _configSetName; + } - public AmazonSesMailDeliveryService( - GlobalSettings globalSettings, - IWebHostEnvironment hostingEnvironment, - ILogger logger) - : this(globalSettings, hostingEnvironment, logger, - new AmazonSimpleEmailServiceClient( - globalSettings.Amazon.AccessKeyId, - globalSettings.Amazon.AccessKeySecret, - RegionEndpoint.GetBySystemName(globalSettings.Amazon.Region)) - ) + public AmazonSesMailDeliveryService( + GlobalSettings globalSettings, + IWebHostEnvironment hostingEnvironment, + ILogger logger, + IAmazonSimpleEmailService amazonSimpleEmailService) + { + if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.AccessKeyId)) { + throw new ArgumentNullException(nameof(globalSettings.Amazon.AccessKeyId)); + } + if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.AccessKeySecret)) + { + throw new ArgumentNullException(nameof(globalSettings.Amazon.AccessKeySecret)); + } + if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.Region)) + { + throw new ArgumentNullException(nameof(globalSettings.Amazon.Region)); } - public AmazonSesMailDeliveryService( - GlobalSettings globalSettings, - IWebHostEnvironment hostingEnvironment, - ILogger logger, - IAmazonSimpleEmailService amazonSimpleEmailService) + var replyToEmail = CoreHelpers.PunyEncode(globalSettings.Mail.ReplyToEmail); + + _globalSettings = globalSettings; + _hostingEnvironment = hostingEnvironment; + _logger = logger; + _client = amazonSimpleEmailService; + _source = $"\"{globalSettings.SiteName}\" <{replyToEmail}>"; + _senderTag = $"Server_{globalSettings.ProjectName?.Replace(' ', '_')}"; + if (!string.IsNullOrWhiteSpace(_globalSettings.Mail.AmazonConfigSetName)) { - if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.AccessKeyId)) - { - throw new ArgumentNullException(nameof(globalSettings.Amazon.AccessKeyId)); - } - if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.AccessKeySecret)) - { - throw new ArgumentNullException(nameof(globalSettings.Amazon.AccessKeySecret)); - } - if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.Region)) - { - throw new ArgumentNullException(nameof(globalSettings.Amazon.Region)); - } - - var replyToEmail = CoreHelpers.PunyEncode(globalSettings.Mail.ReplyToEmail); - - _globalSettings = globalSettings; - _hostingEnvironment = hostingEnvironment; - _logger = logger; - _client = amazonSimpleEmailService; - _source = $"\"{globalSettings.SiteName}\" <{replyToEmail}>"; - _senderTag = $"Server_{globalSettings.ProjectName?.Replace(' ', '_')}"; - if (!string.IsNullOrWhiteSpace(_globalSettings.Mail.AmazonConfigSetName)) - { - _configSetName = _globalSettings.Mail.AmazonConfigSetName; - } - } - - public void Dispose() - { - _client?.Dispose(); - } - - public async Task SendEmailAsync(MailMessage message) - { - var request = new SendEmailRequest - { - ConfigurationSetName = _configSetName, - Source = _source, - Destination = new Destination - { - ToAddresses = message.ToEmails - .Select(email => CoreHelpers.PunyEncode(email)) - .ToList() - }, - Message = new Message - { - Subject = new Content(message.Subject), - Body = new Body - { - Html = new Content - { - Charset = "UTF-8", - Data = message.HtmlContent - }, - Text = new Content - { - Charset = "UTF-8", - Data = message.TextContent - } - } - }, - Tags = new List - { - new MessageTag { Name = "Environment", Value = _hostingEnvironment.EnvironmentName }, - new MessageTag { Name = "Sender", Value = _senderTag } - } - }; - - if (message.BccEmails?.Any() ?? false) - { - request.Destination.BccAddresses = message.BccEmails - .Select(email => CoreHelpers.PunyEncode(email)) - .ToList(); - } - - if (!string.IsNullOrWhiteSpace(message.Category)) - { - request.Tags.Add(new MessageTag { Name = "Category", Value = message.Category }); - } - - try - { - await SendAsync(request, false); - } - catch (Exception e) - { - _logger.LogWarning(e, "Failed to send email. Retrying..."); - await SendAsync(request, true); - throw; - } - } - - private async Task SendAsync(SendEmailRequest request, bool retry) - { - if (retry) - { - // wait and try again - await Task.Delay(2000); - } - await _client.SendEmailAsync(request); + _configSetName = _globalSettings.Mail.AmazonConfigSetName; } } + + public void Dispose() + { + _client?.Dispose(); + } + + public async Task SendEmailAsync(MailMessage message) + { + var request = new SendEmailRequest + { + ConfigurationSetName = _configSetName, + Source = _source, + Destination = new Destination + { + ToAddresses = message.ToEmails + .Select(email => CoreHelpers.PunyEncode(email)) + .ToList() + }, + Message = new Message + { + Subject = new Content(message.Subject), + Body = new Body + { + Html = new Content + { + Charset = "UTF-8", + Data = message.HtmlContent + }, + Text = new Content + { + Charset = "UTF-8", + Data = message.TextContent + } + } + }, + Tags = new List + { + new MessageTag { Name = "Environment", Value = _hostingEnvironment.EnvironmentName }, + new MessageTag { Name = "Sender", Value = _senderTag } + } + }; + + if (message.BccEmails?.Any() ?? false) + { + request.Destination.BccAddresses = message.BccEmails + .Select(email => CoreHelpers.PunyEncode(email)) + .ToList(); + } + + if (!string.IsNullOrWhiteSpace(message.Category)) + { + request.Tags.Add(new MessageTag { Name = "Category", Value = message.Category }); + } + + try + { + await SendAsync(request, false); + } + catch (Exception e) + { + _logger.LogWarning(e, "Failed to send email. Retrying..."); + await SendAsync(request, true); + throw; + } + } + + private async Task SendAsync(SendEmailRequest request, bool retry) + { + if (retry) + { + // wait and try again + await Task.Delay(2000); + } + await _client.SendEmailAsync(request); + } } diff --git a/src/Core/Services/Implementations/AmazonSqsBlockIpService.cs b/src/Core/Services/Implementations/AmazonSqsBlockIpService.cs index 1e6dcf935..ac5dfb45c 100644 --- a/src/Core/Services/Implementations/AmazonSqsBlockIpService.cs +++ b/src/Core/Services/Implementations/AmazonSqsBlockIpService.cs @@ -2,81 +2,80 @@ using Amazon.SQS; using Bit.Core.Settings; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class AmazonSqsBlockIpService : IBlockIpService, IDisposable { - public class AmazonSqsBlockIpService : IBlockIpService, IDisposable + private readonly IAmazonSQS _client; + private string _blockIpQueueUrl; + private string _unblockIpQueueUrl; + private bool _didInit = false; + private Tuple _lastBlock; + + public AmazonSqsBlockIpService( + GlobalSettings globalSettings) + : this(globalSettings, new AmazonSQSClient( + globalSettings.Amazon.AccessKeyId, + globalSettings.Amazon.AccessKeySecret, + RegionEndpoint.GetBySystemName(globalSettings.Amazon.Region)) + ) { - private readonly IAmazonSQS _client; - private string _blockIpQueueUrl; - private string _unblockIpQueueUrl; - private bool _didInit = false; - private Tuple _lastBlock; + } - public AmazonSqsBlockIpService( - GlobalSettings globalSettings) - : this(globalSettings, new AmazonSQSClient( - globalSettings.Amazon.AccessKeyId, - globalSettings.Amazon.AccessKeySecret, - RegionEndpoint.GetBySystemName(globalSettings.Amazon.Region)) - ) + public AmazonSqsBlockIpService( + GlobalSettings globalSettings, + IAmazonSQS amazonSqs) + { + if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.AccessKeyId)) { + throw new ArgumentNullException(nameof(globalSettings.Amazon.AccessKeyId)); + } + if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.AccessKeySecret)) + { + throw new ArgumentNullException(nameof(globalSettings.Amazon.AccessKeySecret)); + } + if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.Region)) + { + throw new ArgumentNullException(nameof(globalSettings.Amazon.Region)); } - public AmazonSqsBlockIpService( - GlobalSettings globalSettings, - IAmazonSQS amazonSqs) - { - if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.AccessKeyId)) - { - throw new ArgumentNullException(nameof(globalSettings.Amazon.AccessKeyId)); - } - if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.AccessKeySecret)) - { - throw new ArgumentNullException(nameof(globalSettings.Amazon.AccessKeySecret)); - } - if (string.IsNullOrWhiteSpace(globalSettings.Amazon?.Region)) - { - throw new ArgumentNullException(nameof(globalSettings.Amazon.Region)); - } + _client = amazonSqs; + } - _client = amazonSqs; + public void Dispose() + { + _client?.Dispose(); + } + + public async Task BlockIpAsync(string ipAddress, bool permanentBlock) + { + var now = DateTime.UtcNow; + if (_lastBlock != null && _lastBlock.Item1 == ipAddress && _lastBlock.Item2 == permanentBlock && + (now - _lastBlock.Item3) < TimeSpan.FromMinutes(1)) + { + // Already blocked this IP recently. + return; } - public void Dispose() + _lastBlock = new Tuple(ipAddress, permanentBlock, now); + await _client.SendMessageAsync(_blockIpQueueUrl, ipAddress); + if (!permanentBlock) { - _client?.Dispose(); - } - - public async Task BlockIpAsync(string ipAddress, bool permanentBlock) - { - var now = DateTime.UtcNow; - if (_lastBlock != null && _lastBlock.Item1 == ipAddress && _lastBlock.Item2 == permanentBlock && - (now - _lastBlock.Item3) < TimeSpan.FromMinutes(1)) - { - // Already blocked this IP recently. - return; - } - - _lastBlock = new Tuple(ipAddress, permanentBlock, now); - await _client.SendMessageAsync(_blockIpQueueUrl, ipAddress); - if (!permanentBlock) - { - await _client.SendMessageAsync(_unblockIpQueueUrl, ipAddress); - } - } - - private async Task InitAsync() - { - if (_didInit) - { - return; - } - - var blockIpQueue = await _client.GetQueueUrlAsync("block-ip"); - _blockIpQueueUrl = blockIpQueue.QueueUrl; - var unblockIpQueue = await _client.GetQueueUrlAsync("unblock-ip"); - _unblockIpQueueUrl = unblockIpQueue.QueueUrl; - _didInit = true; + await _client.SendMessageAsync(_unblockIpQueueUrl, ipAddress); } } + + private async Task InitAsync() + { + if (_didInit) + { + return; + } + + var blockIpQueue = await _client.GetQueueUrlAsync("block-ip"); + _blockIpQueueUrl = blockIpQueue.QueueUrl; + var unblockIpQueue = await _client.GetQueueUrlAsync("unblock-ip"); + _unblockIpQueueUrl = unblockIpQueue.QueueUrl; + _didInit = true; + } } diff --git a/src/Core/Services/Implementations/AppleIapService.cs b/src/Core/Services/Implementations/AppleIapService.cs index 2fa8edfd7..35cd2ac11 100644 --- a/src/Core/Services/Implementations/AppleIapService.cs +++ b/src/Core/Services/Implementations/AppleIapService.cs @@ -7,127 +7,126 @@ using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class AppleIapService : IAppleIapService { - public class AppleIapService : IAppleIapService + private readonly HttpClient _httpClient = new HttpClient(); + + private readonly GlobalSettings _globalSettings; + private readonly IWebHostEnvironment _hostingEnvironment; + private readonly IMetaDataRepository _metaDataRespository; + private readonly ILogger _logger; + + public AppleIapService( + GlobalSettings globalSettings, + IWebHostEnvironment hostingEnvironment, + IMetaDataRepository metaDataRespository, + ILogger logger) { - private readonly HttpClient _httpClient = new HttpClient(); + _globalSettings = globalSettings; + _hostingEnvironment = hostingEnvironment; + _metaDataRespository = metaDataRespository; + _logger = logger; + } - private readonly GlobalSettings _globalSettings; - private readonly IWebHostEnvironment _hostingEnvironment; - private readonly IMetaDataRepository _metaDataRespository; - private readonly ILogger _logger; - - public AppleIapService( - GlobalSettings globalSettings, - IWebHostEnvironment hostingEnvironment, - IMetaDataRepository metaDataRespository, - ILogger logger) + public async Task GetVerifiedReceiptStatusAsync(string receiptData) + { + var receiptStatus = await GetReceiptStatusAsync(receiptData); + if (receiptStatus?.Status != 0) { - _globalSettings = globalSettings; - _hostingEnvironment = hostingEnvironment; - _metaDataRespository = metaDataRespository; - _logger = logger; + return null; } - - public async Task GetVerifiedReceiptStatusAsync(string receiptData) + var validEnvironment = _globalSettings.AppleIap.AppInReview || + (!(_hostingEnvironment.IsProduction() || _hostingEnvironment.IsEnvironment("QA")) && receiptStatus.Environment == "Sandbox") || + ((_hostingEnvironment.IsProduction() || _hostingEnvironment.IsEnvironment("QA")) && receiptStatus.Environment != "Sandbox"); + var validProductBundle = receiptStatus.Receipt.BundleId == "com.bitwarden.desktop" || + receiptStatus.Receipt.BundleId == "com.8bit.bitwarden"; + var validProduct = receiptStatus.LatestReceiptInfo.LastOrDefault()?.ProductId == "premium_annually"; + var validIds = receiptStatus.GetOriginalTransactionId() != null && + receiptStatus.GetLastTransactionId() != null; + var validTransaction = receiptStatus.GetLastExpiresDate() + .GetValueOrDefault(DateTime.MinValue) > DateTime.UtcNow; + if (validEnvironment && validProductBundle && validProduct && validIds && validTransaction) { - var receiptStatus = await GetReceiptStatusAsync(receiptData); - if (receiptStatus?.Status != 0) + return receiptStatus; + } + return null; + } + + public async Task SaveReceiptAsync(AppleReceiptStatus receiptStatus, Guid userId) + { + var originalTransactionId = receiptStatus.GetOriginalTransactionId(); + if (string.IsNullOrWhiteSpace(originalTransactionId)) + { + throw new Exception("OriginalTransactionId is null"); + } + await _metaDataRespository.UpsertAsync("AppleReceipt", originalTransactionId, + new Dictionary { - return null; + ["Data"] = receiptStatus.GetReceiptData(), + ["UserId"] = userId.ToString() + }); + } + + public async Task> GetReceiptAsync(string originalTransactionId) + { + var receipt = await _metaDataRespository.GetAsync("AppleReceipt", originalTransactionId); + if (receipt == null) + { + return null; + } + return new Tuple(receipt.ContainsKey("Data") ? receipt["Data"] : null, + receipt.ContainsKey("UserId") ? new Guid(receipt["UserId"]) : (Guid?)null); + } + + // Internal for testing + internal async Task GetReceiptStatusAsync(string receiptData, bool prod = true, + int attempt = 0, AppleReceiptStatus lastReceiptStatus = null) + { + try + { + if (attempt > 4) + { + throw new Exception( + $"Failed verifying Apple IAP after too many attempts. Last attempt status: {lastReceiptStatus?.Status.ToString() ?? "null"}"); } - var validEnvironment = _globalSettings.AppleIap.AppInReview || - (!(_hostingEnvironment.IsProduction() || _hostingEnvironment.IsEnvironment("QA")) && receiptStatus.Environment == "Sandbox") || - ((_hostingEnvironment.IsProduction() || _hostingEnvironment.IsEnvironment("QA")) && receiptStatus.Environment != "Sandbox"); - var validProductBundle = receiptStatus.Receipt.BundleId == "com.bitwarden.desktop" || - receiptStatus.Receipt.BundleId == "com.8bit.bitwarden"; - var validProduct = receiptStatus.LatestReceiptInfo.LastOrDefault()?.ProductId == "premium_annually"; - var validIds = receiptStatus.GetOriginalTransactionId() != null && - receiptStatus.GetLastTransactionId() != null; - var validTransaction = receiptStatus.GetLastExpiresDate() - .GetValueOrDefault(DateTime.MinValue) > DateTime.UtcNow; - if (validEnvironment && validProductBundle && validProduct && validIds && validTransaction) + + var url = string.Format("https://{0}.itunes.apple.com/verifyReceipt", prod ? "buy" : "sandbox"); + + var response = await _httpClient.PostAsJsonAsync(url, new AppleVerifyReceiptRequestModel { + ReceiptData = receiptData, + Password = _globalSettings.AppleIap.Password + }); + + if (response.IsSuccessStatusCode) + { + var receiptStatus = await response.Content.ReadFromJsonAsync(); + if (receiptStatus.Status == 21007) + { + return await GetReceiptStatusAsync(receiptData, false, attempt + 1, receiptStatus); + } + else if (receiptStatus.Status == 21005) + { + await Task.Delay(2000); + return await GetReceiptStatusAsync(receiptData, prod, attempt + 1, receiptStatus); + } return receiptStatus; } - return null; } - - public async Task SaveReceiptAsync(AppleReceiptStatus receiptStatus, Guid userId) + catch (Exception e) { - var originalTransactionId = receiptStatus.GetOriginalTransactionId(); - if (string.IsNullOrWhiteSpace(originalTransactionId)) - { - throw new Exception("OriginalTransactionId is null"); - } - await _metaDataRespository.UpsertAsync("AppleReceipt", originalTransactionId, - new Dictionary - { - ["Data"] = receiptStatus.GetReceiptData(), - ["UserId"] = userId.ToString() - }); + _logger.LogWarning(e, "Error verifying Apple IAP receipt."); } - - public async Task> GetReceiptAsync(string originalTransactionId) - { - var receipt = await _metaDataRespository.GetAsync("AppleReceipt", originalTransactionId); - if (receipt == null) - { - return null; - } - return new Tuple(receipt.ContainsKey("Data") ? receipt["Data"] : null, - receipt.ContainsKey("UserId") ? new Guid(receipt["UserId"]) : (Guid?)null); - } - - // Internal for testing - internal async Task GetReceiptStatusAsync(string receiptData, bool prod = true, - int attempt = 0, AppleReceiptStatus lastReceiptStatus = null) - { - try - { - if (attempt > 4) - { - throw new Exception( - $"Failed verifying Apple IAP after too many attempts. Last attempt status: {lastReceiptStatus?.Status.ToString() ?? "null"}"); - } - - var url = string.Format("https://{0}.itunes.apple.com/verifyReceipt", prod ? "buy" : "sandbox"); - - var response = await _httpClient.PostAsJsonAsync(url, new AppleVerifyReceiptRequestModel - { - ReceiptData = receiptData, - Password = _globalSettings.AppleIap.Password - }); - - if (response.IsSuccessStatusCode) - { - var receiptStatus = await response.Content.ReadFromJsonAsync(); - if (receiptStatus.Status == 21007) - { - return await GetReceiptStatusAsync(receiptData, false, attempt + 1, receiptStatus); - } - else if (receiptStatus.Status == 21005) - { - await Task.Delay(2000); - return await GetReceiptStatusAsync(receiptData, prod, attempt + 1, receiptStatus); - } - return receiptStatus; - } - } - catch (Exception e) - { - _logger.LogWarning(e, "Error verifying Apple IAP receipt."); - } - return null; - } - } - - public class AppleVerifyReceiptRequestModel - { - [JsonPropertyName("receipt-data")] - public string ReceiptData { get; set; } - [JsonPropertyName("password")] - public string Password { get; set; } + return null; } } + +public class AppleVerifyReceiptRequestModel +{ + [JsonPropertyName("receipt-data")] + public string ReceiptData { get; set; } + [JsonPropertyName("password")] + public string Password { get; set; } +} diff --git a/src/Core/Services/Implementations/AzureAttachmentStorageService.cs b/src/Core/Services/Implementations/AzureAttachmentStorageService.cs index 6a9e8f77f..edc35e03a 100644 --- a/src/Core/Services/Implementations/AzureAttachmentStorageService.cs +++ b/src/Core/Services/Implementations/AzureAttachmentStorageService.cs @@ -7,260 +7,259 @@ using Bit.Core.Models.Data; using Bit.Core.Settings; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class AzureAttachmentStorageService : IAttachmentStorageService { - public class AzureAttachmentStorageService : IAttachmentStorageService + public FileUploadType FileUploadType => FileUploadType.Azure; + public const string EventGridEnabledContainerName = "attachments-v2"; + private const string _defaultContainerName = "attachments"; + private readonly static string[] _attachmentContainerName = { "attachments", "attachments-v2" }; + private static readonly TimeSpan blobLinkLiveTime = TimeSpan.FromMinutes(1); + private readonly BlobServiceClient _blobServiceClient; + private readonly Dictionary _attachmentContainers = new Dictionary(); + private readonly ILogger _logger; + + private string BlobName(Guid cipherId, CipherAttachment.MetaData attachmentData, Guid? organizationId = null, bool temp = false) => + string.Concat( + temp ? "temp/" : "", + $"{cipherId}/", + organizationId != null ? $"{organizationId.Value}/" : "", + attachmentData.AttachmentId + ); + + public static (string cipherId, string organizationId, string attachmentId) IdentifiersFromBlobName(string blobName) { - public FileUploadType FileUploadType => FileUploadType.Azure; - public const string EventGridEnabledContainerName = "attachments-v2"; - private const string _defaultContainerName = "attachments"; - private readonly static string[] _attachmentContainerName = { "attachments", "attachments-v2" }; - private static readonly TimeSpan blobLinkLiveTime = TimeSpan.FromMinutes(1); - private readonly BlobServiceClient _blobServiceClient; - private readonly Dictionary _attachmentContainers = new Dictionary(); - private readonly ILogger _logger; - - private string BlobName(Guid cipherId, CipherAttachment.MetaData attachmentData, Guid? organizationId = null, bool temp = false) => - string.Concat( - temp ? "temp/" : "", - $"{cipherId}/", - organizationId != null ? $"{organizationId.Value}/" : "", - attachmentData.AttachmentId - ); - - public static (string cipherId, string organizationId, string attachmentId) IdentifiersFromBlobName(string blobName) + var parts = blobName.Split('/'); + switch (parts.Length) { - var parts = blobName.Split('/'); - switch (parts.Length) - { - case 4: - return (parts[1], parts[2], parts[3]); - case 3: - if (parts[0] == "temp") - { - return (parts[1], null, parts[2]); - } - else - { - return (parts[0], parts[1], parts[2]); - } - case 2: - return (parts[0], null, parts[1]); - default: - throw new Exception("Cannot determine cipher information from blob name"); - } + case 4: + return (parts[1], parts[2], parts[3]); + case 3: + if (parts[0] == "temp") + { + return (parts[1], null, parts[2]); + } + else + { + return (parts[0], parts[1], parts[2]); + } + case 2: + return (parts[0], null, parts[1]); + default: + throw new Exception("Cannot determine cipher information from blob name"); + } + } + + public AzureAttachmentStorageService( + GlobalSettings globalSettings, + ILogger logger) + { + _blobServiceClient = new BlobServiceClient(globalSettings.Attachment.ConnectionString); + _logger = logger; + } + + public async Task GetAttachmentDownloadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) + { + await InitAsync(attachmentData.ContainerName); + var blobClient = _attachmentContainers[attachmentData.ContainerName].GetBlobClient(BlobName(cipher.Id, attachmentData)); + var sasUri = blobClient.GenerateSasUri(BlobSasPermissions.Read, DateTime.UtcNow.Add(blobLinkLiveTime)); + return sasUri.ToString(); + } + + public async Task GetAttachmentUploadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) + { + await InitAsync(EventGridEnabledContainerName); + var blobClient = _attachmentContainers[EventGridEnabledContainerName].GetBlobClient(BlobName(cipher.Id, attachmentData)); + attachmentData.ContainerName = EventGridEnabledContainerName; + var sasUri = blobClient.GenerateSasUri(BlobSasPermissions.Create | BlobSasPermissions.Write, DateTime.UtcNow.Add(blobLinkLiveTime)); + return sasUri.ToString(); + } + + public async Task UploadNewAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentData) + { + attachmentData.ContainerName = _defaultContainerName; + await InitAsync(_defaultContainerName); + var blobClient = _attachmentContainers[_defaultContainerName].GetBlobClient(BlobName(cipher.Id, attachmentData)); + + var metadata = new Dictionary(); + metadata.Add("cipherId", cipher.Id.ToString()); + if (cipher.UserId.HasValue) + { + metadata.Add("userId", cipher.UserId.Value.ToString()); + } + else + { + metadata.Add("organizationId", cipher.OrganizationId.Value.ToString()); } - public AzureAttachmentStorageService( - GlobalSettings globalSettings, - ILogger logger) + var headers = new BlobHttpHeaders { - _blobServiceClient = new BlobServiceClient(globalSettings.Attachment.ConnectionString); - _logger = logger; + ContentDisposition = $"attachment; filename=\"{attachmentData.AttachmentId}\"" + }; + await blobClient.UploadAsync(stream, new BlobUploadOptions { Metadata = metadata, HttpHeaders = headers }); + } + + public async Task UploadShareAttachmentAsync(Stream stream, Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData) + { + attachmentData.ContainerName = _defaultContainerName; + await InitAsync(_defaultContainerName); + var blobClient = _attachmentContainers[_defaultContainerName].GetBlobClient( + BlobName(cipherId, attachmentData, organizationId, temp: true)); + + var metadata = new Dictionary(); + metadata.Add("cipherId", cipherId.ToString()); + metadata.Add("organizationId", organizationId.ToString()); + + var headers = new BlobHttpHeaders + { + ContentDisposition = $"attachment; filename=\"{attachmentData.AttachmentId}\"" + }; + await blobClient.UploadAsync(stream, new BlobUploadOptions { Metadata = metadata, HttpHeaders = headers }); + } + + public async Task StartShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData data) + { + await InitAsync(data.ContainerName); + var source = _attachmentContainers[data.ContainerName].GetBlobClient( + BlobName(cipherId, data, organizationId, temp: true)); + if (!await source.ExistsAsync()) + { + return; } - public async Task GetAttachmentDownloadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) + await InitAsync(_defaultContainerName); + var dest = _attachmentContainers[_defaultContainerName].GetBlobClient(BlobName(cipherId, data)); + if (!await dest.ExistsAsync()) { - await InitAsync(attachmentData.ContainerName); - var blobClient = _attachmentContainers[attachmentData.ContainerName].GetBlobClient(BlobName(cipher.Id, attachmentData)); - var sasUri = blobClient.GenerateSasUri(BlobSasPermissions.Read, DateTime.UtcNow.Add(blobLinkLiveTime)); - return sasUri.ToString(); + return; } - public async Task GetAttachmentUploadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) + var original = _attachmentContainers[_defaultContainerName].GetBlobClient( + BlobName(cipherId, data, temp: true)); + await original.DeleteIfExistsAsync(); + await original.StartCopyFromUriAsync(dest.Uri); + + await dest.DeleteIfExistsAsync(); + await dest.StartCopyFromUriAsync(source.Uri); + } + + public async Task RollbackShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData, string originalContainer) + { + await InitAsync(attachmentData.ContainerName); + var source = _attachmentContainers[attachmentData.ContainerName].GetBlobClient( + BlobName(cipherId, attachmentData, organizationId, temp: true)); + await source.DeleteIfExistsAsync(); + + await InitAsync(originalContainer); + var original = _attachmentContainers[originalContainer].GetBlobClient( + BlobName(cipherId, attachmentData, temp: true)); + if (!await original.ExistsAsync()) { - await InitAsync(EventGridEnabledContainerName); - var blobClient = _attachmentContainers[EventGridEnabledContainerName].GetBlobClient(BlobName(cipher.Id, attachmentData)); - attachmentData.ContainerName = EventGridEnabledContainerName; - var sasUri = blobClient.GenerateSasUri(BlobSasPermissions.Create | BlobSasPermissions.Write, DateTime.UtcNow.Add(blobLinkLiveTime)); - return sasUri.ToString(); + return; } - public async Task UploadNewAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentData) - { - attachmentData.ContainerName = _defaultContainerName; - await InitAsync(_defaultContainerName); - var blobClient = _attachmentContainers[_defaultContainerName].GetBlobClient(BlobName(cipher.Id, attachmentData)); + var dest = _attachmentContainers[originalContainer].GetBlobClient( + BlobName(cipherId, attachmentData)); + await dest.DeleteIfExistsAsync(); + await dest.StartCopyFromUriAsync(original.Uri); + await original.DeleteIfExistsAsync(); + } - var metadata = new Dictionary(); - metadata.Add("cipherId", cipher.Id.ToString()); + public async Task DeleteAttachmentAsync(Guid cipherId, CipherAttachment.MetaData attachmentData) + { + await InitAsync(attachmentData.ContainerName); + var blobClient = _attachmentContainers[attachmentData.ContainerName].GetBlobClient( + BlobName(cipherId, attachmentData)); + await blobClient.DeleteIfExistsAsync(); + } + + public async Task CleanupAsync(Guid cipherId) => await DeleteAttachmentsForPathAsync($"temp/{cipherId}"); + + public async Task DeleteAttachmentsForCipherAsync(Guid cipherId) => + await DeleteAttachmentsForPathAsync(cipherId.ToString()); + + public async Task DeleteAttachmentsForOrganizationAsync(Guid organizationId) + { + await InitAsync(_defaultContainerName); + } + + public async Task DeleteAttachmentsForUserAsync(Guid userId) + { + await InitAsync(_defaultContainerName); + } + + public async Task<(bool, long?)> ValidateFileAsync(Cipher cipher, CipherAttachment.MetaData attachmentData, long leeway) + { + await InitAsync(attachmentData.ContainerName); + + var blobClient = _attachmentContainers[attachmentData.ContainerName].GetBlobClient(BlobName(cipher.Id, attachmentData)); + + try + { + var blobProperties = await blobClient.GetPropertiesAsync(); + + var metadata = blobProperties.Value.Metadata; + metadata["cipherId"] = cipher.Id.ToString(); if (cipher.UserId.HasValue) { - metadata.Add("userId", cipher.UserId.Value.ToString()); + metadata["userId"] = cipher.UserId.Value.ToString(); } else { - metadata.Add("organizationId", cipher.OrganizationId.Value.ToString()); + metadata["organizationId"] = cipher.OrganizationId.Value.ToString(); } + await blobClient.SetMetadataAsync(metadata); var headers = new BlobHttpHeaders { ContentDisposition = $"attachment; filename=\"{attachmentData.AttachmentId}\"" }; - await blobClient.UploadAsync(stream, new BlobUploadOptions { Metadata = metadata, HttpHeaders = headers }); - } + await blobClient.SetHttpHeadersAsync(headers); - public async Task UploadShareAttachmentAsync(Stream stream, Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData) - { - attachmentData.ContainerName = _defaultContainerName; - await InitAsync(_defaultContainerName); - var blobClient = _attachmentContainers[_defaultContainerName].GetBlobClient( - BlobName(cipherId, attachmentData, organizationId, temp: true)); - - var metadata = new Dictionary(); - metadata.Add("cipherId", cipherId.ToString()); - metadata.Add("organizationId", organizationId.ToString()); - - var headers = new BlobHttpHeaders + var length = blobProperties.Value.ContentLength; + if (length < attachmentData.Size - leeway || length > attachmentData.Size + leeway) { - ContentDisposition = $"attachment; filename=\"{attachmentData.AttachmentId}\"" - }; - await blobClient.UploadAsync(stream, new BlobUploadOptions { Metadata = metadata, HttpHeaders = headers }); - } - - public async Task StartShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData data) - { - await InitAsync(data.ContainerName); - var source = _attachmentContainers[data.ContainerName].GetBlobClient( - BlobName(cipherId, data, organizationId, temp: true)); - if (!await source.ExistsAsync()) - { - return; + return (false, length); } - await InitAsync(_defaultContainerName); - var dest = _attachmentContainers[_defaultContainerName].GetBlobClient(BlobName(cipherId, data)); - if (!await dest.ExistsAsync()) + return (true, length); + } + catch (Exception ex) + { + _logger.LogError(ex, "Unhandled error in ValidateFileAsync"); + return (false, null); + } + } + + private async Task DeleteAttachmentsForPathAsync(string path) + { + foreach (var container in _attachmentContainerName) + { + await InitAsync(container); + var blobContainerClient = _attachmentContainers[container]; + + var blobItems = blobContainerClient.GetBlobsAsync(BlobTraits.None, BlobStates.None, prefix: path); + await foreach (var blobItem in blobItems) { - return; - } - - var original = _attachmentContainers[_defaultContainerName].GetBlobClient( - BlobName(cipherId, data, temp: true)); - await original.DeleteIfExistsAsync(); - await original.StartCopyFromUriAsync(dest.Uri); - - await dest.DeleteIfExistsAsync(); - await dest.StartCopyFromUriAsync(source.Uri); - } - - public async Task RollbackShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData, string originalContainer) - { - await InitAsync(attachmentData.ContainerName); - var source = _attachmentContainers[attachmentData.ContainerName].GetBlobClient( - BlobName(cipherId, attachmentData, organizationId, temp: true)); - await source.DeleteIfExistsAsync(); - - await InitAsync(originalContainer); - var original = _attachmentContainers[originalContainer].GetBlobClient( - BlobName(cipherId, attachmentData, temp: true)); - if (!await original.ExistsAsync()) - { - return; - } - - var dest = _attachmentContainers[originalContainer].GetBlobClient( - BlobName(cipherId, attachmentData)); - await dest.DeleteIfExistsAsync(); - await dest.StartCopyFromUriAsync(original.Uri); - await original.DeleteIfExistsAsync(); - } - - public async Task DeleteAttachmentAsync(Guid cipherId, CipherAttachment.MetaData attachmentData) - { - await InitAsync(attachmentData.ContainerName); - var blobClient = _attachmentContainers[attachmentData.ContainerName].GetBlobClient( - BlobName(cipherId, attachmentData)); - await blobClient.DeleteIfExistsAsync(); - } - - public async Task CleanupAsync(Guid cipherId) => await DeleteAttachmentsForPathAsync($"temp/{cipherId}"); - - public async Task DeleteAttachmentsForCipherAsync(Guid cipherId) => - await DeleteAttachmentsForPathAsync(cipherId.ToString()); - - public async Task DeleteAttachmentsForOrganizationAsync(Guid organizationId) - { - await InitAsync(_defaultContainerName); - } - - public async Task DeleteAttachmentsForUserAsync(Guid userId) - { - await InitAsync(_defaultContainerName); - } - - public async Task<(bool, long?)> ValidateFileAsync(Cipher cipher, CipherAttachment.MetaData attachmentData, long leeway) - { - await InitAsync(attachmentData.ContainerName); - - var blobClient = _attachmentContainers[attachmentData.ContainerName].GetBlobClient(BlobName(cipher.Id, attachmentData)); - - try - { - var blobProperties = await blobClient.GetPropertiesAsync(); - - var metadata = blobProperties.Value.Metadata; - metadata["cipherId"] = cipher.Id.ToString(); - if (cipher.UserId.HasValue) - { - metadata["userId"] = cipher.UserId.Value.ToString(); - } - else - { - metadata["organizationId"] = cipher.OrganizationId.Value.ToString(); - } - await blobClient.SetMetadataAsync(metadata); - - var headers = new BlobHttpHeaders - { - ContentDisposition = $"attachment; filename=\"{attachmentData.AttachmentId}\"" - }; - await blobClient.SetHttpHeadersAsync(headers); - - var length = blobProperties.Value.ContentLength; - if (length < attachmentData.Size - leeway || length > attachmentData.Size + leeway) - { - return (false, length); - } - - return (true, length); - } - catch (Exception ex) - { - _logger.LogError(ex, "Unhandled error in ValidateFileAsync"); - return (false, null); + BlobClient blobClient = blobContainerClient.GetBlobClient(blobItem.Name); + await blobClient.DeleteIfExistsAsync(); } } + } - private async Task DeleteAttachmentsForPathAsync(string path) + private async Task InitAsync(string containerName) + { + if (!_attachmentContainers.ContainsKey(containerName) || _attachmentContainers[containerName] == null) { - foreach (var container in _attachmentContainerName) + _attachmentContainers[containerName] = _blobServiceClient.GetBlobContainerClient(containerName); + if (containerName == "attachments") { - await InitAsync(container); - var blobContainerClient = _attachmentContainers[container]; - - var blobItems = blobContainerClient.GetBlobsAsync(BlobTraits.None, BlobStates.None, prefix: path); - await foreach (var blobItem in blobItems) - { - BlobClient blobClient = blobContainerClient.GetBlobClient(blobItem.Name); - await blobClient.DeleteIfExistsAsync(); - } + await _attachmentContainers[containerName].CreateIfNotExistsAsync(PublicAccessType.Blob, null, null); } - } - - private async Task InitAsync(string containerName) - { - if (!_attachmentContainers.ContainsKey(containerName) || _attachmentContainers[containerName] == null) + else { - _attachmentContainers[containerName] = _blobServiceClient.GetBlobContainerClient(containerName); - if (containerName == "attachments") - { - await _attachmentContainers[containerName].CreateIfNotExistsAsync(PublicAccessType.Blob, null, null); - } - else - { - await _attachmentContainers[containerName].CreateIfNotExistsAsync(PublicAccessType.None, null, null); - } + await _attachmentContainers[containerName].CreateIfNotExistsAsync(PublicAccessType.None, null, null); } } } diff --git a/src/Core/Services/Implementations/AzureQueueBlockIpService.cs b/src/Core/Services/Implementations/AzureQueueBlockIpService.cs index 8682b0c49..ab78c4654 100644 --- a/src/Core/Services/Implementations/AzureQueueBlockIpService.cs +++ b/src/Core/Services/Implementations/AzureQueueBlockIpService.cs @@ -1,37 +1,36 @@ using Azure.Storage.Queues; using Bit.Core.Settings; -namespace Bit.Core.Services -{ - public class AzureQueueBlockIpService : IBlockIpService - { - private readonly QueueClient _blockIpQueueClient; - private readonly QueueClient _unblockIpQueueClient; - private Tuple _lastBlock; +namespace Bit.Core.Services; - public AzureQueueBlockIpService( - GlobalSettings globalSettings) +public class AzureQueueBlockIpService : IBlockIpService +{ + private readonly QueueClient _blockIpQueueClient; + private readonly QueueClient _unblockIpQueueClient; + private Tuple _lastBlock; + + public AzureQueueBlockIpService( + GlobalSettings globalSettings) + { + _blockIpQueueClient = new QueueClient(globalSettings.Storage.ConnectionString, "blockip"); + _unblockIpQueueClient = new QueueClient(globalSettings.Storage.ConnectionString, "unblockip"); + } + + public async Task BlockIpAsync(string ipAddress, bool permanentBlock) + { + var now = DateTime.UtcNow; + if (_lastBlock != null && _lastBlock.Item1 == ipAddress && _lastBlock.Item2 == permanentBlock && + (now - _lastBlock.Item3) < TimeSpan.FromMinutes(1)) { - _blockIpQueueClient = new QueueClient(globalSettings.Storage.ConnectionString, "blockip"); - _unblockIpQueueClient = new QueueClient(globalSettings.Storage.ConnectionString, "unblockip"); + // Already blocked this IP recently. + return; } - public async Task BlockIpAsync(string ipAddress, bool permanentBlock) + _lastBlock = new Tuple(ipAddress, permanentBlock, now); + await _blockIpQueueClient.SendMessageAsync(ipAddress); + if (!permanentBlock) { - var now = DateTime.UtcNow; - if (_lastBlock != null && _lastBlock.Item1 == ipAddress && _lastBlock.Item2 == permanentBlock && - (now - _lastBlock.Item3) < TimeSpan.FromMinutes(1)) - { - // Already blocked this IP recently. - return; - } - - _lastBlock = new Tuple(ipAddress, permanentBlock, now); - await _blockIpQueueClient.SendMessageAsync(ipAddress); - if (!permanentBlock) - { - await _unblockIpQueueClient.SendMessageAsync(ipAddress, new TimeSpan(0, 15, 0)); - } + await _unblockIpQueueClient.SendMessageAsync(ipAddress, new TimeSpan(0, 15, 0)); } } } diff --git a/src/Core/Services/Implementations/AzureQueueEventWriteService.cs b/src/Core/Services/Implementations/AzureQueueEventWriteService.cs index bf74677f1..f81175f7b 100644 --- a/src/Core/Services/Implementations/AzureQueueEventWriteService.cs +++ b/src/Core/Services/Implementations/AzureQueueEventWriteService.cs @@ -3,15 +3,14 @@ using Bit.Core.Models.Data; using Bit.Core.Settings; using Bit.Core.Utilities; -namespace Bit.Core.Services -{ - public class AzureQueueEventWriteService : AzureQueueService, IEventWriteService - { - public AzureQueueEventWriteService(GlobalSettings globalSettings) : base( - new QueueClient(globalSettings.Events.ConnectionString, "event"), - JsonHelpers.IgnoreWritingNull) - { } +namespace Bit.Core.Services; - public Task CreateAsync(IEvent e) => CreateManyAsync(new[] { e }); - } +public class AzureQueueEventWriteService : AzureQueueService, IEventWriteService +{ + public AzureQueueEventWriteService(GlobalSettings globalSettings) : base( + new QueueClient(globalSettings.Events.ConnectionString, "event"), + JsonHelpers.IgnoreWritingNull) + { } + + public Task CreateAsync(IEvent e) => CreateManyAsync(new[] { e }); } diff --git a/src/Core/Services/Implementations/AzureQueueMailService.cs b/src/Core/Services/Implementations/AzureQueueMailService.cs index e05c106ea..92d6fd17b 100644 --- a/src/Core/Services/Implementations/AzureQueueMailService.cs +++ b/src/Core/Services/Implementations/AzureQueueMailService.cs @@ -3,19 +3,18 @@ using Bit.Core.Models.Mail; using Bit.Core.Settings; using Bit.Core.Utilities; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class AzureQueueMailService : AzureQueueService, IMailEnqueuingService { - public class AzureQueueMailService : AzureQueueService, IMailEnqueuingService - { - public AzureQueueMailService(GlobalSettings globalSettings) : base( - new QueueClient(globalSettings.Mail.ConnectionString, "mail"), - JsonHelpers.IgnoreWritingNull) - { } + public AzureQueueMailService(GlobalSettings globalSettings) : base( + new QueueClient(globalSettings.Mail.ConnectionString, "mail"), + JsonHelpers.IgnoreWritingNull) + { } - public Task EnqueueAsync(IMailQueueMessage message, Func fallback) => - CreateManyAsync(new[] { message }); + public Task EnqueueAsync(IMailQueueMessage message, Func fallback) => + CreateManyAsync(new[] { message }); - public Task EnqueueManyAsync(IEnumerable messages, Func fallback) => - CreateManyAsync(messages); - } + public Task EnqueueManyAsync(IEnumerable messages, Func fallback) => + CreateManyAsync(messages); } diff --git a/src/Core/Services/Implementations/AzureQueuePushNotificationService.cs b/src/Core/Services/Implementations/AzureQueuePushNotificationService.cs index 7062c6c18..fb7bcafca 100644 --- a/src/Core/Services/Implementations/AzureQueuePushNotificationService.cs +++ b/src/Core/Services/Implementations/AzureQueuePushNotificationService.cs @@ -8,190 +8,189 @@ using Bit.Core.Settings; using Bit.Core.Utilities; using Microsoft.AspNetCore.Http; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class AzureQueuePushNotificationService : IPushNotificationService { - public class AzureQueuePushNotificationService : IPushNotificationService + private readonly QueueClient _queueClient; + private readonly GlobalSettings _globalSettings; + private readonly IHttpContextAccessor _httpContextAccessor; + + public AzureQueuePushNotificationService( + GlobalSettings globalSettings, + IHttpContextAccessor httpContextAccessor) { - private readonly QueueClient _queueClient; - private readonly GlobalSettings _globalSettings; - private readonly IHttpContextAccessor _httpContextAccessor; + _queueClient = new QueueClient(globalSettings.Notifications.ConnectionString, "notifications"); + _globalSettings = globalSettings; + _httpContextAccessor = httpContextAccessor; + } - public AzureQueuePushNotificationService( - GlobalSettings globalSettings, - IHttpContextAccessor httpContextAccessor) - { - _queueClient = new QueueClient(globalSettings.Notifications.ConnectionString, "notifications"); - _globalSettings = globalSettings; - _httpContextAccessor = httpContextAccessor; - } + public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) + { + await PushCipherAsync(cipher, PushType.SyncCipherCreate, collectionIds); + } - public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) - { - await PushCipherAsync(cipher, PushType.SyncCipherCreate, collectionIds); - } + public async Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds) + { + await PushCipherAsync(cipher, PushType.SyncCipherUpdate, collectionIds); + } - public async Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds) - { - await PushCipherAsync(cipher, PushType.SyncCipherUpdate, collectionIds); - } + public async Task PushSyncCipherDeleteAsync(Cipher cipher) + { + await PushCipherAsync(cipher, PushType.SyncLoginDelete, null); + } - public async Task PushSyncCipherDeleteAsync(Cipher cipher) + private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable collectionIds) + { + if (cipher.OrganizationId.HasValue) { - await PushCipherAsync(cipher, PushType.SyncLoginDelete, null); - } - - private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable collectionIds) - { - if (cipher.OrganizationId.HasValue) + var message = new SyncCipherPushNotification { - var message = new SyncCipherPushNotification - { - Id = cipher.Id, - OrganizationId = cipher.OrganizationId, - RevisionDate = cipher.RevisionDate, - CollectionIds = collectionIds, - }; - - await SendMessageAsync(type, message, true); - } - else if (cipher.UserId.HasValue) - { - var message = new SyncCipherPushNotification - { - Id = cipher.Id, - UserId = cipher.UserId, - RevisionDate = cipher.RevisionDate, - }; - - await SendMessageAsync(type, message, true); - } - } - - public async Task PushSyncFolderCreateAsync(Folder folder) - { - await PushFolderAsync(folder, PushType.SyncFolderCreate); - } - - public async Task PushSyncFolderUpdateAsync(Folder folder) - { - await PushFolderAsync(folder, PushType.SyncFolderUpdate); - } - - public async Task PushSyncFolderDeleteAsync(Folder folder) - { - await PushFolderAsync(folder, PushType.SyncFolderDelete); - } - - private async Task PushFolderAsync(Folder folder, PushType type) - { - var message = new SyncFolderPushNotification - { - Id = folder.Id, - UserId = folder.UserId, - RevisionDate = folder.RevisionDate + Id = cipher.Id, + OrganizationId = cipher.OrganizationId, + RevisionDate = cipher.RevisionDate, + CollectionIds = collectionIds, }; await SendMessageAsync(type, message, true); } - - public async Task PushSyncCiphersAsync(Guid userId) + else if (cipher.UserId.HasValue) { - await PushUserAsync(userId, PushType.SyncCiphers); - } - - public async Task PushSyncVaultAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncVault); - } - - public async Task PushSyncOrgKeysAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncOrgKeys); - } - - public async Task PushSyncSettingsAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncSettings); - } - - public async Task PushLogOutAsync(Guid userId) - { - await PushUserAsync(userId, PushType.LogOut); - } - - private async Task PushUserAsync(Guid userId, PushType type) - { - var message = new UserPushNotification + var message = new SyncCipherPushNotification { - UserId = userId, - Date = DateTime.UtcNow + Id = cipher.Id, + UserId = cipher.UserId, + RevisionDate = cipher.RevisionDate, }; - await SendMessageAsync(type, message, false); - } - - public async Task PushSyncSendCreateAsync(Send send) - { - await PushSendAsync(send, PushType.SyncSendCreate); - } - - public async Task PushSyncSendUpdateAsync(Send send) - { - await PushSendAsync(send, PushType.SyncSendUpdate); - } - - public async Task PushSyncSendDeleteAsync(Send send) - { - await PushSendAsync(send, PushType.SyncSendDelete); - } - - private async Task PushSendAsync(Send send, PushType type) - { - if (send.UserId.HasValue) - { - var message = new SyncSendPushNotification - { - Id = send.Id, - UserId = send.UserId.Value, - RevisionDate = send.RevisionDate - }; - - await SendMessageAsync(type, message, true); - } - } - - private async Task SendMessageAsync(PushType type, T payload, bool excludeCurrentContext) - { - var contextId = GetContextIdentifier(excludeCurrentContext); - var message = JsonSerializer.Serialize(new PushNotificationData(type, payload, contextId), - JsonHelpers.IgnoreWritingNull); - await _queueClient.SendMessageAsync(message); - } - - private string GetContextIdentifier(bool excludeCurrentContext) - { - if (!excludeCurrentContext) - { - return null; - } - - var currentContext = _httpContextAccessor?.HttpContext?. - RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; - return currentContext?.DeviceIdentifier; - } - - public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, - string deviceId = null) - { - // Noop - return Task.FromResult(0); - } - - public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, - string deviceId = null) - { - // Noop - return Task.FromResult(0); + await SendMessageAsync(type, message, true); } } + + public async Task PushSyncFolderCreateAsync(Folder folder) + { + await PushFolderAsync(folder, PushType.SyncFolderCreate); + } + + public async Task PushSyncFolderUpdateAsync(Folder folder) + { + await PushFolderAsync(folder, PushType.SyncFolderUpdate); + } + + public async Task PushSyncFolderDeleteAsync(Folder folder) + { + await PushFolderAsync(folder, PushType.SyncFolderDelete); + } + + private async Task PushFolderAsync(Folder folder, PushType type) + { + var message = new SyncFolderPushNotification + { + Id = folder.Id, + UserId = folder.UserId, + RevisionDate = folder.RevisionDate + }; + + await SendMessageAsync(type, message, true); + } + + public async Task PushSyncCiphersAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncCiphers); + } + + public async Task PushSyncVaultAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncVault); + } + + public async Task PushSyncOrgKeysAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncOrgKeys); + } + + public async Task PushSyncSettingsAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncSettings); + } + + public async Task PushLogOutAsync(Guid userId) + { + await PushUserAsync(userId, PushType.LogOut); + } + + private async Task PushUserAsync(Guid userId, PushType type) + { + var message = new UserPushNotification + { + UserId = userId, + Date = DateTime.UtcNow + }; + + await SendMessageAsync(type, message, false); + } + + public async Task PushSyncSendCreateAsync(Send send) + { + await PushSendAsync(send, PushType.SyncSendCreate); + } + + public async Task PushSyncSendUpdateAsync(Send send) + { + await PushSendAsync(send, PushType.SyncSendUpdate); + } + + public async Task PushSyncSendDeleteAsync(Send send) + { + await PushSendAsync(send, PushType.SyncSendDelete); + } + + private async Task PushSendAsync(Send send, PushType type) + { + if (send.UserId.HasValue) + { + var message = new SyncSendPushNotification + { + Id = send.Id, + UserId = send.UserId.Value, + RevisionDate = send.RevisionDate + }; + + await SendMessageAsync(type, message, true); + } + } + + private async Task SendMessageAsync(PushType type, T payload, bool excludeCurrentContext) + { + var contextId = GetContextIdentifier(excludeCurrentContext); + var message = JsonSerializer.Serialize(new PushNotificationData(type, payload, contextId), + JsonHelpers.IgnoreWritingNull); + await _queueClient.SendMessageAsync(message); + } + + private string GetContextIdentifier(bool excludeCurrentContext) + { + if (!excludeCurrentContext) + { + return null; + } + + var currentContext = _httpContextAccessor?.HttpContext?. + RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; + return currentContext?.DeviceIdentifier; + } + + public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, + string deviceId = null) + { + // Noop + return Task.FromResult(0); + } + + public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, + string deviceId = null) + { + // Noop + return Task.FromResult(0); + } } diff --git a/src/Core/Services/Implementations/AzureQueueReferenceEventService.cs b/src/Core/Services/Implementations/AzureQueueReferenceEventService.cs index e3b0f0ecf..6abbe9783 100644 --- a/src/Core/Services/Implementations/AzureQueueReferenceEventService.cs +++ b/src/Core/Services/Implementations/AzureQueueReferenceEventService.cs @@ -5,45 +5,44 @@ using Bit.Core.Models.Business; using Bit.Core.Settings; using Bit.Core.Utilities; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class AzureQueueReferenceEventService : IReferenceEventService { - public class AzureQueueReferenceEventService : IReferenceEventService + private const string _queueName = "reference-events"; + + private readonly QueueClient _queueClient; + private readonly GlobalSettings _globalSettings; + + public AzureQueueReferenceEventService( + GlobalSettings globalSettings) { - private const string _queueName = "reference-events"; + _queueClient = new QueueClient(globalSettings.Events.ConnectionString, _queueName); + _globalSettings = globalSettings; + } - private readonly QueueClient _queueClient; - private readonly GlobalSettings _globalSettings; + public async Task RaiseEventAsync(ReferenceEvent referenceEvent) + { + await SendMessageAsync(referenceEvent); + } - public AzureQueueReferenceEventService( - GlobalSettings globalSettings) + private async Task SendMessageAsync(ReferenceEvent referenceEvent) + { + if (_globalSettings.SelfHosted) { - _queueClient = new QueueClient(globalSettings.Events.ConnectionString, _queueName); - _globalSettings = globalSettings; + // Ignore for self-hosted + return; } - - public async Task RaiseEventAsync(ReferenceEvent referenceEvent) + try { - await SendMessageAsync(referenceEvent); + var message = JsonSerializer.Serialize(referenceEvent, JsonHelpers.IgnoreWritingNullAndCamelCase); + // Messages need to be base64 encoded + var encodedMessage = Convert.ToBase64String(Encoding.UTF8.GetBytes(message)); + await _queueClient.SendMessageAsync(encodedMessage); } - - private async Task SendMessageAsync(ReferenceEvent referenceEvent) + catch { - if (_globalSettings.SelfHosted) - { - // Ignore for self-hosted - return; - } - try - { - var message = JsonSerializer.Serialize(referenceEvent, JsonHelpers.IgnoreWritingNullAndCamelCase); - // Messages need to be base64 encoded - var encodedMessage = Convert.ToBase64String(Encoding.UTF8.GetBytes(message)); - await _queueClient.SendMessageAsync(encodedMessage); - } - catch - { - // Ignore failure - } + // Ignore failure } } } diff --git a/src/Core/Services/Implementations/AzureQueueService.cs b/src/Core/Services/Implementations/AzureQueueService.cs index 942be2680..11c1a58ae 100644 --- a/src/Core/Services/Implementations/AzureQueueService.cs +++ b/src/Core/Services/Implementations/AzureQueueService.cs @@ -3,76 +3,75 @@ using System.Text.Json; using Azure.Storage.Queues; using Bit.Core.Utilities; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public abstract class AzureQueueService { - public abstract class AzureQueueService + protected QueueClient _queueClient; + protected JsonSerializerOptions _jsonOptions; + + protected AzureQueueService(QueueClient queueClient, JsonSerializerOptions jsonOptions) { - protected QueueClient _queueClient; - protected JsonSerializerOptions _jsonOptions; + _queueClient = queueClient; + _jsonOptions = jsonOptions; + } - protected AzureQueueService(QueueClient queueClient, JsonSerializerOptions jsonOptions) + public async Task CreateManyAsync(IEnumerable messages) + { + if (messages?.Any() != true) { - _queueClient = queueClient; - _jsonOptions = jsonOptions; + return; } - public async Task CreateManyAsync(IEnumerable messages) + foreach (var json in SerializeMany(messages, _jsonOptions)) { - if (messages?.Any() != true) - { - return; - } + await _queueClient.SendMessageAsync(json); + } + } - foreach (var json in SerializeMany(messages, _jsonOptions)) + protected IEnumerable SerializeMany(IEnumerable messages, JsonSerializerOptions jsonOptions) + { + // Calculate Base-64 encoded text with padding + int getBase64Size(int byteCount) => ((4 * byteCount / 3) + 3) & ~3; + + var messagesList = new List(); + var messagesListSize = 0; + + int calculateByteSize(int totalSize, int toAdd) => + // Calculate the total length this would be w/ "[]" and commas + getBase64Size(totalSize + toAdd + messagesList.Count + 2); + + // Format the final array string, i.e. [{...},{...}] + string getArrayString() + { + if (messagesList.Count == 1) { - await _queueClient.SendMessageAsync(json); + return CoreHelpers.Base64EncodeString(messagesList[0]); } + return CoreHelpers.Base64EncodeString( + string.Concat("[", string.Join(',', messagesList), "]")); } - protected IEnumerable SerializeMany(IEnumerable messages, JsonSerializerOptions jsonOptions) + var serializedMessages = messages.Select(message => + JsonSerializer.Serialize(message, jsonOptions)); + + foreach (var message in serializedMessages) { - // Calculate Base-64 encoded text with padding - int getBase64Size(int byteCount) => ((4 * byteCount / 3) + 3) & ~3; - - var messagesList = new List(); - var messagesListSize = 0; - - int calculateByteSize(int totalSize, int toAdd) => - // Calculate the total length this would be w/ "[]" and commas - getBase64Size(totalSize + toAdd + messagesList.Count + 2); - - // Format the final array string, i.e. [{...},{...}] - string getArrayString() - { - if (messagesList.Count == 1) - { - return CoreHelpers.Base64EncodeString(messagesList[0]); - } - return CoreHelpers.Base64EncodeString( - string.Concat("[", string.Join(',', messagesList), "]")); - } - - var serializedMessages = messages.Select(message => - JsonSerializer.Serialize(message, jsonOptions)); - - foreach (var message in serializedMessages) - { - var messageSize = Encoding.UTF8.GetByteCount(message); - if (calculateByteSize(messagesListSize, messageSize) > _queueClient.MessageMaxBytes) - { - yield return getArrayString(); - messagesListSize = 0; - messagesList.Clear(); - } - - messagesList.Add(message); - messagesListSize += messageSize; - } - - if (messagesList.Any()) + var messageSize = Encoding.UTF8.GetByteCount(message); + if (calculateByteSize(messagesListSize, messageSize) > _queueClient.MessageMaxBytes) { yield return getArrayString(); + messagesListSize = 0; + messagesList.Clear(); } + + messagesList.Add(message); + messagesListSize += messageSize; + } + + if (messagesList.Any()) + { + yield return getArrayString(); } } } diff --git a/src/Core/Services/Implementations/AzureSendFileStorageService.cs b/src/Core/Services/Implementations/AzureSendFileStorageService.cs index 94a0aaaee..d1d7822f2 100644 --- a/src/Core/Services/Implementations/AzureSendFileStorageService.cs +++ b/src/Core/Services/Implementations/AzureSendFileStorageService.cs @@ -6,137 +6,136 @@ using Bit.Core.Enums; using Bit.Core.Settings; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class AzureSendFileStorageService : ISendFileStorageService { - public class AzureSendFileStorageService : ISendFileStorageService + public const string FilesContainerName = "sendfiles"; + private static readonly TimeSpan _downloadLinkLiveTime = TimeSpan.FromMinutes(1); + private readonly BlobServiceClient _blobServiceClient; + private readonly ILogger _logger; + private BlobContainerClient _sendFilesContainerClient; + + public FileUploadType FileUploadType => FileUploadType.Azure; + + public static string SendIdFromBlobName(string blobName) => blobName.Split('/')[0]; + public static string BlobName(Send send, string fileId) => $"{send.Id}/{fileId}"; + + public AzureSendFileStorageService( + GlobalSettings globalSettings, + ILogger logger) { - public const string FilesContainerName = "sendfiles"; - private static readonly TimeSpan _downloadLinkLiveTime = TimeSpan.FromMinutes(1); - private readonly BlobServiceClient _blobServiceClient; - private readonly ILogger _logger; - private BlobContainerClient _sendFilesContainerClient; + _blobServiceClient = new BlobServiceClient(globalSettings.Send.ConnectionString); + _logger = logger; + } - public FileUploadType FileUploadType => FileUploadType.Azure; + public async Task UploadNewFileAsync(Stream stream, Send send, string fileId) + { + await InitAsync(); - public static string SendIdFromBlobName(string blobName) => blobName.Split('/')[0]; - public static string BlobName(Send send, string fileId) => $"{send.Id}/{fileId}"; + var blobClient = _sendFilesContainerClient.GetBlobClient(BlobName(send, fileId)); - public AzureSendFileStorageService( - GlobalSettings globalSettings, - ILogger logger) + var metadata = new Dictionary(); + if (send.UserId.HasValue) { - _blobServiceClient = new BlobServiceClient(globalSettings.Send.ConnectionString); - _logger = logger; + metadata.Add("userId", send.UserId.Value.ToString()); + } + else + { + metadata.Add("organizationId", send.OrganizationId.Value.ToString()); } - public async Task UploadNewFileAsync(Stream stream, Send send, string fileId) + var headers = new BlobHttpHeaders { - await InitAsync(); + ContentDisposition = $"attachment; filename=\"{fileId}\"" + }; - var blobClient = _sendFilesContainerClient.GetBlobClient(BlobName(send, fileId)); + await blobClient.UploadAsync(stream, new BlobUploadOptions { Metadata = metadata, HttpHeaders = headers }); + } + + public async Task DeleteFileAsync(Send send, string fileId) => await DeleteBlobAsync(BlobName(send, fileId)); + + public async Task DeleteBlobAsync(string blobName) + { + await InitAsync(); + var blobClient = _sendFilesContainerClient.GetBlobClient(blobName); + await blobClient.DeleteIfExistsAsync(); + } + + public async Task DeleteFilesForOrganizationAsync(Guid organizationId) + { + await InitAsync(); + } + + public async Task DeleteFilesForUserAsync(Guid userId) + { + await InitAsync(); + } + + public async Task GetSendFileDownloadUrlAsync(Send send, string fileId) + { + await InitAsync(); + var blobClient = _sendFilesContainerClient.GetBlobClient(BlobName(send, fileId)); + var sasUri = blobClient.GenerateSasUri(BlobSasPermissions.Read, DateTime.UtcNow.Add(_downloadLinkLiveTime)); + return sasUri.ToString(); + } + + public async Task GetSendFileUploadUrlAsync(Send send, string fileId) + { + await InitAsync(); + var blobClient = _sendFilesContainerClient.GetBlobClient(BlobName(send, fileId)); + var sasUri = blobClient.GenerateSasUri(BlobSasPermissions.Create | BlobSasPermissions.Write, DateTime.UtcNow.Add(_downloadLinkLiveTime)); + return sasUri.ToString(); + } + + public async Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway) + { + await InitAsync(); + + var blobClient = _sendFilesContainerClient.GetBlobClient(BlobName(send, fileId)); + + try + { + var blobProperties = await blobClient.GetPropertiesAsync(); + var metadata = blobProperties.Value.Metadata; - var metadata = new Dictionary(); if (send.UserId.HasValue) { - metadata.Add("userId", send.UserId.Value.ToString()); + metadata["userId"] = send.UserId.Value.ToString(); } else { - metadata.Add("organizationId", send.OrganizationId.Value.ToString()); + metadata["organizationId"] = send.OrganizationId.Value.ToString(); } + await blobClient.SetMetadataAsync(metadata); var headers = new BlobHttpHeaders { ContentDisposition = $"attachment; filename=\"{fileId}\"" }; + await blobClient.SetHttpHeadersAsync(headers); - await blobClient.UploadAsync(stream, new BlobUploadOptions { Metadata = metadata, HttpHeaders = headers }); - } - - public async Task DeleteFileAsync(Send send, string fileId) => await DeleteBlobAsync(BlobName(send, fileId)); - - public async Task DeleteBlobAsync(string blobName) - { - await InitAsync(); - var blobClient = _sendFilesContainerClient.GetBlobClient(blobName); - await blobClient.DeleteIfExistsAsync(); - } - - public async Task DeleteFilesForOrganizationAsync(Guid organizationId) - { - await InitAsync(); - } - - public async Task DeleteFilesForUserAsync(Guid userId) - { - await InitAsync(); - } - - public async Task GetSendFileDownloadUrlAsync(Send send, string fileId) - { - await InitAsync(); - var blobClient = _sendFilesContainerClient.GetBlobClient(BlobName(send, fileId)); - var sasUri = blobClient.GenerateSasUri(BlobSasPermissions.Read, DateTime.UtcNow.Add(_downloadLinkLiveTime)); - return sasUri.ToString(); - } - - public async Task GetSendFileUploadUrlAsync(Send send, string fileId) - { - await InitAsync(); - var blobClient = _sendFilesContainerClient.GetBlobClient(BlobName(send, fileId)); - var sasUri = blobClient.GenerateSasUri(BlobSasPermissions.Create | BlobSasPermissions.Write, DateTime.UtcNow.Add(_downloadLinkLiveTime)); - return sasUri.ToString(); - } - - public async Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway) - { - await InitAsync(); - - var blobClient = _sendFilesContainerClient.GetBlobClient(BlobName(send, fileId)); - - try + var length = blobProperties.Value.ContentLength; + if (length < expectedFileSize - leeway || length > expectedFileSize + leeway) { - var blobProperties = await blobClient.GetPropertiesAsync(); - var metadata = blobProperties.Value.Metadata; - - if (send.UserId.HasValue) - { - metadata["userId"] = send.UserId.Value.ToString(); - } - else - { - metadata["organizationId"] = send.OrganizationId.Value.ToString(); - } - await blobClient.SetMetadataAsync(metadata); - - var headers = new BlobHttpHeaders - { - ContentDisposition = $"attachment; filename=\"{fileId}\"" - }; - await blobClient.SetHttpHeadersAsync(headers); - - var length = blobProperties.Value.ContentLength; - if (length < expectedFileSize - leeway || length > expectedFileSize + leeway) - { - return (false, length); - } - - return (true, length); - } - catch (Exception ex) - { - _logger.LogError(ex, "Unhandled error in ValidateFileAsync"); - return (false, null); + return (false, length); } + + return (true, length); } - - private async Task InitAsync() + catch (Exception ex) { - if (_sendFilesContainerClient == null) - { - _sendFilesContainerClient = _blobServiceClient.GetBlobContainerClient(FilesContainerName); - await _sendFilesContainerClient.CreateIfNotExistsAsync(PublicAccessType.None, null, null); - } + _logger.LogError(ex, "Unhandled error in ValidateFileAsync"); + return (false, null); + } + } + + private async Task InitAsync() + { + if (_sendFilesContainerClient == null) + { + _sendFilesContainerClient = _blobServiceClient.GetBlobContainerClient(FilesContainerName); + await _sendFilesContainerClient.CreateIfNotExistsAsync(PublicAccessType.None, null, null); } } } diff --git a/src/Core/Services/Implementations/BaseIdentityClientService.cs b/src/Core/Services/Implementations/BaseIdentityClientService.cs index 2115eba24..fd9be533b 100644 --- a/src/Core/Services/Implementations/BaseIdentityClientService.cs +++ b/src/Core/Services/Implementations/BaseIdentityClientService.cs @@ -5,203 +5,202 @@ using System.Text.Json; using Bit.Core.Utilities; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public abstract class BaseIdentityClientService : IDisposable { - public abstract class BaseIdentityClientService : IDisposable + private readonly IHttpClientFactory _httpFactory; + private readonly string _identityScope; + private readonly string _identityClientId; + private readonly string _identityClientSecret; + protected readonly ILogger _logger; + + private JsonDocument _decodedToken; + private DateTime? _nextAuthAttempt = null; + + public BaseIdentityClientService( + IHttpClientFactory httpFactory, + string baseClientServerUri, + string baseIdentityServerUri, + string identityScope, + string identityClientId, + string identityClientSecret, + ILogger logger) { - private readonly IHttpClientFactory _httpFactory; - private readonly string _identityScope; - private readonly string _identityClientId; - private readonly string _identityClientSecret; - protected readonly ILogger _logger; + _httpFactory = httpFactory; + _identityScope = identityScope; + _identityClientId = identityClientId; + _identityClientSecret = identityClientSecret; + _logger = logger; - private JsonDocument _decodedToken; - private DateTime? _nextAuthAttempt = null; + Client = _httpFactory.CreateClient("client"); + Client.BaseAddress = new Uri(baseClientServerUri); + Client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); - public BaseIdentityClientService( - IHttpClientFactory httpFactory, - string baseClientServerUri, - string baseIdentityServerUri, - string identityScope, - string identityClientId, - string identityClientSecret, - ILogger logger) + IdentityClient = _httpFactory.CreateClient("identity"); + IdentityClient.BaseAddress = new Uri(baseIdentityServerUri); + IdentityClient.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + } + + protected HttpClient Client { get; private set; } + protected HttpClient IdentityClient { get; private set; } + protected string AccessToken { get; private set; } + + protected Task SendAsync(HttpMethod method, string path) => + SendAsync(method, path, null); + + protected Task SendAsync(HttpMethod method, string path, TRequest body) => + SendAsync(method, path, body); + + protected async Task SendAsync(HttpMethod method, string path, TRequest requestModel) + { + var tokenStateResponse = await HandleTokenStateAsync(); + if (!tokenStateResponse) { - _httpFactory = httpFactory; - _identityScope = identityScope; - _identityClientId = identityClientId; - _identityClientSecret = identityClientSecret; - _logger = logger; - - Client = _httpFactory.CreateClient("client"); - Client.BaseAddress = new Uri(baseClientServerUri); - Client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); - - IdentityClient = _httpFactory.CreateClient("identity"); - IdentityClient.BaseAddress = new Uri(baseIdentityServerUri); - IdentityClient.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + return default; } - protected HttpClient Client { get; private set; } - protected HttpClient IdentityClient { get; private set; } - protected string AccessToken { get; private set; } - - protected Task SendAsync(HttpMethod method, string path) => - SendAsync(method, path, null); - - protected Task SendAsync(HttpMethod method, string path, TRequest body) => - SendAsync(method, path, body); - - protected async Task SendAsync(HttpMethod method, string path, TRequest requestModel) + var message = new TokenHttpRequestMessage(requestModel, AccessToken) { - var tokenStateResponse = await HandleTokenStateAsync(); - if (!tokenStateResponse) - { - return default; - } - - var message = new TokenHttpRequestMessage(requestModel, AccessToken) - { - Method = method, - RequestUri = new Uri(string.Concat(Client.BaseAddress, path)) - }; - try - { - var response = await Client.SendAsync(message); - return await response.Content.ReadFromJsonAsync(); - } - catch (Exception e) - { - _logger.LogError(12334, e, "Failed to send to {0}.", message.RequestUri.ToString()); - return default; - } + Method = method, + RequestUri = new Uri(string.Concat(Client.BaseAddress, path)) + }; + try + { + var response = await Client.SendAsync(message); + return await response.Content.ReadFromJsonAsync(); } - - protected async Task HandleTokenStateAsync() + catch (Exception e) { - if (_nextAuthAttempt.HasValue && DateTime.UtcNow > _nextAuthAttempt.Value) - { - return false; - } - _nextAuthAttempt = null; + _logger.LogError(12334, e, "Failed to send to {0}.", message.RequestUri.ToString()); + return default; + } + } - if (!string.IsNullOrWhiteSpace(AccessToken) && !TokenNeedsRefresh()) - { - return true; - } + protected async Task HandleTokenStateAsync() + { + if (_nextAuthAttempt.HasValue && DateTime.UtcNow > _nextAuthAttempt.Value) + { + return false; + } + _nextAuthAttempt = null; - var requestMessage = new HttpRequestMessage - { - Method = HttpMethod.Post, - RequestUri = new Uri(string.Concat(IdentityClient.BaseAddress, "connect/token")), - Content = new FormUrlEncodedContent(new Dictionary - { - { "grant_type", "client_credentials" }, - { "scope", _identityScope }, - { "client_id", _identityClientId }, - { "client_secret", _identityClientSecret } - }) - }; - - HttpResponseMessage response = null; - try - { - response = await IdentityClient.SendAsync(requestMessage); - } - catch (Exception e) - { - _logger.LogError(12339, e, "Unable to authenticate with identity server."); - } - - if (response == null) - { - return false; - } - - if (!response.IsSuccessStatusCode) - { - _logger.LogInformation("Unsuccessful token response with status code {StatusCode}", response.StatusCode); - - if (response.StatusCode == HttpStatusCode.BadRequest) - { - _nextAuthAttempt = DateTime.UtcNow.AddDays(1); - } - - if (_logger.IsEnabled(LogLevel.Debug)) - { - var responseBody = await response.Content.ReadAsStringAsync(); - _logger.LogDebug("Error response body:\n{ResponseBody}", responseBody); - } - - return false; - } - - using var jsonDocument = await JsonDocument.ParseAsync(await response.Content.ReadAsStreamAsync()); - - AccessToken = jsonDocument.RootElement.GetProperty("access_token").GetString(); + if (!string.IsNullOrWhiteSpace(AccessToken) && !TokenNeedsRefresh()) + { return true; } - protected class TokenHttpRequestMessage : HttpRequestMessage + var requestMessage = new HttpRequestMessage { - public TokenHttpRequestMessage(string token) + Method = HttpMethod.Post, + RequestUri = new Uri(string.Concat(IdentityClient.BaseAddress, "connect/token")), + Content = new FormUrlEncodedContent(new Dictionary { - Headers.Add("Authorization", $"Bearer {token}"); - } + { "grant_type", "client_credentials" }, + { "scope", _identityScope }, + { "client_id", _identityClientId }, + { "client_secret", _identityClientSecret } + }) + }; - public TokenHttpRequestMessage(object requestObject, string token) - : this(token) - { - if (requestObject != null) - { - Content = JsonContent.Create(requestObject); - } - } + HttpResponseMessage response = null; + try + { + response = await IdentityClient.SendAsync(requestMessage); + } + catch (Exception e) + { + _logger.LogError(12339, e, "Unable to authenticate with identity server."); } - protected bool TokenNeedsRefresh(int minutes = 5) + if (response == null) { - var decoded = DecodeToken(); - if (!decoded.RootElement.TryGetProperty("exp", out var expProp)) - { - throw new InvalidOperationException("No exp in token."); - } - - var expiration = CoreHelpers.FromEpocSeconds(expProp.GetInt64()); - return DateTime.UtcNow.AddMinutes(-1 * minutes) > expiration; + return false; } - protected JsonDocument DecodeToken() + if (!response.IsSuccessStatusCode) { - if (_decodedToken != null) + _logger.LogInformation("Unsuccessful token response with status code {StatusCode}", response.StatusCode); + + if (response.StatusCode == HttpStatusCode.BadRequest) { - return _decodedToken; + _nextAuthAttempt = DateTime.UtcNow.AddDays(1); } - if (AccessToken == null) + if (_logger.IsEnabled(LogLevel.Debug)) { - throw new InvalidOperationException($"{nameof(AccessToken)} not found."); + var responseBody = await response.Content.ReadAsStringAsync(); + _logger.LogDebug("Error response body:\n{ResponseBody}", responseBody); } - var parts = AccessToken.Split('.'); - if (parts.Length != 3) - { - throw new InvalidOperationException($"{nameof(AccessToken)} must have 3 parts"); - } + return false; + } - var decodedBytes = CoreHelpers.Base64UrlDecode(parts[1]); - if (decodedBytes == null || decodedBytes.Length < 1) - { - throw new InvalidOperationException($"{nameof(AccessToken)} must have 3 parts"); - } + using var jsonDocument = await JsonDocument.ParseAsync(await response.Content.ReadAsStreamAsync()); - _decodedToken = JsonDocument.Parse(decodedBytes); + AccessToken = jsonDocument.RootElement.GetProperty("access_token").GetString(); + return true; + } + + protected class TokenHttpRequestMessage : HttpRequestMessage + { + public TokenHttpRequestMessage(string token) + { + Headers.Add("Authorization", $"Bearer {token}"); + } + + public TokenHttpRequestMessage(object requestObject, string token) + : this(token) + { + if (requestObject != null) + { + Content = JsonContent.Create(requestObject); + } + } + } + + protected bool TokenNeedsRefresh(int minutes = 5) + { + var decoded = DecodeToken(); + if (!decoded.RootElement.TryGetProperty("exp", out var expProp)) + { + throw new InvalidOperationException("No exp in token."); + } + + var expiration = CoreHelpers.FromEpocSeconds(expProp.GetInt64()); + return DateTime.UtcNow.AddMinutes(-1 * minutes) > expiration; + } + + protected JsonDocument DecodeToken() + { + if (_decodedToken != null) + { return _decodedToken; } - public void Dispose() + if (AccessToken == null) { - _decodedToken?.Dispose(); + throw new InvalidOperationException($"{nameof(AccessToken)} not found."); } + + var parts = AccessToken.Split('.'); + if (parts.Length != 3) + { + throw new InvalidOperationException($"{nameof(AccessToken)} must have 3 parts"); + } + + var decodedBytes = CoreHelpers.Base64UrlDecode(parts[1]); + if (decodedBytes == null || decodedBytes.Length < 1) + { + throw new InvalidOperationException($"{nameof(AccessToken)} must have 3 parts"); + } + + _decodedToken = JsonDocument.Parse(decodedBytes); + return _decodedToken; + } + + public void Dispose() + { + _decodedToken?.Dispose(); } } diff --git a/src/Core/Services/Implementations/BlockingMailQueueService.cs b/src/Core/Services/Implementations/BlockingMailQueueService.cs index 0a1a99b85..0323b09af 100644 --- a/src/Core/Services/Implementations/BlockingMailQueueService.cs +++ b/src/Core/Services/Implementations/BlockingMailQueueService.cs @@ -1,20 +1,19 @@ using Bit.Core.Models.Mail; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class BlockingMailEnqueuingService : IMailEnqueuingService { - public class BlockingMailEnqueuingService : IMailEnqueuingService + public async Task EnqueueAsync(IMailQueueMessage message, Func fallback) { - public async Task EnqueueAsync(IMailQueueMessage message, Func fallback) + await fallback(message); + } + + public async Task EnqueueManyAsync(IEnumerable messages, Func fallback) + { + foreach (var message in messages) { await fallback(message); } - - public async Task EnqueueManyAsync(IEnumerable messages, Func fallback) - { - foreach (var message in messages) - { - await fallback(message); - } - } } } diff --git a/src/Core/Services/Implementations/CipherService.cs b/src/Core/Services/Implementations/CipherService.cs index dfa156974..e2679e628 100644 --- a/src/Core/Services/Implementations/CipherService.cs +++ b/src/Core/Services/Implementations/CipherService.cs @@ -10,1022 +10,1021 @@ using Bit.Core.Settings; using Bit.Core.Utilities; using Core.Models.Data; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class CipherService : ICipherService { - public class CipherService : ICipherService + public const long MAX_FILE_SIZE = Constants.FileSize501mb; + public const string MAX_FILE_SIZE_READABLE = "500 MB"; + private readonly ICipherRepository _cipherRepository; + private readonly IFolderRepository _folderRepository; + private readonly ICollectionRepository _collectionRepository; + private readonly IUserRepository _userRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly ICollectionCipherRepository _collectionCipherRepository; + private readonly IPushNotificationService _pushService; + private readonly IAttachmentStorageService _attachmentStorageService; + private readonly IEventService _eventService; + private readonly IUserService _userService; + private readonly IPolicyRepository _policyRepository; + private readonly GlobalSettings _globalSettings; + private const long _fileSizeLeeway = 1024L * 1024L; // 1MB + private readonly IReferenceEventService _referenceEventService; + private readonly ICurrentContext _currentContext; + private readonly IProviderService _providerService; + + public CipherService( + ICipherRepository cipherRepository, + IFolderRepository folderRepository, + ICollectionRepository collectionRepository, + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + ICollectionCipherRepository collectionCipherRepository, + IPushNotificationService pushService, + IAttachmentStorageService attachmentStorageService, + IEventService eventService, + IUserService userService, + IPolicyRepository policyRepository, + GlobalSettings globalSettings, + IReferenceEventService referenceEventService, + ICurrentContext currentContext) { - public const long MAX_FILE_SIZE = Constants.FileSize501mb; - public const string MAX_FILE_SIZE_READABLE = "500 MB"; - private readonly ICipherRepository _cipherRepository; - private readonly IFolderRepository _folderRepository; - private readonly ICollectionRepository _collectionRepository; - private readonly IUserRepository _userRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly ICollectionCipherRepository _collectionCipherRepository; - private readonly IPushNotificationService _pushService; - private readonly IAttachmentStorageService _attachmentStorageService; - private readonly IEventService _eventService; - private readonly IUserService _userService; - private readonly IPolicyRepository _policyRepository; - private readonly GlobalSettings _globalSettings; - private const long _fileSizeLeeway = 1024L * 1024L; // 1MB - private readonly IReferenceEventService _referenceEventService; - private readonly ICurrentContext _currentContext; - private readonly IProviderService _providerService; + _cipherRepository = cipherRepository; + _folderRepository = folderRepository; + _collectionRepository = collectionRepository; + _userRepository = userRepository; + _organizationRepository = organizationRepository; + _collectionCipherRepository = collectionCipherRepository; + _pushService = pushService; + _attachmentStorageService = attachmentStorageService; + _eventService = eventService; + _userService = userService; + _policyRepository = policyRepository; + _globalSettings = globalSettings; + _referenceEventService = referenceEventService; + _currentContext = currentContext; + } - public CipherService( - ICipherRepository cipherRepository, - IFolderRepository folderRepository, - ICollectionRepository collectionRepository, - IUserRepository userRepository, - IOrganizationRepository organizationRepository, - ICollectionCipherRepository collectionCipherRepository, - IPushNotificationService pushService, - IAttachmentStorageService attachmentStorageService, - IEventService eventService, - IUserService userService, - IPolicyRepository policyRepository, - GlobalSettings globalSettings, - IReferenceEventService referenceEventService, - ICurrentContext currentContext) + public async Task SaveAsync(Cipher cipher, Guid savingUserId, DateTime? lastKnownRevisionDate, + IEnumerable collectionIds = null, bool skipPermissionCheck = false, bool limitCollectionScope = true) + { + if (!skipPermissionCheck && !(await UserCanEditAsync(cipher, savingUserId))) { - _cipherRepository = cipherRepository; - _folderRepository = folderRepository; - _collectionRepository = collectionRepository; - _userRepository = userRepository; - _organizationRepository = organizationRepository; - _collectionCipherRepository = collectionCipherRepository; - _pushService = pushService; - _attachmentStorageService = attachmentStorageService; - _eventService = eventService; - _userService = userService; - _policyRepository = policyRepository; - _globalSettings = globalSettings; - _referenceEventService = referenceEventService; - _currentContext = currentContext; + throw new BadRequestException("You do not have permissions to edit this."); } - public async Task SaveAsync(Cipher cipher, Guid savingUserId, DateTime? lastKnownRevisionDate, - IEnumerable collectionIds = null, bool skipPermissionCheck = false, bool limitCollectionScope = true) + if (cipher.Id == default(Guid)) { - if (!skipPermissionCheck && !(await UserCanEditAsync(cipher, savingUserId))) + if (cipher.OrganizationId.HasValue && collectionIds != null) { - throw new BadRequestException("You do not have permissions to edit this."); - } - - if (cipher.Id == default(Guid)) - { - if (cipher.OrganizationId.HasValue && collectionIds != null) + if (limitCollectionScope) { - if (limitCollectionScope) - { - // Set user ID to limit scope of collection ids in the create sproc - cipher.UserId = savingUserId; - } - await _cipherRepository.CreateAsync(cipher, collectionIds); - - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.CipherCreated, await _organizationRepository.GetByIdAsync(cipher.OrganizationId.Value))); + // Set user ID to limit scope of collection ids in the create sproc + cipher.UserId = savingUserId; } - else - { - await _cipherRepository.CreateAsync(cipher); - } - await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_Created); + await _cipherRepository.CreateAsync(cipher, collectionIds); - // push - await _pushService.PushSyncCipherCreateAsync(cipher, null); + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.CipherCreated, await _organizationRepository.GetByIdAsync(cipher.OrganizationId.Value))); } else { - ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate); - cipher.RevisionDate = DateTime.UtcNow; - await _cipherRepository.ReplaceAsync(cipher); - await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_Updated); - - // push - await _pushService.PushSyncCipherUpdateAsync(cipher, collectionIds); + await _cipherRepository.CreateAsync(cipher); } + await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_Created); + + // push + await _pushService.PushSyncCipherCreateAsync(cipher, null); + } + else + { + ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate); + cipher.RevisionDate = DateTime.UtcNow; + await _cipherRepository.ReplaceAsync(cipher); + await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_Updated); + + // push + await _pushService.PushSyncCipherUpdateAsync(cipher, collectionIds); + } + } + + public async Task SaveDetailsAsync(CipherDetails cipher, Guid savingUserId, DateTime? lastKnownRevisionDate, + IEnumerable collectionIds = null, bool skipPermissionCheck = false) + { + if (!skipPermissionCheck && !(await UserCanEditAsync(cipher, savingUserId))) + { + throw new BadRequestException("You do not have permissions to edit this."); } - public async Task SaveDetailsAsync(CipherDetails cipher, Guid savingUserId, DateTime? lastKnownRevisionDate, - IEnumerable collectionIds = null, bool skipPermissionCheck = false) + cipher.UserId = savingUserId; + if (cipher.Id == default(Guid)) { - if (!skipPermissionCheck && !(await UserCanEditAsync(cipher, savingUserId))) + if (cipher.OrganizationId.HasValue && collectionIds != null) { - throw new BadRequestException("You do not have permissions to edit this."); - } - - cipher.UserId = savingUserId; - if (cipher.Id == default(Guid)) - { - if (cipher.OrganizationId.HasValue && collectionIds != null) + var existingCollectionIds = (await _collectionRepository.GetManyByOrganizationIdAsync(cipher.OrganizationId.Value)).Select(c => c.Id); + if (collectionIds.Except(existingCollectionIds).Any()) { - var existingCollectionIds = (await _collectionRepository.GetManyByOrganizationIdAsync(cipher.OrganizationId.Value)).Select(c => c.Id); - if (collectionIds.Except(existingCollectionIds).Any()) - { - throw new BadRequestException("Specified CollectionId does not exist on the specified Organization."); - } - await _cipherRepository.CreateAsync(cipher, collectionIds); + throw new BadRequestException("Specified CollectionId does not exist on the specified Organization."); } - else - { - // Make sure the user can save new ciphers to their personal vault - var personalOwnershipPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(savingUserId, - PolicyType.PersonalOwnership); - if (personalOwnershipPolicyCount > 0) - { - throw new BadRequestException("Due to an Enterprise Policy, you are restricted from saving items to your personal vault."); - } - await _cipherRepository.CreateAsync(cipher); - } - await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_Created); - - if (cipher.OrganizationId.HasValue) - { - var org = await _organizationRepository.GetByIdAsync(cipher.OrganizationId.Value); - cipher.OrganizationUseTotp = org.UseTotp; - } - - // push - await _pushService.PushSyncCipherCreateAsync(cipher, null); + await _cipherRepository.CreateAsync(cipher, collectionIds); } else { - ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate); - cipher.RevisionDate = DateTime.UtcNow; - await _cipherRepository.ReplaceAsync(cipher); - await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_Updated); - - // push - await _pushService.PushSyncCipherUpdateAsync(cipher, collectionIds); + // Make sure the user can save new ciphers to their personal vault + var personalOwnershipPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(savingUserId, + PolicyType.PersonalOwnership); + if (personalOwnershipPolicyCount > 0) + { + throw new BadRequestException("Due to an Enterprise Policy, you are restricted from saving items to your personal vault."); + } + await _cipherRepository.CreateAsync(cipher); } + await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_Created); + + if (cipher.OrganizationId.HasValue) + { + var org = await _organizationRepository.GetByIdAsync(cipher.OrganizationId.Value); + cipher.OrganizationUseTotp = org.UseTotp; + } + + // push + await _pushService.PushSyncCipherCreateAsync(cipher, null); + } + else + { + ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate); + cipher.RevisionDate = DateTime.UtcNow; + await _cipherRepository.ReplaceAsync(cipher); + await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_Updated); + + // push + await _pushService.PushSyncCipherUpdateAsync(cipher, collectionIds); + } + } + + public async Task UploadFileForExistingAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachment) + { + if (attachment == null) + { + throw new BadRequestException("Cipher attachment does not exist"); } - public async Task UploadFileForExistingAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachment) + await _attachmentStorageService.UploadNewAttachmentAsync(stream, cipher, attachment); + + if (!await ValidateCipherAttachmentFile(cipher, attachment)) { - if (attachment == null) - { - throw new BadRequestException("Cipher attachment does not exist"); - } - - await _attachmentStorageService.UploadNewAttachmentAsync(stream, cipher, attachment); - - if (!await ValidateCipherAttachmentFile(cipher, attachment)) - { - throw new BadRequestException("File received does not match expected file length."); - } + throw new BadRequestException("File received does not match expected file length."); } + } - public async Task<(string attachmentId, string uploadUrl)> CreateAttachmentForDelayedUploadAsync(Cipher cipher, - string key, string fileName, long fileSize, bool adminRequest, Guid savingUserId) + public async Task<(string attachmentId, string uploadUrl)> CreateAttachmentForDelayedUploadAsync(Cipher cipher, + string key, string fileName, long fileSize, bool adminRequest, Guid savingUserId) + { + await ValidateCipherEditForAttachmentAsync(cipher, savingUserId, adminRequest, fileSize); + + var attachmentId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false); + var data = new CipherAttachment.MetaData { - await ValidateCipherEditForAttachmentAsync(cipher, savingUserId, adminRequest, fileSize); + AttachmentId = attachmentId, + FileName = fileName, + Key = key, + Size = fileSize, + Validated = false, + }; - var attachmentId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false); - var data = new CipherAttachment.MetaData - { - AttachmentId = attachmentId, - FileName = fileName, - Key = key, - Size = fileSize, - Validated = false, - }; + var uploadUrl = await _attachmentStorageService.GetAttachmentUploadUrlAsync(cipher, data); - var uploadUrl = await _attachmentStorageService.GetAttachmentUploadUrlAsync(cipher, data); + await _cipherRepository.UpdateAttachmentAsync(new CipherAttachment + { + Id = cipher.Id, + UserId = cipher.UserId, + OrganizationId = cipher.OrganizationId, + AttachmentId = attachmentId, + AttachmentData = JsonSerializer.Serialize(data) + }); + cipher.AddAttachment(attachmentId, data); + await _pushService.PushSyncCipherUpdateAsync(cipher, null); - await _cipherRepository.UpdateAttachmentAsync(new CipherAttachment + return (attachmentId, uploadUrl); + } + + public async Task CreateAttachmentAsync(Cipher cipher, Stream stream, string fileName, string key, + long requestLength, Guid savingUserId, bool orgAdmin = false) + { + await ValidateCipherEditForAttachmentAsync(cipher, savingUserId, orgAdmin, requestLength); + + var attachmentId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false); + var data = new CipherAttachment.MetaData + { + AttachmentId = attachmentId, + FileName = fileName, + Key = key, + }; + + await _attachmentStorageService.UploadNewAttachmentAsync(stream, cipher, data); + // Must read stream length after it has been saved, otherwise it's 0 + data.Size = stream.Length; + + try + { + var attachment = new CipherAttachment { Id = cipher.Id, UserId = cipher.UserId, OrganizationId = cipher.OrganizationId, AttachmentId = attachmentId, AttachmentData = JsonSerializer.Serialize(data) - }); + }; + + await _cipherRepository.UpdateAttachmentAsync(attachment); + await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_AttachmentCreated); cipher.AddAttachment(attachmentId, data); - await _pushService.PushSyncCipherUpdateAsync(cipher, null); - return (attachmentId, uploadUrl); - } - - public async Task CreateAttachmentAsync(Cipher cipher, Stream stream, string fileName, string key, - long requestLength, Guid savingUserId, bool orgAdmin = false) - { - await ValidateCipherEditForAttachmentAsync(cipher, savingUserId, orgAdmin, requestLength); - - var attachmentId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false); - var data = new CipherAttachment.MetaData + if (!await ValidateCipherAttachmentFile(cipher, data)) { - AttachmentId = attachmentId, - FileName = fileName, - Key = key, - }; - - await _attachmentStorageService.UploadNewAttachmentAsync(stream, cipher, data); - // Must read stream length after it has been saved, otherwise it's 0 - data.Size = stream.Length; - - try - { - var attachment = new CipherAttachment - { - Id = cipher.Id, - UserId = cipher.UserId, - OrganizationId = cipher.OrganizationId, - AttachmentId = attachmentId, - AttachmentData = JsonSerializer.Serialize(data) - }; - - await _cipherRepository.UpdateAttachmentAsync(attachment); - await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_AttachmentCreated); - cipher.AddAttachment(attachmentId, data); - - if (!await ValidateCipherAttachmentFile(cipher, data)) - { - throw new Exception("Content-Length does not match uploaded file size"); - } - } - catch - { - // Clean up since this is not transactional - await _attachmentStorageService.DeleteAttachmentAsync(cipher.Id, data); - throw; - } - - // push - await _pushService.PushSyncCipherUpdateAsync(cipher, null); - } - - public async Task CreateAttachmentShareAsync(Cipher cipher, Stream stream, long requestLength, - string attachmentId, Guid organizationId) - { - try - { - if (requestLength < 1) - { - throw new BadRequestException("No data to attach."); - } - - if (cipher.Id == default(Guid)) - { - throw new BadRequestException(nameof(cipher.Id)); - } - - if (cipher.OrganizationId.HasValue) - { - throw new BadRequestException("Cipher belongs to an organization already."); - } - - var org = await _organizationRepository.GetByIdAsync(organizationId); - if (org == null || !org.MaxStorageGb.HasValue) - { - throw new BadRequestException("This organization cannot use attachments."); - } - - var storageBytesRemaining = org.StorageBytesRemaining(); - if (storageBytesRemaining < requestLength) - { - throw new BadRequestException("Not enough storage available for this organization."); - } - - var attachments = cipher.GetAttachments(); - if (!attachments.ContainsKey(attachmentId)) - { - throw new BadRequestException($"Cipher does not own specified attachment"); - } - - await _attachmentStorageService.UploadShareAttachmentAsync(stream, cipher.Id, organizationId, - attachments[attachmentId]); - - // Previous call may alter metadata - var updatedAttachment = new CipherAttachment - { - Id = cipher.Id, - UserId = cipher.UserId, - OrganizationId = cipher.OrganizationId, - AttachmentId = attachmentId, - AttachmentData = JsonSerializer.Serialize(attachments[attachmentId]) - }; - - await _cipherRepository.UpdateAttachmentAsync(updatedAttachment); - } - catch - { - await _attachmentStorageService.CleanupAsync(cipher.Id); - throw; + throw new Exception("Content-Length does not match uploaded file size"); } } - - public async Task ValidateCipherAttachmentFile(Cipher cipher, CipherAttachment.MetaData attachmentData) + catch { - var (valid, realSize) = await _attachmentStorageService.ValidateFileAsync(cipher, attachmentData, _fileSizeLeeway); - - if (!valid || realSize > MAX_FILE_SIZE) - { - // File reported differs in size from that promised. Must be a rogue client. Delete Send - await DeleteAttachmentAsync(cipher, attachmentData); - return false; - } - // Update Send data if necessary - if (realSize != attachmentData.Size) - { - attachmentData.Size = realSize.Value; - } - attachmentData.Validated = true; - - var updatedAttachment = new CipherAttachment - { - Id = cipher.Id, - UserId = cipher.UserId, - OrganizationId = cipher.OrganizationId, - AttachmentId = attachmentData.AttachmentId, - AttachmentData = JsonSerializer.Serialize(attachmentData) - }; - - - await _cipherRepository.UpdateAttachmentAsync(updatedAttachment); - - return valid; + // Clean up since this is not transactional + await _attachmentStorageService.DeleteAttachmentAsync(cipher.Id, data); + throw; } - public async Task GetAttachmentDownloadDataAsync(Cipher cipher, string attachmentId) + // push + await _pushService.PushSyncCipherUpdateAsync(cipher, null); + } + + public async Task CreateAttachmentShareAsync(Cipher cipher, Stream stream, long requestLength, + string attachmentId, Guid organizationId) + { + try { - var attachments = cipher?.GetAttachments() ?? new Dictionary(); - - if (!attachments.ContainsKey(attachmentId)) - { - throw new NotFoundException(); - } - - var data = attachments[attachmentId]; - var response = new AttachmentResponseData - { - Cipher = cipher, - Data = data, - Id = attachmentId, - Url = await _attachmentStorageService.GetAttachmentDownloadUrlAsync(cipher, data), - }; - - return response; - } - - public async Task DeleteAsync(Cipher cipher, Guid deletingUserId, bool orgAdmin = false) - { - if (!orgAdmin && !(await UserCanEditAsync(cipher, deletingUserId))) - { - throw new BadRequestException("You do not have permissions to delete this."); - } - - await _cipherRepository.DeleteAsync(cipher); - await _attachmentStorageService.DeleteAttachmentsForCipherAsync(cipher.Id); - await _eventService.LogCipherEventAsync(cipher, EventType.Cipher_Deleted); - - // push - await _pushService.PushSyncCipherDeleteAsync(cipher); - } - - public async Task DeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false) - { - var cipherIdsSet = new HashSet(cipherIds); - var deletingCiphers = new List(); - - if (orgAdmin && organizationId.HasValue) - { - var ciphers = await _cipherRepository.GetManyByOrganizationIdAsync(organizationId.Value); - deletingCiphers = ciphers.Where(c => cipherIdsSet.Contains(c.Id)).ToList(); - await _cipherRepository.DeleteByIdsOrganizationIdAsync(deletingCiphers.Select(c => c.Id), organizationId.Value); - } - else - { - var ciphers = await _cipherRepository.GetManyByUserIdAsync(deletingUserId); - deletingCiphers = ciphers.Where(c => cipherIdsSet.Contains(c.Id) && c.Edit).Select(x => (Cipher)x).ToList(); - await _cipherRepository.DeleteAsync(deletingCiphers.Select(c => c.Id), deletingUserId); - } - - var events = deletingCiphers.Select(c => - new Tuple(c, EventType.Cipher_Deleted, null)); - foreach (var eventsBatch in events.Batch(100)) - { - await _eventService.LogCipherEventsAsync(eventsBatch); - } - - // push - await _pushService.PushSyncCiphersAsync(deletingUserId); - } - - public async Task DeleteAttachmentAsync(Cipher cipher, string attachmentId, Guid deletingUserId, - bool orgAdmin = false) - { - if (!orgAdmin && !(await UserCanEditAsync(cipher, deletingUserId))) - { - throw new BadRequestException("You do not have permissions to delete this."); - } - - if (!cipher.ContainsAttachment(attachmentId)) - { - throw new NotFoundException(); - } - - await DeleteAttachmentAsync(cipher, cipher.GetAttachments()[attachmentId]); - } - - public async Task PurgeAsync(Guid organizationId) - { - var org = await _organizationRepository.GetByIdAsync(organizationId); - if (org == null) - { - throw new NotFoundException(); - } - await _cipherRepository.DeleteByOrganizationIdAsync(organizationId); - await _eventService.LogOrganizationEventAsync(org, Enums.EventType.Organization_PurgedVault); - } - - public async Task MoveManyAsync(IEnumerable cipherIds, Guid? destinationFolderId, Guid movingUserId) - { - if (destinationFolderId.HasValue) - { - var folder = await _folderRepository.GetByIdAsync(destinationFolderId.Value); - if (folder == null || folder.UserId != movingUserId) - { - throw new BadRequestException("Invalid folder."); - } - } - - await _cipherRepository.MoveAsync(cipherIds, destinationFolderId, movingUserId); - // push - await _pushService.PushSyncCiphersAsync(movingUserId); - } - - public async Task SaveFolderAsync(Folder folder) - { - if (folder.Id == default(Guid)) - { - await _folderRepository.CreateAsync(folder); - - // push - await _pushService.PushSyncFolderCreateAsync(folder); - } - else - { - folder.RevisionDate = DateTime.UtcNow; - await _folderRepository.UpsertAsync(folder); - - // push - await _pushService.PushSyncFolderUpdateAsync(folder); - } - } - - public async Task DeleteFolderAsync(Folder folder) - { - await _folderRepository.DeleteAsync(folder); - - // push - await _pushService.PushSyncFolderDeleteAsync(folder); - } - - public async Task ShareAsync(Cipher originalCipher, Cipher cipher, Guid organizationId, - IEnumerable collectionIds, Guid sharingUserId, DateTime? lastKnownRevisionDate) - { - var attachments = cipher.GetAttachments(); - var hasOldAttachments = attachments?.Any(a => a.Key == null) ?? false; - var updatedCipher = false; - var migratedAttachments = false; - var originalAttachments = CoreHelpers.CloneObject(attachments); - - try - { - await ValidateCipherCanBeShared(cipher, sharingUserId, organizationId, lastKnownRevisionDate); - - // Sproc will not save this UserId on the cipher. It is used limit scope of the collectionIds. - cipher.UserId = sharingUserId; - cipher.OrganizationId = organizationId; - cipher.RevisionDate = DateTime.UtcNow; - if (!await _cipherRepository.ReplaceAsync(cipher, collectionIds)) - { - throw new BadRequestException("Unable to save."); - } - - updatedCipher = true; - await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_Shared); - - if (hasOldAttachments) - { - // migrate old attachments - foreach (var attachment in attachments.Where(a => a.Key == null)) - { - await _attachmentStorageService.StartShareAttachmentAsync(cipher.Id, organizationId, - attachment.Value); - migratedAttachments = true; - } - - // commit attachment migration - await _attachmentStorageService.CleanupAsync(cipher.Id); - } - - // push - await _pushService.PushSyncCipherUpdateAsync(cipher, collectionIds); - } - catch - { - // roll everything back - if (updatedCipher) - { - await _cipherRepository.ReplaceAsync(originalCipher); - } - - if (!hasOldAttachments || !migratedAttachments) - { - throw; - } - - if (updatedCipher) - { - await _userRepository.UpdateStorageAsync(sharingUserId); - await _organizationRepository.UpdateStorageAsync(organizationId); - } - - foreach (var attachment in attachments.Where(a => a.Key == null)) - { - await _attachmentStorageService.RollbackShareAttachmentAsync(cipher.Id, organizationId, - attachment.Value, originalAttachments[attachment.Key].ContainerName); - } - - await _attachmentStorageService.CleanupAsync(cipher.Id); - throw; - } - } - - public async Task ShareManyAsync(IEnumerable<(Cipher cipher, DateTime? lastKnownRevisionDate)> cipherInfos, - Guid organizationId, IEnumerable collectionIds, Guid sharingUserId) - { - var cipherIds = new List(); - foreach (var (cipher, lastKnownRevisionDate) in cipherInfos) - { - await ValidateCipherCanBeShared(cipher, sharingUserId, organizationId, lastKnownRevisionDate); - - cipher.UserId = null; - cipher.OrganizationId = organizationId; - cipher.RevisionDate = DateTime.UtcNow; - cipherIds.Add(cipher.Id); - } - - await _cipherRepository.UpdateCiphersAsync(sharingUserId, cipherInfos.Select(c => c.cipher)); - await _collectionCipherRepository.UpdateCollectionsForCiphersAsync(cipherIds, sharingUserId, - organizationId, collectionIds); - - var events = cipherInfos.Select(c => - new Tuple(c.cipher, EventType.Cipher_Shared, null)); - foreach (var eventsBatch in events.Batch(100)) - { - await _eventService.LogCipherEventsAsync(eventsBatch); - } - - // push - await _pushService.PushSyncCiphersAsync(sharingUserId); - } - - public async Task SaveCollectionsAsync(Cipher cipher, IEnumerable collectionIds, Guid savingUserId, - bool orgAdmin) - { - if (cipher.Id == default(Guid)) - { - throw new BadRequestException(nameof(cipher.Id)); - } - - if (!cipher.OrganizationId.HasValue) - { - throw new BadRequestException("Cipher must belong to an organization."); - } - - cipher.RevisionDate = DateTime.UtcNow; - - // The sprocs will validate that all collections belong to this org/user and that they have - // proper write permissions. - if (orgAdmin) - { - await _collectionCipherRepository.UpdateCollectionsForAdminAsync(cipher.Id, - cipher.OrganizationId.Value, collectionIds); - } - else - { - if (!(await UserCanEditAsync(cipher, savingUserId))) - { - throw new BadRequestException("You do not have permissions to edit this."); - } - await _collectionCipherRepository.UpdateCollectionsAsync(cipher.Id, savingUserId, collectionIds); - } - - await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_UpdatedCollections); - - // push - await _pushService.PushSyncCipherUpdateAsync(cipher, collectionIds); - } - - public async Task ImportCiphersAsync( - List folders, - List ciphers, - IEnumerable> folderRelationships) - { - var userId = folders.FirstOrDefault()?.UserId ?? ciphers.FirstOrDefault()?.UserId; - - // Make sure the user can save new ciphers to their personal vault - var personalOwnershipPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(userId.Value, - PolicyType.PersonalOwnership); - if (personalOwnershipPolicyCount > 0) - { - throw new BadRequestException("You cannot import items into your personal vault because you are " + - "a member of an organization which forbids it."); - } - - foreach (var cipher in ciphers) - { - cipher.SetNewId(); - - if (cipher.UserId.HasValue && cipher.Favorite) - { - cipher.Favorites = $"{{\"{cipher.UserId.ToString().ToUpperInvariant()}\":\"true\"}}"; - } - } - - // Init. ids for folders - foreach (var folder in folders) - { - folder.SetNewId(); - } - - // Create the folder associations based on the newly created folder ids - foreach (var relationship in folderRelationships) - { - var cipher = ciphers.ElementAtOrDefault(relationship.Key); - var folder = folders.ElementAtOrDefault(relationship.Value); - - if (cipher == null || folder == null) - { - continue; - } - - cipher.Folders = $"{{\"{cipher.UserId.ToString().ToUpperInvariant()}\":" + - $"\"{folder.Id.ToString().ToUpperInvariant()}\"}}"; - } - - // Create it all - await _cipherRepository.CreateAsync(ciphers, folders); - - // push - if (userId.HasValue) - { - await _pushService.PushSyncVaultAsync(userId.Value); - } - } - - public async Task ImportCiphersAsync( - List collections, - List ciphers, - IEnumerable> collectionRelationships, - Guid importingUserId) - { - var org = collections.Count > 0 ? - await _organizationRepository.GetByIdAsync(collections[0].OrganizationId) : - await _organizationRepository.GetByIdAsync(ciphers.FirstOrDefault(c => c.OrganizationId.HasValue).OrganizationId.Value); - - if (collections.Count > 0 && org != null && org.MaxCollections.HasValue) - { - var collectionCount = await _collectionRepository.GetCountByOrganizationIdAsync(org.Id); - if (org.MaxCollections.Value < (collectionCount + collections.Count)) - { - throw new BadRequestException("This organization can only have a maximum of " + - $"{org.MaxCollections.Value} collections."); - } - } - - // Init. ids for ciphers - foreach (var cipher in ciphers) - { - cipher.SetNewId(); - } - - // Init. ids for collections - foreach (var collection in collections) - { - collection.SetNewId(); - } - - // Create associations based on the newly assigned ids - var collectionCiphers = new List(); - foreach (var relationship in collectionRelationships) - { - var cipher = ciphers.ElementAtOrDefault(relationship.Key); - var collection = collections.ElementAtOrDefault(relationship.Value); - - if (cipher == null || collection == null) - { - continue; - } - - collectionCiphers.Add(new CollectionCipher - { - CipherId = cipher.Id, - CollectionId = collection.Id - }); - } - - // Create it all - await _cipherRepository.CreateAsync(ciphers, collections, collectionCiphers); - - // push - await _pushService.PushSyncVaultAsync(importingUserId); - - - if (org != null) - { - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.VaultImported, org)); - } - } - - public async Task SoftDeleteAsync(Cipher cipher, Guid deletingUserId, bool orgAdmin = false) - { - if (!orgAdmin && !(await UserCanEditAsync(cipher, deletingUserId))) - { - throw new BadRequestException("You do not have permissions to soft delete this."); - } - - if (cipher.DeletedDate.HasValue) - { - // Already soft-deleted, we can safely ignore this - return; - } - - cipher.DeletedDate = cipher.RevisionDate = DateTime.UtcNow; - - if (cipher is CipherDetails details) - { - await _cipherRepository.UpsertAsync(details); - } - else - { - await _cipherRepository.UpsertAsync(cipher); - } - await _eventService.LogCipherEventAsync(cipher, EventType.Cipher_SoftDeleted); - - // push - await _pushService.PushSyncCipherUpdateAsync(cipher, null); - } - - public async Task SoftDeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId, bool orgAdmin) - { - var cipherIdsSet = new HashSet(cipherIds); - var deletingCiphers = new List(); - - if (orgAdmin && organizationId.HasValue) - { - var ciphers = await _cipherRepository.GetManyByOrganizationIdAsync(organizationId.Value); - deletingCiphers = ciphers.Where(c => cipherIdsSet.Contains(c.Id)).ToList(); - await _cipherRepository.SoftDeleteByIdsOrganizationIdAsync(deletingCiphers.Select(c => c.Id), organizationId.Value); - } - else - { - var ciphers = await _cipherRepository.GetManyByUserIdAsync(deletingUserId); - deletingCiphers = ciphers.Where(c => cipherIdsSet.Contains(c.Id) && c.Edit).Select(x => (Cipher)x).ToList(); - await _cipherRepository.SoftDeleteAsync(deletingCiphers.Select(c => c.Id), deletingUserId); - } - - var events = deletingCiphers.Select(c => - new Tuple(c, EventType.Cipher_SoftDeleted, null)); - foreach (var eventsBatch in events.Batch(100)) - { - await _eventService.LogCipherEventsAsync(eventsBatch); - } - - // push - await _pushService.PushSyncCiphersAsync(deletingUserId); - } - - public async Task RestoreAsync(Cipher cipher, Guid restoringUserId, bool orgAdmin = false) - { - if (!orgAdmin && !(await UserCanEditAsync(cipher, restoringUserId))) - { - throw new BadRequestException("You do not have permissions to delete this."); - } - - if (!cipher.DeletedDate.HasValue) - { - // Already restored, we can safely ignore this - return; - } - - cipher.DeletedDate = null; - cipher.RevisionDate = DateTime.UtcNow; - - if (cipher is CipherDetails details) - { - await _cipherRepository.UpsertAsync(details); - } - else - { - await _cipherRepository.UpsertAsync(cipher); - } - await _eventService.LogCipherEventAsync(cipher, EventType.Cipher_Restored); - - // push - await _pushService.PushSyncCipherUpdateAsync(cipher, null); - } - - public async Task RestoreManyAsync(IEnumerable ciphers, Guid restoringUserId) - { - var revisionDate = await _cipherRepository.RestoreAsync(ciphers.Select(c => c.Id), restoringUserId); - - var events = ciphers.Select(c => - { - c.RevisionDate = revisionDate; - c.DeletedDate = null; - return new Tuple(c, EventType.Cipher_Restored, null); - }); - foreach (var eventsBatch in events.Batch(100)) - { - await _eventService.LogCipherEventsAsync(eventsBatch); - } - - // push - await _pushService.PushSyncCiphersAsync(restoringUserId); - } - - public async Task<(IEnumerable, Dictionary>)> GetOrganizationCiphers(Guid userId, Guid organizationId) - { - if (!await _currentContext.ViewAllCollections(organizationId) && !await _currentContext.AccessReports(organizationId)) - { - throw new NotFoundException(); - } - - IEnumerable orgCiphers; - if (await _currentContext.OrganizationAdmin(organizationId)) - { - // Admins, Owners and Providers can access all items even if not assigned to them - orgCiphers = await _cipherRepository.GetManyOrganizationDetailsByOrganizationIdAsync(organizationId); - } - else - { - var ciphers = await _cipherRepository.GetManyByUserIdAsync(userId, true); - orgCiphers = ciphers.Where(c => c.OrganizationId == organizationId); - } - - var orgCipherIds = orgCiphers.Select(c => c.Id); - - var collectionCiphers = await _collectionCipherRepository.GetManyByOrganizationIdAsync(organizationId); - var collectionCiphersGroupDict = collectionCiphers - .Where(c => orgCipherIds.Contains(c.CipherId)) - .GroupBy(c => c.CipherId).ToDictionary(s => s.Key); - - var providerId = await _currentContext.ProviderIdForOrg(organizationId); - if (providerId.HasValue) - { - await _providerService.LogProviderAccessToOrganizationAsync(organizationId); - } - - return (orgCiphers, collectionCiphersGroupDict); - } - - private async Task UserCanEditAsync(Cipher cipher, Guid userId) - { - if (!cipher.OrganizationId.HasValue && cipher.UserId.HasValue && cipher.UserId.Value == userId) - { - return true; - } - - return await _cipherRepository.GetCanEditByIdAsync(userId, cipher.Id); - } - - private void ValidateCipherLastKnownRevisionDateAsync(Cipher cipher, DateTime? lastKnownRevisionDate) - { - if (cipher.Id == default || !lastKnownRevisionDate.HasValue) - { - return; - } - - if ((cipher.RevisionDate - lastKnownRevisionDate.Value).Duration() > TimeSpan.FromSeconds(1)) - { - throw new BadRequestException( - "The cipher you are updating is out of date. Please save your work, sync your vault, and try again." - ); - } - } - - private async Task DeleteAttachmentAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) - { - if (attachmentData == null || string.IsNullOrWhiteSpace(attachmentData.AttachmentId)) - { - return; - } - - await _cipherRepository.DeleteAttachmentAsync(cipher.Id, attachmentData.AttachmentId); - cipher.DeleteAttachment(attachmentData.AttachmentId); - await _attachmentStorageService.DeleteAttachmentAsync(cipher.Id, attachmentData); - await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_AttachmentDeleted); - - // push - await _pushService.PushSyncCipherUpdateAsync(cipher, null); - } - - private async Task ValidateCipherEditForAttachmentAsync(Cipher cipher, Guid savingUserId, bool orgAdmin, - long requestLength) - { - if (!orgAdmin && !(await UserCanEditAsync(cipher, savingUserId))) - { - throw new BadRequestException("You do not have permissions to edit this."); - } - if (requestLength < 1) { throw new BadRequestException("No data to attach."); } - var storageBytesRemaining = await StorageBytesRemainingForCipherAsync(cipher); - - if (storageBytesRemaining < requestLength) - { - throw new BadRequestException("Not enough storage available."); - } - } - - private async Task StorageBytesRemainingForCipherAsync(Cipher cipher) - { - var storageBytesRemaining = 0L; - if (cipher.UserId.HasValue) - { - var user = await _userRepository.GetByIdAsync(cipher.UserId.Value); - if (!(await _userService.CanAccessPremium(user))) - { - throw new BadRequestException("You must have premium status to use attachments."); - } - - if (user.Premium) - { - storageBytesRemaining = user.StorageBytesRemaining(); - } - else - { - // Users that get access to file storage/premium from their organization get the default - // 1 GB max storage. - storageBytesRemaining = user.StorageBytesRemaining( - _globalSettings.SelfHosted ? (short)10240 : (short)1); - } - } - else if (cipher.OrganizationId.HasValue) - { - var org = await _organizationRepository.GetByIdAsync(cipher.OrganizationId.Value); - if (!org.MaxStorageGb.HasValue) - { - throw new BadRequestException("This organization cannot use attachments."); - } - - storageBytesRemaining = org.StorageBytesRemaining(); - } - - return storageBytesRemaining; - } - - private async Task ValidateCipherCanBeShared( - Cipher cipher, - Guid sharingUserId, - Guid organizationId, - DateTime? lastKnownRevisionDate) - { if (cipher.Id == default(Guid)) { - throw new BadRequestException("Cipher must already exist."); + throw new BadRequestException(nameof(cipher.Id)); } if (cipher.OrganizationId.HasValue) { - throw new BadRequestException("One or more ciphers already belong to an organization."); + throw new BadRequestException("Cipher belongs to an organization already."); } - if (!cipher.UserId.HasValue || cipher.UserId.Value != sharingUserId) - { - throw new BadRequestException("One or more ciphers do not belong to you."); - } - - var attachments = cipher.GetAttachments(); - var hasAttachments = attachments?.Any() ?? false; var org = await _organizationRepository.GetByIdAsync(organizationId); - - if (org == null) - { - throw new BadRequestException("Could not find organization."); - } - - if (hasAttachments && !org.MaxStorageGb.HasValue) + if (org == null || !org.MaxStorageGb.HasValue) { throw new BadRequestException("This organization cannot use attachments."); } - var storageAdjustment = attachments?.Sum(a => a.Value.Size) ?? 0; - if (org.StorageBytesRemaining() < storageAdjustment) + var storageBytesRemaining = org.StorageBytesRemaining(); + if (storageBytesRemaining < requestLength) { throw new BadRequestException("Not enough storage available for this organization."); } - ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate); + var attachments = cipher.GetAttachments(); + if (!attachments.ContainsKey(attachmentId)) + { + throw new BadRequestException($"Cipher does not own specified attachment"); + } + + await _attachmentStorageService.UploadShareAttachmentAsync(stream, cipher.Id, organizationId, + attachments[attachmentId]); + + // Previous call may alter metadata + var updatedAttachment = new CipherAttachment + { + Id = cipher.Id, + UserId = cipher.UserId, + OrganizationId = cipher.OrganizationId, + AttachmentId = attachmentId, + AttachmentData = JsonSerializer.Serialize(attachments[attachmentId]) + }; + + await _cipherRepository.UpdateAttachmentAsync(updatedAttachment); + } + catch + { + await _attachmentStorageService.CleanupAsync(cipher.Id); + throw; } } + + public async Task ValidateCipherAttachmentFile(Cipher cipher, CipherAttachment.MetaData attachmentData) + { + var (valid, realSize) = await _attachmentStorageService.ValidateFileAsync(cipher, attachmentData, _fileSizeLeeway); + + if (!valid || realSize > MAX_FILE_SIZE) + { + // File reported differs in size from that promised. Must be a rogue client. Delete Send + await DeleteAttachmentAsync(cipher, attachmentData); + return false; + } + // Update Send data if necessary + if (realSize != attachmentData.Size) + { + attachmentData.Size = realSize.Value; + } + attachmentData.Validated = true; + + var updatedAttachment = new CipherAttachment + { + Id = cipher.Id, + UserId = cipher.UserId, + OrganizationId = cipher.OrganizationId, + AttachmentId = attachmentData.AttachmentId, + AttachmentData = JsonSerializer.Serialize(attachmentData) + }; + + + await _cipherRepository.UpdateAttachmentAsync(updatedAttachment); + + return valid; + } + + public async Task GetAttachmentDownloadDataAsync(Cipher cipher, string attachmentId) + { + var attachments = cipher?.GetAttachments() ?? new Dictionary(); + + if (!attachments.ContainsKey(attachmentId)) + { + throw new NotFoundException(); + } + + var data = attachments[attachmentId]; + var response = new AttachmentResponseData + { + Cipher = cipher, + Data = data, + Id = attachmentId, + Url = await _attachmentStorageService.GetAttachmentDownloadUrlAsync(cipher, data), + }; + + return response; + } + + public async Task DeleteAsync(Cipher cipher, Guid deletingUserId, bool orgAdmin = false) + { + if (!orgAdmin && !(await UserCanEditAsync(cipher, deletingUserId))) + { + throw new BadRequestException("You do not have permissions to delete this."); + } + + await _cipherRepository.DeleteAsync(cipher); + await _attachmentStorageService.DeleteAttachmentsForCipherAsync(cipher.Id); + await _eventService.LogCipherEventAsync(cipher, EventType.Cipher_Deleted); + + // push + await _pushService.PushSyncCipherDeleteAsync(cipher); + } + + public async Task DeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false) + { + var cipherIdsSet = new HashSet(cipherIds); + var deletingCiphers = new List(); + + if (orgAdmin && organizationId.HasValue) + { + var ciphers = await _cipherRepository.GetManyByOrganizationIdAsync(organizationId.Value); + deletingCiphers = ciphers.Where(c => cipherIdsSet.Contains(c.Id)).ToList(); + await _cipherRepository.DeleteByIdsOrganizationIdAsync(deletingCiphers.Select(c => c.Id), organizationId.Value); + } + else + { + var ciphers = await _cipherRepository.GetManyByUserIdAsync(deletingUserId); + deletingCiphers = ciphers.Where(c => cipherIdsSet.Contains(c.Id) && c.Edit).Select(x => (Cipher)x).ToList(); + await _cipherRepository.DeleteAsync(deletingCiphers.Select(c => c.Id), deletingUserId); + } + + var events = deletingCiphers.Select(c => + new Tuple(c, EventType.Cipher_Deleted, null)); + foreach (var eventsBatch in events.Batch(100)) + { + await _eventService.LogCipherEventsAsync(eventsBatch); + } + + // push + await _pushService.PushSyncCiphersAsync(deletingUserId); + } + + public async Task DeleteAttachmentAsync(Cipher cipher, string attachmentId, Guid deletingUserId, + bool orgAdmin = false) + { + if (!orgAdmin && !(await UserCanEditAsync(cipher, deletingUserId))) + { + throw new BadRequestException("You do not have permissions to delete this."); + } + + if (!cipher.ContainsAttachment(attachmentId)) + { + throw new NotFoundException(); + } + + await DeleteAttachmentAsync(cipher, cipher.GetAttachments()[attachmentId]); + } + + public async Task PurgeAsync(Guid organizationId) + { + var org = await _organizationRepository.GetByIdAsync(organizationId); + if (org == null) + { + throw new NotFoundException(); + } + await _cipherRepository.DeleteByOrganizationIdAsync(organizationId); + await _eventService.LogOrganizationEventAsync(org, Enums.EventType.Organization_PurgedVault); + } + + public async Task MoveManyAsync(IEnumerable cipherIds, Guid? destinationFolderId, Guid movingUserId) + { + if (destinationFolderId.HasValue) + { + var folder = await _folderRepository.GetByIdAsync(destinationFolderId.Value); + if (folder == null || folder.UserId != movingUserId) + { + throw new BadRequestException("Invalid folder."); + } + } + + await _cipherRepository.MoveAsync(cipherIds, destinationFolderId, movingUserId); + // push + await _pushService.PushSyncCiphersAsync(movingUserId); + } + + public async Task SaveFolderAsync(Folder folder) + { + if (folder.Id == default(Guid)) + { + await _folderRepository.CreateAsync(folder); + + // push + await _pushService.PushSyncFolderCreateAsync(folder); + } + else + { + folder.RevisionDate = DateTime.UtcNow; + await _folderRepository.UpsertAsync(folder); + + // push + await _pushService.PushSyncFolderUpdateAsync(folder); + } + } + + public async Task DeleteFolderAsync(Folder folder) + { + await _folderRepository.DeleteAsync(folder); + + // push + await _pushService.PushSyncFolderDeleteAsync(folder); + } + + public async Task ShareAsync(Cipher originalCipher, Cipher cipher, Guid organizationId, + IEnumerable collectionIds, Guid sharingUserId, DateTime? lastKnownRevisionDate) + { + var attachments = cipher.GetAttachments(); + var hasOldAttachments = attachments?.Any(a => a.Key == null) ?? false; + var updatedCipher = false; + var migratedAttachments = false; + var originalAttachments = CoreHelpers.CloneObject(attachments); + + try + { + await ValidateCipherCanBeShared(cipher, sharingUserId, organizationId, lastKnownRevisionDate); + + // Sproc will not save this UserId on the cipher. It is used limit scope of the collectionIds. + cipher.UserId = sharingUserId; + cipher.OrganizationId = organizationId; + cipher.RevisionDate = DateTime.UtcNow; + if (!await _cipherRepository.ReplaceAsync(cipher, collectionIds)) + { + throw new BadRequestException("Unable to save."); + } + + updatedCipher = true; + await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_Shared); + + if (hasOldAttachments) + { + // migrate old attachments + foreach (var attachment in attachments.Where(a => a.Key == null)) + { + await _attachmentStorageService.StartShareAttachmentAsync(cipher.Id, organizationId, + attachment.Value); + migratedAttachments = true; + } + + // commit attachment migration + await _attachmentStorageService.CleanupAsync(cipher.Id); + } + + // push + await _pushService.PushSyncCipherUpdateAsync(cipher, collectionIds); + } + catch + { + // roll everything back + if (updatedCipher) + { + await _cipherRepository.ReplaceAsync(originalCipher); + } + + if (!hasOldAttachments || !migratedAttachments) + { + throw; + } + + if (updatedCipher) + { + await _userRepository.UpdateStorageAsync(sharingUserId); + await _organizationRepository.UpdateStorageAsync(organizationId); + } + + foreach (var attachment in attachments.Where(a => a.Key == null)) + { + await _attachmentStorageService.RollbackShareAttachmentAsync(cipher.Id, organizationId, + attachment.Value, originalAttachments[attachment.Key].ContainerName); + } + + await _attachmentStorageService.CleanupAsync(cipher.Id); + throw; + } + } + + public async Task ShareManyAsync(IEnumerable<(Cipher cipher, DateTime? lastKnownRevisionDate)> cipherInfos, + Guid organizationId, IEnumerable collectionIds, Guid sharingUserId) + { + var cipherIds = new List(); + foreach (var (cipher, lastKnownRevisionDate) in cipherInfos) + { + await ValidateCipherCanBeShared(cipher, sharingUserId, organizationId, lastKnownRevisionDate); + + cipher.UserId = null; + cipher.OrganizationId = organizationId; + cipher.RevisionDate = DateTime.UtcNow; + cipherIds.Add(cipher.Id); + } + + await _cipherRepository.UpdateCiphersAsync(sharingUserId, cipherInfos.Select(c => c.cipher)); + await _collectionCipherRepository.UpdateCollectionsForCiphersAsync(cipherIds, sharingUserId, + organizationId, collectionIds); + + var events = cipherInfos.Select(c => + new Tuple(c.cipher, EventType.Cipher_Shared, null)); + foreach (var eventsBatch in events.Batch(100)) + { + await _eventService.LogCipherEventsAsync(eventsBatch); + } + + // push + await _pushService.PushSyncCiphersAsync(sharingUserId); + } + + public async Task SaveCollectionsAsync(Cipher cipher, IEnumerable collectionIds, Guid savingUserId, + bool orgAdmin) + { + if (cipher.Id == default(Guid)) + { + throw new BadRequestException(nameof(cipher.Id)); + } + + if (!cipher.OrganizationId.HasValue) + { + throw new BadRequestException("Cipher must belong to an organization."); + } + + cipher.RevisionDate = DateTime.UtcNow; + + // The sprocs will validate that all collections belong to this org/user and that they have + // proper write permissions. + if (orgAdmin) + { + await _collectionCipherRepository.UpdateCollectionsForAdminAsync(cipher.Id, + cipher.OrganizationId.Value, collectionIds); + } + else + { + if (!(await UserCanEditAsync(cipher, savingUserId))) + { + throw new BadRequestException("You do not have permissions to edit this."); + } + await _collectionCipherRepository.UpdateCollectionsAsync(cipher.Id, savingUserId, collectionIds); + } + + await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_UpdatedCollections); + + // push + await _pushService.PushSyncCipherUpdateAsync(cipher, collectionIds); + } + + public async Task ImportCiphersAsync( + List folders, + List ciphers, + IEnumerable> folderRelationships) + { + var userId = folders.FirstOrDefault()?.UserId ?? ciphers.FirstOrDefault()?.UserId; + + // Make sure the user can save new ciphers to their personal vault + var personalOwnershipPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(userId.Value, + PolicyType.PersonalOwnership); + if (personalOwnershipPolicyCount > 0) + { + throw new BadRequestException("You cannot import items into your personal vault because you are " + + "a member of an organization which forbids it."); + } + + foreach (var cipher in ciphers) + { + cipher.SetNewId(); + + if (cipher.UserId.HasValue && cipher.Favorite) + { + cipher.Favorites = $"{{\"{cipher.UserId.ToString().ToUpperInvariant()}\":\"true\"}}"; + } + } + + // Init. ids for folders + foreach (var folder in folders) + { + folder.SetNewId(); + } + + // Create the folder associations based on the newly created folder ids + foreach (var relationship in folderRelationships) + { + var cipher = ciphers.ElementAtOrDefault(relationship.Key); + var folder = folders.ElementAtOrDefault(relationship.Value); + + if (cipher == null || folder == null) + { + continue; + } + + cipher.Folders = $"{{\"{cipher.UserId.ToString().ToUpperInvariant()}\":" + + $"\"{folder.Id.ToString().ToUpperInvariant()}\"}}"; + } + + // Create it all + await _cipherRepository.CreateAsync(ciphers, folders); + + // push + if (userId.HasValue) + { + await _pushService.PushSyncVaultAsync(userId.Value); + } + } + + public async Task ImportCiphersAsync( + List collections, + List ciphers, + IEnumerable> collectionRelationships, + Guid importingUserId) + { + var org = collections.Count > 0 ? + await _organizationRepository.GetByIdAsync(collections[0].OrganizationId) : + await _organizationRepository.GetByIdAsync(ciphers.FirstOrDefault(c => c.OrganizationId.HasValue).OrganizationId.Value); + + if (collections.Count > 0 && org != null && org.MaxCollections.HasValue) + { + var collectionCount = await _collectionRepository.GetCountByOrganizationIdAsync(org.Id); + if (org.MaxCollections.Value < (collectionCount + collections.Count)) + { + throw new BadRequestException("This organization can only have a maximum of " + + $"{org.MaxCollections.Value} collections."); + } + } + + // Init. ids for ciphers + foreach (var cipher in ciphers) + { + cipher.SetNewId(); + } + + // Init. ids for collections + foreach (var collection in collections) + { + collection.SetNewId(); + } + + // Create associations based on the newly assigned ids + var collectionCiphers = new List(); + foreach (var relationship in collectionRelationships) + { + var cipher = ciphers.ElementAtOrDefault(relationship.Key); + var collection = collections.ElementAtOrDefault(relationship.Value); + + if (cipher == null || collection == null) + { + continue; + } + + collectionCiphers.Add(new CollectionCipher + { + CipherId = cipher.Id, + CollectionId = collection.Id + }); + } + + // Create it all + await _cipherRepository.CreateAsync(ciphers, collections, collectionCiphers); + + // push + await _pushService.PushSyncVaultAsync(importingUserId); + + + if (org != null) + { + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.VaultImported, org)); + } + } + + public async Task SoftDeleteAsync(Cipher cipher, Guid deletingUserId, bool orgAdmin = false) + { + if (!orgAdmin && !(await UserCanEditAsync(cipher, deletingUserId))) + { + throw new BadRequestException("You do not have permissions to soft delete this."); + } + + if (cipher.DeletedDate.HasValue) + { + // Already soft-deleted, we can safely ignore this + return; + } + + cipher.DeletedDate = cipher.RevisionDate = DateTime.UtcNow; + + if (cipher is CipherDetails details) + { + await _cipherRepository.UpsertAsync(details); + } + else + { + await _cipherRepository.UpsertAsync(cipher); + } + await _eventService.LogCipherEventAsync(cipher, EventType.Cipher_SoftDeleted); + + // push + await _pushService.PushSyncCipherUpdateAsync(cipher, null); + } + + public async Task SoftDeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId, bool orgAdmin) + { + var cipherIdsSet = new HashSet(cipherIds); + var deletingCiphers = new List(); + + if (orgAdmin && organizationId.HasValue) + { + var ciphers = await _cipherRepository.GetManyByOrganizationIdAsync(organizationId.Value); + deletingCiphers = ciphers.Where(c => cipherIdsSet.Contains(c.Id)).ToList(); + await _cipherRepository.SoftDeleteByIdsOrganizationIdAsync(deletingCiphers.Select(c => c.Id), organizationId.Value); + } + else + { + var ciphers = await _cipherRepository.GetManyByUserIdAsync(deletingUserId); + deletingCiphers = ciphers.Where(c => cipherIdsSet.Contains(c.Id) && c.Edit).Select(x => (Cipher)x).ToList(); + await _cipherRepository.SoftDeleteAsync(deletingCiphers.Select(c => c.Id), deletingUserId); + } + + var events = deletingCiphers.Select(c => + new Tuple(c, EventType.Cipher_SoftDeleted, null)); + foreach (var eventsBatch in events.Batch(100)) + { + await _eventService.LogCipherEventsAsync(eventsBatch); + } + + // push + await _pushService.PushSyncCiphersAsync(deletingUserId); + } + + public async Task RestoreAsync(Cipher cipher, Guid restoringUserId, bool orgAdmin = false) + { + if (!orgAdmin && !(await UserCanEditAsync(cipher, restoringUserId))) + { + throw new BadRequestException("You do not have permissions to delete this."); + } + + if (!cipher.DeletedDate.HasValue) + { + // Already restored, we can safely ignore this + return; + } + + cipher.DeletedDate = null; + cipher.RevisionDate = DateTime.UtcNow; + + if (cipher is CipherDetails details) + { + await _cipherRepository.UpsertAsync(details); + } + else + { + await _cipherRepository.UpsertAsync(cipher); + } + await _eventService.LogCipherEventAsync(cipher, EventType.Cipher_Restored); + + // push + await _pushService.PushSyncCipherUpdateAsync(cipher, null); + } + + public async Task RestoreManyAsync(IEnumerable ciphers, Guid restoringUserId) + { + var revisionDate = await _cipherRepository.RestoreAsync(ciphers.Select(c => c.Id), restoringUserId); + + var events = ciphers.Select(c => + { + c.RevisionDate = revisionDate; + c.DeletedDate = null; + return new Tuple(c, EventType.Cipher_Restored, null); + }); + foreach (var eventsBatch in events.Batch(100)) + { + await _eventService.LogCipherEventsAsync(eventsBatch); + } + + // push + await _pushService.PushSyncCiphersAsync(restoringUserId); + } + + public async Task<(IEnumerable, Dictionary>)> GetOrganizationCiphers(Guid userId, Guid organizationId) + { + if (!await _currentContext.ViewAllCollections(organizationId) && !await _currentContext.AccessReports(organizationId)) + { + throw new NotFoundException(); + } + + IEnumerable orgCiphers; + if (await _currentContext.OrganizationAdmin(organizationId)) + { + // Admins, Owners and Providers can access all items even if not assigned to them + orgCiphers = await _cipherRepository.GetManyOrganizationDetailsByOrganizationIdAsync(organizationId); + } + else + { + var ciphers = await _cipherRepository.GetManyByUserIdAsync(userId, true); + orgCiphers = ciphers.Where(c => c.OrganizationId == organizationId); + } + + var orgCipherIds = orgCiphers.Select(c => c.Id); + + var collectionCiphers = await _collectionCipherRepository.GetManyByOrganizationIdAsync(organizationId); + var collectionCiphersGroupDict = collectionCiphers + .Where(c => orgCipherIds.Contains(c.CipherId)) + .GroupBy(c => c.CipherId).ToDictionary(s => s.Key); + + var providerId = await _currentContext.ProviderIdForOrg(organizationId); + if (providerId.HasValue) + { + await _providerService.LogProviderAccessToOrganizationAsync(organizationId); + } + + return (orgCiphers, collectionCiphersGroupDict); + } + + private async Task UserCanEditAsync(Cipher cipher, Guid userId) + { + if (!cipher.OrganizationId.HasValue && cipher.UserId.HasValue && cipher.UserId.Value == userId) + { + return true; + } + + return await _cipherRepository.GetCanEditByIdAsync(userId, cipher.Id); + } + + private void ValidateCipherLastKnownRevisionDateAsync(Cipher cipher, DateTime? lastKnownRevisionDate) + { + if (cipher.Id == default || !lastKnownRevisionDate.HasValue) + { + return; + } + + if ((cipher.RevisionDate - lastKnownRevisionDate.Value).Duration() > TimeSpan.FromSeconds(1)) + { + throw new BadRequestException( + "The cipher you are updating is out of date. Please save your work, sync your vault, and try again." + ); + } + } + + private async Task DeleteAttachmentAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) + { + if (attachmentData == null || string.IsNullOrWhiteSpace(attachmentData.AttachmentId)) + { + return; + } + + await _cipherRepository.DeleteAttachmentAsync(cipher.Id, attachmentData.AttachmentId); + cipher.DeleteAttachment(attachmentData.AttachmentId); + await _attachmentStorageService.DeleteAttachmentAsync(cipher.Id, attachmentData); + await _eventService.LogCipherEventAsync(cipher, Enums.EventType.Cipher_AttachmentDeleted); + + // push + await _pushService.PushSyncCipherUpdateAsync(cipher, null); + } + + private async Task ValidateCipherEditForAttachmentAsync(Cipher cipher, Guid savingUserId, bool orgAdmin, + long requestLength) + { + if (!orgAdmin && !(await UserCanEditAsync(cipher, savingUserId))) + { + throw new BadRequestException("You do not have permissions to edit this."); + } + + if (requestLength < 1) + { + throw new BadRequestException("No data to attach."); + } + + var storageBytesRemaining = await StorageBytesRemainingForCipherAsync(cipher); + + if (storageBytesRemaining < requestLength) + { + throw new BadRequestException("Not enough storage available."); + } + } + + private async Task StorageBytesRemainingForCipherAsync(Cipher cipher) + { + var storageBytesRemaining = 0L; + if (cipher.UserId.HasValue) + { + var user = await _userRepository.GetByIdAsync(cipher.UserId.Value); + if (!(await _userService.CanAccessPremium(user))) + { + throw new BadRequestException("You must have premium status to use attachments."); + } + + if (user.Premium) + { + storageBytesRemaining = user.StorageBytesRemaining(); + } + else + { + // Users that get access to file storage/premium from their organization get the default + // 1 GB max storage. + storageBytesRemaining = user.StorageBytesRemaining( + _globalSettings.SelfHosted ? (short)10240 : (short)1); + } + } + else if (cipher.OrganizationId.HasValue) + { + var org = await _organizationRepository.GetByIdAsync(cipher.OrganizationId.Value); + if (!org.MaxStorageGb.HasValue) + { + throw new BadRequestException("This organization cannot use attachments."); + } + + storageBytesRemaining = org.StorageBytesRemaining(); + } + + return storageBytesRemaining; + } + + private async Task ValidateCipherCanBeShared( + Cipher cipher, + Guid sharingUserId, + Guid organizationId, + DateTime? lastKnownRevisionDate) + { + if (cipher.Id == default(Guid)) + { + throw new BadRequestException("Cipher must already exist."); + } + + if (cipher.OrganizationId.HasValue) + { + throw new BadRequestException("One or more ciphers already belong to an organization."); + } + + if (!cipher.UserId.HasValue || cipher.UserId.Value != sharingUserId) + { + throw new BadRequestException("One or more ciphers do not belong to you."); + } + + var attachments = cipher.GetAttachments(); + var hasAttachments = attachments?.Any() ?? false; + var org = await _organizationRepository.GetByIdAsync(organizationId); + + if (org == null) + { + throw new BadRequestException("Could not find organization."); + } + + if (hasAttachments && !org.MaxStorageGb.HasValue) + { + throw new BadRequestException("This organization cannot use attachments."); + } + + var storageAdjustment = attachments?.Sum(a => a.Value.Size) ?? 0; + if (org.StorageBytesRemaining() < storageAdjustment) + { + throw new BadRequestException("Not enough storage available for this organization."); + } + + ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate); + } } diff --git a/src/Core/Services/Implementations/CollectionService.cs b/src/Core/Services/Implementations/CollectionService.cs index e41532c1e..699f38925 100644 --- a/src/Core/Services/Implementations/CollectionService.cs +++ b/src/Core/Services/Implementations/CollectionService.cs @@ -6,136 +6,135 @@ using Bit.Core.Models.Business; using Bit.Core.Models.Data; using Bit.Core.Repositories; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class CollectionService : ICollectionService { - public class CollectionService : ICollectionService + private readonly IEventService _eventService; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly ICollectionRepository _collectionRepository; + private readonly IUserRepository _userRepository; + private readonly IMailService _mailService; + private readonly IReferenceEventService _referenceEventService; + private readonly ICurrentContext _currentContext; + + public CollectionService( + IEventService eventService, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + ICollectionRepository collectionRepository, + IUserRepository userRepository, + IMailService mailService, + IReferenceEventService referenceEventService, + ICurrentContext currentContext) { - private readonly IEventService _eventService; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly ICollectionRepository _collectionRepository; - private readonly IUserRepository _userRepository; - private readonly IMailService _mailService; - private readonly IReferenceEventService _referenceEventService; - private readonly ICurrentContext _currentContext; + _eventService = eventService; + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _collectionRepository = collectionRepository; + _userRepository = userRepository; + _mailService = mailService; + _referenceEventService = referenceEventService; + _currentContext = currentContext; + } - public CollectionService( - IEventService eventService, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - ICollectionRepository collectionRepository, - IUserRepository userRepository, - IMailService mailService, - IReferenceEventService referenceEventService, - ICurrentContext currentContext) + public async Task SaveAsync(Collection collection, IEnumerable groups = null, + Guid? assignUserId = null) + { + var org = await _organizationRepository.GetByIdAsync(collection.OrganizationId); + if (org == null) { - _eventService = eventService; - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _collectionRepository = collectionRepository; - _userRepository = userRepository; - _mailService = mailService; - _referenceEventService = referenceEventService; - _currentContext = currentContext; + throw new BadRequestException("Organization not found"); } - public async Task SaveAsync(Collection collection, IEnumerable groups = null, - Guid? assignUserId = null) + if (collection.Id == default(Guid)) { - var org = await _organizationRepository.GetByIdAsync(collection.OrganizationId); - if (org == null) + if (org.MaxCollections.HasValue) { - throw new BadRequestException("Organization not found"); + var collectionCount = await _collectionRepository.GetCountByOrganizationIdAsync(org.Id); + if (org.MaxCollections.Value <= collectionCount) + { + throw new BadRequestException("You have reached the maximum number of collections " + + $"({org.MaxCollections.Value}) for this organization."); + } } - if (collection.Id == default(Guid)) + if (groups == null || !org.UseGroups) { - if (org.MaxCollections.HasValue) - { - var collectionCount = await _collectionRepository.GetCountByOrganizationIdAsync(org.Id); - if (org.MaxCollections.Value <= collectionCount) - { - throw new BadRequestException("You have reached the maximum number of collections " + - $"({org.MaxCollections.Value}) for this organization."); - } - } - - if (groups == null || !org.UseGroups) - { - await _collectionRepository.CreateAsync(collection); - } - else - { - await _collectionRepository.CreateAsync(collection, groups); - } - - // Assign a user to the newly created collection. - if (assignUserId.HasValue) - { - var orgUser = await _organizationUserRepository.GetByOrganizationAsync(org.Id, assignUserId.Value); - if (orgUser != null && orgUser.Status == Enums.OrganizationUserStatusType.Confirmed) - { - await _collectionRepository.UpdateUsersAsync(collection.Id, - new List { - new SelectionReadOnly { Id = orgUser.Id, ReadOnly = false } }); - } - } - - await _eventService.LogCollectionEventAsync(collection, Enums.EventType.Collection_Created); - await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.CollectionCreated, org)); + await _collectionRepository.CreateAsync(collection); } else { - if (!org.UseGroups) + await _collectionRepository.CreateAsync(collection, groups); + } + + // Assign a user to the newly created collection. + if (assignUserId.HasValue) + { + var orgUser = await _organizationUserRepository.GetByOrganizationAsync(org.Id, assignUserId.Value); + if (orgUser != null && orgUser.Status == Enums.OrganizationUserStatusType.Confirmed) { - await _collectionRepository.ReplaceAsync(collection); + await _collectionRepository.UpdateUsersAsync(collection.Id, + new List { + new SelectionReadOnly { Id = orgUser.Id, ReadOnly = false } }); } - else - { - await _collectionRepository.ReplaceAsync(collection, groups ?? new List()); - } - - await _eventService.LogCollectionEventAsync(collection, Enums.EventType.Collection_Updated); - } - } - - public async Task DeleteAsync(Collection collection) - { - await _collectionRepository.DeleteAsync(collection); - await _eventService.LogCollectionEventAsync(collection, Enums.EventType.Collection_Deleted); - } - - public async Task DeleteUserAsync(Collection collection, Guid organizationUserId) - { - var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); - if (orgUser == null || orgUser.OrganizationId != collection.OrganizationId) - { - throw new NotFoundException(); - } - await _collectionRepository.DeleteUserAsync(collection.Id, organizationUserId); - await _eventService.LogOrganizationUserEventAsync(orgUser, Enums.EventType.OrganizationUser_Updated); - } - - public async Task> GetOrganizationCollections(Guid organizationId) - { - if (!await _currentContext.ViewAllCollections(organizationId) && !await _currentContext.ManageUsers(organizationId)) - { - throw new NotFoundException(); } - IEnumerable orgCollections; - if (await _currentContext.OrganizationAdmin(organizationId) || await _currentContext.ViewAllCollections(organizationId)) + await _eventService.LogCollectionEventAsync(collection, Enums.EventType.Collection_Created); + await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.CollectionCreated, org)); + } + else + { + if (!org.UseGroups) { - // Admins, Owners, Providers and Custom (with collection management permissions) can access all items even if not assigned to them - orgCollections = await _collectionRepository.GetManyByOrganizationIdAsync(organizationId); + await _collectionRepository.ReplaceAsync(collection); } else { - var collections = await _collectionRepository.GetManyByUserIdAsync(_currentContext.UserId.Value); - orgCollections = collections.Where(c => c.OrganizationId == organizationId); + await _collectionRepository.ReplaceAsync(collection, groups ?? new List()); } - return orgCollections; + await _eventService.LogCollectionEventAsync(collection, Enums.EventType.Collection_Updated); } } + + public async Task DeleteAsync(Collection collection) + { + await _collectionRepository.DeleteAsync(collection); + await _eventService.LogCollectionEventAsync(collection, Enums.EventType.Collection_Deleted); + } + + public async Task DeleteUserAsync(Collection collection, Guid organizationUserId) + { + var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); + if (orgUser == null || orgUser.OrganizationId != collection.OrganizationId) + { + throw new NotFoundException(); + } + await _collectionRepository.DeleteUserAsync(collection.Id, organizationUserId); + await _eventService.LogOrganizationUserEventAsync(orgUser, Enums.EventType.OrganizationUser_Updated); + } + + public async Task> GetOrganizationCollections(Guid organizationId) + { + if (!await _currentContext.ViewAllCollections(organizationId) && !await _currentContext.ManageUsers(organizationId)) + { + throw new NotFoundException(); + } + + IEnumerable orgCollections; + if (await _currentContext.OrganizationAdmin(organizationId) || await _currentContext.ViewAllCollections(organizationId)) + { + // Admins, Owners, Providers and Custom (with collection management permissions) can access all items even if not assigned to them + orgCollections = await _collectionRepository.GetManyByOrganizationIdAsync(organizationId); + } + else + { + var collections = await _collectionRepository.GetManyByUserIdAsync(_currentContext.UserId.Value); + orgCollections = collections.Where(c => c.OrganizationId == organizationId); + } + + return orgCollections; + } } diff --git a/src/Core/Services/Implementations/DeviceService.cs b/src/Core/Services/Implementations/DeviceService.cs index a65a49bdd..99f4648a3 100644 --- a/src/Core/Services/Implementations/DeviceService.cs +++ b/src/Core/Services/Implementations/DeviceService.cs @@ -1,47 +1,46 @@ using Bit.Core.Entities; using Bit.Core.Repositories; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class DeviceService : IDeviceService { - public class DeviceService : IDeviceService + private readonly IDeviceRepository _deviceRepository; + private readonly IPushRegistrationService _pushRegistrationService; + + public DeviceService( + IDeviceRepository deviceRepository, + IPushRegistrationService pushRegistrationService) { - private readonly IDeviceRepository _deviceRepository; - private readonly IPushRegistrationService _pushRegistrationService; + _deviceRepository = deviceRepository; + _pushRegistrationService = pushRegistrationService; + } - public DeviceService( - IDeviceRepository deviceRepository, - IPushRegistrationService pushRegistrationService) + public async Task SaveAsync(Device device) + { + if (device.Id == default(Guid)) { - _deviceRepository = deviceRepository; - _pushRegistrationService = pushRegistrationService; + await _deviceRepository.CreateAsync(device); + } + else + { + device.RevisionDate = DateTime.UtcNow; + await _deviceRepository.ReplaceAsync(device); } - public async Task SaveAsync(Device device) - { - if (device.Id == default(Guid)) - { - await _deviceRepository.CreateAsync(device); - } - else - { - device.RevisionDate = DateTime.UtcNow; - await _deviceRepository.ReplaceAsync(device); - } + await _pushRegistrationService.CreateOrUpdateRegistrationAsync(device.PushToken, device.Id.ToString(), + device.UserId.ToString(), device.Identifier, device.Type); + } - await _pushRegistrationService.CreateOrUpdateRegistrationAsync(device.PushToken, device.Id.ToString(), - device.UserId.ToString(), device.Identifier, device.Type); - } + public async Task ClearTokenAsync(Device device) + { + await _deviceRepository.ClearPushTokenAsync(device.Id); + await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString()); + } - public async Task ClearTokenAsync(Device device) - { - await _deviceRepository.ClearPushTokenAsync(device.Id); - await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString()); - } - - public async Task DeleteAsync(Device device) - { - await _deviceRepository.DeleteAsync(device); - await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString()); - } + public async Task DeleteAsync(Device device) + { + await _deviceRepository.DeleteAsync(device); + await _pushRegistrationService.DeleteRegistrationAsync(device.Id.ToString()); } } diff --git a/src/Core/Services/Implementations/EmergencyAccessService.cs b/src/Core/Services/Implementations/EmergencyAccessService.cs index 06a5e5a85..e48000b52 100644 --- a/src/Core/Services/Implementations/EmergencyAccessService.cs +++ b/src/Core/Services/Implementations/EmergencyAccessService.cs @@ -9,416 +9,415 @@ using Bit.Core.Settings; using Bit.Core.Tokens; using Microsoft.AspNetCore.Identity; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class EmergencyAccessService : IEmergencyAccessService { - public class EmergencyAccessService : IEmergencyAccessService + private readonly IEmergencyAccessRepository _emergencyAccessRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IUserRepository _userRepository; + private readonly ICipherRepository _cipherRepository; + private readonly IPolicyRepository _policyRepository; + private readonly ICipherService _cipherService; + private readonly IMailService _mailService; + private readonly IUserService _userService; + private readonly GlobalSettings _globalSettings; + private readonly IPasswordHasher _passwordHasher; + private readonly IOrganizationService _organizationService; + private readonly IDataProtectorTokenFactory _dataProtectorTokenizer; + + public EmergencyAccessService( + IEmergencyAccessRepository emergencyAccessRepository, + IOrganizationUserRepository organizationUserRepository, + IUserRepository userRepository, + ICipherRepository cipherRepository, + IPolicyRepository policyRepository, + ICipherService cipherService, + IMailService mailService, + IUserService userService, + IPasswordHasher passwordHasher, + GlobalSettings globalSettings, + IOrganizationService organizationService, + IDataProtectorTokenFactory dataProtectorTokenizer) { - private readonly IEmergencyAccessRepository _emergencyAccessRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IUserRepository _userRepository; - private readonly ICipherRepository _cipherRepository; - private readonly IPolicyRepository _policyRepository; - private readonly ICipherService _cipherService; - private readonly IMailService _mailService; - private readonly IUserService _userService; - private readonly GlobalSettings _globalSettings; - private readonly IPasswordHasher _passwordHasher; - private readonly IOrganizationService _organizationService; - private readonly IDataProtectorTokenFactory _dataProtectorTokenizer; + _emergencyAccessRepository = emergencyAccessRepository; + _organizationUserRepository = organizationUserRepository; + _userRepository = userRepository; + _cipherRepository = cipherRepository; + _policyRepository = policyRepository; + _cipherService = cipherService; + _mailService = mailService; + _userService = userService; + _passwordHasher = passwordHasher; + _globalSettings = globalSettings; + _organizationService = organizationService; + _dataProtectorTokenizer = dataProtectorTokenizer; + } - public EmergencyAccessService( - IEmergencyAccessRepository emergencyAccessRepository, - IOrganizationUserRepository organizationUserRepository, - IUserRepository userRepository, - ICipherRepository cipherRepository, - IPolicyRepository policyRepository, - ICipherService cipherService, - IMailService mailService, - IUserService userService, - IPasswordHasher passwordHasher, - GlobalSettings globalSettings, - IOrganizationService organizationService, - IDataProtectorTokenFactory dataProtectorTokenizer) + public async Task InviteAsync(User invitingUser, string email, EmergencyAccessType type, int waitTime) + { + if (!await _userService.CanAccessPremium(invitingUser)) { - _emergencyAccessRepository = emergencyAccessRepository; - _organizationUserRepository = organizationUserRepository; - _userRepository = userRepository; - _cipherRepository = cipherRepository; - _policyRepository = policyRepository; - _cipherService = cipherService; - _mailService = mailService; - _userService = userService; - _passwordHasher = passwordHasher; - _globalSettings = globalSettings; - _organizationService = organizationService; - _dataProtectorTokenizer = dataProtectorTokenizer; + throw new BadRequestException("Not a premium user."); } - public async Task InviteAsync(User invitingUser, string email, EmergencyAccessType type, int waitTime) + if (type == EmergencyAccessType.Takeover && invitingUser.UsesKeyConnector) { - if (!await _userService.CanAccessPremium(invitingUser)) - { - throw new BadRequestException("Not a premium user."); - } + throw new BadRequestException("You cannot use Emergency Access Takeover because you are using Key Connector."); + } - if (type == EmergencyAccessType.Takeover && invitingUser.UsesKeyConnector) + var emergencyAccess = new EmergencyAccess + { + GrantorId = invitingUser.Id, + Email = email.ToLowerInvariant(), + Status = EmergencyAccessStatusType.Invited, + Type = type, + WaitTimeDays = waitTime, + CreationDate = DateTime.UtcNow, + RevisionDate = DateTime.UtcNow, + }; + + await _emergencyAccessRepository.CreateAsync(emergencyAccess); + await SendInviteAsync(emergencyAccess, NameOrEmail(invitingUser)); + + return emergencyAccess; + } + + public async Task GetAsync(Guid emergencyAccessId, Guid userId) + { + var emergencyAccess = await _emergencyAccessRepository.GetDetailsByIdGrantorIdAsync(emergencyAccessId, userId); + if (emergencyAccess == null) + { + throw new BadRequestException("Emergency Access not valid."); + } + + return emergencyAccess; + } + + public async Task ResendInviteAsync(User invitingUser, Guid emergencyAccessId) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId); + if (emergencyAccess == null || emergencyAccess.GrantorId != invitingUser.Id || + emergencyAccess.Status != EmergencyAccessStatusType.Invited) + { + throw new BadRequestException("Emergency Access not valid."); + } + + await SendInviteAsync(emergencyAccess, NameOrEmail(invitingUser)); + } + + public async Task AcceptUserAsync(Guid emergencyAccessId, User user, string token, IUserService userService) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId); + if (emergencyAccess == null) + { + throw new BadRequestException("Emergency Access not valid."); + } + + if (!_dataProtectorTokenizer.TryUnprotect(token, out var data) && data.IsValid(emergencyAccessId, user.Email)) + { + throw new BadRequestException("Invalid token."); + } + + if (emergencyAccess.Status == EmergencyAccessStatusType.Accepted) + { + throw new BadRequestException("Invitation already accepted. You will receive an email when the grantor confirms you as an emergency access contact."); + } + else if (emergencyAccess.Status != EmergencyAccessStatusType.Invited) + { + throw new BadRequestException("Invitation already accepted."); + } + + if (string.IsNullOrWhiteSpace(emergencyAccess.Email) || + !emergencyAccess.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase)) + { + throw new BadRequestException("User email does not match invite."); + } + + var granteeEmail = emergencyAccess.Email; + + emergencyAccess.Status = EmergencyAccessStatusType.Accepted; + emergencyAccess.GranteeId = user.Id; + emergencyAccess.Email = null; + + var grantor = await userService.GetUserByIdAsync(emergencyAccess.GrantorId); + + await _emergencyAccessRepository.ReplaceAsync(emergencyAccess); + await _mailService.SendEmergencyAccessAcceptedEmailAsync(granteeEmail, grantor.Email); + + return emergencyAccess; + } + + public async Task DeleteAsync(Guid emergencyAccessId, Guid grantorId) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId); + if (emergencyAccess == null || (emergencyAccess.GrantorId != grantorId && emergencyAccess.GranteeId != grantorId)) + { + throw new BadRequestException("Emergency Access not valid."); + } + + await _emergencyAccessRepository.DeleteAsync(emergencyAccess); + } + + public async Task ConfirmUserAsync(Guid emergencyAcccessId, string key, Guid confirmingUserId) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAcccessId); + if (emergencyAccess == null || emergencyAccess.Status != EmergencyAccessStatusType.Accepted || + emergencyAccess.GrantorId != confirmingUserId) + { + throw new BadRequestException("Emergency Access not valid."); + } + + var grantor = await _userRepository.GetByIdAsync(confirmingUserId); + if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector) + { + throw new BadRequestException("You cannot use Emergency Access Takeover because you are using Key Connector."); + } + + var grantee = await _userRepository.GetByIdAsync(emergencyAccess.GranteeId.Value); + + emergencyAccess.Status = EmergencyAccessStatusType.Confirmed; + emergencyAccess.KeyEncrypted = key; + emergencyAccess.Email = null; + await _emergencyAccessRepository.ReplaceAsync(emergencyAccess); + await _mailService.SendEmergencyAccessConfirmedEmailAsync(NameOrEmail(grantor), grantee.Email); + + return emergencyAccess; + } + + public async Task SaveAsync(EmergencyAccess emergencyAccess, User savingUser) + { + if (!await _userService.CanAccessPremium(savingUser)) + { + throw new BadRequestException("Not a premium user."); + } + + if (emergencyAccess.GrantorId != savingUser.Id) + { + throw new BadRequestException("Emergency Access not valid."); + } + + if (emergencyAccess.Type == EmergencyAccessType.Takeover) + { + var grantor = await _userService.GetUserByIdAsync(emergencyAccess.GrantorId); + if (grantor.UsesKeyConnector) { throw new BadRequestException("You cannot use Emergency Access Takeover because you are using Key Connector."); } - - var emergencyAccess = new EmergencyAccess - { - GrantorId = invitingUser.Id, - Email = email.ToLowerInvariant(), - Status = EmergencyAccessStatusType.Invited, - Type = type, - WaitTimeDays = waitTime, - CreationDate = DateTime.UtcNow, - RevisionDate = DateTime.UtcNow, - }; - - await _emergencyAccessRepository.CreateAsync(emergencyAccess); - await SendInviteAsync(emergencyAccess, NameOrEmail(invitingUser)); - - return emergencyAccess; } - public async Task GetAsync(Guid emergencyAccessId, Guid userId) - { - var emergencyAccess = await _emergencyAccessRepository.GetDetailsByIdGrantorIdAsync(emergencyAccessId, userId); - if (emergencyAccess == null) - { - throw new BadRequestException("Emergency Access not valid."); - } + await _emergencyAccessRepository.ReplaceAsync(emergencyAccess); + } - return emergencyAccess; + public async Task InitiateAsync(Guid id, User initiatingUser) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); + + if (emergencyAccess == null || emergencyAccess.GranteeId != initiatingUser.Id || + emergencyAccess.Status != EmergencyAccessStatusType.Confirmed) + { + throw new BadRequestException("Emergency Access not valid."); } - public async Task ResendInviteAsync(User invitingUser, Guid emergencyAccessId) - { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId); - if (emergencyAccess == null || emergencyAccess.GrantorId != invitingUser.Id || - emergencyAccess.Status != EmergencyAccessStatusType.Invited) - { - throw new BadRequestException("Emergency Access not valid."); - } + var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId); - await SendInviteAsync(emergencyAccess, NameOrEmail(invitingUser)); + if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector) + { + throw new BadRequestException("You cannot takeover an account that is using Key Connector."); } - public async Task AcceptUserAsync(Guid emergencyAccessId, User user, string token, IUserService userService) + var now = DateTime.UtcNow; + emergencyAccess.Status = EmergencyAccessStatusType.RecoveryInitiated; + emergencyAccess.RevisionDate = now; + emergencyAccess.RecoveryInitiatedDate = now; + emergencyAccess.LastNotificationDate = now; + await _emergencyAccessRepository.ReplaceAsync(emergencyAccess); + + await _mailService.SendEmergencyAccessRecoveryInitiated(emergencyAccess, NameOrEmail(initiatingUser), grantor.Email); + } + + public async Task ApproveAsync(Guid id, User approvingUser) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); + + if (emergencyAccess == null || emergencyAccess.GrantorId != approvingUser.Id || + emergencyAccess.Status != EmergencyAccessStatusType.RecoveryInitiated) { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId); - if (emergencyAccess == null) - { - throw new BadRequestException("Emergency Access not valid."); - } - - if (!_dataProtectorTokenizer.TryUnprotect(token, out var data) && data.IsValid(emergencyAccessId, user.Email)) - { - throw new BadRequestException("Invalid token."); - } - - if (emergencyAccess.Status == EmergencyAccessStatusType.Accepted) - { - throw new BadRequestException("Invitation already accepted. You will receive an email when the grantor confirms you as an emergency access contact."); - } - else if (emergencyAccess.Status != EmergencyAccessStatusType.Invited) - { - throw new BadRequestException("Invitation already accepted."); - } - - if (string.IsNullOrWhiteSpace(emergencyAccess.Email) || - !emergencyAccess.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase)) - { - throw new BadRequestException("User email does not match invite."); - } - - var granteeEmail = emergencyAccess.Email; - - emergencyAccess.Status = EmergencyAccessStatusType.Accepted; - emergencyAccess.GranteeId = user.Id; - emergencyAccess.Email = null; - - var grantor = await userService.GetUserByIdAsync(emergencyAccess.GrantorId); - - await _emergencyAccessRepository.ReplaceAsync(emergencyAccess); - await _mailService.SendEmergencyAccessAcceptedEmailAsync(granteeEmail, grantor.Email); - - return emergencyAccess; + throw new BadRequestException("Emergency Access not valid."); } - public async Task DeleteAsync(Guid emergencyAccessId, Guid grantorId) - { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId); - if (emergencyAccess == null || (emergencyAccess.GrantorId != grantorId && emergencyAccess.GranteeId != grantorId)) - { - throw new BadRequestException("Emergency Access not valid."); - } + emergencyAccess.Status = EmergencyAccessStatusType.RecoveryApproved; + await _emergencyAccessRepository.ReplaceAsync(emergencyAccess); - await _emergencyAccessRepository.DeleteAsync(emergencyAccess); + var grantee = await _userRepository.GetByIdAsync(emergencyAccess.GranteeId.Value); + await _mailService.SendEmergencyAccessRecoveryApproved(emergencyAccess, NameOrEmail(approvingUser), grantee.Email); + } + + public async Task RejectAsync(Guid id, User rejectingUser) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); + + if (emergencyAccess == null || emergencyAccess.GrantorId != rejectingUser.Id || + (emergencyAccess.Status != EmergencyAccessStatusType.RecoveryInitiated && + emergencyAccess.Status != EmergencyAccessStatusType.RecoveryApproved)) + { + throw new BadRequestException("Emergency Access not valid."); } - public async Task ConfirmUserAsync(Guid emergencyAcccessId, string key, Guid confirmingUserId) + emergencyAccess.Status = EmergencyAccessStatusType.Confirmed; + await _emergencyAccessRepository.ReplaceAsync(emergencyAccess); + + var grantee = await _userRepository.GetByIdAsync(emergencyAccess.GranteeId.Value); + await _mailService.SendEmergencyAccessRecoveryRejected(emergencyAccess, NameOrEmail(rejectingUser), grantee.Email); + } + + public async Task> GetPoliciesAsync(Guid id, User requestingUser) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); + + if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover)) { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAcccessId); - if (emergencyAccess == null || emergencyAccess.Status != EmergencyAccessStatusType.Accepted || - emergencyAccess.GrantorId != confirmingUserId) - { - throw new BadRequestException("Emergency Access not valid."); - } - - var grantor = await _userRepository.GetByIdAsync(confirmingUserId); - if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector) - { - throw new BadRequestException("You cannot use Emergency Access Takeover because you are using Key Connector."); - } - - var grantee = await _userRepository.GetByIdAsync(emergencyAccess.GranteeId.Value); - - emergencyAccess.Status = EmergencyAccessStatusType.Confirmed; - emergencyAccess.KeyEncrypted = key; - emergencyAccess.Email = null; - await _emergencyAccessRepository.ReplaceAsync(emergencyAccess); - await _mailService.SendEmergencyAccessConfirmedEmailAsync(NameOrEmail(grantor), grantee.Email); - - return emergencyAccess; + throw new BadRequestException("Emergency Access not valid."); } - public async Task SaveAsync(EmergencyAccess emergencyAccess, User savingUser) + var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId); + + var grantorOrganizations = await _organizationUserRepository.GetManyByUserAsync(grantor.Id); + var isOrganizationOwner = grantorOrganizations.Any(organization => organization.Type == OrganizationUserType.Owner); + var policies = isOrganizationOwner ? await _policyRepository.GetManyByUserIdAsync(grantor.Id) : null; + + return policies; + } + + public async Task<(EmergencyAccess, User)> TakeoverAsync(Guid id, User requestingUser) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); + + if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover)) { - if (!await _userService.CanAccessPremium(savingUser)) - { - throw new BadRequestException("Not a premium user."); - } - - if (emergencyAccess.GrantorId != savingUser.Id) - { - throw new BadRequestException("Emergency Access not valid."); - } - - if (emergencyAccess.Type == EmergencyAccessType.Takeover) - { - var grantor = await _userService.GetUserByIdAsync(emergencyAccess.GrantorId); - if (grantor.UsesKeyConnector) - { - throw new BadRequestException("You cannot use Emergency Access Takeover because you are using Key Connector."); - } - } - - await _emergencyAccessRepository.ReplaceAsync(emergencyAccess); + throw new BadRequestException("Emergency Access not valid."); } - public async Task InitiateAsync(Guid id, User initiatingUser) + var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId); + + if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector) { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); - - if (emergencyAccess == null || emergencyAccess.GranteeId != initiatingUser.Id || - emergencyAccess.Status != EmergencyAccessStatusType.Confirmed) - { - throw new BadRequestException("Emergency Access not valid."); - } - - var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId); - - if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector) - { - throw new BadRequestException("You cannot takeover an account that is using Key Connector."); - } - - var now = DateTime.UtcNow; - emergencyAccess.Status = EmergencyAccessStatusType.RecoveryInitiated; - emergencyAccess.RevisionDate = now; - emergencyAccess.RecoveryInitiatedDate = now; - emergencyAccess.LastNotificationDate = now; - await _emergencyAccessRepository.ReplaceAsync(emergencyAccess); - - await _mailService.SendEmergencyAccessRecoveryInitiated(emergencyAccess, NameOrEmail(initiatingUser), grantor.Email); + throw new BadRequestException("You cannot takeover an account that is using Key Connector."); } - public async Task ApproveAsync(Guid id, User approvingUser) + return (emergencyAccess, grantor); + } + + public async Task PasswordAsync(Guid id, User requestingUser, string newMasterPasswordHash, string key) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); + + if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover)) { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); - - if (emergencyAccess == null || emergencyAccess.GrantorId != approvingUser.Id || - emergencyAccess.Status != EmergencyAccessStatusType.RecoveryInitiated) - { - throw new BadRequestException("Emergency Access not valid."); - } - - emergencyAccess.Status = EmergencyAccessStatusType.RecoveryApproved; - await _emergencyAccessRepository.ReplaceAsync(emergencyAccess); - - var grantee = await _userRepository.GetByIdAsync(emergencyAccess.GranteeId.Value); - await _mailService.SendEmergencyAccessRecoveryApproved(emergencyAccess, NameOrEmail(approvingUser), grantee.Email); + throw new BadRequestException("Emergency Access not valid."); } - public async Task RejectAsync(Guid id, User rejectingUser) - { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); + var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId); - if (emergencyAccess == null || emergencyAccess.GrantorId != rejectingUser.Id || - (emergencyAccess.Status != EmergencyAccessStatusType.RecoveryInitiated && - emergencyAccess.Status != EmergencyAccessStatusType.RecoveryApproved)) + grantor.MasterPassword = _passwordHasher.HashPassword(grantor, newMasterPasswordHash); + grantor.Key = key; + // Disable TwoFactor providers since they will otherwise block logins + grantor.SetTwoFactorProviders(new Dictionary()); + grantor.UnknownDeviceVerificationEnabled = false; + await _userRepository.ReplaceAsync(grantor); + + // Remove grantor from all organizations unless Owner + var orgUser = await _organizationUserRepository.GetManyByUserAsync(grantor.Id); + foreach (var o in orgUser) + { + if (o.Type != OrganizationUserType.Owner) { - throw new BadRequestException("Emergency Access not valid."); + await _organizationService.DeleteUserAsync(o.OrganizationId, grantor.Id); } - - emergencyAccess.Status = EmergencyAccessStatusType.Confirmed; - await _emergencyAccessRepository.ReplaceAsync(emergencyAccess); - - var grantee = await _userRepository.GetByIdAsync(emergencyAccess.GranteeId.Value); - await _mailService.SendEmergencyAccessRecoveryRejected(emergencyAccess, NameOrEmail(rejectingUser), grantee.Email); - } - - public async Task> GetPoliciesAsync(Guid id, User requestingUser) - { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); - - if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover)) - { - throw new BadRequestException("Emergency Access not valid."); - } - - var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId); - - var grantorOrganizations = await _organizationUserRepository.GetManyByUserAsync(grantor.Id); - var isOrganizationOwner = grantorOrganizations.Any(organization => organization.Type == OrganizationUserType.Owner); - var policies = isOrganizationOwner ? await _policyRepository.GetManyByUserIdAsync(grantor.Id) : null; - - return policies; - } - - public async Task<(EmergencyAccess, User)> TakeoverAsync(Guid id, User requestingUser) - { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); - - if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover)) - { - throw new BadRequestException("Emergency Access not valid."); - } - - var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId); - - if (emergencyAccess.Type == EmergencyAccessType.Takeover && grantor.UsesKeyConnector) - { - throw new BadRequestException("You cannot takeover an account that is using Key Connector."); - } - - return (emergencyAccess, grantor); - } - - public async Task PasswordAsync(Guid id, User requestingUser, string newMasterPasswordHash, string key) - { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); - - if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.Takeover)) - { - throw new BadRequestException("Emergency Access not valid."); - } - - var grantor = await _userRepository.GetByIdAsync(emergencyAccess.GrantorId); - - grantor.MasterPassword = _passwordHasher.HashPassword(grantor, newMasterPasswordHash); - grantor.Key = key; - // Disable TwoFactor providers since they will otherwise block logins - grantor.SetTwoFactorProviders(new Dictionary()); - grantor.UnknownDeviceVerificationEnabled = false; - await _userRepository.ReplaceAsync(grantor); - - // Remove grantor from all organizations unless Owner - var orgUser = await _organizationUserRepository.GetManyByUserAsync(grantor.Id); - foreach (var o in orgUser) - { - if (o.Type != OrganizationUserType.Owner) - { - await _organizationService.DeleteUserAsync(o.OrganizationId, grantor.Id); - } - } - } - - public async Task SendNotificationsAsync() - { - var toNotify = await _emergencyAccessRepository.GetManyToNotifyAsync(); - - foreach (var notify in toNotify) - { - var ea = notify.ToEmergencyAccess(); - ea.LastNotificationDate = DateTime.UtcNow; - await _emergencyAccessRepository.ReplaceAsync(ea); - - var granteeNameOrEmail = string.IsNullOrWhiteSpace(notify.GranteeName) ? notify.GranteeEmail : notify.GranteeName; - - await _mailService.SendEmergencyAccessRecoveryReminder(ea, granteeNameOrEmail, notify.GrantorEmail); - } - } - - public async Task HandleTimedOutRequestsAsync() - { - var expired = await _emergencyAccessRepository.GetExpiredRecoveriesAsync(); - - foreach (var details in expired) - { - var ea = details.ToEmergencyAccess(); - ea.Status = EmergencyAccessStatusType.RecoveryApproved; - await _emergencyAccessRepository.ReplaceAsync(ea); - - var grantorNameOrEmail = string.IsNullOrWhiteSpace(details.GrantorName) ? details.GrantorEmail : details.GrantorName; - var granteeNameOrEmail = string.IsNullOrWhiteSpace(details.GranteeName) ? details.GranteeEmail : details.GranteeName; - - await _mailService.SendEmergencyAccessRecoveryApproved(ea, grantorNameOrEmail, details.GranteeEmail); - await _mailService.SendEmergencyAccessRecoveryTimedOut(ea, granteeNameOrEmail, details.GrantorEmail); - } - } - - public async Task ViewAsync(Guid id, User requestingUser) - { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); - - if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.View)) - { - throw new BadRequestException("Emergency Access not valid."); - } - - var ciphers = await _cipherRepository.GetManyByUserIdAsync(emergencyAccess.GrantorId, false); - - return new EmergencyAccessViewData - { - EmergencyAccess = emergencyAccess, - Ciphers = ciphers, - }; - } - - public async Task GetAttachmentDownloadAsync(Guid id, Guid cipherId, string attachmentId, User requestingUser) - { - var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); - - if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.View)) - { - throw new BadRequestException("Emergency Access not valid."); - } - - var cipher = await _cipherRepository.GetByIdAsync(cipherId, emergencyAccess.GrantorId); - return await _cipherService.GetAttachmentDownloadDataAsync(cipher, attachmentId); - } - - private async Task SendInviteAsync(EmergencyAccess emergencyAccess, string invitingUsersName) - { - var token = _dataProtectorTokenizer.Protect(new EmergencyAccessInviteTokenable(emergencyAccess, _globalSettings.OrganizationInviteExpirationHours)); - await _mailService.SendEmergencyAccessInviteEmailAsync(emergencyAccess, invitingUsersName, token); - } - - private string NameOrEmail(User user) - { - return string.IsNullOrWhiteSpace(user.Name) ? user.Email : user.Name; - } - - private bool IsValidRequest(EmergencyAccess availibleAccess, User requestingUser, EmergencyAccessType requestedAccessType) - { - return availibleAccess != null && - availibleAccess.GranteeId == requestingUser.Id && - availibleAccess.Status == EmergencyAccessStatusType.RecoveryApproved && - availibleAccess.Type == requestedAccessType; } } + + public async Task SendNotificationsAsync() + { + var toNotify = await _emergencyAccessRepository.GetManyToNotifyAsync(); + + foreach (var notify in toNotify) + { + var ea = notify.ToEmergencyAccess(); + ea.LastNotificationDate = DateTime.UtcNow; + await _emergencyAccessRepository.ReplaceAsync(ea); + + var granteeNameOrEmail = string.IsNullOrWhiteSpace(notify.GranteeName) ? notify.GranteeEmail : notify.GranteeName; + + await _mailService.SendEmergencyAccessRecoveryReminder(ea, granteeNameOrEmail, notify.GrantorEmail); + } + } + + public async Task HandleTimedOutRequestsAsync() + { + var expired = await _emergencyAccessRepository.GetExpiredRecoveriesAsync(); + + foreach (var details in expired) + { + var ea = details.ToEmergencyAccess(); + ea.Status = EmergencyAccessStatusType.RecoveryApproved; + await _emergencyAccessRepository.ReplaceAsync(ea); + + var grantorNameOrEmail = string.IsNullOrWhiteSpace(details.GrantorName) ? details.GrantorEmail : details.GrantorName; + var granteeNameOrEmail = string.IsNullOrWhiteSpace(details.GranteeName) ? details.GranteeEmail : details.GranteeName; + + await _mailService.SendEmergencyAccessRecoveryApproved(ea, grantorNameOrEmail, details.GranteeEmail); + await _mailService.SendEmergencyAccessRecoveryTimedOut(ea, granteeNameOrEmail, details.GrantorEmail); + } + } + + public async Task ViewAsync(Guid id, User requestingUser) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); + + if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.View)) + { + throw new BadRequestException("Emergency Access not valid."); + } + + var ciphers = await _cipherRepository.GetManyByUserIdAsync(emergencyAccess.GrantorId, false); + + return new EmergencyAccessViewData + { + EmergencyAccess = emergencyAccess, + Ciphers = ciphers, + }; + } + + public async Task GetAttachmentDownloadAsync(Guid id, Guid cipherId, string attachmentId, User requestingUser) + { + var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(id); + + if (!IsValidRequest(emergencyAccess, requestingUser, EmergencyAccessType.View)) + { + throw new BadRequestException("Emergency Access not valid."); + } + + var cipher = await _cipherRepository.GetByIdAsync(cipherId, emergencyAccess.GrantorId); + return await _cipherService.GetAttachmentDownloadDataAsync(cipher, attachmentId); + } + + private async Task SendInviteAsync(EmergencyAccess emergencyAccess, string invitingUsersName) + { + var token = _dataProtectorTokenizer.Protect(new EmergencyAccessInviteTokenable(emergencyAccess, _globalSettings.OrganizationInviteExpirationHours)); + await _mailService.SendEmergencyAccessInviteEmailAsync(emergencyAccess, invitingUsersName, token); + } + + private string NameOrEmail(User user) + { + return string.IsNullOrWhiteSpace(user.Name) ? user.Email : user.Name; + } + + private bool IsValidRequest(EmergencyAccess availibleAccess, User requestingUser, EmergencyAccessType requestedAccessType) + { + return availibleAccess != null && + availibleAccess.GranteeId == requestingUser.Id && + availibleAccess.Status == EmergencyAccessStatusType.RecoveryApproved && + availibleAccess.Type == requestedAccessType; + } } diff --git a/src/Core/Services/Implementations/EventService.cs b/src/Core/Services/Implementations/EventService.cs index 3a35555d1..18a4b19cf 100644 --- a/src/Core/Services/Implementations/EventService.cs +++ b/src/Core/Services/Implementations/EventService.cs @@ -7,322 +7,321 @@ using Bit.Core.Models.Data.Organizations; using Bit.Core.Repositories; using Bit.Core.Settings; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class EventService : IEventService { - public class EventService : IEventService + private readonly IEventWriteService _eventWriteService; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IProviderUserRepository _providerUserRepository; + private readonly IApplicationCacheService _applicationCacheService; + private readonly ICurrentContext _currentContext; + private readonly GlobalSettings _globalSettings; + + public EventService( + IEventWriteService eventWriteService, + IOrganizationUserRepository organizationUserRepository, + IProviderUserRepository providerUserRepository, + IApplicationCacheService applicationCacheService, + ICurrentContext currentContext, + GlobalSettings globalSettings) { - private readonly IEventWriteService _eventWriteService; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IProviderUserRepository _providerUserRepository; - private readonly IApplicationCacheService _applicationCacheService; - private readonly ICurrentContext _currentContext; - private readonly GlobalSettings _globalSettings; + _eventWriteService = eventWriteService; + _organizationUserRepository = organizationUserRepository; + _providerUserRepository = providerUserRepository; + _applicationCacheService = applicationCacheService; + _currentContext = currentContext; + _globalSettings = globalSettings; + } - public EventService( - IEventWriteService eventWriteService, - IOrganizationUserRepository organizationUserRepository, - IProviderUserRepository providerUserRepository, - IApplicationCacheService applicationCacheService, - ICurrentContext currentContext, - GlobalSettings globalSettings) + public async Task LogUserEventAsync(Guid userId, EventType type, DateTime? date = null) + { + var events = new List { - _eventWriteService = eventWriteService; - _organizationUserRepository = organizationUserRepository; - _providerUserRepository = providerUserRepository; - _applicationCacheService = applicationCacheService; - _currentContext = currentContext; - _globalSettings = globalSettings; - } - - public async Task LogUserEventAsync(Guid userId, EventType type, DateTime? date = null) - { - var events = new List + new EventMessage(_currentContext) { - new EventMessage(_currentContext) - { - UserId = userId, - ActingUserId = userId, - Type = type, - Date = date.GetValueOrDefault(DateTime.UtcNow) - } - }; - - var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); - var orgs = await _currentContext.OrganizationMembershipAsync(_organizationUserRepository, userId); - var orgEvents = orgs.Where(o => CanUseEvents(orgAbilities, o.Id)) - .Select(o => new EventMessage(_currentContext) - { - OrganizationId = o.Id, - UserId = userId, - ActingUserId = userId, - Type = type, - Date = DateTime.UtcNow - }); - - var providerAbilities = await _applicationCacheService.GetProviderAbilitiesAsync(); - var providers = await _currentContext.ProviderMembershipAsync(_providerUserRepository, userId); - var providerEvents = providers.Where(o => CanUseProviderEvents(providerAbilities, o.Id)) - .Select(p => new EventMessage(_currentContext) - { - ProviderId = p.Id, - UserId = userId, - ActingUserId = userId, - Type = type, - Date = DateTime.UtcNow - }); - - if (orgEvents.Any() || providerEvents.Any()) - { - events.AddRange(orgEvents); - events.AddRange(providerEvents); - await _eventWriteService.CreateManyAsync(events); - } - else - { - await _eventWriteService.CreateAsync(events.First()); - } - } - - public async Task LogCipherEventAsync(Cipher cipher, EventType type, DateTime? date = null) - { - var e = await BuildCipherEventMessageAsync(cipher, type, date); - if (e != null) - { - await _eventWriteService.CreateAsync(e); - } - } - - public async Task LogCipherEventsAsync(IEnumerable> events) - { - var cipherEvents = new List(); - foreach (var ev in events) - { - var e = await BuildCipherEventMessageAsync(ev.Item1, ev.Item2, ev.Item3); - if (e != null) - { - cipherEvents.Add(e); - } - } - await _eventWriteService.CreateManyAsync(cipherEvents); - } - - private async Task BuildCipherEventMessageAsync(Cipher cipher, EventType type, DateTime? date = null) - { - // Only logging organization cipher events for now. - if (!cipher.OrganizationId.HasValue || (!_currentContext?.UserId.HasValue ?? true)) - { - return null; - } - - if (cipher.OrganizationId.HasValue) - { - var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); - if (!CanUseEvents(orgAbilities, cipher.OrganizationId.Value)) - { - return null; - } - } - - return new EventMessage(_currentContext) - { - OrganizationId = cipher.OrganizationId, - UserId = cipher.OrganizationId.HasValue ? null : cipher.UserId, - CipherId = cipher.Id, + UserId = userId, + ActingUserId = userId, Type = type, - ActingUserId = _currentContext?.UserId, - ProviderId = await GetProviderIdAsync(cipher.OrganizationId), Date = date.GetValueOrDefault(DateTime.UtcNow) - }; - } - - public async Task LogCollectionEventAsync(Collection collection, EventType type, DateTime? date = null) - { - var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); - if (!CanUseEvents(orgAbilities, collection.OrganizationId)) - { - return; } + }; - var e = new EventMessage(_currentContext) + var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + var orgs = await _currentContext.OrganizationMembershipAsync(_organizationUserRepository, userId); + var orgEvents = orgs.Where(o => CanUseEvents(orgAbilities, o.Id)) + .Select(o => new EventMessage(_currentContext) { - OrganizationId = collection.OrganizationId, - CollectionId = collection.Id, + OrganizationId = o.Id, + UserId = userId, + ActingUserId = userId, Type = type, - ActingUserId = _currentContext?.UserId, - ProviderId = await GetProviderIdAsync(collection.OrganizationId), - Date = date.GetValueOrDefault(DateTime.UtcNow) - }; - await _eventWriteService.CreateAsync(e); - } + Date = DateTime.UtcNow + }); - public async Task LogGroupEventAsync(Group group, EventType type, DateTime? date = null) - { - var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); - if (!CanUseEvents(orgAbilities, group.OrganizationId)) + var providerAbilities = await _applicationCacheService.GetProviderAbilitiesAsync(); + var providers = await _currentContext.ProviderMembershipAsync(_providerUserRepository, userId); + var providerEvents = providers.Where(o => CanUseProviderEvents(providerAbilities, o.Id)) + .Select(p => new EventMessage(_currentContext) { - return; - } - - var e = new EventMessage(_currentContext) - { - OrganizationId = group.OrganizationId, - GroupId = group.Id, + ProviderId = p.Id, + UserId = userId, + ActingUserId = userId, Type = type, - ActingUserId = _currentContext?.UserId, - ProviderId = await GetProviderIdAsync(@group.OrganizationId), - Date = date.GetValueOrDefault(DateTime.UtcNow) - }; - await _eventWriteService.CreateAsync(e); - } + Date = DateTime.UtcNow + }); - public async Task LogPolicyEventAsync(Policy policy, EventType type, DateTime? date = null) + if (orgEvents.Any() || providerEvents.Any()) { - var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); - if (!CanUseEvents(orgAbilities, policy.OrganizationId)) - { - return; - } - - var e = new EventMessage(_currentContext) - { - OrganizationId = policy.OrganizationId, - PolicyId = policy.Id, - Type = type, - ActingUserId = _currentContext?.UserId, - ProviderId = await GetProviderIdAsync(policy.OrganizationId), - Date = date.GetValueOrDefault(DateTime.UtcNow) - }; - await _eventWriteService.CreateAsync(e); + events.AddRange(orgEvents); + events.AddRange(providerEvents); + await _eventWriteService.CreateManyAsync(events); } - - public async Task LogOrganizationUserEventAsync(OrganizationUser organizationUser, EventType type, - DateTime? date = null) => - await LogOrganizationUserEventsAsync(new[] { (organizationUser, type, date) }); - - public async Task LogOrganizationUserEventsAsync(IEnumerable<(OrganizationUser, EventType, DateTime?)> events) + else { - var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); - var eventMessages = new List(); - foreach (var (organizationUser, type, date) in events) - { - if (!CanUseEvents(orgAbilities, organizationUser.OrganizationId)) - { - continue; - } - - eventMessages.Add(new EventMessage(_currentContext) - { - OrganizationId = organizationUser.OrganizationId, - UserId = organizationUser.UserId, - OrganizationUserId = organizationUser.Id, - ProviderId = await GetProviderIdAsync(organizationUser.OrganizationId), - Type = type, - ActingUserId = _currentContext?.UserId, - Date = date.GetValueOrDefault(DateTime.UtcNow) - }); - } - - await _eventWriteService.CreateManyAsync(eventMessages); - } - - public async Task LogOrganizationEventAsync(Organization organization, EventType type, DateTime? date = null) - { - if (!organization.Enabled || !organization.UseEvents) - { - return; - } - - var e = new EventMessage(_currentContext) - { - OrganizationId = organization.Id, - ProviderId = await GetProviderIdAsync(organization.Id), - Type = type, - ActingUserId = _currentContext?.UserId, - Date = date.GetValueOrDefault(DateTime.UtcNow), - InstallationId = GetInstallationId(), - }; - await _eventWriteService.CreateAsync(e); - } - - public async Task LogProviderUserEventAsync(ProviderUser providerUser, EventType type, DateTime? date = null) - { - await LogProviderUsersEventAsync(new[] { (providerUser, type, date) }); - } - - public async Task LogProviderUsersEventAsync(IEnumerable<(ProviderUser, EventType, DateTime?)> events) - { - var providerAbilities = await _applicationCacheService.GetProviderAbilitiesAsync(); - var eventMessages = new List(); - foreach (var (providerUser, type, date) in events) - { - if (!CanUseProviderEvents(providerAbilities, providerUser.ProviderId)) - { - continue; - } - eventMessages.Add(new EventMessage(_currentContext) - { - ProviderId = providerUser.ProviderId, - UserId = providerUser.UserId, - ProviderUserId = providerUser.Id, - Type = type, - ActingUserId = _currentContext?.UserId, - Date = date.GetValueOrDefault(DateTime.UtcNow) - }); - } - - await _eventWriteService.CreateManyAsync(eventMessages); - } - - public async Task LogProviderOrganizationEventAsync(ProviderOrganization providerOrganization, EventType type, - DateTime? date = null) - { - var providerAbilities = await _applicationCacheService.GetProviderAbilitiesAsync(); - if (!CanUseProviderEvents(providerAbilities, providerOrganization.ProviderId)) - { - return; - } - - var e = new EventMessage(_currentContext) - { - ProviderId = providerOrganization.ProviderId, - ProviderOrganizationId = providerOrganization.Id, - Type = type, - ActingUserId = _currentContext?.UserId, - Date = date.GetValueOrDefault(DateTime.UtcNow) - }; - await _eventWriteService.CreateAsync(e); - } - - private async Task GetProviderIdAsync(Guid? orgId) - { - if (_currentContext == null || !orgId.HasValue) - { - return null; - } - - return await _currentContext.ProviderIdForOrg(orgId.Value); - } - - private Guid? GetInstallationId() - { - if (_currentContext == null) - { - return null; - } - - return _currentContext.InstallationId; - } - - private bool CanUseEvents(IDictionary orgAbilities, Guid orgId) - { - return orgAbilities != null && orgAbilities.ContainsKey(orgId) && - orgAbilities[orgId].Enabled && orgAbilities[orgId].UseEvents; - } - - private bool CanUseProviderEvents(IDictionary providerAbilities, Guid providerId) - { - return providerAbilities != null && providerAbilities.ContainsKey(providerId) && - providerAbilities[providerId].Enabled && providerAbilities[providerId].UseEvents; + await _eventWriteService.CreateAsync(events.First()); } } + + public async Task LogCipherEventAsync(Cipher cipher, EventType type, DateTime? date = null) + { + var e = await BuildCipherEventMessageAsync(cipher, type, date); + if (e != null) + { + await _eventWriteService.CreateAsync(e); + } + } + + public async Task LogCipherEventsAsync(IEnumerable> events) + { + var cipherEvents = new List(); + foreach (var ev in events) + { + var e = await BuildCipherEventMessageAsync(ev.Item1, ev.Item2, ev.Item3); + if (e != null) + { + cipherEvents.Add(e); + } + } + await _eventWriteService.CreateManyAsync(cipherEvents); + } + + private async Task BuildCipherEventMessageAsync(Cipher cipher, EventType type, DateTime? date = null) + { + // Only logging organization cipher events for now. + if (!cipher.OrganizationId.HasValue || (!_currentContext?.UserId.HasValue ?? true)) + { + return null; + } + + if (cipher.OrganizationId.HasValue) + { + var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + if (!CanUseEvents(orgAbilities, cipher.OrganizationId.Value)) + { + return null; + } + } + + return new EventMessage(_currentContext) + { + OrganizationId = cipher.OrganizationId, + UserId = cipher.OrganizationId.HasValue ? null : cipher.UserId, + CipherId = cipher.Id, + Type = type, + ActingUserId = _currentContext?.UserId, + ProviderId = await GetProviderIdAsync(cipher.OrganizationId), + Date = date.GetValueOrDefault(DateTime.UtcNow) + }; + } + + public async Task LogCollectionEventAsync(Collection collection, EventType type, DateTime? date = null) + { + var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + if (!CanUseEvents(orgAbilities, collection.OrganizationId)) + { + return; + } + + var e = new EventMessage(_currentContext) + { + OrganizationId = collection.OrganizationId, + CollectionId = collection.Id, + Type = type, + ActingUserId = _currentContext?.UserId, + ProviderId = await GetProviderIdAsync(collection.OrganizationId), + Date = date.GetValueOrDefault(DateTime.UtcNow) + }; + await _eventWriteService.CreateAsync(e); + } + + public async Task LogGroupEventAsync(Group group, EventType type, DateTime? date = null) + { + var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + if (!CanUseEvents(orgAbilities, group.OrganizationId)) + { + return; + } + + var e = new EventMessage(_currentContext) + { + OrganizationId = group.OrganizationId, + GroupId = group.Id, + Type = type, + ActingUserId = _currentContext?.UserId, + ProviderId = await GetProviderIdAsync(@group.OrganizationId), + Date = date.GetValueOrDefault(DateTime.UtcNow) + }; + await _eventWriteService.CreateAsync(e); + } + + public async Task LogPolicyEventAsync(Policy policy, EventType type, DateTime? date = null) + { + var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + if (!CanUseEvents(orgAbilities, policy.OrganizationId)) + { + return; + } + + var e = new EventMessage(_currentContext) + { + OrganizationId = policy.OrganizationId, + PolicyId = policy.Id, + Type = type, + ActingUserId = _currentContext?.UserId, + ProviderId = await GetProviderIdAsync(policy.OrganizationId), + Date = date.GetValueOrDefault(DateTime.UtcNow) + }; + await _eventWriteService.CreateAsync(e); + } + + public async Task LogOrganizationUserEventAsync(OrganizationUser organizationUser, EventType type, + DateTime? date = null) => + await LogOrganizationUserEventsAsync(new[] { (organizationUser, type, date) }); + + public async Task LogOrganizationUserEventsAsync(IEnumerable<(OrganizationUser, EventType, DateTime?)> events) + { + var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + var eventMessages = new List(); + foreach (var (organizationUser, type, date) in events) + { + if (!CanUseEvents(orgAbilities, organizationUser.OrganizationId)) + { + continue; + } + + eventMessages.Add(new EventMessage(_currentContext) + { + OrganizationId = organizationUser.OrganizationId, + UserId = organizationUser.UserId, + OrganizationUserId = organizationUser.Id, + ProviderId = await GetProviderIdAsync(organizationUser.OrganizationId), + Type = type, + ActingUserId = _currentContext?.UserId, + Date = date.GetValueOrDefault(DateTime.UtcNow) + }); + } + + await _eventWriteService.CreateManyAsync(eventMessages); + } + + public async Task LogOrganizationEventAsync(Organization organization, EventType type, DateTime? date = null) + { + if (!organization.Enabled || !organization.UseEvents) + { + return; + } + + var e = new EventMessage(_currentContext) + { + OrganizationId = organization.Id, + ProviderId = await GetProviderIdAsync(organization.Id), + Type = type, + ActingUserId = _currentContext?.UserId, + Date = date.GetValueOrDefault(DateTime.UtcNow), + InstallationId = GetInstallationId(), + }; + await _eventWriteService.CreateAsync(e); + } + + public async Task LogProviderUserEventAsync(ProviderUser providerUser, EventType type, DateTime? date = null) + { + await LogProviderUsersEventAsync(new[] { (providerUser, type, date) }); + } + + public async Task LogProviderUsersEventAsync(IEnumerable<(ProviderUser, EventType, DateTime?)> events) + { + var providerAbilities = await _applicationCacheService.GetProviderAbilitiesAsync(); + var eventMessages = new List(); + foreach (var (providerUser, type, date) in events) + { + if (!CanUseProviderEvents(providerAbilities, providerUser.ProviderId)) + { + continue; + } + eventMessages.Add(new EventMessage(_currentContext) + { + ProviderId = providerUser.ProviderId, + UserId = providerUser.UserId, + ProviderUserId = providerUser.Id, + Type = type, + ActingUserId = _currentContext?.UserId, + Date = date.GetValueOrDefault(DateTime.UtcNow) + }); + } + + await _eventWriteService.CreateManyAsync(eventMessages); + } + + public async Task LogProviderOrganizationEventAsync(ProviderOrganization providerOrganization, EventType type, + DateTime? date = null) + { + var providerAbilities = await _applicationCacheService.GetProviderAbilitiesAsync(); + if (!CanUseProviderEvents(providerAbilities, providerOrganization.ProviderId)) + { + return; + } + + var e = new EventMessage(_currentContext) + { + ProviderId = providerOrganization.ProviderId, + ProviderOrganizationId = providerOrganization.Id, + Type = type, + ActingUserId = _currentContext?.UserId, + Date = date.GetValueOrDefault(DateTime.UtcNow) + }; + await _eventWriteService.CreateAsync(e); + } + + private async Task GetProviderIdAsync(Guid? orgId) + { + if (_currentContext == null || !orgId.HasValue) + { + return null; + } + + return await _currentContext.ProviderIdForOrg(orgId.Value); + } + + private Guid? GetInstallationId() + { + if (_currentContext == null) + { + return null; + } + + return _currentContext.InstallationId; + } + + private bool CanUseEvents(IDictionary orgAbilities, Guid orgId) + { + return orgAbilities != null && orgAbilities.ContainsKey(orgId) && + orgAbilities[orgId].Enabled && orgAbilities[orgId].UseEvents; + } + + private bool CanUseProviderEvents(IDictionary providerAbilities, Guid providerId) + { + return providerAbilities != null && providerAbilities.ContainsKey(providerId) && + providerAbilities[providerId].Enabled && providerAbilities[providerId].UseEvents; + } } diff --git a/src/Core/Services/Implementations/GroupService.cs b/src/Core/Services/Implementations/GroupService.cs index 3d7872f70..c637fd0ce 100644 --- a/src/Core/Services/Implementations/GroupService.cs +++ b/src/Core/Services/Implementations/GroupService.cs @@ -5,82 +5,81 @@ using Bit.Core.Models.Business; using Bit.Core.Models.Data; using Bit.Core.Repositories; -namespace Bit.Core.Services -{ - public class GroupService : IGroupService - { - private readonly IEventService _eventService; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IGroupRepository _groupRepository; - private readonly IReferenceEventService _referenceEventService; +namespace Bit.Core.Services; - public GroupService( - IEventService eventService, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IGroupRepository groupRepository, - IReferenceEventService referenceEventService) +public class GroupService : IGroupService +{ + private readonly IEventService _eventService; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IGroupRepository _groupRepository; + private readonly IReferenceEventService _referenceEventService; + + public GroupService( + IEventService eventService, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IGroupRepository groupRepository, + IReferenceEventService referenceEventService) + { + _eventService = eventService; + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _groupRepository = groupRepository; + _referenceEventService = referenceEventService; + } + + public async Task SaveAsync(Group group, IEnumerable collections = null) + { + var org = await _organizationRepository.GetByIdAsync(group.OrganizationId); + if (org == null) { - _eventService = eventService; - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _groupRepository = groupRepository; - _referenceEventService = referenceEventService; + throw new BadRequestException("Organization not found"); } - public async Task SaveAsync(Group group, IEnumerable collections = null) + if (!org.UseGroups) { - var org = await _organizationRepository.GetByIdAsync(group.OrganizationId); - if (org == null) + throw new BadRequestException("This organization cannot use groups."); + } + + if (group.Id == default(Guid)) + { + group.CreationDate = group.RevisionDate = DateTime.UtcNow; + + if (collections == null) { - throw new BadRequestException("Organization not found"); - } - - if (!org.UseGroups) - { - throw new BadRequestException("This organization cannot use groups."); - } - - if (group.Id == default(Guid)) - { - group.CreationDate = group.RevisionDate = DateTime.UtcNow; - - if (collections == null) - { - await _groupRepository.CreateAsync(group); - } - else - { - await _groupRepository.CreateAsync(group, collections); - } - - await _eventService.LogGroupEventAsync(group, Enums.EventType.Group_Created); - await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.GroupCreated, org)); + await _groupRepository.CreateAsync(group); } else { - group.RevisionDate = DateTime.UtcNow; - await _groupRepository.ReplaceAsync(group, collections ?? new List()); - await _eventService.LogGroupEventAsync(group, Enums.EventType.Group_Updated); + await _groupRepository.CreateAsync(group, collections); } - } - public async Task DeleteAsync(Group group) - { - await _groupRepository.DeleteAsync(group); - await _eventService.LogGroupEventAsync(group, Enums.EventType.Group_Deleted); + await _eventService.LogGroupEventAsync(group, Enums.EventType.Group_Created); + await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.GroupCreated, org)); } - - public async Task DeleteUserAsync(Group group, Guid organizationUserId) + else { - var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); - if (orgUser == null || orgUser.OrganizationId != group.OrganizationId) - { - throw new NotFoundException(); - } - await _groupRepository.DeleteUserAsync(group.Id, organizationUserId); - await _eventService.LogOrganizationUserEventAsync(orgUser, Enums.EventType.OrganizationUser_UpdatedGroups); + group.RevisionDate = DateTime.UtcNow; + await _groupRepository.ReplaceAsync(group, collections ?? new List()); + await _eventService.LogGroupEventAsync(group, Enums.EventType.Group_Updated); } } + + public async Task DeleteAsync(Group group) + { + await _groupRepository.DeleteAsync(group); + await _eventService.LogGroupEventAsync(group, Enums.EventType.Group_Deleted); + } + + public async Task DeleteUserAsync(Group group, Guid organizationUserId) + { + var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); + if (orgUser == null || orgUser.OrganizationId != group.OrganizationId) + { + throw new NotFoundException(); + } + await _groupRepository.DeleteUserAsync(group.Id, organizationUserId); + await _eventService.LogOrganizationUserEventAsync(orgUser, Enums.EventType.OrganizationUser_UpdatedGroups); + } } diff --git a/src/Core/Services/Implementations/HCaptchaValidationService.cs b/src/Core/Services/Implementations/HCaptchaValidationService.cs index 0b72d5286..b8a63c642 100644 --- a/src/Core/Services/Implementations/HCaptchaValidationService.cs +++ b/src/Core/Services/Implementations/HCaptchaValidationService.cs @@ -8,125 +8,124 @@ using Bit.Core.Settings; using Bit.Core.Tokens; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class HCaptchaValidationService : ICaptchaValidationService { - public class HCaptchaValidationService : ICaptchaValidationService + private readonly ILogger _logger; + private readonly IHttpClientFactory _httpClientFactory; + private readonly GlobalSettings _globalSettings; + private readonly IDataProtectorTokenFactory _tokenizer; + + public HCaptchaValidationService( + ILogger logger, + IHttpClientFactory httpClientFactory, + IDataProtectorTokenFactory tokenizer, + GlobalSettings globalSettings) { - private readonly ILogger _logger; - private readonly IHttpClientFactory _httpClientFactory; - private readonly GlobalSettings _globalSettings; - private readonly IDataProtectorTokenFactory _tokenizer; + _logger = logger; + _httpClientFactory = httpClientFactory; + _globalSettings = globalSettings; + _tokenizer = tokenizer; + } - public HCaptchaValidationService( - ILogger logger, - IHttpClientFactory httpClientFactory, - IDataProtectorTokenFactory tokenizer, - GlobalSettings globalSettings) + public string SiteKeyResponseKeyName => "HCaptcha_SiteKey"; + public string SiteKey => _globalSettings.Captcha.HCaptchaSiteKey; + + public string GenerateCaptchaBypassToken(User user) => _tokenizer.Protect(new HCaptchaTokenable(user)); + + public async Task ValidateCaptchaResponseAsync(string captchaResponse, string clientIpAddress, + User user = null) + { + var response = new CaptchaResponse { Success = false }; + if (string.IsNullOrWhiteSpace(captchaResponse)) { - _logger = logger; - _httpClientFactory = httpClientFactory; - _globalSettings = globalSettings; - _tokenizer = tokenizer; - } - - public string SiteKeyResponseKeyName => "HCaptcha_SiteKey"; - public string SiteKey => _globalSettings.Captcha.HCaptchaSiteKey; - - public string GenerateCaptchaBypassToken(User user) => _tokenizer.Protect(new HCaptchaTokenable(user)); - - public async Task ValidateCaptchaResponseAsync(string captchaResponse, string clientIpAddress, - User user = null) - { - var response = new CaptchaResponse { Success = false }; - if (string.IsNullOrWhiteSpace(captchaResponse)) - { - return response; - } - - if (user != null && ValidateCaptchaBypassToken(captchaResponse, user)) - { - response.Success = true; - return response; - } - - var httpClient = _httpClientFactory.CreateClient("HCaptchaValidationService"); - - var requestMessage = new HttpRequestMessage - { - Method = HttpMethod.Post, - RequestUri = new Uri("https://hcaptcha.com/siteverify"), - Content = new FormUrlEncodedContent(new Dictionary - { - { "response", captchaResponse.TrimStart("hcaptcha|".ToCharArray()) }, - { "secret", _globalSettings.Captcha.HCaptchaSecretKey }, - { "sitekey", SiteKey }, - { "remoteip", clientIpAddress } - }) - }; - - HttpResponseMessage responseMessage; - try - { - responseMessage = await httpClient.SendAsync(requestMessage); - } - catch (Exception e) - { - _logger.LogError(11389, e, "Unable to verify with HCaptcha."); - return response; - } - - if (!responseMessage.IsSuccessStatusCode) - { - return response; - } - - using var hcaptchaResponse = await responseMessage.Content.ReadFromJsonAsync(); - response.Success = hcaptchaResponse.Success; - var score = hcaptchaResponse.Score.GetValueOrDefault(); - response.MaybeBot = score >= _globalSettings.Captcha.MaybeBotScoreThreshold; - response.IsBot = score >= _globalSettings.Captcha.IsBotScoreThreshold; - response.Score = score; return response; } - public bool RequireCaptchaValidation(ICurrentContext currentContext, User user = null) + if (user != null && ValidateCaptchaBypassToken(captchaResponse, user)) { - if (user == null) + response.Success = true; + return response; + } + + var httpClient = _httpClientFactory.CreateClient("HCaptchaValidationService"); + + var requestMessage = new HttpRequestMessage + { + Method = HttpMethod.Post, + RequestUri = new Uri("https://hcaptcha.com/siteverify"), + Content = new FormUrlEncodedContent(new Dictionary { - return currentContext.IsBot || _globalSettings.Captcha.ForceCaptchaRequired; - } + { "response", captchaResponse.TrimStart("hcaptcha|".ToCharArray()) }, + { "secret", _globalSettings.Captcha.HCaptchaSecretKey }, + { "sitekey", SiteKey }, + { "remoteip", clientIpAddress } + }) + }; - var failedLoginCeiling = _globalSettings.Captcha.MaximumFailedLoginAttempts; - var failedLoginCount = user?.FailedLoginCount ?? 0; - var cloudEmailUnverified = !_globalSettings.SelfHosted && !user.EmailVerified; - return currentContext.IsBot || - _globalSettings.Captcha.ForceCaptchaRequired || - cloudEmailUnverified || - failedLoginCeiling > 0 && failedLoginCount >= failedLoginCeiling; - } - - private static bool TokenIsValidApiKey(string bypassToken, User user) => - !string.IsNullOrWhiteSpace(bypassToken) && user != null && user.ApiKey == bypassToken; - - private bool TokenIsValidCaptchaBypassToken(string encryptedToken, User user) + HttpResponseMessage responseMessage; + try { - return _tokenizer.TryUnprotect(encryptedToken, out var data) && - data.Valid && data.TokenIsValid(user); + responseMessage = await httpClient.SendAsync(requestMessage); } - - private bool ValidateCaptchaBypassToken(string bypassToken, User user) => - TokenIsValidApiKey(bypassToken, user) || TokenIsValidCaptchaBypassToken(bypassToken, user); - - public class HCaptchaResponse : IDisposable + catch (Exception e) { - [JsonPropertyName("success")] - public bool Success { get; set; } - [JsonPropertyName("score")] - public double? Score { get; set; } - [JsonPropertyName("score_reason")] - public List ScoreReason { get; set; } - - public void Dispose() { } + _logger.LogError(11389, e, "Unable to verify with HCaptcha."); + return response; } + + if (!responseMessage.IsSuccessStatusCode) + { + return response; + } + + using var hcaptchaResponse = await responseMessage.Content.ReadFromJsonAsync(); + response.Success = hcaptchaResponse.Success; + var score = hcaptchaResponse.Score.GetValueOrDefault(); + response.MaybeBot = score >= _globalSettings.Captcha.MaybeBotScoreThreshold; + response.IsBot = score >= _globalSettings.Captcha.IsBotScoreThreshold; + response.Score = score; + return response; + } + + public bool RequireCaptchaValidation(ICurrentContext currentContext, User user = null) + { + if (user == null) + { + return currentContext.IsBot || _globalSettings.Captcha.ForceCaptchaRequired; + } + + var failedLoginCeiling = _globalSettings.Captcha.MaximumFailedLoginAttempts; + var failedLoginCount = user?.FailedLoginCount ?? 0; + var cloudEmailUnverified = !_globalSettings.SelfHosted && !user.EmailVerified; + return currentContext.IsBot || + _globalSettings.Captcha.ForceCaptchaRequired || + cloudEmailUnverified || + failedLoginCeiling > 0 && failedLoginCount >= failedLoginCeiling; + } + + private static bool TokenIsValidApiKey(string bypassToken, User user) => + !string.IsNullOrWhiteSpace(bypassToken) && user != null && user.ApiKey == bypassToken; + + private bool TokenIsValidCaptchaBypassToken(string encryptedToken, User user) + { + return _tokenizer.TryUnprotect(encryptedToken, out var data) && + data.Valid && data.TokenIsValid(user); + } + + private bool ValidateCaptchaBypassToken(string bypassToken, User user) => + TokenIsValidApiKey(bypassToken, user) || TokenIsValidCaptchaBypassToken(bypassToken, user); + + public class HCaptchaResponse : IDisposable + { + [JsonPropertyName("success")] + public bool Success { get; set; } + [JsonPropertyName("score")] + public double? Score { get; set; } + [JsonPropertyName("score_reason")] + public List ScoreReason { get; set; } + + public void Dispose() { } } } diff --git a/src/Core/Services/Implementations/HandlebarsMailService.cs b/src/Core/Services/Implementations/HandlebarsMailService.cs index a1cfb61ce..3688c8f15 100644 --- a/src/Core/Services/Implementations/HandlebarsMailService.cs +++ b/src/Core/Services/Implementations/HandlebarsMailService.cs @@ -10,880 +10,879 @@ using Bit.Core.Settings; using Bit.Core.Utilities; using HandlebarsDotNet; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class HandlebarsMailService : IMailService { - public class HandlebarsMailService : IMailService + private const string Namespace = "Bit.Core.MailTemplates.Handlebars"; + + private readonly GlobalSettings _globalSettings; + private readonly IMailDeliveryService _mailDeliveryService; + private readonly IMailEnqueuingService _mailEnqueuingService; + private readonly Dictionary> _templateCache = + new Dictionary>(); + + private bool _registeredHelpersAndPartials = false; + + public HandlebarsMailService( + GlobalSettings globalSettings, + IMailDeliveryService mailDeliveryService, + IMailEnqueuingService mailEnqueuingService) { - private const string Namespace = "Bit.Core.MailTemplates.Handlebars"; + _globalSettings = globalSettings; + _mailDeliveryService = mailDeliveryService; + _mailEnqueuingService = mailEnqueuingService; + } - private readonly GlobalSettings _globalSettings; - private readonly IMailDeliveryService _mailDeliveryService; - private readonly IMailEnqueuingService _mailEnqueuingService; - private readonly Dictionary> _templateCache = - new Dictionary>(); - - private bool _registeredHelpersAndPartials = false; - - public HandlebarsMailService( - GlobalSettings globalSettings, - IMailDeliveryService mailDeliveryService, - IMailEnqueuingService mailEnqueuingService) + public async Task SendVerifyEmailEmailAsync(string email, Guid userId, string token) + { + var message = CreateDefaultMessage("Verify Your Email", email); + var model = new VerifyEmailModel { - _globalSettings = globalSettings; - _mailDeliveryService = mailDeliveryService; - _mailEnqueuingService = mailEnqueuingService; + Token = WebUtility.UrlEncode(token), + UserId = userId, + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "VerifyEmail", model); + message.MetaData.Add("SendGridBypassListManagement", true); + message.Category = "VerifyEmail"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendVerifyDeleteEmailAsync(string email, Guid userId, string token) + { + var message = CreateDefaultMessage("Delete Your Account", email); + var model = new VerifyDeleteModel + { + Token = WebUtility.UrlEncode(token), + UserId = userId, + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + Email = email, + EmailEncoded = WebUtility.UrlEncode(email) + }; + await AddMessageContentAsync(message, "VerifyDelete", model); + message.MetaData.Add("SendGridBypassListManagement", true); + message.Category = "VerifyDelete"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendChangeEmailAlreadyExistsEmailAsync(string fromEmail, string toEmail) + { + var message = CreateDefaultMessage("Your Email Change", toEmail); + var model = new ChangeEmailExistsViewModel + { + FromEmail = fromEmail, + ToEmail = toEmail, + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "ChangeEmailAlreadyExists", model); + message.Category = "ChangeEmailAlreadyExists"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendChangeEmailEmailAsync(string newEmailAddress, string token) + { + var message = CreateDefaultMessage("Your Email Change", newEmailAddress); + var model = new EmailTokenViewModel + { + Token = token, + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "ChangeEmail", model); + message.MetaData.Add("SendGridBypassListManagement", true); + message.Category = "ChangeEmail"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendTwoFactorEmailAsync(string email, string token) + { + var message = CreateDefaultMessage("Your Two-step Login Verification Code", email); + var model = new EmailTokenViewModel + { + Token = token, + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "TwoFactorEmail", model); + message.MetaData.Add("SendGridBypassListManagement", true); + message.Category = "TwoFactorEmail"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendNewDeviceLoginTwoFactorEmailAsync(string email, string token) + { + var message = CreateDefaultMessage("New Device Login Verification Code", email); + var model = new EmailTokenViewModel + { + Token = token, + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "NewDeviceLoginTwoFactorEmail", model); + message.MetaData.Add("SendGridBypassListManagement", true); + message.Category = "TwoFactorEmail"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendMasterPasswordHintEmailAsync(string email, string hint) + { + var message = CreateDefaultMessage("Your Master Password Hint", email); + var model = new MasterPasswordHintViewModel + { + Hint = CoreHelpers.SanitizeForEmail(hint, false), + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "MasterPasswordHint", model); + message.Category = "MasterPasswordHint"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendNoMasterPasswordHintEmailAsync(string email) + { + var message = CreateDefaultMessage("Your Master Password Hint", email); + var model = new BaseMailModel + { + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "NoMasterPasswordHint", model); + message.Category = "NoMasterPasswordHint"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendOrganizationAutoscaledEmailAsync(Organization organization, int initialSeatCount, IEnumerable ownerEmails) + { + var message = CreateDefaultMessage($"{organization.Name} Seat Count Has Increased", ownerEmails); + var model = new OrganizationSeatsAutoscaledViewModel + { + OrganizationId = organization.Id, + InitialSeatCount = initialSeatCount, + CurrentSeatCount = organization.Seats.Value, + }; + + await AddMessageContentAsync(message, "OrganizationSeatsAutoscaled", model); + message.Category = "OrganizationSeatsAutoscaled"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendOrganizationMaxSeatLimitReachedEmailAsync(Organization organization, int maxSeatCount, IEnumerable ownerEmails) + { + var message = CreateDefaultMessage($"{organization.Name} Seat Limit Reached", ownerEmails); + var model = new OrganizationSeatsMaxReachedViewModel + { + OrganizationId = organization.Id, + MaxSeatCount = maxSeatCount, + }; + + await AddMessageContentAsync(message, "OrganizationSeatsMaxReached", model); + message.Category = "OrganizationSeatsMaxReached"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendOrganizationAcceptedEmailAsync(Organization organization, string userIdentifier, + IEnumerable adminEmails) + { + var message = CreateDefaultMessage($"Action Required: {userIdentifier} Needs to Be Confirmed", adminEmails); + var model = new OrganizationUserAcceptedViewModel + { + OrganizationId = organization.Id, + OrganizationName = CoreHelpers.SanitizeForEmail(organization.Name, false), + UserIdentifier = userIdentifier, + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "OrganizationUserAccepted", model); + message.Category = "OrganizationUserAccepted"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendOrganizationConfirmedEmailAsync(string organizationName, string email) + { + var message = CreateDefaultMessage($"You Have Been Confirmed To {organizationName}", email); + var model = new OrganizationUserConfirmedViewModel + { + OrganizationName = CoreHelpers.SanitizeForEmail(organizationName, false), + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "OrganizationUserConfirmed", model); + message.Category = "OrganizationUserConfirmed"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public Task SendOrganizationInviteEmailAsync(string organizationName, OrganizationUser orgUser, ExpiringToken token) => + BulkSendOrganizationInviteEmailAsync(organizationName, new[] { (orgUser, token) }); + + public async Task BulkSendOrganizationInviteEmailAsync(string organizationName, IEnumerable<(OrganizationUser orgUser, ExpiringToken token)> invites) + { + MailQueueMessage CreateMessage(string email, object model) + { + var message = CreateDefaultMessage($"Join {organizationName}", email); + return new MailQueueMessage(message, "OrganizationUserInvited", model); } - public async Task SendVerifyEmailEmailAsync(string email, Guid userId, string token) - { - var message = CreateDefaultMessage("Verify Your Email", email); - var model = new VerifyEmailModel - { - Token = WebUtility.UrlEncode(token), - UserId = userId, - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "VerifyEmail", model); - message.MetaData.Add("SendGridBypassListManagement", true); - message.Category = "VerifyEmail"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendVerifyDeleteEmailAsync(string email, Guid userId, string token) - { - var message = CreateDefaultMessage("Delete Your Account", email); - var model = new VerifyDeleteModel - { - Token = WebUtility.UrlEncode(token), - UserId = userId, - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName, - Email = email, - EmailEncoded = WebUtility.UrlEncode(email) - }; - await AddMessageContentAsync(message, "VerifyDelete", model); - message.MetaData.Add("SendGridBypassListManagement", true); - message.Category = "VerifyDelete"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendChangeEmailAlreadyExistsEmailAsync(string fromEmail, string toEmail) - { - var message = CreateDefaultMessage("Your Email Change", toEmail); - var model = new ChangeEmailExistsViewModel - { - FromEmail = fromEmail, - ToEmail = toEmail, - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "ChangeEmailAlreadyExists", model); - message.Category = "ChangeEmailAlreadyExists"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendChangeEmailEmailAsync(string newEmailAddress, string token) - { - var message = CreateDefaultMessage("Your Email Change", newEmailAddress); - var model = new EmailTokenViewModel - { - Token = token, - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "ChangeEmail", model); - message.MetaData.Add("SendGridBypassListManagement", true); - message.Category = "ChangeEmail"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendTwoFactorEmailAsync(string email, string token) - { - var message = CreateDefaultMessage("Your Two-step Login Verification Code", email); - var model = new EmailTokenViewModel - { - Token = token, - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "TwoFactorEmail", model); - message.MetaData.Add("SendGridBypassListManagement", true); - message.Category = "TwoFactorEmail"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendNewDeviceLoginTwoFactorEmailAsync(string email, string token) - { - var message = CreateDefaultMessage("New Device Login Verification Code", email); - var model = new EmailTokenViewModel - { - Token = token, - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "NewDeviceLoginTwoFactorEmail", model); - message.MetaData.Add("SendGridBypassListManagement", true); - message.Category = "TwoFactorEmail"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendMasterPasswordHintEmailAsync(string email, string hint) - { - var message = CreateDefaultMessage("Your Master Password Hint", email); - var model = new MasterPasswordHintViewModel - { - Hint = CoreHelpers.SanitizeForEmail(hint, false), - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "MasterPasswordHint", model); - message.Category = "MasterPasswordHint"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendNoMasterPasswordHintEmailAsync(string email) - { - var message = CreateDefaultMessage("Your Master Password Hint", email); - var model = new BaseMailModel - { - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "NoMasterPasswordHint", model); - message.Category = "NoMasterPasswordHint"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendOrganizationAutoscaledEmailAsync(Organization organization, int initialSeatCount, IEnumerable ownerEmails) - { - var message = CreateDefaultMessage($"{organization.Name} Seat Count Has Increased", ownerEmails); - var model = new OrganizationSeatsAutoscaledViewModel - { - OrganizationId = organization.Id, - InitialSeatCount = initialSeatCount, - CurrentSeatCount = organization.Seats.Value, - }; - - await AddMessageContentAsync(message, "OrganizationSeatsAutoscaled", model); - message.Category = "OrganizationSeatsAutoscaled"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendOrganizationMaxSeatLimitReachedEmailAsync(Organization organization, int maxSeatCount, IEnumerable ownerEmails) - { - var message = CreateDefaultMessage($"{organization.Name} Seat Limit Reached", ownerEmails); - var model = new OrganizationSeatsMaxReachedViewModel - { - OrganizationId = organization.Id, - MaxSeatCount = maxSeatCount, - }; - - await AddMessageContentAsync(message, "OrganizationSeatsMaxReached", model); - message.Category = "OrganizationSeatsMaxReached"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendOrganizationAcceptedEmailAsync(Organization organization, string userIdentifier, - IEnumerable adminEmails) - { - var message = CreateDefaultMessage($"Action Required: {userIdentifier} Needs to Be Confirmed", adminEmails); - var model = new OrganizationUserAcceptedViewModel - { - OrganizationId = organization.Id, - OrganizationName = CoreHelpers.SanitizeForEmail(organization.Name, false), - UserIdentifier = userIdentifier, - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "OrganizationUserAccepted", model); - message.Category = "OrganizationUserAccepted"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendOrganizationConfirmedEmailAsync(string organizationName, string email) - { - var message = CreateDefaultMessage($"You Have Been Confirmed To {organizationName}", email); - var model = new OrganizationUserConfirmedViewModel + var messageModels = invites.Select(invite => CreateMessage(invite.orgUser.Email, + new OrganizationUserInvitedViewModel { OrganizationName = CoreHelpers.SanitizeForEmail(organizationName, false), - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "OrganizationUserConfirmed", model); - message.Category = "OrganizationUserConfirmed"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public Task SendOrganizationInviteEmailAsync(string organizationName, OrganizationUser orgUser, ExpiringToken token) => - BulkSendOrganizationInviteEmailAsync(organizationName, new[] { (orgUser, token) }); - - public async Task BulkSendOrganizationInviteEmailAsync(string organizationName, IEnumerable<(OrganizationUser orgUser, ExpiringToken token)> invites) - { - MailQueueMessage CreateMessage(string email, object model) - { - var message = CreateDefaultMessage($"Join {organizationName}", email); - return new MailQueueMessage(message, "OrganizationUserInvited", model); - } - - var messageModels = invites.Select(invite => CreateMessage(invite.orgUser.Email, - new OrganizationUserInvitedViewModel - { - OrganizationName = CoreHelpers.SanitizeForEmail(organizationName, false), - Email = WebUtility.UrlEncode(invite.orgUser.Email), - OrganizationId = invite.orgUser.OrganizationId.ToString(), - OrganizationUserId = invite.orgUser.Id.ToString(), - Token = WebUtility.UrlEncode(invite.token.Token), - ExpirationDate = $"{invite.token.ExpirationDate.ToLongDateString()} {invite.token.ExpirationDate.ToShortTimeString()} UTC", - OrganizationNameUrlEncoded = WebUtility.UrlEncode(organizationName), - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName, - } - )); - - await EnqueueMailAsync(messageModels); - } - - public async Task SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(string organizationName, string email) - { - var message = CreateDefaultMessage($"You have been removed from {organizationName}", email); - var model = new OrganizationUserRemovedForPolicyTwoStepViewModel - { - OrganizationName = CoreHelpers.SanitizeForEmail(organizationName, false), - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "OrganizationUserRemovedForPolicyTwoStep", model); - message.Category = "OrganizationUserRemovedForPolicyTwoStep"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendWelcomeEmailAsync(User user) - { - var message = CreateDefaultMessage("Welcome to Bitwarden!", user.Email); - var model = new BaseMailModel - { - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "Welcome", model); - message.Category = "Welcome"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendPasswordlessSignInAsync(string returnUrl, string token, string email) - { - var message = CreateDefaultMessage("[Admin] Continue Logging In", email); - var url = CoreHelpers.ExtendQuery(new Uri($"{_globalSettings.BaseServiceUri.Admin}/login/confirm"), - new Dictionary - { - ["returnUrl"] = returnUrl, - ["email"] = email, - ["token"] = token, - }); - var model = new PasswordlessSignInModel - { - Url = url.ToString() - }; - await AddMessageContentAsync(message, "PasswordlessSignIn", model); - message.Category = "PasswordlessSignIn"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendInvoiceUpcomingAsync(string email, decimal amount, DateTime dueDate, - List items, bool mentionInvoices) - { - var message = CreateDefaultMessage("Your Subscription Will Renew Soon", email); - var model = new InvoiceUpcomingViewModel - { + Email = WebUtility.UrlEncode(invite.orgUser.Email), + OrganizationId = invite.orgUser.OrganizationId.ToString(), + OrganizationUserId = invite.orgUser.Id.ToString(), + Token = WebUtility.UrlEncode(invite.token.Token), + ExpirationDate = $"{invite.token.ExpirationDate.ToLongDateString()} {invite.token.ExpirationDate.ToShortTimeString()} UTC", + OrganizationNameUrlEncoded = WebUtility.UrlEncode(organizationName), WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, SiteName = _globalSettings.SiteName, - AmountDue = amount, - DueDate = dueDate, - Items = items, - MentionInvoices = mentionInvoices - }; - await AddMessageContentAsync(message, "InvoiceUpcoming", model); - message.Category = "InvoiceUpcoming"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendPaymentFailedAsync(string email, decimal amount, bool mentionInvoices) - { - var message = CreateDefaultMessage("Payment Failed", email); - var model = new PaymentFailedViewModel - { - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName, - Amount = amount, - MentionInvoices = mentionInvoices - }; - await AddMessageContentAsync(message, "PaymentFailed", model); - message.Category = "PaymentFailed"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendAddedCreditAsync(string email, decimal amount) - { - var message = CreateDefaultMessage("Account Credit Payment Processed", email); - var model = new AddedCreditViewModel - { - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName, - Amount = amount - }; - await AddMessageContentAsync(message, "AddedCredit", model); - message.Category = "AddedCredit"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendLicenseExpiredAsync(IEnumerable emails, string organizationName = null) - { - var message = CreateDefaultMessage("License Expired", emails); - var model = new LicenseExpiredViewModel - { - OrganizationName = CoreHelpers.SanitizeForEmail(organizationName, false), - }; - await AddMessageContentAsync(message, "LicenseExpired", model); - message.Category = "LicenseExpired"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendNewDeviceLoggedInEmail(string email, string deviceType, DateTime timestamp, string ip) - { - var message = CreateDefaultMessage($"New Device Logged In From {deviceType}", email); - var model = new NewDeviceLoggedInModel - { - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName, - DeviceType = deviceType, - TheDate = timestamp.ToLongDateString(), - TheTime = timestamp.ToShortTimeString(), - TimeZone = "UTC", - IpAddress = ip - }; - await AddMessageContentAsync(message, "NewDeviceLoggedIn", model); - message.Category = "NewDeviceLoggedIn"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendRecoverTwoFactorEmail(string email, DateTime timestamp, string ip) - { - var message = CreateDefaultMessage($"Recover 2FA From {ip}", email); - var model = new RecoverTwoFactorModel - { - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName, - TheDate = timestamp.ToLongDateString(), - TheTime = timestamp.ToShortTimeString(), - TimeZone = "UTC", - IpAddress = ip - }; - await AddMessageContentAsync(message, "RecoverTwoFactor", model); - message.Category = "RecoverTwoFactor"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendOrganizationUserRemovedForPolicySingleOrgEmailAsync(string organizationName, string email) - { - var message = CreateDefaultMessage($"You have been removed from {organizationName}", email); - var model = new OrganizationUserRemovedForPolicySingleOrgViewModel - { - OrganizationName = CoreHelpers.SanitizeForEmail(organizationName, false), - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "OrganizationUserRemovedForPolicySingleOrg", model); - message.Category = "OrganizationUserRemovedForPolicySingleOrg"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendEnqueuedMailMessageAsync(IMailQueueMessage queueMessage) - { - var message = CreateDefaultMessage(queueMessage.Subject, queueMessage.ToEmails); - message.BccEmails = queueMessage.BccEmails; - message.Category = queueMessage.Category; - await AddMessageContentAsync(message, queueMessage.TemplateName, queueMessage.Model); - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendAdminResetPasswordEmailAsync(string email, string userName, string orgName) - { - var message = CreateDefaultMessage("Master Password Has Been Changed", email); - var model = new AdminResetPasswordViewModel() - { - UserName = GetUserIdentifier(email, userName), - OrgName = CoreHelpers.SanitizeForEmail(orgName, false), - }; - await AddMessageContentAsync(message, "AdminResetPassword", model); - message.Category = "AdminResetPassword"; - await _mailDeliveryService.SendEmailAsync(message); - } - - private Task EnqueueMailAsync(IMailQueueMessage queueMessage) => - _mailEnqueuingService.EnqueueAsync(queueMessage, SendEnqueuedMailMessageAsync); - - private Task EnqueueMailAsync(IEnumerable queueMessages) => - _mailEnqueuingService.EnqueueManyAsync(queueMessages, SendEnqueuedMailMessageAsync); - - private MailMessage CreateDefaultMessage(string subject, string toEmail) - { - return CreateDefaultMessage(subject, new List { toEmail }); - } - - private MailMessage CreateDefaultMessage(string subject, IEnumerable toEmails) - { - return new MailMessage - { - ToEmails = toEmails, - Subject = subject, - MetaData = new Dictionary() - }; - } - - private async Task AddMessageContentAsync(MailMessage message, string templateName, T model) - { - message.HtmlContent = await RenderAsync($"{templateName}.html", model); - message.TextContent = await RenderAsync($"{templateName}.text", model); - } - - private async Task RenderAsync(string templateName, T model) - { - await RegisterHelpersAndPartialsAsync(); - if (!_templateCache.TryGetValue(templateName, out var template)) - { - var source = await ReadSourceAsync(templateName); - if (source != null) - { - template = Handlebars.Compile(source); - _templateCache.Add(templateName, template); - } } - return template != null ? template(model) : null; - } + )); - private async Task ReadSourceAsync(string templateName) + await EnqueueMailAsync(messageModels); + } + + public async Task SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(string organizationName, string email) + { + var message = CreateDefaultMessage($"You have been removed from {organizationName}", email); + var model = new OrganizationUserRemovedForPolicyTwoStepViewModel { - var assembly = typeof(HandlebarsMailService).GetTypeInfo().Assembly; - var fullTemplateName = $"{Namespace}.{templateName}.hbs"; - if (!assembly.GetManifestResourceNames().Any(f => f == fullTemplateName)) - { - return null; - } - using (var s = assembly.GetManifestResourceStream(fullTemplateName)) - using (var sr = new StreamReader(s)) - { - return await sr.ReadToEndAsync(); - } - } + OrganizationName = CoreHelpers.SanitizeForEmail(organizationName, false), + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "OrganizationUserRemovedForPolicyTwoStep", model); + message.Category = "OrganizationUserRemovedForPolicyTwoStep"; + await _mailDeliveryService.SendEmailAsync(message); + } - private async Task RegisterHelpersAndPartialsAsync() + public async Task SendWelcomeEmailAsync(User user) + { + var message = CreateDefaultMessage("Welcome to Bitwarden!", user.Email); + var model = new BaseMailModel { - if (_registeredHelpersAndPartials) - { - return; - } - _registeredHelpersAndPartials = true; + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "Welcome", model); + message.Category = "Welcome"; + await _mailDeliveryService.SendEmailAsync(message); + } - var basicHtmlLayoutSource = await ReadSourceAsync("Layouts.Basic.html"); - Handlebars.RegisterTemplate("BasicHtmlLayout", basicHtmlLayoutSource); - var basicTextLayoutSource = await ReadSourceAsync("Layouts.Basic.text"); - Handlebars.RegisterTemplate("BasicTextLayout", basicTextLayoutSource); - var fullHtmlLayoutSource = await ReadSourceAsync("Layouts.Full.html"); - Handlebars.RegisterTemplate("FullHtmlLayout", fullHtmlLayoutSource); - var fullTextLayoutSource = await ReadSourceAsync("Layouts.Full.text"); - Handlebars.RegisterTemplate("FullTextLayout", fullTextLayoutSource); - - Handlebars.RegisterHelper("date", (writer, context, parameters) => + public async Task SendPasswordlessSignInAsync(string returnUrl, string token, string email) + { + var message = CreateDefaultMessage("[Admin] Continue Logging In", email); + var url = CoreHelpers.ExtendQuery(new Uri($"{_globalSettings.BaseServiceUri.Admin}/login/confirm"), + new Dictionary { - if (parameters.Length == 0 || !(parameters[0] is DateTime)) - { - writer.WriteSafeString(string.Empty); - return; - } - if (parameters.Length > 0 && parameters[1] is string) - { - writer.WriteSafeString(((DateTime)parameters[0]).ToString(parameters[1].ToString())); - } - else - { - writer.WriteSafeString(((DateTime)parameters[0]).ToString()); - } + ["returnUrl"] = returnUrl, + ["email"] = email, + ["token"] = token, }); - - Handlebars.RegisterHelper("usd", (writer, context, parameters) => - { - if (parameters.Length == 0 || !(parameters[0] is decimal)) - { - writer.WriteSafeString(string.Empty); - return; - } - writer.WriteSafeString(((decimal)parameters[0]).ToString("C")); - }); - - Handlebars.RegisterHelper("link", (writer, context, parameters) => - { - if (parameters.Length == 0) - { - writer.WriteSafeString(string.Empty); - return; - } - - var text = parameters[0].ToString(); - var href = text; - var clickTrackingOff = false; - if (parameters.Length == 2) - { - if (parameters[1] is string) - { - var p1 = parameters[1].ToString(); - if (p1 == "true" || p1 == "false") - { - clickTrackingOff = p1 == "true"; - } - else - { - href = p1; - } - } - else if (parameters[1] is bool) - { - clickTrackingOff = (bool)parameters[1]; - } - } - else if (parameters.Length > 2) - { - if (parameters[1] is string) - { - href = parameters[1].ToString(); - } - if (parameters[2] is string) - { - var p2 = parameters[2].ToString(); - if (p2 == "true" || p2 == "false") - { - clickTrackingOff = p2 == "true"; - } - } - else if (parameters[2] is bool) - { - clickTrackingOff = (bool)parameters[2]; - } - } - - var clickTrackingText = (clickTrackingOff ? "clicktracking=off" : string.Empty); - writer.WriteSafeString($"{text}"); - }); - } - - public async Task SendEmergencyAccessInviteEmailAsync(EmergencyAccess emergencyAccess, string name, string token) + var model = new PasswordlessSignInModel { - var message = CreateDefaultMessage($"Emergency Access Contact Invite", emergencyAccess.Email); - var model = new EmergencyAccessInvitedViewModel - { - Name = CoreHelpers.SanitizeForEmail(name), - Email = WebUtility.UrlEncode(emergencyAccess.Email), - Id = emergencyAccess.Id.ToString(), - Token = WebUtility.UrlEncode(token), - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "EmergencyAccessInvited", model); - message.Category = "EmergencyAccessInvited"; - await _mailDeliveryService.SendEmailAsync(message); - } + Url = url.ToString() + }; + await AddMessageContentAsync(message, "PasswordlessSignIn", model); + message.Category = "PasswordlessSignIn"; + await _mailDeliveryService.SendEmailAsync(message); + } - public async Task SendEmergencyAccessAcceptedEmailAsync(string granteeEmail, string email) + public async Task SendInvoiceUpcomingAsync(string email, decimal amount, DateTime dueDate, + List items, bool mentionInvoices) + { + var message = CreateDefaultMessage("Your Subscription Will Renew Soon", email); + var model = new InvoiceUpcomingViewModel { - var message = CreateDefaultMessage($"Accepted Emergency Access", email); - var model = new EmergencyAccessAcceptedViewModel - { - GranteeEmail = granteeEmail, - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "EmergencyAccessAccepted", model); - message.Category = "EmergencyAccessAccepted"; - await _mailDeliveryService.SendEmailAsync(message); - } + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + AmountDue = amount, + DueDate = dueDate, + Items = items, + MentionInvoices = mentionInvoices + }; + await AddMessageContentAsync(message, "InvoiceUpcoming", model); + message.Category = "InvoiceUpcoming"; + await _mailDeliveryService.SendEmailAsync(message); + } - public async Task SendEmergencyAccessConfirmedEmailAsync(string grantorName, string email) + public async Task SendPaymentFailedAsync(string email, decimal amount, bool mentionInvoices) + { + var message = CreateDefaultMessage("Payment Failed", email); + var model = new PaymentFailedViewModel { - var message = CreateDefaultMessage($"You Have Been Confirmed as Emergency Access Contact", email); - var model = new EmergencyAccessConfirmedViewModel - { - Name = CoreHelpers.SanitizeForEmail(grantorName), - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "EmergencyAccessConfirmed", model); - message.Category = "EmergencyAccessConfirmed"; - await _mailDeliveryService.SendEmailAsync(message); - } + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + Amount = amount, + MentionInvoices = mentionInvoices + }; + await AddMessageContentAsync(message, "PaymentFailed", model); + message.Category = "PaymentFailed"; + await _mailDeliveryService.SendEmailAsync(message); + } - public async Task SendEmergencyAccessRecoveryInitiated(EmergencyAccess emergencyAccess, string initiatingName, string email) + public async Task SendAddedCreditAsync(string email, decimal amount) + { + var message = CreateDefaultMessage("Account Credit Payment Processed", email); + var model = new AddedCreditViewModel { - var message = CreateDefaultMessage("Emergency Access Initiated", email); + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + Amount = amount + }; + await AddMessageContentAsync(message, "AddedCredit", model); + message.Category = "AddedCredit"; + await _mailDeliveryService.SendEmailAsync(message); + } - var remainingTime = DateTime.UtcNow - emergencyAccess.RecoveryInitiatedDate.GetValueOrDefault(); - - var model = new EmergencyAccessRecoveryViewModel - { - Name = CoreHelpers.SanitizeForEmail(initiatingName), - Action = emergencyAccess.Type.ToString(), - DaysLeft = emergencyAccess.WaitTimeDays - Convert.ToInt32((remainingTime).TotalDays), - }; - await AddMessageContentAsync(message, "EmergencyAccessRecovery", model); - message.Category = "EmergencyAccessRecovery"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendEmergencyAccessRecoveryApproved(EmergencyAccess emergencyAccess, string approvingName, string email) + public async Task SendLicenseExpiredAsync(IEnumerable emails, string organizationName = null) + { + var message = CreateDefaultMessage("License Expired", emails); + var model = new LicenseExpiredViewModel { - var message = CreateDefaultMessage("Emergency Access Approved", email); - var model = new EmergencyAccessApprovedViewModel - { - Name = CoreHelpers.SanitizeForEmail(approvingName), - }; - await AddMessageContentAsync(message, "EmergencyAccessApproved", model); - message.Category = "EmergencyAccessApproved"; - await _mailDeliveryService.SendEmailAsync(message); - } + OrganizationName = CoreHelpers.SanitizeForEmail(organizationName, false), + }; + await AddMessageContentAsync(message, "LicenseExpired", model); + message.Category = "LicenseExpired"; + await _mailDeliveryService.SendEmailAsync(message); + } - public async Task SendEmergencyAccessRecoveryRejected(EmergencyAccess emergencyAccess, string rejectingName, string email) + public async Task SendNewDeviceLoggedInEmail(string email, string deviceType, DateTime timestamp, string ip) + { + var message = CreateDefaultMessage($"New Device Logged In From {deviceType}", email); + var model = new NewDeviceLoggedInModel { - var message = CreateDefaultMessage("Emergency Access Rejected", email); - var model = new EmergencyAccessRejectedViewModel - { - Name = CoreHelpers.SanitizeForEmail(rejectingName), - }; - await AddMessageContentAsync(message, "EmergencyAccessRejected", model); - message.Category = "EmergencyAccessRejected"; - await _mailDeliveryService.SendEmailAsync(message); - } + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + DeviceType = deviceType, + TheDate = timestamp.ToLongDateString(), + TheTime = timestamp.ToShortTimeString(), + TimeZone = "UTC", + IpAddress = ip + }; + await AddMessageContentAsync(message, "NewDeviceLoggedIn", model); + message.Category = "NewDeviceLoggedIn"; + await _mailDeliveryService.SendEmailAsync(message); + } - public async Task SendEmergencyAccessRecoveryReminder(EmergencyAccess emergencyAccess, string initiatingName, string email) + public async Task SendRecoverTwoFactorEmail(string email, DateTime timestamp, string ip) + { + var message = CreateDefaultMessage($"Recover 2FA From {ip}", email); + var model = new RecoverTwoFactorModel { - var message = CreateDefaultMessage("Pending Emergency Access Request", email); + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + TheDate = timestamp.ToLongDateString(), + TheTime = timestamp.ToShortTimeString(), + TimeZone = "UTC", + IpAddress = ip + }; + await AddMessageContentAsync(message, "RecoverTwoFactor", model); + message.Category = "RecoverTwoFactor"; + await _mailDeliveryService.SendEmailAsync(message); + } - var remainingTime = DateTime.UtcNow - emergencyAccess.RecoveryInitiatedDate.GetValueOrDefault(); - - var model = new EmergencyAccessRecoveryViewModel - { - Name = CoreHelpers.SanitizeForEmail(initiatingName), - Action = emergencyAccess.Type.ToString(), - DaysLeft = emergencyAccess.WaitTimeDays - Convert.ToInt32((remainingTime).TotalDays), - }; - await AddMessageContentAsync(message, "EmergencyAccessRecoveryReminder", model); - message.Category = "EmergencyAccessRecoveryReminder"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendEmergencyAccessRecoveryTimedOut(EmergencyAccess emergencyAccess, string initiatingName, string email) + public async Task SendOrganizationUserRemovedForPolicySingleOrgEmailAsync(string organizationName, string email) + { + var message = CreateDefaultMessage($"You have been removed from {organizationName}", email); + var model = new OrganizationUserRemovedForPolicySingleOrgViewModel { - var message = CreateDefaultMessage("Emergency Access Granted", email); - var model = new EmergencyAccessRecoveryTimedOutViewModel - { - Name = CoreHelpers.SanitizeForEmail(initiatingName), - Action = emergencyAccess.Type.ToString(), - }; - await AddMessageContentAsync(message, "EmergencyAccessRecoveryTimedOut", model); - message.Category = "EmergencyAccessRecoveryTimedOut"; - await _mailDeliveryService.SendEmailAsync(message); - } + OrganizationName = CoreHelpers.SanitizeForEmail(organizationName, false), + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "OrganizationUserRemovedForPolicySingleOrg", model); + message.Category = "OrganizationUserRemovedForPolicySingleOrg"; + await _mailDeliveryService.SendEmailAsync(message); + } - public async Task SendProviderSetupInviteEmailAsync(Provider provider, string token, string email) + public async Task SendEnqueuedMailMessageAsync(IMailQueueMessage queueMessage) + { + var message = CreateDefaultMessage(queueMessage.Subject, queueMessage.ToEmails); + message.BccEmails = queueMessage.BccEmails; + message.Category = queueMessage.Category; + await AddMessageContentAsync(message, queueMessage.TemplateName, queueMessage.Model); + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendAdminResetPasswordEmailAsync(string email, string userName, string orgName) + { + var message = CreateDefaultMessage("Master Password Has Been Changed", email); + var model = new AdminResetPasswordViewModel() { - var message = CreateDefaultMessage($"Create a Provider", email); - var model = new ProviderSetupInviteViewModel - { - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName, - ProviderId = provider.Id.ToString(), - Email = WebUtility.UrlEncode(email), - Token = WebUtility.UrlEncode(token), - }; - await AddMessageContentAsync(message, "Provider.ProviderSetupInvite", model); - message.Category = "ProviderSetupInvite"; - await _mailDeliveryService.SendEmailAsync(message); - } + UserName = GetUserIdentifier(email, userName), + OrgName = CoreHelpers.SanitizeForEmail(orgName, false), + }; + await AddMessageContentAsync(message, "AdminResetPassword", model); + message.Category = "AdminResetPassword"; + await _mailDeliveryService.SendEmailAsync(message); + } - public async Task SendProviderInviteEmailAsync(string providerName, ProviderUser providerUser, string token, string email) + private Task EnqueueMailAsync(IMailQueueMessage queueMessage) => + _mailEnqueuingService.EnqueueAsync(queueMessage, SendEnqueuedMailMessageAsync); + + private Task EnqueueMailAsync(IEnumerable queueMessages) => + _mailEnqueuingService.EnqueueManyAsync(queueMessages, SendEnqueuedMailMessageAsync); + + private MailMessage CreateDefaultMessage(string subject, string toEmail) + { + return CreateDefaultMessage(subject, new List { toEmail }); + } + + private MailMessage CreateDefaultMessage(string subject, IEnumerable toEmails) + { + return new MailMessage { - var message = CreateDefaultMessage($"Join {providerName}", email); - var model = new ProviderUserInvitedViewModel - { - ProviderName = CoreHelpers.SanitizeForEmail(providerName), - Email = WebUtility.UrlEncode(providerUser.Email), - ProviderId = providerUser.ProviderId.ToString(), - ProviderUserId = providerUser.Id.ToString(), - ProviderNameUrlEncoded = WebUtility.UrlEncode(providerName), - Token = WebUtility.UrlEncode(token), - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName, - }; - await AddMessageContentAsync(message, "Provider.ProviderUserInvited", model); - message.Category = "ProviderSetupInvite"; - await _mailDeliveryService.SendEmailAsync(message); - } + ToEmails = toEmails, + Subject = subject, + MetaData = new Dictionary() + }; + } - public async Task SendProviderConfirmedEmailAsync(string providerName, string email) + private async Task AddMessageContentAsync(MailMessage message, string templateName, T model) + { + message.HtmlContent = await RenderAsync($"{templateName}.html", model); + message.TextContent = await RenderAsync($"{templateName}.text", model); + } + + private async Task RenderAsync(string templateName, T model) + { + await RegisterHelpersAndPartialsAsync(); + if (!_templateCache.TryGetValue(templateName, out var template)) { - var message = CreateDefaultMessage($"You Have Been Confirmed To {providerName}", email); - var model = new ProviderUserConfirmedViewModel + var source = await ReadSourceAsync(templateName); + if (source != null) { - ProviderName = CoreHelpers.SanitizeForEmail(providerName), - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "Provider.ProviderUserConfirmed", model); - message.Category = "ProviderUserConfirmed"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendProviderUserRemoved(string providerName, string email) - { - var message = CreateDefaultMessage($"You Have Been Removed from {providerName}", email); - var model = new ProviderUserRemovedViewModel - { - ProviderName = CoreHelpers.SanitizeForEmail(providerName), - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName - }; - await AddMessageContentAsync(message, "Provider.ProviderUserRemoved", model); - message.Category = "ProviderUserRemoved"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendUpdatedTempPasswordEmailAsync(string email, string userName) - { - var message = CreateDefaultMessage("Master Password Has Been Changed", email); - var model = new UpdateTempPasswordViewModel() - { - UserName = GetUserIdentifier(email, userName) - }; - await AddMessageContentAsync(message, "UpdatedTempPassword", model); - message.Category = "UpdatedTempPassword"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendFamiliesForEnterpriseOfferEmailAsync(string sponsorOrgName, string email, bool existingAccount, string token) => - await BulkSendFamiliesForEnterpriseOfferEmailAsync(sponsorOrgName, new[] { (email, existingAccount, token) }); - - public async Task BulkSendFamiliesForEnterpriseOfferEmailAsync(string sponsorOrgName, IEnumerable<(string Email, bool ExistingAccount, string Token)> invites) - { - MailQueueMessage CreateMessage((string Email, bool ExistingAccount, string Token) invite) - { - var message = CreateDefaultMessage("Accept Your Free Families Subscription", invite.Email); - message.Category = "FamiliesForEnterpriseOffer"; - var model = new FamiliesForEnterpriseOfferViewModel - { - SponsorOrgName = sponsorOrgName, - SponsoredEmail = WebUtility.UrlEncode(invite.Email), - ExistingAccount = invite.ExistingAccount, - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName, - SponsorshipToken = invite.Token, - }; - var templateName = invite.ExistingAccount ? - "FamiliesForEnterprise.FamiliesForEnterpriseOfferExistingAccount" : - "FamiliesForEnterprise.FamiliesForEnterpriseOfferNewAccount"; - - return new MailQueueMessage(message, templateName, model); + template = Handlebars.Compile(source); + _templateCache.Add(templateName, template); } - var messageModels = invites.Select(invite => CreateMessage(invite)); - await EnqueueMailAsync(messageModels); } + return template != null ? template(model) : null; + } - public async Task SendFamiliesForEnterpriseRedeemedEmailsAsync(string familyUserEmail, string sponsorEmail) + private async Task ReadSourceAsync(string templateName) + { + var assembly = typeof(HandlebarsMailService).GetTypeInfo().Assembly; + var fullTemplateName = $"{Namespace}.{templateName}.hbs"; + if (!assembly.GetManifestResourceNames().Any(f => f == fullTemplateName)) { - // Email family user - await SendFamiliesForEnterpriseInviteRedeemedToFamilyUserEmailAsync(familyUserEmail); - - // Email enterprise org user - await SendFamiliesForEnterpriseInviteRedeemedToEnterpriseUserEmailAsync(sponsorEmail); + return null; } - - private async Task SendFamiliesForEnterpriseInviteRedeemedToFamilyUserEmailAsync(string email) + using (var s = assembly.GetManifestResourceStream(fullTemplateName)) + using (var sr = new StreamReader(s)) { - var message = CreateDefaultMessage("Success! Families Subscription Accepted", email); - await AddMessageContentAsync(message, "FamiliesForEnterprise.FamiliesForEnterpriseRedeemedToFamilyUser", new BaseMailModel()); - message.Category = "FamilyForEnterpriseRedeemedToFamilyUser"; - await _mailDeliveryService.SendEmailAsync(message); - } - - private async Task SendFamiliesForEnterpriseInviteRedeemedToEnterpriseUserEmailAsync(string email) - { - var message = CreateDefaultMessage("Success! Families Subscription Accepted", email); - await AddMessageContentAsync(message, "FamiliesForEnterprise.FamiliesForEnterpriseRedeemedToEnterpriseUser", new BaseMailModel()); - message.Category = "FamilyForEnterpriseRedeemedToEnterpriseUser"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendFamiliesForEnterpriseSponsorshipRevertingEmailAsync(string email, DateTime expirationDate) - { - var message = CreateDefaultMessage("Your Families Sponsorship was Removed", email); - var model = new FamiliesForEnterpriseSponsorshipRevertingViewModel - { - ExpirationDate = expirationDate, - }; - await AddMessageContentAsync(message, "FamiliesForEnterprise.FamiliesForEnterpriseSponsorshipReverting", model); - message.Category = "FamiliesForEnterpriseSponsorshipReverting"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendOTPEmailAsync(string email, string token) - { - var message = CreateDefaultMessage("Your Bitwarden Verification Code", email); - var model = new EmailTokenViewModel - { - Token = token, - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName, - }; - await AddMessageContentAsync(message, "OTPEmail", model); - message.MetaData.Add("SendGridBypassListManagement", true); - message.Category = "OTP"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendFailedLoginAttemptsEmailAsync(string email, DateTime utcNow, string ip) - { - var message = CreateDefaultMessage("Failed login attempts detected", email); - var model = new FailedAuthAttemptsModel() - { - TheDate = utcNow.ToLongDateString(), - TheTime = utcNow.ToShortTimeString(), - TimeZone = "UTC", - IpAddress = ip, - AffectedEmail = email - - }; - await AddMessageContentAsync(message, "FailedLoginAttempts", model); - message.Category = "FailedLoginAttempts"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendFailedTwoFactorAttemptsEmailAsync(string email, DateTime utcNow, string ip) - { - var message = CreateDefaultMessage("Failed login attempts detected", email); - var model = new FailedAuthAttemptsModel() - { - TheDate = utcNow.ToLongDateString(), - TheTime = utcNow.ToShortTimeString(), - TimeZone = "UTC", - IpAddress = ip, - AffectedEmail = email - - }; - await AddMessageContentAsync(message, "FailedTwoFactorAttempts", model); - message.Category = "FailedTwoFactorAttempts"; - await _mailDeliveryService.SendEmailAsync(message); - } - - private static string GetUserIdentifier(string email, string userName) - { - return string.IsNullOrEmpty(userName) ? email : CoreHelpers.SanitizeForEmail(userName, false); + return await sr.ReadToEndAsync(); } } + + private async Task RegisterHelpersAndPartialsAsync() + { + if (_registeredHelpersAndPartials) + { + return; + } + _registeredHelpersAndPartials = true; + + var basicHtmlLayoutSource = await ReadSourceAsync("Layouts.Basic.html"); + Handlebars.RegisterTemplate("BasicHtmlLayout", basicHtmlLayoutSource); + var basicTextLayoutSource = await ReadSourceAsync("Layouts.Basic.text"); + Handlebars.RegisterTemplate("BasicTextLayout", basicTextLayoutSource); + var fullHtmlLayoutSource = await ReadSourceAsync("Layouts.Full.html"); + Handlebars.RegisterTemplate("FullHtmlLayout", fullHtmlLayoutSource); + var fullTextLayoutSource = await ReadSourceAsync("Layouts.Full.text"); + Handlebars.RegisterTemplate("FullTextLayout", fullTextLayoutSource); + + Handlebars.RegisterHelper("date", (writer, context, parameters) => + { + if (parameters.Length == 0 || !(parameters[0] is DateTime)) + { + writer.WriteSafeString(string.Empty); + return; + } + if (parameters.Length > 0 && parameters[1] is string) + { + writer.WriteSafeString(((DateTime)parameters[0]).ToString(parameters[1].ToString())); + } + else + { + writer.WriteSafeString(((DateTime)parameters[0]).ToString()); + } + }); + + Handlebars.RegisterHelper("usd", (writer, context, parameters) => + { + if (parameters.Length == 0 || !(parameters[0] is decimal)) + { + writer.WriteSafeString(string.Empty); + return; + } + writer.WriteSafeString(((decimal)parameters[0]).ToString("C")); + }); + + Handlebars.RegisterHelper("link", (writer, context, parameters) => + { + if (parameters.Length == 0) + { + writer.WriteSafeString(string.Empty); + return; + } + + var text = parameters[0].ToString(); + var href = text; + var clickTrackingOff = false; + if (parameters.Length == 2) + { + if (parameters[1] is string) + { + var p1 = parameters[1].ToString(); + if (p1 == "true" || p1 == "false") + { + clickTrackingOff = p1 == "true"; + } + else + { + href = p1; + } + } + else if (parameters[1] is bool) + { + clickTrackingOff = (bool)parameters[1]; + } + } + else if (parameters.Length > 2) + { + if (parameters[1] is string) + { + href = parameters[1].ToString(); + } + if (parameters[2] is string) + { + var p2 = parameters[2].ToString(); + if (p2 == "true" || p2 == "false") + { + clickTrackingOff = p2 == "true"; + } + } + else if (parameters[2] is bool) + { + clickTrackingOff = (bool)parameters[2]; + } + } + + var clickTrackingText = (clickTrackingOff ? "clicktracking=off" : string.Empty); + writer.WriteSafeString($"{text}"); + }); + } + + public async Task SendEmergencyAccessInviteEmailAsync(EmergencyAccess emergencyAccess, string name, string token) + { + var message = CreateDefaultMessage($"Emergency Access Contact Invite", emergencyAccess.Email); + var model = new EmergencyAccessInvitedViewModel + { + Name = CoreHelpers.SanitizeForEmail(name), + Email = WebUtility.UrlEncode(emergencyAccess.Email), + Id = emergencyAccess.Id.ToString(), + Token = WebUtility.UrlEncode(token), + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "EmergencyAccessInvited", model); + message.Category = "EmergencyAccessInvited"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendEmergencyAccessAcceptedEmailAsync(string granteeEmail, string email) + { + var message = CreateDefaultMessage($"Accepted Emergency Access", email); + var model = new EmergencyAccessAcceptedViewModel + { + GranteeEmail = granteeEmail, + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "EmergencyAccessAccepted", model); + message.Category = "EmergencyAccessAccepted"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendEmergencyAccessConfirmedEmailAsync(string grantorName, string email) + { + var message = CreateDefaultMessage($"You Have Been Confirmed as Emergency Access Contact", email); + var model = new EmergencyAccessConfirmedViewModel + { + Name = CoreHelpers.SanitizeForEmail(grantorName), + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "EmergencyAccessConfirmed", model); + message.Category = "EmergencyAccessConfirmed"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendEmergencyAccessRecoveryInitiated(EmergencyAccess emergencyAccess, string initiatingName, string email) + { + var message = CreateDefaultMessage("Emergency Access Initiated", email); + + var remainingTime = DateTime.UtcNow - emergencyAccess.RecoveryInitiatedDate.GetValueOrDefault(); + + var model = new EmergencyAccessRecoveryViewModel + { + Name = CoreHelpers.SanitizeForEmail(initiatingName), + Action = emergencyAccess.Type.ToString(), + DaysLeft = emergencyAccess.WaitTimeDays - Convert.ToInt32((remainingTime).TotalDays), + }; + await AddMessageContentAsync(message, "EmergencyAccessRecovery", model); + message.Category = "EmergencyAccessRecovery"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendEmergencyAccessRecoveryApproved(EmergencyAccess emergencyAccess, string approvingName, string email) + { + var message = CreateDefaultMessage("Emergency Access Approved", email); + var model = new EmergencyAccessApprovedViewModel + { + Name = CoreHelpers.SanitizeForEmail(approvingName), + }; + await AddMessageContentAsync(message, "EmergencyAccessApproved", model); + message.Category = "EmergencyAccessApproved"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendEmergencyAccessRecoveryRejected(EmergencyAccess emergencyAccess, string rejectingName, string email) + { + var message = CreateDefaultMessage("Emergency Access Rejected", email); + var model = new EmergencyAccessRejectedViewModel + { + Name = CoreHelpers.SanitizeForEmail(rejectingName), + }; + await AddMessageContentAsync(message, "EmergencyAccessRejected", model); + message.Category = "EmergencyAccessRejected"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendEmergencyAccessRecoveryReminder(EmergencyAccess emergencyAccess, string initiatingName, string email) + { + var message = CreateDefaultMessage("Pending Emergency Access Request", email); + + var remainingTime = DateTime.UtcNow - emergencyAccess.RecoveryInitiatedDate.GetValueOrDefault(); + + var model = new EmergencyAccessRecoveryViewModel + { + Name = CoreHelpers.SanitizeForEmail(initiatingName), + Action = emergencyAccess.Type.ToString(), + DaysLeft = emergencyAccess.WaitTimeDays - Convert.ToInt32((remainingTime).TotalDays), + }; + await AddMessageContentAsync(message, "EmergencyAccessRecoveryReminder", model); + message.Category = "EmergencyAccessRecoveryReminder"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendEmergencyAccessRecoveryTimedOut(EmergencyAccess emergencyAccess, string initiatingName, string email) + { + var message = CreateDefaultMessage("Emergency Access Granted", email); + var model = new EmergencyAccessRecoveryTimedOutViewModel + { + Name = CoreHelpers.SanitizeForEmail(initiatingName), + Action = emergencyAccess.Type.ToString(), + }; + await AddMessageContentAsync(message, "EmergencyAccessRecoveryTimedOut", model); + message.Category = "EmergencyAccessRecoveryTimedOut"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendProviderSetupInviteEmailAsync(Provider provider, string token, string email) + { + var message = CreateDefaultMessage($"Create a Provider", email); + var model = new ProviderSetupInviteViewModel + { + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + ProviderId = provider.Id.ToString(), + Email = WebUtility.UrlEncode(email), + Token = WebUtility.UrlEncode(token), + }; + await AddMessageContentAsync(message, "Provider.ProviderSetupInvite", model); + message.Category = "ProviderSetupInvite"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendProviderInviteEmailAsync(string providerName, ProviderUser providerUser, string token, string email) + { + var message = CreateDefaultMessage($"Join {providerName}", email); + var model = new ProviderUserInvitedViewModel + { + ProviderName = CoreHelpers.SanitizeForEmail(providerName), + Email = WebUtility.UrlEncode(providerUser.Email), + ProviderId = providerUser.ProviderId.ToString(), + ProviderUserId = providerUser.Id.ToString(), + ProviderNameUrlEncoded = WebUtility.UrlEncode(providerName), + Token = WebUtility.UrlEncode(token), + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + }; + await AddMessageContentAsync(message, "Provider.ProviderUserInvited", model); + message.Category = "ProviderSetupInvite"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendProviderConfirmedEmailAsync(string providerName, string email) + { + var message = CreateDefaultMessage($"You Have Been Confirmed To {providerName}", email); + var model = new ProviderUserConfirmedViewModel + { + ProviderName = CoreHelpers.SanitizeForEmail(providerName), + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "Provider.ProviderUserConfirmed", model); + message.Category = "ProviderUserConfirmed"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendProviderUserRemoved(string providerName, string email) + { + var message = CreateDefaultMessage($"You Have Been Removed from {providerName}", email); + var model = new ProviderUserRemovedViewModel + { + ProviderName = CoreHelpers.SanitizeForEmail(providerName), + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "Provider.ProviderUserRemoved", model); + message.Category = "ProviderUserRemoved"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendUpdatedTempPasswordEmailAsync(string email, string userName) + { + var message = CreateDefaultMessage("Master Password Has Been Changed", email); + var model = new UpdateTempPasswordViewModel() + { + UserName = GetUserIdentifier(email, userName) + }; + await AddMessageContentAsync(message, "UpdatedTempPassword", model); + message.Category = "UpdatedTempPassword"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendFamiliesForEnterpriseOfferEmailAsync(string sponsorOrgName, string email, bool existingAccount, string token) => + await BulkSendFamiliesForEnterpriseOfferEmailAsync(sponsorOrgName, new[] { (email, existingAccount, token) }); + + public async Task BulkSendFamiliesForEnterpriseOfferEmailAsync(string sponsorOrgName, IEnumerable<(string Email, bool ExistingAccount, string Token)> invites) + { + MailQueueMessage CreateMessage((string Email, bool ExistingAccount, string Token) invite) + { + var message = CreateDefaultMessage("Accept Your Free Families Subscription", invite.Email); + message.Category = "FamiliesForEnterpriseOffer"; + var model = new FamiliesForEnterpriseOfferViewModel + { + SponsorOrgName = sponsorOrgName, + SponsoredEmail = WebUtility.UrlEncode(invite.Email), + ExistingAccount = invite.ExistingAccount, + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + SponsorshipToken = invite.Token, + }; + var templateName = invite.ExistingAccount ? + "FamiliesForEnterprise.FamiliesForEnterpriseOfferExistingAccount" : + "FamiliesForEnterprise.FamiliesForEnterpriseOfferNewAccount"; + + return new MailQueueMessage(message, templateName, model); + } + var messageModels = invites.Select(invite => CreateMessage(invite)); + await EnqueueMailAsync(messageModels); + } + + public async Task SendFamiliesForEnterpriseRedeemedEmailsAsync(string familyUserEmail, string sponsorEmail) + { + // Email family user + await SendFamiliesForEnterpriseInviteRedeemedToFamilyUserEmailAsync(familyUserEmail); + + // Email enterprise org user + await SendFamiliesForEnterpriseInviteRedeemedToEnterpriseUserEmailAsync(sponsorEmail); + } + + private async Task SendFamiliesForEnterpriseInviteRedeemedToFamilyUserEmailAsync(string email) + { + var message = CreateDefaultMessage("Success! Families Subscription Accepted", email); + await AddMessageContentAsync(message, "FamiliesForEnterprise.FamiliesForEnterpriseRedeemedToFamilyUser", new BaseMailModel()); + message.Category = "FamilyForEnterpriseRedeemedToFamilyUser"; + await _mailDeliveryService.SendEmailAsync(message); + } + + private async Task SendFamiliesForEnterpriseInviteRedeemedToEnterpriseUserEmailAsync(string email) + { + var message = CreateDefaultMessage("Success! Families Subscription Accepted", email); + await AddMessageContentAsync(message, "FamiliesForEnterprise.FamiliesForEnterpriseRedeemedToEnterpriseUser", new BaseMailModel()); + message.Category = "FamilyForEnterpriseRedeemedToEnterpriseUser"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendFamiliesForEnterpriseSponsorshipRevertingEmailAsync(string email, DateTime expirationDate) + { + var message = CreateDefaultMessage("Your Families Sponsorship was Removed", email); + var model = new FamiliesForEnterpriseSponsorshipRevertingViewModel + { + ExpirationDate = expirationDate, + }; + await AddMessageContentAsync(message, "FamiliesForEnterprise.FamiliesForEnterpriseSponsorshipReverting", model); + message.Category = "FamiliesForEnterpriseSponsorshipReverting"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendOTPEmailAsync(string email, string token) + { + var message = CreateDefaultMessage("Your Bitwarden Verification Code", email); + var model = new EmailTokenViewModel + { + Token = token, + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + }; + await AddMessageContentAsync(message, "OTPEmail", model); + message.MetaData.Add("SendGridBypassListManagement", true); + message.Category = "OTP"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendFailedLoginAttemptsEmailAsync(string email, DateTime utcNow, string ip) + { + var message = CreateDefaultMessage("Failed login attempts detected", email); + var model = new FailedAuthAttemptsModel() + { + TheDate = utcNow.ToLongDateString(), + TheTime = utcNow.ToShortTimeString(), + TimeZone = "UTC", + IpAddress = ip, + AffectedEmail = email + + }; + await AddMessageContentAsync(message, "FailedLoginAttempts", model); + message.Category = "FailedLoginAttempts"; + await _mailDeliveryService.SendEmailAsync(message); + } + + public async Task SendFailedTwoFactorAttemptsEmailAsync(string email, DateTime utcNow, string ip) + { + var message = CreateDefaultMessage("Failed login attempts detected", email); + var model = new FailedAuthAttemptsModel() + { + TheDate = utcNow.ToLongDateString(), + TheTime = utcNow.ToShortTimeString(), + TimeZone = "UTC", + IpAddress = ip, + AffectedEmail = email + + }; + await AddMessageContentAsync(message, "FailedTwoFactorAttempts", model); + message.Category = "FailedTwoFactorAttempts"; + await _mailDeliveryService.SendEmailAsync(message); + } + + private static string GetUserIdentifier(string email, string userName) + { + return string.IsNullOrEmpty(userName) ? email : CoreHelpers.SanitizeForEmail(userName, false); + } } diff --git a/src/Core/Services/Implementations/I18nService.cs b/src/Core/Services/Implementations/I18nService.cs index e9675ca58..7d99dacba 100644 --- a/src/Core/Services/Implementations/I18nService.cs +++ b/src/Core/Services/Implementations/I18nService.cs @@ -2,36 +2,35 @@ using Bit.Core.Resources; using Microsoft.Extensions.Localization; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class I18nService : II18nService { - public class I18nService : II18nService + private readonly IStringLocalizer _localizer; + + public I18nService(IStringLocalizerFactory factory) { - private readonly IStringLocalizer _localizer; + var assemblyName = new AssemblyName(typeof(SharedResources).GetTypeInfo().Assembly.FullName); + _localizer = factory.Create("SharedResources", assemblyName.Name); + } - public I18nService(IStringLocalizerFactory factory) - { - var assemblyName = new AssemblyName(typeof(SharedResources).GetTypeInfo().Assembly.FullName); - _localizer = factory.Create("SharedResources", assemblyName.Name); - } + public LocalizedString GetLocalizedHtmlString(string key) + { + return _localizer[key]; + } - public LocalizedString GetLocalizedHtmlString(string key) - { - return _localizer[key]; - } + public LocalizedString GetLocalizedHtmlString(string key, params object[] args) + { + return _localizer[key, args]; + } - public LocalizedString GetLocalizedHtmlString(string key, params object[] args) - { - return _localizer[key, args]; - } + public string Translate(string key, params object[] args) + { + return string.Format(GetLocalizedHtmlString(key).ToString(), args); + } - public string Translate(string key, params object[] args) - { - return string.Format(GetLocalizedHtmlString(key).ToString(), args); - } - - public string T(string key, params object[] args) - { - return Translate(key, args); - } + public string T(string key, params object[] args) + { + return Translate(key, args); } } diff --git a/src/Core/Services/Implementations/I18nViewLocalizer.cs b/src/Core/Services/Implementations/I18nViewLocalizer.cs index 4a8d86678..69699d9c4 100644 --- a/src/Core/Services/Implementations/I18nViewLocalizer.cs +++ b/src/Core/Services/Implementations/I18nViewLocalizer.cs @@ -3,29 +3,28 @@ using Bit.Core.Resources; using Microsoft.AspNetCore.Mvc.Localization; using Microsoft.Extensions.Localization; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class I18nViewLocalizer : IViewLocalizer { - public class I18nViewLocalizer : IViewLocalizer + private readonly IStringLocalizer _stringLocalizer; + private readonly IHtmlLocalizer _htmlLocalizer; + + public I18nViewLocalizer(IStringLocalizerFactory stringFactory, + IHtmlLocalizerFactory htmlFactory) { - private readonly IStringLocalizer _stringLocalizer; - private readonly IHtmlLocalizer _htmlLocalizer; - - public I18nViewLocalizer(IStringLocalizerFactory stringFactory, - IHtmlLocalizerFactory htmlFactory) - { - var assemblyName = new AssemblyName(typeof(SharedResources).GetTypeInfo().Assembly.FullName); - _stringLocalizer = stringFactory.Create("SharedResources", assemblyName.Name); - _htmlLocalizer = htmlFactory.Create("SharedResources", assemblyName.Name); - } - - public LocalizedHtmlString this[string name] => _htmlLocalizer[name]; - public LocalizedHtmlString this[string name, params object[] args] => _htmlLocalizer[name, args]; - - public IEnumerable GetAllStrings(bool includeParentCultures) => - _stringLocalizer.GetAllStrings(includeParentCultures); - - public LocalizedString GetString(string name) => _stringLocalizer[name]; - public LocalizedString GetString(string name, params object[] arguments) => - _stringLocalizer[name, arguments]; + var assemblyName = new AssemblyName(typeof(SharedResources).GetTypeInfo().Assembly.FullName); + _stringLocalizer = stringFactory.Create("SharedResources", assemblyName.Name); + _htmlLocalizer = htmlFactory.Create("SharedResources", assemblyName.Name); } + + public LocalizedHtmlString this[string name] => _htmlLocalizer[name]; + public LocalizedHtmlString this[string name, params object[] args] => _htmlLocalizer[name, args]; + + public IEnumerable GetAllStrings(bool includeParentCultures) => + _stringLocalizer.GetAllStrings(includeParentCultures); + + public LocalizedString GetString(string name) => _stringLocalizer[name]; + public LocalizedString GetString(string name, params object[] arguments) => + _stringLocalizer[name, arguments]; } diff --git a/src/Core/Services/Implementations/InMemoryApplicationCacheService.cs b/src/Core/Services/Implementations/InMemoryApplicationCacheService.cs index 98333ff55..dc23fcdb8 100644 --- a/src/Core/Services/Implementations/InMemoryApplicationCacheService.cs +++ b/src/Core/Services/Implementations/InMemoryApplicationCacheService.cs @@ -4,97 +4,96 @@ using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations; using Bit.Core.Repositories; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class InMemoryApplicationCacheService : IApplicationCacheService { - public class InMemoryApplicationCacheService : IApplicationCacheService + private readonly IOrganizationRepository _organizationRepository; + private readonly IProviderRepository _providerRepository; + private DateTime _lastOrgAbilityRefresh = DateTime.MinValue; + private IDictionary _orgAbilities; + private TimeSpan _orgAbilitiesRefreshInterval = TimeSpan.FromMinutes(10); + + private IDictionary _providerAbilities; + + public InMemoryApplicationCacheService( + IOrganizationRepository organizationRepository, IProviderRepository providerRepository) { - private readonly IOrganizationRepository _organizationRepository; - private readonly IProviderRepository _providerRepository; - private DateTime _lastOrgAbilityRefresh = DateTime.MinValue; - private IDictionary _orgAbilities; - private TimeSpan _orgAbilitiesRefreshInterval = TimeSpan.FromMinutes(10); + _organizationRepository = organizationRepository; + _providerRepository = providerRepository; + } - private IDictionary _providerAbilities; + public virtual async Task> GetOrganizationAbilitiesAsync() + { + await InitOrganizationAbilitiesAsync(); + return _orgAbilities; + } - public InMemoryApplicationCacheService( - IOrganizationRepository organizationRepository, IProviderRepository providerRepository) + public virtual async Task> GetProviderAbilitiesAsync() + { + await InitProviderAbilitiesAsync(); + return _providerAbilities; + } + + public virtual async Task UpsertProviderAbilityAsync(Provider provider) + { + await InitProviderAbilitiesAsync(); + var newAbility = new ProviderAbility(provider); + + if (_providerAbilities.ContainsKey(provider.Id)) { - _organizationRepository = organizationRepository; - _providerRepository = providerRepository; + _providerAbilities[provider.Id] = newAbility; + } + else + { + _providerAbilities.Add(provider.Id, newAbility); + } + } + + public virtual async Task UpsertOrganizationAbilityAsync(Organization organization) + { + await InitOrganizationAbilitiesAsync(); + var newAbility = new OrganizationAbility(organization); + + if (_orgAbilities.ContainsKey(organization.Id)) + { + _orgAbilities[organization.Id] = newAbility; + } + else + { + _orgAbilities.Add(organization.Id, newAbility); + } + } + + public virtual Task DeleteOrganizationAbilityAsync(Guid organizationId) + { + if (_orgAbilities != null && _orgAbilities.ContainsKey(organizationId)) + { + _orgAbilities.Remove(organizationId); } - public virtual async Task> GetOrganizationAbilitiesAsync() + return Task.FromResult(0); + } + + private async Task InitOrganizationAbilitiesAsync() + { + var now = DateTime.UtcNow; + if (_orgAbilities == null || (now - _lastOrgAbilityRefresh) > _orgAbilitiesRefreshInterval) { - await InitOrganizationAbilitiesAsync(); - return _orgAbilities; + var abilities = await _organizationRepository.GetManyAbilitiesAsync(); + _orgAbilities = abilities.ToDictionary(a => a.Id); + _lastOrgAbilityRefresh = now; } + } - public virtual async Task> GetProviderAbilitiesAsync() + private async Task InitProviderAbilitiesAsync() + { + var now = DateTime.UtcNow; + if (_providerAbilities == null || (now - _lastOrgAbilityRefresh) > _orgAbilitiesRefreshInterval) { - await InitProviderAbilitiesAsync(); - return _providerAbilities; - } - - public virtual async Task UpsertProviderAbilityAsync(Provider provider) - { - await InitProviderAbilitiesAsync(); - var newAbility = new ProviderAbility(provider); - - if (_providerAbilities.ContainsKey(provider.Id)) - { - _providerAbilities[provider.Id] = newAbility; - } - else - { - _providerAbilities.Add(provider.Id, newAbility); - } - } - - public virtual async Task UpsertOrganizationAbilityAsync(Organization organization) - { - await InitOrganizationAbilitiesAsync(); - var newAbility = new OrganizationAbility(organization); - - if (_orgAbilities.ContainsKey(organization.Id)) - { - _orgAbilities[organization.Id] = newAbility; - } - else - { - _orgAbilities.Add(organization.Id, newAbility); - } - } - - public virtual Task DeleteOrganizationAbilityAsync(Guid organizationId) - { - if (_orgAbilities != null && _orgAbilities.ContainsKey(organizationId)) - { - _orgAbilities.Remove(organizationId); - } - - return Task.FromResult(0); - } - - private async Task InitOrganizationAbilitiesAsync() - { - var now = DateTime.UtcNow; - if (_orgAbilities == null || (now - _lastOrgAbilityRefresh) > _orgAbilitiesRefreshInterval) - { - var abilities = await _organizationRepository.GetManyAbilitiesAsync(); - _orgAbilities = abilities.ToDictionary(a => a.Id); - _lastOrgAbilityRefresh = now; - } - } - - private async Task InitProviderAbilitiesAsync() - { - var now = DateTime.UtcNow; - if (_providerAbilities == null || (now - _lastOrgAbilityRefresh) > _orgAbilitiesRefreshInterval) - { - var abilities = await _providerRepository.GetManyAbilitiesAsync(); - _providerAbilities = abilities.ToDictionary(a => a.Id); - _lastOrgAbilityRefresh = now; - } + var abilities = await _providerRepository.GetManyAbilitiesAsync(); + _providerAbilities = abilities.ToDictionary(a => a.Id); + _lastOrgAbilityRefresh = now; } } } diff --git a/src/Core/Services/Implementations/InMemoryServiceBusApplicationCacheService.cs b/src/Core/Services/Implementations/InMemoryServiceBusApplicationCacheService.cs index c12efb409..1c059e4ca 100644 --- a/src/Core/Services/Implementations/InMemoryServiceBusApplicationCacheService.cs +++ b/src/Core/Services/Implementations/InMemoryServiceBusApplicationCacheService.cs @@ -5,62 +5,61 @@ using Bit.Core.Settings; using Bit.Core.Utilities; using Microsoft.Azure.ServiceBus; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class InMemoryServiceBusApplicationCacheService : InMemoryApplicationCacheService, IApplicationCacheService { - public class InMemoryServiceBusApplicationCacheService : InMemoryApplicationCacheService, IApplicationCacheService + private readonly TopicClient _topicClient; + private readonly string _subName; + + public InMemoryServiceBusApplicationCacheService( + IOrganizationRepository organizationRepository, + IProviderRepository providerRepository, + GlobalSettings globalSettings) + : base(organizationRepository, providerRepository) { - private readonly TopicClient _topicClient; - private readonly string _subName; + _subName = CoreHelpers.GetApplicationCacheServiceBusSubcriptionName(globalSettings); + _topicClient = new TopicClient(globalSettings.ServiceBus.ConnectionString, + globalSettings.ServiceBus.ApplicationCacheTopicName); + } - public InMemoryServiceBusApplicationCacheService( - IOrganizationRepository organizationRepository, - IProviderRepository providerRepository, - GlobalSettings globalSettings) - : base(organizationRepository, providerRepository) + public override async Task UpsertOrganizationAbilityAsync(Organization organization) + { + await base.UpsertOrganizationAbilityAsync(organization); + var message = new Message { - _subName = CoreHelpers.GetApplicationCacheServiceBusSubcriptionName(globalSettings); - _topicClient = new TopicClient(globalSettings.ServiceBus.ConnectionString, - globalSettings.ServiceBus.ApplicationCacheTopicName); - } - - public override async Task UpsertOrganizationAbilityAsync(Organization organization) - { - await base.UpsertOrganizationAbilityAsync(organization); - var message = new Message + Label = _subName, + UserProperties = { - Label = _subName, - UserProperties = - { - { "type", (byte)ApplicationCacheMessageType.UpsertOrganizationAbility }, - { "id", organization.Id }, - } - }; - var task = _topicClient.SendAsync(message); - } + { "type", (byte)ApplicationCacheMessageType.UpsertOrganizationAbility }, + { "id", organization.Id }, + } + }; + var task = _topicClient.SendAsync(message); + } - public override async Task DeleteOrganizationAbilityAsync(Guid organizationId) + public override async Task DeleteOrganizationAbilityAsync(Guid organizationId) + { + await base.DeleteOrganizationAbilityAsync(organizationId); + var message = new Message { - await base.DeleteOrganizationAbilityAsync(organizationId); - var message = new Message + Label = _subName, + UserProperties = { - Label = _subName, - UserProperties = - { - { "type", (byte)ApplicationCacheMessageType.DeleteOrganizationAbility }, - { "id", organizationId }, - } - }; - var task = _topicClient.SendAsync(message); - } + { "type", (byte)ApplicationCacheMessageType.DeleteOrganizationAbility }, + { "id", organizationId }, + } + }; + var task = _topicClient.SendAsync(message); + } - public async Task BaseUpsertOrganizationAbilityAsync(Organization organization) - { - await base.UpsertOrganizationAbilityAsync(organization); - } + public async Task BaseUpsertOrganizationAbilityAsync(Organization organization) + { + await base.UpsertOrganizationAbilityAsync(organization); + } - public async Task BaseDeleteOrganizationAbilityAsync(Guid organizationId) - { - await base.DeleteOrganizationAbilityAsync(organizationId); - } + public async Task BaseDeleteOrganizationAbilityAsync(Guid organizationId) + { + await base.DeleteOrganizationAbilityAsync(organizationId); } } diff --git a/src/Core/Services/Implementations/LicensingService.cs b/src/Core/Services/Implementations/LicensingService.cs index 70625e758..893ea2268 100644 --- a/src/Core/Services/Implementations/LicensingService.cs +++ b/src/Core/Services/Implementations/LicensingService.cs @@ -10,252 +10,251 @@ using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class LicensingService : ILicensingService { - public class LicensingService : ILicensingService + private readonly X509Certificate2 _certificate; + private readonly IGlobalSettings _globalSettings; + private readonly IUserRepository _userRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IMailService _mailService; + private readonly ILogger _logger; + + private IDictionary _userCheckCache = new Dictionary(); + + public LicensingService( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IMailService mailService, + IWebHostEnvironment environment, + ILogger logger, + IGlobalSettings globalSettings) { - private readonly X509Certificate2 _certificate; - private readonly IGlobalSettings _globalSettings; - private readonly IUserRepository _userRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IMailService _mailService; - private readonly ILogger _logger; + _userRepository = userRepository; + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _mailService = mailService; + _logger = logger; + _globalSettings = globalSettings; - private IDictionary _userCheckCache = new Dictionary(); - - public LicensingService( - IUserRepository userRepository, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IMailService mailService, - IWebHostEnvironment environment, - ILogger logger, - IGlobalSettings globalSettings) + var certThumbprint = environment.IsDevelopment() ? + "207E64A231E8AA32AAF68A61037C075EBEBD553F" : + "‎B34876439FCDA2846505B2EFBBA6C4A951313EBE"; + if (_globalSettings.SelfHosted) { - _userRepository = userRepository; - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _mailService = mailService; - _logger = logger; - _globalSettings = globalSettings; - - var certThumbprint = environment.IsDevelopment() ? - "207E64A231E8AA32AAF68A61037C075EBEBD553F" : - "‎B34876439FCDA2846505B2EFBBA6C4A951313EBE"; - if (_globalSettings.SelfHosted) - { - _certificate = CoreHelpers.GetEmbeddedCertificateAsync(environment.IsDevelopment() ? "licensing_dev.cer" : "licensing.cer", null) - .GetAwaiter().GetResult(); - } - else if (CoreHelpers.SettingHasValue(_globalSettings.Storage?.ConnectionString) && - CoreHelpers.SettingHasValue(_globalSettings.LicenseCertificatePassword)) - { - _certificate = CoreHelpers.GetBlobCertificateAsync(globalSettings.Storage.ConnectionString, "certificates", - "licensing.pfx", _globalSettings.LicenseCertificatePassword) - .GetAwaiter().GetResult(); - } - else - { - _certificate = CoreHelpers.GetCertificate(certThumbprint); - } - - if (_certificate == null || !_certificate.Thumbprint.Equals(CoreHelpers.CleanCertificateThumbprint(certThumbprint), - StringComparison.InvariantCultureIgnoreCase)) - { - throw new Exception("Invalid licensing certificate."); - } - - if (_globalSettings.SelfHosted && !CoreHelpers.SettingHasValue(_globalSettings.LicenseDirectory)) - { - throw new InvalidOperationException("No license directory."); - } + _certificate = CoreHelpers.GetEmbeddedCertificateAsync(environment.IsDevelopment() ? "licensing_dev.cer" : "licensing.cer", null) + .GetAwaiter().GetResult(); + } + else if (CoreHelpers.SettingHasValue(_globalSettings.Storage?.ConnectionString) && + CoreHelpers.SettingHasValue(_globalSettings.LicenseCertificatePassword)) + { + _certificate = CoreHelpers.GetBlobCertificateAsync(globalSettings.Storage.ConnectionString, "certificates", + "licensing.pfx", _globalSettings.LicenseCertificatePassword) + .GetAwaiter().GetResult(); + } + else + { + _certificate = CoreHelpers.GetCertificate(certThumbprint); } - public async Task ValidateOrganizationsAsync() + if (_certificate == null || !_certificate.Thumbprint.Equals(CoreHelpers.CleanCertificateThumbprint(certThumbprint), + StringComparison.InvariantCultureIgnoreCase)) { - if (!_globalSettings.SelfHosted) - { - return; - } - - var enabledOrgs = await _organizationRepository.GetManyByEnabledAsync(); - _logger.LogInformation(Constants.BypassFiltersEventId, null, - "Validating licenses for {0} organizations.", enabledOrgs.Count); - - foreach (var org in enabledOrgs) - { - var license = await ReadOrganizationLicenseAsync(org); - if (license == null) - { - await DisableOrganizationAsync(org, null, "No license file."); - continue; - } - - var totalLicensedOrgs = enabledOrgs.Count(o => o.LicenseKey.Equals(license.LicenseKey)); - if (totalLicensedOrgs > 1) - { - await DisableOrganizationAsync(org, license, "Multiple organizations."); - continue; - } - - if (!license.VerifyData(org, _globalSettings)) - { - await DisableOrganizationAsync(org, license, "Invalid data."); - continue; - } - - if (!license.VerifySignature(_certificate)) - { - await DisableOrganizationAsync(org, license, "Invalid signature."); - continue; - } - } + throw new Exception("Invalid licensing certificate."); } - private async Task DisableOrganizationAsync(Organization org, ILicense license, string reason) + if (_globalSettings.SelfHosted && !CoreHelpers.SettingHasValue(_globalSettings.LicenseDirectory)) { - _logger.LogInformation(Constants.BypassFiltersEventId, null, - "Organization {0} ({1}) has an invalid license and is being disabled. Reason: {2}", - org.Id, org.Name, reason); - org.Enabled = false; - org.ExpirationDate = license?.Expires ?? DateTime.UtcNow; - org.RevisionDate = DateTime.UtcNow; - await _organizationRepository.ReplaceAsync(org); + throw new InvalidOperationException("No license directory."); + } + } - await _mailService.SendLicenseExpiredAsync(new List { org.BillingEmail }, org.Name); + public async Task ValidateOrganizationsAsync() + { + if (!_globalSettings.SelfHosted) + { + return; } - public async Task ValidateUsersAsync() + var enabledOrgs = await _organizationRepository.GetManyByEnabledAsync(); + _logger.LogInformation(Constants.BypassFiltersEventId, null, + "Validating licenses for {0} organizations.", enabledOrgs.Count); + + foreach (var org in enabledOrgs) { - if (!_globalSettings.SelfHosted) - { - return; - } - - var premiumUsers = await _userRepository.GetManyByPremiumAsync(true); - _logger.LogInformation(Constants.BypassFiltersEventId, null, - "Validating premium for {0} users.", premiumUsers.Count); - - foreach (var user in premiumUsers) - { - await ProcessUserValidationAsync(user); - } - } - - public async Task ValidateUserPremiumAsync(User user) - { - if (!_globalSettings.SelfHosted) - { - return user.Premium; - } - - if (!user.Premium) - { - return false; - } - - // Only check once per day - var now = DateTime.UtcNow; - if (_userCheckCache.ContainsKey(user.Id)) - { - var lastCheck = _userCheckCache[user.Id]; - if (lastCheck < now && now - lastCheck < TimeSpan.FromDays(1)) - { - return user.Premium; - } - else - { - _userCheckCache[user.Id] = now; - } - } - else - { - _userCheckCache.Add(user.Id, now); - } - - _logger.LogInformation(Constants.BypassFiltersEventId, null, - "Validating premium license for user {0}({1}).", user.Id, user.Email); - return await ProcessUserValidationAsync(user); - } - - private async Task ProcessUserValidationAsync(User user) - { - var license = ReadUserLicense(user); + var license = await ReadOrganizationLicenseAsync(org); if (license == null) { - await DisablePremiumAsync(user, null, "No license file."); - return false; + await DisableOrganizationAsync(org, null, "No license file."); + continue; } - if (!license.VerifyData(user)) + var totalLicensedOrgs = enabledOrgs.Count(o => o.LicenseKey.Equals(license.LicenseKey)); + if (totalLicensedOrgs > 1) { - await DisablePremiumAsync(user, license, "Invalid data."); - return false; + await DisableOrganizationAsync(org, license, "Multiple organizations."); + continue; + } + + if (!license.VerifyData(org, _globalSettings)) + { + await DisableOrganizationAsync(org, license, "Invalid data."); + continue; } if (!license.VerifySignature(_certificate)) { - await DisablePremiumAsync(user, license, "Invalid signature."); - return false; + await DisableOrganizationAsync(org, license, "Invalid signature."); + continue; } - - return true; - } - - private async Task DisablePremiumAsync(User user, ILicense license, string reason) - { - _logger.LogInformation(Constants.BypassFiltersEventId, null, - "User {0}({1}) has an invalid license and premium is being disabled. Reason: {2}", - user.Id, user.Email, reason); - - user.Premium = false; - user.PremiumExpirationDate = license?.Expires ?? DateTime.UtcNow; - user.RevisionDate = DateTime.UtcNow; - await _userRepository.ReplaceAsync(user); - - await _mailService.SendLicenseExpiredAsync(new List { user.Email }); - } - - public bool VerifyLicense(ILicense license) - { - return license.VerifySignature(_certificate); - } - - public byte[] SignLicense(ILicense license) - { - if (_globalSettings.SelfHosted || !_certificate.HasPrivateKey) - { - throw new InvalidOperationException("Cannot sign licenses."); - } - - return license.Sign(_certificate); - } - - private UserLicense ReadUserLicense(User user) - { - var filePath = $"{_globalSettings.LicenseDirectory}/user/{user.Id}.json"; - if (!File.Exists(filePath)) - { - return null; - } - - var data = File.ReadAllText(filePath, Encoding.UTF8); - return JsonSerializer.Deserialize(data); - } - - public Task ReadOrganizationLicenseAsync(Organization organization) => - ReadOrganizationLicenseAsync(organization.Id); - public async Task ReadOrganizationLicenseAsync(Guid organizationId) - { - var filePath = Path.Combine(_globalSettings.LicenseDirectory, "organization", $"{organizationId}.json"); - if (!File.Exists(filePath)) - { - return null; - } - - using var fs = File.OpenRead(filePath); - return await JsonSerializer.DeserializeAsync(fs); } } + + private async Task DisableOrganizationAsync(Organization org, ILicense license, string reason) + { + _logger.LogInformation(Constants.BypassFiltersEventId, null, + "Organization {0} ({1}) has an invalid license and is being disabled. Reason: {2}", + org.Id, org.Name, reason); + org.Enabled = false; + org.ExpirationDate = license?.Expires ?? DateTime.UtcNow; + org.RevisionDate = DateTime.UtcNow; + await _organizationRepository.ReplaceAsync(org); + + await _mailService.SendLicenseExpiredAsync(new List { org.BillingEmail }, org.Name); + } + + public async Task ValidateUsersAsync() + { + if (!_globalSettings.SelfHosted) + { + return; + } + + var premiumUsers = await _userRepository.GetManyByPremiumAsync(true); + _logger.LogInformation(Constants.BypassFiltersEventId, null, + "Validating premium for {0} users.", premiumUsers.Count); + + foreach (var user in premiumUsers) + { + await ProcessUserValidationAsync(user); + } + } + + public async Task ValidateUserPremiumAsync(User user) + { + if (!_globalSettings.SelfHosted) + { + return user.Premium; + } + + if (!user.Premium) + { + return false; + } + + // Only check once per day + var now = DateTime.UtcNow; + if (_userCheckCache.ContainsKey(user.Id)) + { + var lastCheck = _userCheckCache[user.Id]; + if (lastCheck < now && now - lastCheck < TimeSpan.FromDays(1)) + { + return user.Premium; + } + else + { + _userCheckCache[user.Id] = now; + } + } + else + { + _userCheckCache.Add(user.Id, now); + } + + _logger.LogInformation(Constants.BypassFiltersEventId, null, + "Validating premium license for user {0}({1}).", user.Id, user.Email); + return await ProcessUserValidationAsync(user); + } + + private async Task ProcessUserValidationAsync(User user) + { + var license = ReadUserLicense(user); + if (license == null) + { + await DisablePremiumAsync(user, null, "No license file."); + return false; + } + + if (!license.VerifyData(user)) + { + await DisablePremiumAsync(user, license, "Invalid data."); + return false; + } + + if (!license.VerifySignature(_certificate)) + { + await DisablePremiumAsync(user, license, "Invalid signature."); + return false; + } + + return true; + } + + private async Task DisablePremiumAsync(User user, ILicense license, string reason) + { + _logger.LogInformation(Constants.BypassFiltersEventId, null, + "User {0}({1}) has an invalid license and premium is being disabled. Reason: {2}", + user.Id, user.Email, reason); + + user.Premium = false; + user.PremiumExpirationDate = license?.Expires ?? DateTime.UtcNow; + user.RevisionDate = DateTime.UtcNow; + await _userRepository.ReplaceAsync(user); + + await _mailService.SendLicenseExpiredAsync(new List { user.Email }); + } + + public bool VerifyLicense(ILicense license) + { + return license.VerifySignature(_certificate); + } + + public byte[] SignLicense(ILicense license) + { + if (_globalSettings.SelfHosted || !_certificate.HasPrivateKey) + { + throw new InvalidOperationException("Cannot sign licenses."); + } + + return license.Sign(_certificate); + } + + private UserLicense ReadUserLicense(User user) + { + var filePath = $"{_globalSettings.LicenseDirectory}/user/{user.Id}.json"; + if (!File.Exists(filePath)) + { + return null; + } + + var data = File.ReadAllText(filePath, Encoding.UTF8); + return JsonSerializer.Deserialize(data); + } + + public Task ReadOrganizationLicenseAsync(Organization organization) => + ReadOrganizationLicenseAsync(organization.Id); + public async Task ReadOrganizationLicenseAsync(Guid organizationId) + { + var filePath = Path.Combine(_globalSettings.LicenseDirectory, "organization", $"{organizationId}.json"); + if (!File.Exists(filePath)) + { + return null; + } + + using var fs = File.OpenRead(filePath); + return await JsonSerializer.DeserializeAsync(fs); + } } diff --git a/src/Core/Services/Implementations/LocalAttachmentStorageService.cs b/src/Core/Services/Implementations/LocalAttachmentStorageService.cs index d24a561e3..4949ff312 100644 --- a/src/Core/Services/Implementations/LocalAttachmentStorageService.cs +++ b/src/Core/Services/Implementations/LocalAttachmentStorageService.cs @@ -3,195 +3,194 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; using Bit.Core.Settings; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class LocalAttachmentStorageService : IAttachmentStorageService { - public class LocalAttachmentStorageService : IAttachmentStorageService + private readonly string _baseAttachmentUrl; + private readonly string _baseDirPath; + private readonly string _baseTempDirPath; + + public FileUploadType FileUploadType => FileUploadType.Direct; + + public LocalAttachmentStorageService( + IGlobalSettings globalSettings) { - private readonly string _baseAttachmentUrl; - private readonly string _baseDirPath; - private readonly string _baseTempDirPath; + _baseDirPath = globalSettings.Attachment.BaseDirectory; + _baseTempDirPath = $"{_baseDirPath}/temp"; + _baseAttachmentUrl = globalSettings.Attachment.BaseUrl; + } - public FileUploadType FileUploadType => FileUploadType.Direct; + public async Task GetAttachmentDownloadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) + { + await InitAsync(); + return $"{_baseAttachmentUrl}/{cipher.Id}/{attachmentData.AttachmentId}"; + } - public LocalAttachmentStorageService( - IGlobalSettings globalSettings) + public async Task UploadNewAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentData) + { + await InitAsync(); + var cipherDirPath = CipherDirectoryPath(cipher.Id, temp: false); + CreateDirectoryIfNotExists(cipherDirPath); + + using (var fs = File.Create(AttachmentFilePath(cipherDirPath, attachmentData.AttachmentId))) { - _baseDirPath = globalSettings.Attachment.BaseDirectory; - _baseTempDirPath = $"{_baseDirPath}/temp"; - _baseAttachmentUrl = globalSettings.Attachment.BaseUrl; - } - - public async Task GetAttachmentDownloadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) - { - await InitAsync(); - return $"{_baseAttachmentUrl}/{cipher.Id}/{attachmentData.AttachmentId}"; - } - - public async Task UploadNewAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentData) - { - await InitAsync(); - var cipherDirPath = CipherDirectoryPath(cipher.Id, temp: false); - CreateDirectoryIfNotExists(cipherDirPath); - - using (var fs = File.Create(AttachmentFilePath(cipherDirPath, attachmentData.AttachmentId))) - { - stream.Seek(0, SeekOrigin.Begin); - await stream.CopyToAsync(fs); - } - } - - public async Task UploadShareAttachmentAsync(Stream stream, Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData) - { - await InitAsync(); - var tempCipherOrgDirPath = OrganizationDirectoryPath(cipherId, organizationId, temp: true); - CreateDirectoryIfNotExists(tempCipherOrgDirPath); - - using (var fs = File.Create(AttachmentFilePath(tempCipherOrgDirPath, attachmentData.AttachmentId))) - { - stream.Seek(0, SeekOrigin.Begin); - await stream.CopyToAsync(fs); - } - } - - public async Task StartShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData) - { - await InitAsync(); - var sourceFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, organizationId, temp: true); - if (!File.Exists(sourceFilePath)) - { - return; - } - - var destFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: false); - if (!File.Exists(destFilePath)) - { - return; - } - - var originalFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: true); - DeleteFileIfExists(originalFilePath); - - File.Move(destFilePath, originalFilePath); - DeleteFileIfExists(destFilePath); - - File.Move(sourceFilePath, destFilePath); - } - - public async Task RollbackShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData, string originalContainer) - { - await InitAsync(); - DeleteFileIfExists(AttachmentFilePath(attachmentData.AttachmentId, cipherId, organizationId, temp: true)); - - var originalFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: true); - if (!File.Exists(originalFilePath)) - { - return; - } - - var destFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: false); - DeleteFileIfExists(destFilePath); - - File.Move(originalFilePath, destFilePath); - DeleteFileIfExists(originalFilePath); - } - - public async Task DeleteAttachmentAsync(Guid cipherId, CipherAttachment.MetaData attachmentData) - { - await InitAsync(); - DeleteFileIfExists(AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: false)); - } - - public async Task CleanupAsync(Guid cipherId) - { - await InitAsync(); - DeleteDirectoryIfExists(CipherDirectoryPath(cipherId, temp: true)); - } - - public async Task DeleteAttachmentsForCipherAsync(Guid cipherId) - { - await InitAsync(); - DeleteDirectoryIfExists(CipherDirectoryPath(cipherId, temp: false)); - } - - public async Task DeleteAttachmentsForOrganizationAsync(Guid organizationId) - { - await InitAsync(); - } - - public async Task DeleteAttachmentsForUserAsync(Guid userId) - { - await InitAsync(); - } - - private void DeleteFileIfExists(string path) - { - if (File.Exists(path)) - { - File.Delete(path); - } - } - - private void DeleteDirectoryIfExists(string path) - { - if (Directory.Exists(path)) - { - Directory.Delete(path, true); - } - } - - private void CreateDirectoryIfNotExists(string path) - { - if (!Directory.Exists(path)) - { - Directory.CreateDirectory(path); - } - } - - private Task InitAsync() - { - if (!Directory.Exists(_baseDirPath)) - { - Directory.CreateDirectory(_baseDirPath); - } - - if (!Directory.Exists(_baseTempDirPath)) - { - Directory.CreateDirectory(_baseTempDirPath); - } - - return Task.FromResult(0); - } - - private string CipherDirectoryPath(Guid cipherId, bool temp = false) => - Path.Combine(temp ? _baseTempDirPath : _baseDirPath, cipherId.ToString()); - private string OrganizationDirectoryPath(Guid cipherId, Guid organizationId, bool temp = false) => - Path.Combine(temp ? _baseTempDirPath : _baseDirPath, cipherId.ToString(), organizationId.ToString()); - - private string AttachmentFilePath(string dir, string attachmentId) => Path.Combine(dir, attachmentId); - private string AttachmentFilePath(string attachmentId, Guid cipherId, Guid? organizationId = null, - bool temp = false) => - organizationId.HasValue ? - AttachmentFilePath(OrganizationDirectoryPath(cipherId, organizationId.Value, temp), attachmentId) : - AttachmentFilePath(CipherDirectoryPath(cipherId, temp), attachmentId); - public Task GetAttachmentUploadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) - => Task.FromResult($"{cipher.Id}/attachment/{attachmentData.AttachmentId}"); - - public Task<(bool, long?)> ValidateFileAsync(Cipher cipher, CipherAttachment.MetaData attachmentData, long leeway) - { - long? length = null; - var path = AttachmentFilePath(attachmentData.AttachmentId, cipher.Id, temp: false); - if (!File.Exists(path)) - { - return Task.FromResult((false, length)); - } - - length = new FileInfo(path).Length; - if (attachmentData.Size < length - leeway || attachmentData.Size > length + leeway) - { - return Task.FromResult((false, length)); - } - - return Task.FromResult((true, length)); + stream.Seek(0, SeekOrigin.Begin); + await stream.CopyToAsync(fs); } } + + public async Task UploadShareAttachmentAsync(Stream stream, Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData) + { + await InitAsync(); + var tempCipherOrgDirPath = OrganizationDirectoryPath(cipherId, organizationId, temp: true); + CreateDirectoryIfNotExists(tempCipherOrgDirPath); + + using (var fs = File.Create(AttachmentFilePath(tempCipherOrgDirPath, attachmentData.AttachmentId))) + { + stream.Seek(0, SeekOrigin.Begin); + await stream.CopyToAsync(fs); + } + } + + public async Task StartShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData) + { + await InitAsync(); + var sourceFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, organizationId, temp: true); + if (!File.Exists(sourceFilePath)) + { + return; + } + + var destFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: false); + if (!File.Exists(destFilePath)) + { + return; + } + + var originalFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: true); + DeleteFileIfExists(originalFilePath); + + File.Move(destFilePath, originalFilePath); + DeleteFileIfExists(destFilePath); + + File.Move(sourceFilePath, destFilePath); + } + + public async Task RollbackShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData, string originalContainer) + { + await InitAsync(); + DeleteFileIfExists(AttachmentFilePath(attachmentData.AttachmentId, cipherId, organizationId, temp: true)); + + var originalFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: true); + if (!File.Exists(originalFilePath)) + { + return; + } + + var destFilePath = AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: false); + DeleteFileIfExists(destFilePath); + + File.Move(originalFilePath, destFilePath); + DeleteFileIfExists(originalFilePath); + } + + public async Task DeleteAttachmentAsync(Guid cipherId, CipherAttachment.MetaData attachmentData) + { + await InitAsync(); + DeleteFileIfExists(AttachmentFilePath(attachmentData.AttachmentId, cipherId, temp: false)); + } + + public async Task CleanupAsync(Guid cipherId) + { + await InitAsync(); + DeleteDirectoryIfExists(CipherDirectoryPath(cipherId, temp: true)); + } + + public async Task DeleteAttachmentsForCipherAsync(Guid cipherId) + { + await InitAsync(); + DeleteDirectoryIfExists(CipherDirectoryPath(cipherId, temp: false)); + } + + public async Task DeleteAttachmentsForOrganizationAsync(Guid organizationId) + { + await InitAsync(); + } + + public async Task DeleteAttachmentsForUserAsync(Guid userId) + { + await InitAsync(); + } + + private void DeleteFileIfExists(string path) + { + if (File.Exists(path)) + { + File.Delete(path); + } + } + + private void DeleteDirectoryIfExists(string path) + { + if (Directory.Exists(path)) + { + Directory.Delete(path, true); + } + } + + private void CreateDirectoryIfNotExists(string path) + { + if (!Directory.Exists(path)) + { + Directory.CreateDirectory(path); + } + } + + private Task InitAsync() + { + if (!Directory.Exists(_baseDirPath)) + { + Directory.CreateDirectory(_baseDirPath); + } + + if (!Directory.Exists(_baseTempDirPath)) + { + Directory.CreateDirectory(_baseTempDirPath); + } + + return Task.FromResult(0); + } + + private string CipherDirectoryPath(Guid cipherId, bool temp = false) => + Path.Combine(temp ? _baseTempDirPath : _baseDirPath, cipherId.ToString()); + private string OrganizationDirectoryPath(Guid cipherId, Guid organizationId, bool temp = false) => + Path.Combine(temp ? _baseTempDirPath : _baseDirPath, cipherId.ToString(), organizationId.ToString()); + + private string AttachmentFilePath(string dir, string attachmentId) => Path.Combine(dir, attachmentId); + private string AttachmentFilePath(string attachmentId, Guid cipherId, Guid? organizationId = null, + bool temp = false) => + organizationId.HasValue ? + AttachmentFilePath(OrganizationDirectoryPath(cipherId, organizationId.Value, temp), attachmentId) : + AttachmentFilePath(CipherDirectoryPath(cipherId, temp), attachmentId); + public Task GetAttachmentUploadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) + => Task.FromResult($"{cipher.Id}/attachment/{attachmentData.AttachmentId}"); + + public Task<(bool, long?)> ValidateFileAsync(Cipher cipher, CipherAttachment.MetaData attachmentData, long leeway) + { + long? length = null; + var path = AttachmentFilePath(attachmentData.AttachmentId, cipher.Id, temp: false); + if (!File.Exists(path)) + { + return Task.FromResult((false, length)); + } + + length = new FileInfo(path).Length; + if (attachmentData.Size < length - leeway || attachmentData.Size > length + leeway) + { + return Task.FromResult((false, length)); + } + + return Task.FromResult((true, length)); + } } diff --git a/src/Core/Services/Implementations/LocalSendStorageService.cs b/src/Core/Services/Implementations/LocalSendStorageService.cs index 200309f5b..30872cbcc 100644 --- a/src/Core/Services/Implementations/LocalSendStorageService.cs +++ b/src/Core/Services/Implementations/LocalSendStorageService.cs @@ -2,105 +2,104 @@ using Bit.Core.Enums; using Bit.Core.Settings; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class LocalSendStorageService : ISendFileStorageService { - public class LocalSendStorageService : ISendFileStorageService + private readonly string _baseDirPath; + private readonly string _baseSendUrl; + + private string RelativeFilePath(Send send, string fileID) => $"{send.Id}/{fileID}"; + private string FilePath(Send send, string fileID) => $"{_baseDirPath}/{RelativeFilePath(send, fileID)}"; + public FileUploadType FileUploadType => FileUploadType.Direct; + + public LocalSendStorageService( + GlobalSettings globalSettings) { - private readonly string _baseDirPath; - private readonly string _baseSendUrl; + _baseDirPath = globalSettings.Send.BaseDirectory; + _baseSendUrl = globalSettings.Send.BaseUrl; + } - private string RelativeFilePath(Send send, string fileID) => $"{send.Id}/{fileID}"; - private string FilePath(Send send, string fileID) => $"{_baseDirPath}/{RelativeFilePath(send, fileID)}"; - public FileUploadType FileUploadType => FileUploadType.Direct; - - public LocalSendStorageService( - GlobalSettings globalSettings) + public async Task UploadNewFileAsync(Stream stream, Send send, string fileId) + { + await InitAsync(); + var path = FilePath(send, fileId); + Directory.CreateDirectory(Path.GetDirectoryName(path)); + using (var fs = File.Create(path)) { - _baseDirPath = globalSettings.Send.BaseDirectory; - _baseSendUrl = globalSettings.Send.BaseUrl; - } - - public async Task UploadNewFileAsync(Stream stream, Send send, string fileId) - { - await InitAsync(); - var path = FilePath(send, fileId); - Directory.CreateDirectory(Path.GetDirectoryName(path)); - using (var fs = File.Create(path)) - { - stream.Seek(0, SeekOrigin.Begin); - await stream.CopyToAsync(fs); - } - } - - public async Task DeleteFileAsync(Send send, string fileId) - { - await InitAsync(); - var path = FilePath(send, fileId); - DeleteFileIfExists(path); - DeleteDirectoryIfExistsAndEmpty(Path.GetDirectoryName(path)); - } - - public async Task DeleteFilesForOrganizationAsync(Guid organizationId) - { - await InitAsync(); - } - - public async Task DeleteFilesForUserAsync(Guid userId) - { - await InitAsync(); - } - - public async Task GetSendFileDownloadUrlAsync(Send send, string fileId) - { - await InitAsync(); - return $"{_baseSendUrl}/{RelativeFilePath(send, fileId)}"; - } - - private void DeleteFileIfExists(string path) - { - if (File.Exists(path)) - { - File.Delete(path); - } - } - - private void DeleteDirectoryIfExistsAndEmpty(string path) - { - if (Directory.Exists(path) && !Directory.EnumerateFiles(path).Any()) - { - Directory.Delete(path); - } - } - - private Task InitAsync() - { - if (!Directory.Exists(_baseDirPath)) - { - Directory.CreateDirectory(_baseDirPath); - } - - return Task.FromResult(0); - } - - public Task GetSendFileUploadUrlAsync(Send send, string fileId) - => Task.FromResult($"/sends/{send.Id}/file/{fileId}"); - - public Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway) - { - long? length = null; - var path = FilePath(send, fileId); - if (!File.Exists(path)) - { - return Task.FromResult((false, length)); - } - - length = new FileInfo(path).Length; - if (expectedFileSize < length - leeway || expectedFileSize > length + leeway) - { - return Task.FromResult((false, length)); - } - - return Task.FromResult((true, length)); + stream.Seek(0, SeekOrigin.Begin); + await stream.CopyToAsync(fs); } } + + public async Task DeleteFileAsync(Send send, string fileId) + { + await InitAsync(); + var path = FilePath(send, fileId); + DeleteFileIfExists(path); + DeleteDirectoryIfExistsAndEmpty(Path.GetDirectoryName(path)); + } + + public async Task DeleteFilesForOrganizationAsync(Guid organizationId) + { + await InitAsync(); + } + + public async Task DeleteFilesForUserAsync(Guid userId) + { + await InitAsync(); + } + + public async Task GetSendFileDownloadUrlAsync(Send send, string fileId) + { + await InitAsync(); + return $"{_baseSendUrl}/{RelativeFilePath(send, fileId)}"; + } + + private void DeleteFileIfExists(string path) + { + if (File.Exists(path)) + { + File.Delete(path); + } + } + + private void DeleteDirectoryIfExistsAndEmpty(string path) + { + if (Directory.Exists(path) && !Directory.EnumerateFiles(path).Any()) + { + Directory.Delete(path); + } + } + + private Task InitAsync() + { + if (!Directory.Exists(_baseDirPath)) + { + Directory.CreateDirectory(_baseDirPath); + } + + return Task.FromResult(0); + } + + public Task GetSendFileUploadUrlAsync(Send send, string fileId) + => Task.FromResult($"/sends/{send.Id}/file/{fileId}"); + + public Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway) + { + long? length = null; + var path = FilePath(send, fileId); + if (!File.Exists(path)) + { + return Task.FromResult((false, length)); + } + + length = new FileInfo(path).Length; + if (expectedFileSize < length - leeway || expectedFileSize > length + leeway) + { + return Task.FromResult((false, length)); + } + + return Task.FromResult((true, length)); + } } diff --git a/src/Core/Services/Implementations/MailKitSmtpMailDeliveryService.cs b/src/Core/Services/Implementations/MailKitSmtpMailDeliveryService.cs index b4b93278e..4e7b7ee10 100644 --- a/src/Core/Services/Implementations/MailKitSmtpMailDeliveryService.cs +++ b/src/Core/Services/Implementations/MailKitSmtpMailDeliveryService.cs @@ -4,98 +4,97 @@ using MailKit.Net.Smtp; using Microsoft.Extensions.Logging; using MimeKit; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class MailKitSmtpMailDeliveryService : IMailDeliveryService { - public class MailKitSmtpMailDeliveryService : IMailDeliveryService + private readonly GlobalSettings _globalSettings; + private readonly ILogger _logger; + private readonly string _replyDomain; + private readonly string _replyEmail; + + public MailKitSmtpMailDeliveryService( + GlobalSettings globalSettings, + ILogger logger) { - private readonly GlobalSettings _globalSettings; - private readonly ILogger _logger; - private readonly string _replyDomain; - private readonly string _replyEmail; - - public MailKitSmtpMailDeliveryService( - GlobalSettings globalSettings, - ILogger logger) + if (globalSettings.Mail?.Smtp?.Host == null) { - if (globalSettings.Mail?.Smtp?.Host == null) - { - throw new ArgumentNullException(nameof(globalSettings.Mail.Smtp.Host)); - } - - _replyEmail = CoreHelpers.PunyEncode(globalSettings.Mail?.ReplyToEmail); - - if (_replyEmail.Contains("@")) - { - _replyDomain = _replyEmail.Split('@')[1]; - } - - _globalSettings = globalSettings; - _logger = logger; + throw new ArgumentNullException(nameof(globalSettings.Mail.Smtp.Host)); } - public async Task SendEmailAsync(Models.Mail.MailMessage message) - { - var mimeMessage = new MimeMessage(); - mimeMessage.From.Add(new MailboxAddress(_globalSettings.SiteName, _replyEmail)); - mimeMessage.Subject = message.Subject; - if (!string.IsNullOrWhiteSpace(_replyDomain)) - { - mimeMessage.MessageId = $"<{Guid.NewGuid()}@{_replyDomain}>"; - } + _replyEmail = CoreHelpers.PunyEncode(globalSettings.Mail?.ReplyToEmail); - foreach (var address in message.ToEmails) + if (_replyEmail.Contains("@")) + { + _replyDomain = _replyEmail.Split('@')[1]; + } + + _globalSettings = globalSettings; + _logger = logger; + } + + public async Task SendEmailAsync(Models.Mail.MailMessage message) + { + var mimeMessage = new MimeMessage(); + mimeMessage.From.Add(new MailboxAddress(_globalSettings.SiteName, _replyEmail)); + mimeMessage.Subject = message.Subject; + if (!string.IsNullOrWhiteSpace(_replyDomain)) + { + mimeMessage.MessageId = $"<{Guid.NewGuid()}@{_replyDomain}>"; + } + + foreach (var address in message.ToEmails) + { + var punyencoded = CoreHelpers.PunyEncode(address); + mimeMessage.To.Add(MailboxAddress.Parse(punyencoded)); + } + + if (message.BccEmails != null) + { + foreach (var address in message.BccEmails) { var punyencoded = CoreHelpers.PunyEncode(address); - mimeMessage.To.Add(MailboxAddress.Parse(punyencoded)); + mimeMessage.Bcc.Add(MailboxAddress.Parse(punyencoded)); } + } - if (message.BccEmails != null) + var builder = new BodyBuilder(); + if (!string.IsNullOrWhiteSpace(message.TextContent)) + { + builder.TextBody = message.TextContent; + } + builder.HtmlBody = message.HtmlContent; + mimeMessage.Body = builder.ToMessageBody(); + + using (var client = new SmtpClient()) + { + if (_globalSettings.Mail.Smtp.TrustServer) { - foreach (var address in message.BccEmails) - { - var punyencoded = CoreHelpers.PunyEncode(address); - mimeMessage.Bcc.Add(MailboxAddress.Parse(punyencoded)); - } + client.ServerCertificateValidationCallback = (s, c, h, e) => true; } - var builder = new BodyBuilder(); - if (!string.IsNullOrWhiteSpace(message.TextContent)) + if (!_globalSettings.Mail.Smtp.StartTls && !_globalSettings.Mail.Smtp.Ssl && + _globalSettings.Mail.Smtp.Port == 25) { - builder.TextBody = message.TextContent; + await client.ConnectAsync(_globalSettings.Mail.Smtp.Host, _globalSettings.Mail.Smtp.Port, + MailKit.Security.SecureSocketOptions.None); } - builder.HtmlBody = message.HtmlContent; - mimeMessage.Body = builder.ToMessageBody(); - - using (var client = new SmtpClient()) + else { - if (_globalSettings.Mail.Smtp.TrustServer) - { - client.ServerCertificateValidationCallback = (s, c, h, e) => true; - } - - if (!_globalSettings.Mail.Smtp.StartTls && !_globalSettings.Mail.Smtp.Ssl && - _globalSettings.Mail.Smtp.Port == 25) - { - await client.ConnectAsync(_globalSettings.Mail.Smtp.Host, _globalSettings.Mail.Smtp.Port, - MailKit.Security.SecureSocketOptions.None); - } - else - { - var useSsl = _globalSettings.Mail.Smtp.Port == 587 && !_globalSettings.Mail.Smtp.SslOverride ? - false : _globalSettings.Mail.Smtp.Ssl; - await client.ConnectAsync(_globalSettings.Mail.Smtp.Host, _globalSettings.Mail.Smtp.Port, useSsl); - } - - if (CoreHelpers.SettingHasValue(_globalSettings.Mail.Smtp.Username) && - CoreHelpers.SettingHasValue(_globalSettings.Mail.Smtp.Password)) - { - await client.AuthenticateAsync(_globalSettings.Mail.Smtp.Username, - _globalSettings.Mail.Smtp.Password); - } - - await client.SendAsync(mimeMessage); - await client.DisconnectAsync(true); + var useSsl = _globalSettings.Mail.Smtp.Port == 587 && !_globalSettings.Mail.Smtp.SslOverride ? + false : _globalSettings.Mail.Smtp.Ssl; + await client.ConnectAsync(_globalSettings.Mail.Smtp.Host, _globalSettings.Mail.Smtp.Port, useSsl); } + + if (CoreHelpers.SettingHasValue(_globalSettings.Mail.Smtp.Username) && + CoreHelpers.SettingHasValue(_globalSettings.Mail.Smtp.Password)) + { + await client.AuthenticateAsync(_globalSettings.Mail.Smtp.Username, + _globalSettings.Mail.Smtp.Password); + } + + await client.SendAsync(mimeMessage); + await client.DisconnectAsync(true); } } } diff --git a/src/Core/Services/Implementations/MultiServiceMailDeliveryService.cs b/src/Core/Services/Implementations/MultiServiceMailDeliveryService.cs index 286415fc2..e08841096 100644 --- a/src/Core/Services/Implementations/MultiServiceMailDeliveryService.cs +++ b/src/Core/Services/Implementations/MultiServiceMailDeliveryService.cs @@ -3,40 +3,39 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class MultiServiceMailDeliveryService : IMailDeliveryService { - public class MultiServiceMailDeliveryService : IMailDeliveryService + private readonly IMailDeliveryService _sesService; + private readonly IMailDeliveryService _sendGridService; + private readonly int _sendGridPercentage; + + private static Random _random = new Random(); + + public MultiServiceMailDeliveryService( + GlobalSettings globalSettings, + IWebHostEnvironment hostingEnvironment, + ILogger sesLogger, + ILogger sendGridLogger) { - private readonly IMailDeliveryService _sesService; - private readonly IMailDeliveryService _sendGridService; - private readonly int _sendGridPercentage; + _sesService = new AmazonSesMailDeliveryService(globalSettings, hostingEnvironment, sesLogger); + _sendGridService = new SendGridMailDeliveryService(globalSettings, hostingEnvironment, sendGridLogger); - private static Random _random = new Random(); + // disabled by default (-1) + _sendGridPercentage = (globalSettings.Mail?.SendGridPercentage).GetValueOrDefault(-1); + } - public MultiServiceMailDeliveryService( - GlobalSettings globalSettings, - IWebHostEnvironment hostingEnvironment, - ILogger sesLogger, - ILogger sendGridLogger) + public async Task SendEmailAsync(MailMessage message) + { + var roll = _random.Next(0, 99); + if (roll < _sendGridPercentage) { - _sesService = new AmazonSesMailDeliveryService(globalSettings, hostingEnvironment, sesLogger); - _sendGridService = new SendGridMailDeliveryService(globalSettings, hostingEnvironment, sendGridLogger); - - // disabled by default (-1) - _sendGridPercentage = (globalSettings.Mail?.SendGridPercentage).GetValueOrDefault(-1); + await _sendGridService.SendEmailAsync(message); } - - public async Task SendEmailAsync(MailMessage message) + else { - var roll = _random.Next(0, 99); - if (roll < _sendGridPercentage) - { - await _sendGridService.SendEmailAsync(message); - } - else - { - await _sesService.SendEmailAsync(message); - } + await _sesService.SendEmailAsync(message); } } } diff --git a/src/Core/Services/Implementations/MultiServicePushNotificationService.cs b/src/Core/Services/Implementations/MultiServicePushNotificationService.cs index f940bad00..4e1678da6 100644 --- a/src/Core/Services/Implementations/MultiServicePushNotificationService.cs +++ b/src/Core/Services/Implementations/MultiServicePushNotificationService.cs @@ -6,161 +6,160 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class MultiServicePushNotificationService : IPushNotificationService { - public class MultiServicePushNotificationService : IPushNotificationService + private readonly List _services = new List(); + private readonly ILogger _logger; + + public MultiServicePushNotificationService( + IHttpClientFactory httpFactory, + IDeviceRepository deviceRepository, + IInstallationDeviceRepository installationDeviceRepository, + GlobalSettings globalSettings, + IHttpContextAccessor httpContextAccessor, + ILogger logger, + ILogger relayLogger, + ILogger hubLogger) { - private readonly List _services = new List(); - private readonly ILogger _logger; - - public MultiServicePushNotificationService( - IHttpClientFactory httpFactory, - IDeviceRepository deviceRepository, - IInstallationDeviceRepository installationDeviceRepository, - GlobalSettings globalSettings, - IHttpContextAccessor httpContextAccessor, - ILogger logger, - ILogger relayLogger, - ILogger hubLogger) + if (globalSettings.SelfHosted) { - if (globalSettings.SelfHosted) + if (CoreHelpers.SettingHasValue(globalSettings.PushRelayBaseUri) && + globalSettings.Installation?.Id != null && + CoreHelpers.SettingHasValue(globalSettings.Installation?.Key)) { - if (CoreHelpers.SettingHasValue(globalSettings.PushRelayBaseUri) && - globalSettings.Installation?.Id != null && - CoreHelpers.SettingHasValue(globalSettings.Installation?.Key)) - { - _services.Add(new RelayPushNotificationService(httpFactory, deviceRepository, globalSettings, - httpContextAccessor, relayLogger)); - } - if (CoreHelpers.SettingHasValue(globalSettings.InternalIdentityKey) && - CoreHelpers.SettingHasValue(globalSettings.BaseServiceUri.InternalNotifications)) - { - _services.Add(new NotificationsApiPushNotificationService( - httpFactory, globalSettings, httpContextAccessor, hubLogger)); - } + _services.Add(new RelayPushNotificationService(httpFactory, deviceRepository, globalSettings, + httpContextAccessor, relayLogger)); } - else + if (CoreHelpers.SettingHasValue(globalSettings.InternalIdentityKey) && + CoreHelpers.SettingHasValue(globalSettings.BaseServiceUri.InternalNotifications)) { - if (CoreHelpers.SettingHasValue(globalSettings.NotificationHub.ConnectionString)) - { - _services.Add(new NotificationHubPushNotificationService(installationDeviceRepository, - globalSettings, httpContextAccessor)); - } - if (CoreHelpers.SettingHasValue(globalSettings.Notifications?.ConnectionString)) - { - _services.Add(new AzureQueuePushNotificationService(globalSettings, httpContextAccessor)); - } + _services.Add(new NotificationsApiPushNotificationService( + httpFactory, globalSettings, httpContextAccessor, hubLogger)); } - - _logger = logger; } - - public Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) + else { - PushToServices((s) => s.PushSyncCipherCreateAsync(cipher, collectionIds)); - return Task.FromResult(0); - } - - public Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds) - { - PushToServices((s) => s.PushSyncCipherUpdateAsync(cipher, collectionIds)); - return Task.FromResult(0); - } - - public Task PushSyncCipherDeleteAsync(Cipher cipher) - { - PushToServices((s) => s.PushSyncCipherDeleteAsync(cipher)); - return Task.FromResult(0); - } - - public Task PushSyncFolderCreateAsync(Folder folder) - { - PushToServices((s) => s.PushSyncFolderCreateAsync(folder)); - return Task.FromResult(0); - } - - public Task PushSyncFolderUpdateAsync(Folder folder) - { - PushToServices((s) => s.PushSyncFolderUpdateAsync(folder)); - return Task.FromResult(0); - } - - public Task PushSyncFolderDeleteAsync(Folder folder) - { - PushToServices((s) => s.PushSyncFolderDeleteAsync(folder)); - return Task.FromResult(0); - } - - public Task PushSyncCiphersAsync(Guid userId) - { - PushToServices((s) => s.PushSyncCiphersAsync(userId)); - return Task.FromResult(0); - } - - public Task PushSyncVaultAsync(Guid userId) - { - PushToServices((s) => s.PushSyncVaultAsync(userId)); - return Task.FromResult(0); - } - - public Task PushSyncOrgKeysAsync(Guid userId) - { - PushToServices((s) => s.PushSyncOrgKeysAsync(userId)); - return Task.FromResult(0); - } - - public Task PushSyncSettingsAsync(Guid userId) - { - PushToServices((s) => s.PushSyncSettingsAsync(userId)); - return Task.FromResult(0); - } - - public Task PushLogOutAsync(Guid userId) - { - PushToServices((s) => s.PushLogOutAsync(userId)); - return Task.FromResult(0); - } - - public Task PushSyncSendCreateAsync(Send send) - { - PushToServices((s) => s.PushSyncSendCreateAsync(send)); - return Task.FromResult(0); - } - - public Task PushSyncSendUpdateAsync(Send send) - { - PushToServices((s) => s.PushSyncSendUpdateAsync(send)); - return Task.FromResult(0); - } - - public Task PushSyncSendDeleteAsync(Send send) - { - PushToServices((s) => s.PushSyncSendDeleteAsync(send)); - return Task.FromResult(0); - } - - public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, - string deviceId = null) - { - PushToServices((s) => s.SendPayloadToUserAsync(userId, type, payload, identifier, deviceId)); - return Task.FromResult(0); - } - - public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, - string deviceId = null) - { - PushToServices((s) => s.SendPayloadToOrganizationAsync(orgId, type, payload, identifier, deviceId)); - return Task.FromResult(0); - } - - private void PushToServices(Func pushFunc) - { - if (_services != null) + if (CoreHelpers.SettingHasValue(globalSettings.NotificationHub.ConnectionString)) { - foreach (var service in _services) - { - pushFunc(service); - } + _services.Add(new NotificationHubPushNotificationService(installationDeviceRepository, + globalSettings, httpContextAccessor)); + } + if (CoreHelpers.SettingHasValue(globalSettings.Notifications?.ConnectionString)) + { + _services.Add(new AzureQueuePushNotificationService(globalSettings, httpContextAccessor)); + } + } + + _logger = logger; + } + + public Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) + { + PushToServices((s) => s.PushSyncCipherCreateAsync(cipher, collectionIds)); + return Task.FromResult(0); + } + + public Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds) + { + PushToServices((s) => s.PushSyncCipherUpdateAsync(cipher, collectionIds)); + return Task.FromResult(0); + } + + public Task PushSyncCipherDeleteAsync(Cipher cipher) + { + PushToServices((s) => s.PushSyncCipherDeleteAsync(cipher)); + return Task.FromResult(0); + } + + public Task PushSyncFolderCreateAsync(Folder folder) + { + PushToServices((s) => s.PushSyncFolderCreateAsync(folder)); + return Task.FromResult(0); + } + + public Task PushSyncFolderUpdateAsync(Folder folder) + { + PushToServices((s) => s.PushSyncFolderUpdateAsync(folder)); + return Task.FromResult(0); + } + + public Task PushSyncFolderDeleteAsync(Folder folder) + { + PushToServices((s) => s.PushSyncFolderDeleteAsync(folder)); + return Task.FromResult(0); + } + + public Task PushSyncCiphersAsync(Guid userId) + { + PushToServices((s) => s.PushSyncCiphersAsync(userId)); + return Task.FromResult(0); + } + + public Task PushSyncVaultAsync(Guid userId) + { + PushToServices((s) => s.PushSyncVaultAsync(userId)); + return Task.FromResult(0); + } + + public Task PushSyncOrgKeysAsync(Guid userId) + { + PushToServices((s) => s.PushSyncOrgKeysAsync(userId)); + return Task.FromResult(0); + } + + public Task PushSyncSettingsAsync(Guid userId) + { + PushToServices((s) => s.PushSyncSettingsAsync(userId)); + return Task.FromResult(0); + } + + public Task PushLogOutAsync(Guid userId) + { + PushToServices((s) => s.PushLogOutAsync(userId)); + return Task.FromResult(0); + } + + public Task PushSyncSendCreateAsync(Send send) + { + PushToServices((s) => s.PushSyncSendCreateAsync(send)); + return Task.FromResult(0); + } + + public Task PushSyncSendUpdateAsync(Send send) + { + PushToServices((s) => s.PushSyncSendUpdateAsync(send)); + return Task.FromResult(0); + } + + public Task PushSyncSendDeleteAsync(Send send) + { + PushToServices((s) => s.PushSyncSendDeleteAsync(send)); + return Task.FromResult(0); + } + + public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, + string deviceId = null) + { + PushToServices((s) => s.SendPayloadToUserAsync(userId, type, payload, identifier, deviceId)); + return Task.FromResult(0); + } + + public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, + string deviceId = null) + { + PushToServices((s) => s.SendPayloadToOrganizationAsync(orgId, type, payload, identifier, deviceId)); + return Task.FromResult(0); + } + + private void PushToServices(Func pushFunc) + { + if (_services != null) + { + foreach (var service in _services) + { + pushFunc(service); } } } diff --git a/src/Core/Services/Implementations/NotificationHubPushNotificationService.cs b/src/Core/Services/Implementations/NotificationHubPushNotificationService.cs index dbf4e55aa..fbd7ab9ce 100644 --- a/src/Core/Services/Implementations/NotificationHubPushNotificationService.cs +++ b/src/Core/Services/Implementations/NotificationHubPushNotificationService.cs @@ -10,231 +10,230 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Http; using Microsoft.Azure.NotificationHubs; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class NotificationHubPushNotificationService : IPushNotificationService { - public class NotificationHubPushNotificationService : IPushNotificationService + private readonly IInstallationDeviceRepository _installationDeviceRepository; + private readonly GlobalSettings _globalSettings; + private readonly IHttpContextAccessor _httpContextAccessor; + + private NotificationHubClient _client = null; + + public NotificationHubPushNotificationService( + IInstallationDeviceRepository installationDeviceRepository, + GlobalSettings globalSettings, + IHttpContextAccessor httpContextAccessor) { - private readonly IInstallationDeviceRepository _installationDeviceRepository; - private readonly GlobalSettings _globalSettings; - private readonly IHttpContextAccessor _httpContextAccessor; + _installationDeviceRepository = installationDeviceRepository; + _globalSettings = globalSettings; + _httpContextAccessor = httpContextAccessor; + _client = NotificationHubClient.CreateClientFromConnectionString( + _globalSettings.NotificationHub.ConnectionString, + _globalSettings.NotificationHub.HubName); + } - private NotificationHubClient _client = null; + public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) + { + await PushCipherAsync(cipher, PushType.SyncCipherCreate, collectionIds); + } - public NotificationHubPushNotificationService( - IInstallationDeviceRepository installationDeviceRepository, - GlobalSettings globalSettings, - IHttpContextAccessor httpContextAccessor) + public async Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds) + { + await PushCipherAsync(cipher, PushType.SyncCipherUpdate, collectionIds); + } + + public async Task PushSyncCipherDeleteAsync(Cipher cipher) + { + await PushCipherAsync(cipher, PushType.SyncLoginDelete, null); + } + + private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable collectionIds) + { + if (cipher.OrganizationId.HasValue) { - _installationDeviceRepository = installationDeviceRepository; - _globalSettings = globalSettings; - _httpContextAccessor = httpContextAccessor; - _client = NotificationHubClient.CreateClientFromConnectionString( - _globalSettings.NotificationHub.ConnectionString, - _globalSettings.NotificationHub.HubName); + // We cannot send org pushes since access logic is much more complicated than just the fact that they belong + // to the organization. Potentially we could blindly send to just users that have the access all permission + // device registration needs to be more granular to handle that appropriately. A more brute force approach could + // me to send "full sync" push to all org users, but that has the potential to DDOS the API in bursts. + + // await SendPayloadToOrganizationAsync(cipher.OrganizationId.Value, type, message, true); } - - public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) + else if (cipher.UserId.HasValue) { - await PushCipherAsync(cipher, PushType.SyncCipherCreate, collectionIds); - } - - public async Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds) - { - await PushCipherAsync(cipher, PushType.SyncCipherUpdate, collectionIds); - } - - public async Task PushSyncCipherDeleteAsync(Cipher cipher) - { - await PushCipherAsync(cipher, PushType.SyncLoginDelete, null); - } - - private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable collectionIds) - { - if (cipher.OrganizationId.HasValue) + var message = new SyncCipherPushNotification { - // We cannot send org pushes since access logic is much more complicated than just the fact that they belong - // to the organization. Potentially we could blindly send to just users that have the access all permission - // device registration needs to be more granular to handle that appropriately. A more brute force approach could - // me to send "full sync" push to all org users, but that has the potential to DDOS the API in bursts. - - // await SendPayloadToOrganizationAsync(cipher.OrganizationId.Value, type, message, true); - } - else if (cipher.UserId.HasValue) - { - var message = new SyncCipherPushNotification - { - Id = cipher.Id, - UserId = cipher.UserId, - OrganizationId = cipher.OrganizationId, - RevisionDate = cipher.RevisionDate, - CollectionIds = collectionIds, - }; - - await SendPayloadToUserAsync(cipher.UserId.Value, type, message, true); - } - } - - public async Task PushSyncFolderCreateAsync(Folder folder) - { - await PushFolderAsync(folder, PushType.SyncFolderCreate); - } - - public async Task PushSyncFolderUpdateAsync(Folder folder) - { - await PushFolderAsync(folder, PushType.SyncFolderUpdate); - } - - public async Task PushSyncFolderDeleteAsync(Folder folder) - { - await PushFolderAsync(folder, PushType.SyncFolderDelete); - } - - private async Task PushFolderAsync(Folder folder, PushType type) - { - var message = new SyncFolderPushNotification - { - Id = folder.Id, - UserId = folder.UserId, - RevisionDate = folder.RevisionDate + Id = cipher.Id, + UserId = cipher.UserId, + OrganizationId = cipher.OrganizationId, + RevisionDate = cipher.RevisionDate, + CollectionIds = collectionIds, }; - await SendPayloadToUserAsync(folder.UserId, type, message, true); - } - - public async Task PushSyncCiphersAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncCiphers); - } - - public async Task PushSyncVaultAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncVault); - } - - public async Task PushSyncOrgKeysAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncOrgKeys); - } - - public async Task PushSyncSettingsAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncSettings); - } - - public async Task PushLogOutAsync(Guid userId) - { - await PushUserAsync(userId, PushType.LogOut); - } - - private async Task PushUserAsync(Guid userId, PushType type) - { - var message = new UserPushNotification - { - UserId = userId, - Date = DateTime.UtcNow - }; - - await SendPayloadToUserAsync(userId, type, message, false); - } - - public async Task PushSyncSendCreateAsync(Send send) - { - await PushSendAsync(send, PushType.SyncSendCreate); - } - - public async Task PushSyncSendUpdateAsync(Send send) - { - await PushSendAsync(send, PushType.SyncSendUpdate); - } - - public async Task PushSyncSendDeleteAsync(Send send) - { - await PushSendAsync(send, PushType.SyncSendDelete); - } - - private async Task PushSendAsync(Send send, PushType type) - { - if (send.UserId.HasValue) - { - var message = new SyncSendPushNotification - { - Id = send.Id, - UserId = send.UserId.Value, - RevisionDate = send.RevisionDate - }; - - await SendPayloadToUserAsync(message.UserId, type, message, true); - } - } - - private async Task SendPayloadToUserAsync(Guid userId, PushType type, object payload, bool excludeCurrentContext) - { - await SendPayloadToUserAsync(userId.ToString(), type, payload, GetContextIdentifier(excludeCurrentContext)); - } - - private async Task SendPayloadToOrganizationAsync(Guid orgId, PushType type, object payload, bool excludeCurrentContext) - { - await SendPayloadToUserAsync(orgId.ToString(), type, payload, GetContextIdentifier(excludeCurrentContext)); - } - - public async Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, - string deviceId = null) - { - var tag = BuildTag($"template:payload_userId:{SanitizeTagInput(userId)}", identifier); - await SendPayloadAsync(tag, type, payload); - if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId)) - { - await _installationDeviceRepository.UpsertAsync(new InstallationDeviceEntity(deviceId)); - } - } - - public async Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, - string deviceId = null) - { - var tag = BuildTag($"template:payload && organizationId:{SanitizeTagInput(orgId)}", identifier); - await SendPayloadAsync(tag, type, payload); - if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId)) - { - await _installationDeviceRepository.UpsertAsync(new InstallationDeviceEntity(deviceId)); - } - } - - private string GetContextIdentifier(bool excludeCurrentContext) - { - if (!excludeCurrentContext) - { - return null; - } - - var currentContext = _httpContextAccessor?.HttpContext?. - RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; - return currentContext?.DeviceIdentifier; - } - - private string BuildTag(string tag, string identifier) - { - if (!string.IsNullOrWhiteSpace(identifier)) - { - tag += $" && !deviceIdentifier:{SanitizeTagInput(identifier)}"; - } - - return $"({tag})"; - } - - private async Task SendPayloadAsync(string tag, PushType type, object payload) - { - await _client.SendTemplateNotificationAsync( - new Dictionary - { - { "type", ((byte)type).ToString() }, - { "payload", JsonSerializer.Serialize(payload) } - }, tag); - } - - private string SanitizeTagInput(string input) - { - // Only allow a-z, A-Z, 0-9, and special characters -_: - return Regex.Replace(input, "[^a-zA-Z0-9-_:]", string.Empty); + await SendPayloadToUserAsync(cipher.UserId.Value, type, message, true); } } + + public async Task PushSyncFolderCreateAsync(Folder folder) + { + await PushFolderAsync(folder, PushType.SyncFolderCreate); + } + + public async Task PushSyncFolderUpdateAsync(Folder folder) + { + await PushFolderAsync(folder, PushType.SyncFolderUpdate); + } + + public async Task PushSyncFolderDeleteAsync(Folder folder) + { + await PushFolderAsync(folder, PushType.SyncFolderDelete); + } + + private async Task PushFolderAsync(Folder folder, PushType type) + { + var message = new SyncFolderPushNotification + { + Id = folder.Id, + UserId = folder.UserId, + RevisionDate = folder.RevisionDate + }; + + await SendPayloadToUserAsync(folder.UserId, type, message, true); + } + + public async Task PushSyncCiphersAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncCiphers); + } + + public async Task PushSyncVaultAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncVault); + } + + public async Task PushSyncOrgKeysAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncOrgKeys); + } + + public async Task PushSyncSettingsAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncSettings); + } + + public async Task PushLogOutAsync(Guid userId) + { + await PushUserAsync(userId, PushType.LogOut); + } + + private async Task PushUserAsync(Guid userId, PushType type) + { + var message = new UserPushNotification + { + UserId = userId, + Date = DateTime.UtcNow + }; + + await SendPayloadToUserAsync(userId, type, message, false); + } + + public async Task PushSyncSendCreateAsync(Send send) + { + await PushSendAsync(send, PushType.SyncSendCreate); + } + + public async Task PushSyncSendUpdateAsync(Send send) + { + await PushSendAsync(send, PushType.SyncSendUpdate); + } + + public async Task PushSyncSendDeleteAsync(Send send) + { + await PushSendAsync(send, PushType.SyncSendDelete); + } + + private async Task PushSendAsync(Send send, PushType type) + { + if (send.UserId.HasValue) + { + var message = new SyncSendPushNotification + { + Id = send.Id, + UserId = send.UserId.Value, + RevisionDate = send.RevisionDate + }; + + await SendPayloadToUserAsync(message.UserId, type, message, true); + } + } + + private async Task SendPayloadToUserAsync(Guid userId, PushType type, object payload, bool excludeCurrentContext) + { + await SendPayloadToUserAsync(userId.ToString(), type, payload, GetContextIdentifier(excludeCurrentContext)); + } + + private async Task SendPayloadToOrganizationAsync(Guid orgId, PushType type, object payload, bool excludeCurrentContext) + { + await SendPayloadToUserAsync(orgId.ToString(), type, payload, GetContextIdentifier(excludeCurrentContext)); + } + + public async Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, + string deviceId = null) + { + var tag = BuildTag($"template:payload_userId:{SanitizeTagInput(userId)}", identifier); + await SendPayloadAsync(tag, type, payload); + if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId)) + { + await _installationDeviceRepository.UpsertAsync(new InstallationDeviceEntity(deviceId)); + } + } + + public async Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, + string deviceId = null) + { + var tag = BuildTag($"template:payload && organizationId:{SanitizeTagInput(orgId)}", identifier); + await SendPayloadAsync(tag, type, payload); + if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId)) + { + await _installationDeviceRepository.UpsertAsync(new InstallationDeviceEntity(deviceId)); + } + } + + private string GetContextIdentifier(bool excludeCurrentContext) + { + if (!excludeCurrentContext) + { + return null; + } + + var currentContext = _httpContextAccessor?.HttpContext?. + RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; + return currentContext?.DeviceIdentifier; + } + + private string BuildTag(string tag, string identifier) + { + if (!string.IsNullOrWhiteSpace(identifier)) + { + tag += $" && !deviceIdentifier:{SanitizeTagInput(identifier)}"; + } + + return $"({tag})"; + } + + private async Task SendPayloadAsync(string tag, PushType type, object payload) + { + await _client.SendTemplateNotificationAsync( + new Dictionary + { + { "type", ((byte)type).ToString() }, + { "payload", JsonSerializer.Serialize(payload) } + }, tag); + } + + private string SanitizeTagInput(string input) + { + // Only allow a-z, A-Z, 0-9, and special characters -_: + return Regex.Replace(input, "[^a-zA-Z0-9-_:]", string.Empty); + } } diff --git a/src/Core/Services/Implementations/NotificationHubPushRegistrationService.cs b/src/Core/Services/Implementations/NotificationHubPushRegistrationService.cs index be3f8735f..6f0937539 100644 --- a/src/Core/Services/Implementations/NotificationHubPushRegistrationService.cs +++ b/src/Core/Services/Implementations/NotificationHubPushRegistrationService.cs @@ -4,192 +4,191 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Microsoft.Azure.NotificationHubs; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class NotificationHubPushRegistrationService : IPushRegistrationService { - public class NotificationHubPushRegistrationService : IPushRegistrationService + private readonly IInstallationDeviceRepository _installationDeviceRepository; + private readonly GlobalSettings _globalSettings; + + private NotificationHubClient _client = null; + + public NotificationHubPushRegistrationService( + IInstallationDeviceRepository installationDeviceRepository, + GlobalSettings globalSettings) { - private readonly IInstallationDeviceRepository _installationDeviceRepository; - private readonly GlobalSettings _globalSettings; + _installationDeviceRepository = installationDeviceRepository; + _globalSettings = globalSettings; + _client = NotificationHubClient.CreateClientFromConnectionString( + _globalSettings.NotificationHub.ConnectionString, + _globalSettings.NotificationHub.HubName); + } - private NotificationHubClient _client = null; - - public NotificationHubPushRegistrationService( - IInstallationDeviceRepository installationDeviceRepository, - GlobalSettings globalSettings) + public async Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId, + string identifier, DeviceType type) + { + if (string.IsNullOrWhiteSpace(pushToken)) { - _installationDeviceRepository = installationDeviceRepository; - _globalSettings = globalSettings; - _client = NotificationHubClient.CreateClientFromConnectionString( - _globalSettings.NotificationHub.ConnectionString, - _globalSettings.NotificationHub.HubName); + return; } - public async Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId, - string identifier, DeviceType type) + var installation = new Installation { - if (string.IsNullOrWhiteSpace(pushToken)) + InstallationId = deviceId, + PushChannel = pushToken, + Templates = new Dictionary() + }; + + installation.Tags = new List + { + $"userId:{userId}" + }; + + if (!string.IsNullOrWhiteSpace(identifier)) + { + installation.Tags.Add("deviceIdentifier:" + identifier); + } + + string payloadTemplate = null, messageTemplate = null, badgeMessageTemplate = null; + switch (type) + { + case DeviceType.Android: + payloadTemplate = "{\"data\":{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}}}"; + messageTemplate = "{\"data\":{\"data\":{\"type\":\"#(type)\"}," + + "\"notification\":{\"title\":\"$(title)\",\"body\":\"$(message)\"}}}"; + + installation.Platform = NotificationPlatform.Fcm; + break; + case DeviceType.iOS: + payloadTemplate = "{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}," + + "\"aps\":{\"content-available\":1}}"; + messageTemplate = "{\"data\":{\"type\":\"#(type)\"}," + + "\"aps\":{\"alert\":\"$(message)\",\"badge\":null,\"content-available\":1}}"; + badgeMessageTemplate = "{\"data\":{\"type\":\"#(type)\"}," + + "\"aps\":{\"alert\":\"$(message)\",\"badge\":\"#(badge)\",\"content-available\":1}}"; + + installation.Platform = NotificationPlatform.Apns; + break; + case DeviceType.AndroidAmazon: + payloadTemplate = "{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}}"; + messageTemplate = "{\"data\":{\"type\":\"#(type)\",\"message\":\"$(message)\"}}"; + + installation.Platform = NotificationPlatform.Adm; + break; + default: + break; + } + + BuildInstallationTemplate(installation, "payload", payloadTemplate, userId, identifier); + BuildInstallationTemplate(installation, "message", messageTemplate, userId, identifier); + BuildInstallationTemplate(installation, "badgeMessage", badgeMessageTemplate ?? messageTemplate, + userId, identifier); + + await _client.CreateOrUpdateInstallationAsync(installation); + if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId)) + { + await _installationDeviceRepository.UpsertAsync(new InstallationDeviceEntity(deviceId)); + } + } + + private void BuildInstallationTemplate(Installation installation, string templateId, string templateBody, + string userId, string identifier) + { + if (templateBody == null) + { + return; + } + + var fullTemplateId = $"template:{templateId}"; + + var template = new InstallationTemplate + { + Body = templateBody, + Tags = new List { - return; + fullTemplateId, + $"{fullTemplateId}_userId:{userId}" } + }; - var installation = new Installation - { - InstallationId = deviceId, - PushChannel = pushToken, - Templates = new Dictionary() - }; + if (!string.IsNullOrWhiteSpace(identifier)) + { + template.Tags.Add($"{fullTemplateId}_deviceIdentifier:{identifier}"); + } - installation.Tags = new List - { - $"userId:{userId}" - }; + installation.Templates.Add(fullTemplateId, template); + } - if (!string.IsNullOrWhiteSpace(identifier)) - { - installation.Tags.Add("deviceIdentifier:" + identifier); - } - - string payloadTemplate = null, messageTemplate = null, badgeMessageTemplate = null; - switch (type) - { - case DeviceType.Android: - payloadTemplate = "{\"data\":{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}}}"; - messageTemplate = "{\"data\":{\"data\":{\"type\":\"#(type)\"}," + - "\"notification\":{\"title\":\"$(title)\",\"body\":\"$(message)\"}}}"; - - installation.Platform = NotificationPlatform.Fcm; - break; - case DeviceType.iOS: - payloadTemplate = "{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}," + - "\"aps\":{\"content-available\":1}}"; - messageTemplate = "{\"data\":{\"type\":\"#(type)\"}," + - "\"aps\":{\"alert\":\"$(message)\",\"badge\":null,\"content-available\":1}}"; - badgeMessageTemplate = "{\"data\":{\"type\":\"#(type)\"}," + - "\"aps\":{\"alert\":\"$(message)\",\"badge\":\"#(badge)\",\"content-available\":1}}"; - - installation.Platform = NotificationPlatform.Apns; - break; - case DeviceType.AndroidAmazon: - payloadTemplate = "{\"data\":{\"type\":\"#(type)\",\"payload\":\"$(payload)\"}}"; - messageTemplate = "{\"data\":{\"type\":\"#(type)\",\"message\":\"$(message)\"}}"; - - installation.Platform = NotificationPlatform.Adm; - break; - default: - break; - } - - BuildInstallationTemplate(installation, "payload", payloadTemplate, userId, identifier); - BuildInstallationTemplate(installation, "message", messageTemplate, userId, identifier); - BuildInstallationTemplate(installation, "badgeMessage", badgeMessageTemplate ?? messageTemplate, - userId, identifier); - - await _client.CreateOrUpdateInstallationAsync(installation); + public async Task DeleteRegistrationAsync(string deviceId) + { + try + { + await _client.DeleteInstallationAsync(deviceId); if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId)) { - await _installationDeviceRepository.UpsertAsync(new InstallationDeviceEntity(deviceId)); + await _installationDeviceRepository.DeleteAsync(new InstallationDeviceEntity(deviceId)); } } - - private void BuildInstallationTemplate(Installation installation, string templateId, string templateBody, - string userId, string identifier) + catch (Exception e) when (e.InnerException == null || !e.InnerException.Message.Contains("(404) Not Found")) { - if (templateBody == null) - { - return; - } + throw; + } + } - var fullTemplateId = $"template:{templateId}"; + public async Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) + { + await PatchTagsForUserDevicesAsync(deviceIds, UpdateOperationType.Add, $"organizationId:{organizationId}"); + if (deviceIds.Any() && InstallationDeviceEntity.IsInstallationDeviceId(deviceIds.First())) + { + var entities = deviceIds.Select(e => new InstallationDeviceEntity(e)); + await _installationDeviceRepository.UpsertManyAsync(entities.ToList()); + } + } - var template = new InstallationTemplate - { - Body = templateBody, - Tags = new List - { - fullTemplateId, - $"{fullTemplateId}_userId:{userId}" - } - }; + public async Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) + { + await PatchTagsForUserDevicesAsync(deviceIds, UpdateOperationType.Remove, + $"organizationId:{organizationId}"); + if (deviceIds.Any() && InstallationDeviceEntity.IsInstallationDeviceId(deviceIds.First())) + { + var entities = deviceIds.Select(e => new InstallationDeviceEntity(e)); + await _installationDeviceRepository.UpsertManyAsync(entities.ToList()); + } + } - if (!string.IsNullOrWhiteSpace(identifier)) - { - template.Tags.Add($"{fullTemplateId}_deviceIdentifier:{identifier}"); - } - - installation.Templates.Add(fullTemplateId, template); + private async Task PatchTagsForUserDevicesAsync(IEnumerable deviceIds, UpdateOperationType op, + string tag) + { + if (!deviceIds.Any()) + { + return; } - public async Task DeleteRegistrationAsync(string deviceId) + var operation = new PartialUpdateOperation + { + Operation = op, + Path = "/tags" + }; + + if (op == UpdateOperationType.Add) + { + operation.Value = tag; + } + else if (op == UpdateOperationType.Remove) + { + operation.Path += $"/{tag}"; + } + + foreach (var id in deviceIds) { try { - await _client.DeleteInstallationAsync(deviceId); - if (InstallationDeviceEntity.IsInstallationDeviceId(deviceId)) - { - await _installationDeviceRepository.DeleteAsync(new InstallationDeviceEntity(deviceId)); - } + await _client.PatchInstallationAsync(id, new List { operation }); } catch (Exception e) when (e.InnerException == null || !e.InnerException.Message.Contains("(404) Not Found")) { throw; } } - - public async Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) - { - await PatchTagsForUserDevicesAsync(deviceIds, UpdateOperationType.Add, $"organizationId:{organizationId}"); - if (deviceIds.Any() && InstallationDeviceEntity.IsInstallationDeviceId(deviceIds.First())) - { - var entities = deviceIds.Select(e => new InstallationDeviceEntity(e)); - await _installationDeviceRepository.UpsertManyAsync(entities.ToList()); - } - } - - public async Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) - { - await PatchTagsForUserDevicesAsync(deviceIds, UpdateOperationType.Remove, - $"organizationId:{organizationId}"); - if (deviceIds.Any() && InstallationDeviceEntity.IsInstallationDeviceId(deviceIds.First())) - { - var entities = deviceIds.Select(e => new InstallationDeviceEntity(e)); - await _installationDeviceRepository.UpsertManyAsync(entities.ToList()); - } - } - - private async Task PatchTagsForUserDevicesAsync(IEnumerable deviceIds, UpdateOperationType op, - string tag) - { - if (!deviceIds.Any()) - { - return; - } - - var operation = new PartialUpdateOperation - { - Operation = op, - Path = "/tags" - }; - - if (op == UpdateOperationType.Add) - { - operation.Value = tag; - } - else if (op == UpdateOperationType.Remove) - { - operation.Path += $"/{tag}"; - } - - foreach (var id in deviceIds) - { - try - { - await _client.PatchInstallationAsync(id, new List { operation }); - } - catch (Exception e) when (e.InnerException == null || !e.InnerException.Message.Contains("(404) Not Found")) - { - throw; - } - } - } } } diff --git a/src/Core/Services/Implementations/NotificationsApiPushNotificationService.cs b/src/Core/Services/Implementations/NotificationsApiPushNotificationService.cs index 87729be70..144178f84 100644 --- a/src/Core/Services/Implementations/NotificationsApiPushNotificationService.cs +++ b/src/Core/Services/Implementations/NotificationsApiPushNotificationService.cs @@ -6,198 +6,197 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class NotificationsApiPushNotificationService : BaseIdentityClientService, IPushNotificationService { - public class NotificationsApiPushNotificationService : BaseIdentityClientService, IPushNotificationService + private readonly GlobalSettings _globalSettings; + private readonly IHttpContextAccessor _httpContextAccessor; + + public NotificationsApiPushNotificationService( + IHttpClientFactory httpFactory, + GlobalSettings globalSettings, + IHttpContextAccessor httpContextAccessor, + ILogger logger) + : base( + httpFactory, + globalSettings.BaseServiceUri.InternalNotifications, + globalSettings.BaseServiceUri.InternalIdentity, + "internal", + $"internal.{globalSettings.ProjectName}", + globalSettings.InternalIdentityKey, + logger) { - private readonly GlobalSettings _globalSettings; - private readonly IHttpContextAccessor _httpContextAccessor; + _globalSettings = globalSettings; + _httpContextAccessor = httpContextAccessor; + } - public NotificationsApiPushNotificationService( - IHttpClientFactory httpFactory, - GlobalSettings globalSettings, - IHttpContextAccessor httpContextAccessor, - ILogger logger) - : base( - httpFactory, - globalSettings.BaseServiceUri.InternalNotifications, - globalSettings.BaseServiceUri.InternalIdentity, - "internal", - $"internal.{globalSettings.ProjectName}", - globalSettings.InternalIdentityKey, - logger) - { - _globalSettings = globalSettings; - _httpContextAccessor = httpContextAccessor; - } + public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) + { + await PushCipherAsync(cipher, PushType.SyncCipherCreate, collectionIds); + } - public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) - { - await PushCipherAsync(cipher, PushType.SyncCipherCreate, collectionIds); - } + public async Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds) + { + await PushCipherAsync(cipher, PushType.SyncCipherUpdate, collectionIds); + } - public async Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds) - { - await PushCipherAsync(cipher, PushType.SyncCipherUpdate, collectionIds); - } + public async Task PushSyncCipherDeleteAsync(Cipher cipher) + { + await PushCipherAsync(cipher, PushType.SyncLoginDelete, null); + } - public async Task PushSyncCipherDeleteAsync(Cipher cipher) + private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable collectionIds) + { + if (cipher.OrganizationId.HasValue) { - await PushCipherAsync(cipher, PushType.SyncLoginDelete, null); - } - - private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable collectionIds) - { - if (cipher.OrganizationId.HasValue) + var message = new SyncCipherPushNotification { - var message = new SyncCipherPushNotification - { - Id = cipher.Id, - OrganizationId = cipher.OrganizationId, - RevisionDate = cipher.RevisionDate, - CollectionIds = collectionIds, - }; - - await SendMessageAsync(type, message, true); - } - else if (cipher.UserId.HasValue) - { - var message = new SyncCipherPushNotification - { - Id = cipher.Id, - UserId = cipher.UserId, - RevisionDate = cipher.RevisionDate, - CollectionIds = collectionIds, - }; - - await SendMessageAsync(type, message, true); - } - } - - public async Task PushSyncFolderCreateAsync(Folder folder) - { - await PushFolderAsync(folder, PushType.SyncFolderCreate); - } - - public async Task PushSyncFolderUpdateAsync(Folder folder) - { - await PushFolderAsync(folder, PushType.SyncFolderUpdate); - } - - public async Task PushSyncFolderDeleteAsync(Folder folder) - { - await PushFolderAsync(folder, PushType.SyncFolderDelete); - } - - private async Task PushFolderAsync(Folder folder, PushType type) - { - var message = new SyncFolderPushNotification - { - Id = folder.Id, - UserId = folder.UserId, - RevisionDate = folder.RevisionDate + Id = cipher.Id, + OrganizationId = cipher.OrganizationId, + RevisionDate = cipher.RevisionDate, + CollectionIds = collectionIds, }; await SendMessageAsync(type, message, true); } - - public async Task PushSyncCiphersAsync(Guid userId) + else if (cipher.UserId.HasValue) { - await PushUserAsync(userId, PushType.SyncCiphers); - } - - public async Task PushSyncVaultAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncVault); - } - - public async Task PushSyncOrgKeysAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncOrgKeys); - } - - public async Task PushSyncSettingsAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncSettings); - } - - public async Task PushLogOutAsync(Guid userId) - { - await PushUserAsync(userId, PushType.LogOut); - } - - private async Task PushUserAsync(Guid userId, PushType type) - { - var message = new UserPushNotification + var message = new SyncCipherPushNotification { - UserId = userId, - Date = DateTime.UtcNow + Id = cipher.Id, + UserId = cipher.UserId, + RevisionDate = cipher.RevisionDate, + CollectionIds = collectionIds, + }; + + await SendMessageAsync(type, message, true); + } + } + + public async Task PushSyncFolderCreateAsync(Folder folder) + { + await PushFolderAsync(folder, PushType.SyncFolderCreate); + } + + public async Task PushSyncFolderUpdateAsync(Folder folder) + { + await PushFolderAsync(folder, PushType.SyncFolderUpdate); + } + + public async Task PushSyncFolderDeleteAsync(Folder folder) + { + await PushFolderAsync(folder, PushType.SyncFolderDelete); + } + + private async Task PushFolderAsync(Folder folder, PushType type) + { + var message = new SyncFolderPushNotification + { + Id = folder.Id, + UserId = folder.UserId, + RevisionDate = folder.RevisionDate + }; + + await SendMessageAsync(type, message, true); + } + + public async Task PushSyncCiphersAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncCiphers); + } + + public async Task PushSyncVaultAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncVault); + } + + public async Task PushSyncOrgKeysAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncOrgKeys); + } + + public async Task PushSyncSettingsAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncSettings); + } + + public async Task PushLogOutAsync(Guid userId) + { + await PushUserAsync(userId, PushType.LogOut); + } + + private async Task PushUserAsync(Guid userId, PushType type) + { + var message = new UserPushNotification + { + UserId = userId, + Date = DateTime.UtcNow + }; + + await SendMessageAsync(type, message, false); + } + + public async Task PushSyncSendCreateAsync(Send send) + { + await PushSendAsync(send, PushType.SyncSendCreate); + } + + public async Task PushSyncSendUpdateAsync(Send send) + { + await PushSendAsync(send, PushType.SyncSendUpdate); + } + + public async Task PushSyncSendDeleteAsync(Send send) + { + await PushSendAsync(send, PushType.SyncSendDelete); + } + + private async Task PushSendAsync(Send send, PushType type) + { + if (send.UserId.HasValue) + { + var message = new SyncSendPushNotification + { + Id = send.Id, + UserId = send.UserId.Value, + RevisionDate = send.RevisionDate }; await SendMessageAsync(type, message, false); } + } - public async Task PushSyncSendCreateAsync(Send send) + private async Task SendMessageAsync(PushType type, T payload, bool excludeCurrentContext) + { + var contextId = GetContextIdentifier(excludeCurrentContext); + var request = new PushNotificationData(type, payload, contextId); + await SendAsync(HttpMethod.Post, "send", request); + } + + private string GetContextIdentifier(bool excludeCurrentContext) + { + if (!excludeCurrentContext) { - await PushSendAsync(send, PushType.SyncSendCreate); + return null; } - public async Task PushSyncSendUpdateAsync(Send send) - { - await PushSendAsync(send, PushType.SyncSendUpdate); - } + var currentContext = _httpContextAccessor?.HttpContext?. + RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; + return currentContext?.DeviceIdentifier; + } - public async Task PushSyncSendDeleteAsync(Send send) - { - await PushSendAsync(send, PushType.SyncSendDelete); - } + public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, + string deviceId = null) + { + // Noop + return Task.FromResult(0); + } - private async Task PushSendAsync(Send send, PushType type) - { - if (send.UserId.HasValue) - { - var message = new SyncSendPushNotification - { - Id = send.Id, - UserId = send.UserId.Value, - RevisionDate = send.RevisionDate - }; - - await SendMessageAsync(type, message, false); - } - } - - private async Task SendMessageAsync(PushType type, T payload, bool excludeCurrentContext) - { - var contextId = GetContextIdentifier(excludeCurrentContext); - var request = new PushNotificationData(type, payload, contextId); - await SendAsync(HttpMethod.Post, "send", request); - } - - private string GetContextIdentifier(bool excludeCurrentContext) - { - if (!excludeCurrentContext) - { - return null; - } - - var currentContext = _httpContextAccessor?.HttpContext?. - RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; - return currentContext?.DeviceIdentifier; - } - - public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, - string deviceId = null) - { - // Noop - return Task.FromResult(0); - } - - public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, - string deviceId = null) - { - // Noop - return Task.FromResult(0); - } + public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, + string deviceId = null) + { + // Noop + return Task.FromResult(0); } } diff --git a/src/Core/Services/Implementations/OrganizationService.cs b/src/Core/Services/Implementations/OrganizationService.cs index b0b3dfc07..3d9f1da1f 100644 --- a/src/Core/Services/Implementations/OrganizationService.cs +++ b/src/Core/Services/Implementations/OrganizationService.cs @@ -14,2441 +14,2440 @@ using Microsoft.AspNetCore.DataProtection; using Microsoft.Extensions.Logging; using Stripe; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class OrganizationService : IOrganizationService { - public class OrganizationService : IOrganizationService + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly ICollectionRepository _collectionRepository; + private readonly IUserRepository _userRepository; + private readonly IGroupRepository _groupRepository; + private readonly IDataProtector _dataProtector; + private readonly IMailService _mailService; + private readonly IPushNotificationService _pushNotificationService; + private readonly IPushRegistrationService _pushRegistrationService; + private readonly IDeviceRepository _deviceRepository; + private readonly ILicensingService _licensingService; + private readonly IEventService _eventService; + private readonly IInstallationRepository _installationRepository; + private readonly IApplicationCacheService _applicationCacheService; + private readonly IPaymentService _paymentService; + private readonly IPolicyRepository _policyRepository; + private readonly ISsoConfigRepository _ssoConfigRepository; + private readonly ISsoUserRepository _ssoUserRepository; + private readonly IReferenceEventService _referenceEventService; + private readonly IGlobalSettings _globalSettings; + private readonly ITaxRateRepository _taxRateRepository; + private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; + private readonly IOrganizationConnectionRepository _organizationConnectionRepository; + private readonly ICurrentContext _currentContext; + private readonly ILogger _logger; + + + public OrganizationService( + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + ICollectionRepository collectionRepository, + IUserRepository userRepository, + IGroupRepository groupRepository, + IDataProtectionProvider dataProtectionProvider, + IMailService mailService, + IPushNotificationService pushNotificationService, + IPushRegistrationService pushRegistrationService, + IDeviceRepository deviceRepository, + ILicensingService licensingService, + IEventService eventService, + IInstallationRepository installationRepository, + IApplicationCacheService applicationCacheService, + IPaymentService paymentService, + IPolicyRepository policyRepository, + ISsoConfigRepository ssoConfigRepository, + ISsoUserRepository ssoUserRepository, + IReferenceEventService referenceEventService, + IGlobalSettings globalSettings, + ITaxRateRepository taxRateRepository, + IOrganizationApiKeyRepository organizationApiKeyRepository, + IOrganizationConnectionRepository organizationConnectionRepository, + ICurrentContext currentContext, + ILogger logger) { - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly ICollectionRepository _collectionRepository; - private readonly IUserRepository _userRepository; - private readonly IGroupRepository _groupRepository; - private readonly IDataProtector _dataProtector; - private readonly IMailService _mailService; - private readonly IPushNotificationService _pushNotificationService; - private readonly IPushRegistrationService _pushRegistrationService; - private readonly IDeviceRepository _deviceRepository; - private readonly ILicensingService _licensingService; - private readonly IEventService _eventService; - private readonly IInstallationRepository _installationRepository; - private readonly IApplicationCacheService _applicationCacheService; - private readonly IPaymentService _paymentService; - private readonly IPolicyRepository _policyRepository; - private readonly ISsoConfigRepository _ssoConfigRepository; - private readonly ISsoUserRepository _ssoUserRepository; - private readonly IReferenceEventService _referenceEventService; - private readonly IGlobalSettings _globalSettings; - private readonly ITaxRateRepository _taxRateRepository; - private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; - private readonly IOrganizationConnectionRepository _organizationConnectionRepository; - private readonly ICurrentContext _currentContext; - private readonly ILogger _logger; + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _collectionRepository = collectionRepository; + _userRepository = userRepository; + _groupRepository = groupRepository; + _dataProtector = dataProtectionProvider.CreateProtector("OrganizationServiceDataProtector"); + _mailService = mailService; + _pushNotificationService = pushNotificationService; + _pushRegistrationService = pushRegistrationService; + _deviceRepository = deviceRepository; + _licensingService = licensingService; + _eventService = eventService; + _installationRepository = installationRepository; + _applicationCacheService = applicationCacheService; + _paymentService = paymentService; + _policyRepository = policyRepository; + _ssoConfigRepository = ssoConfigRepository; + _ssoUserRepository = ssoUserRepository; + _referenceEventService = referenceEventService; + _globalSettings = globalSettings; + _taxRateRepository = taxRateRepository; + _organizationApiKeyRepository = organizationApiKeyRepository; + _organizationConnectionRepository = organizationConnectionRepository; + _currentContext = currentContext; + _logger = logger; + } - - public OrganizationService( - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - ICollectionRepository collectionRepository, - IUserRepository userRepository, - IGroupRepository groupRepository, - IDataProtectionProvider dataProtectionProvider, - IMailService mailService, - IPushNotificationService pushNotificationService, - IPushRegistrationService pushRegistrationService, - IDeviceRepository deviceRepository, - ILicensingService licensingService, - IEventService eventService, - IInstallationRepository installationRepository, - IApplicationCacheService applicationCacheService, - IPaymentService paymentService, - IPolicyRepository policyRepository, - ISsoConfigRepository ssoConfigRepository, - ISsoUserRepository ssoUserRepository, - IReferenceEventService referenceEventService, - IGlobalSettings globalSettings, - ITaxRateRepository taxRateRepository, - IOrganizationApiKeyRepository organizationApiKeyRepository, - IOrganizationConnectionRepository organizationConnectionRepository, - ICurrentContext currentContext, - ILogger logger) + public async Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken, + PaymentMethodType paymentMethodType, TaxInfo taxInfo) + { + var organization = await GetOrgById(organizationId); + if (organization == null) { - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _collectionRepository = collectionRepository; - _userRepository = userRepository; - _groupRepository = groupRepository; - _dataProtector = dataProtectionProvider.CreateProtector("OrganizationServiceDataProtector"); - _mailService = mailService; - _pushNotificationService = pushNotificationService; - _pushRegistrationService = pushRegistrationService; - _deviceRepository = deviceRepository; - _licensingService = licensingService; - _eventService = eventService; - _installationRepository = installationRepository; - _applicationCacheService = applicationCacheService; - _paymentService = paymentService; - _policyRepository = policyRepository; - _ssoConfigRepository = ssoConfigRepository; - _ssoUserRepository = ssoUserRepository; - _referenceEventService = referenceEventService; - _globalSettings = globalSettings; - _taxRateRepository = taxRateRepository; - _organizationApiKeyRepository = organizationApiKeyRepository; - _organizationConnectionRepository = organizationConnectionRepository; - _currentContext = currentContext; - _logger = logger; + throw new NotFoundException(); } - public async Task ReplacePaymentMethodAsync(Guid organizationId, string paymentToken, - PaymentMethodType paymentMethodType, TaxInfo taxInfo) + await _paymentService.SaveTaxInfoAsync(organization, taxInfo); + var updated = await _paymentService.UpdatePaymentMethodAsync(organization, + paymentMethodType, paymentToken); + if (updated) { - var organization = await GetOrgById(organizationId); - if (organization == null) - { - throw new NotFoundException(); - } + await ReplaceAndUpdateCache(organization); + } + } - await _paymentService.SaveTaxInfoAsync(organization, taxInfo); - var updated = await _paymentService.UpdatePaymentMethodAsync(organization, - paymentMethodType, paymentToken); - if (updated) + public async Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null) + { + var organization = await GetOrgById(organizationId); + if (organization == null) + { + throw new NotFoundException(); + } + + var eop = endOfPeriod.GetValueOrDefault(true); + if (!endOfPeriod.HasValue && organization.ExpirationDate.HasValue && + organization.ExpirationDate.Value < DateTime.UtcNow) + { + eop = false; + } + + await _paymentService.CancelSubscriptionAsync(organization, eop); + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.CancelSubscription, organization) { - await ReplaceAndUpdateCache(organization); + EndOfPeriod = endOfPeriod, + }); + } + + public async Task ReinstateSubscriptionAsync(Guid organizationId) + { + var organization = await GetOrgById(organizationId); + if (organization == null) + { + throw new NotFoundException(); + } + + await _paymentService.ReinstateSubscriptionAsync(organization); + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.ReinstateSubscription, organization)); + } + + public async Task> UpgradePlanAsync(Guid organizationId, OrganizationUpgrade upgrade) + { + var organization = await GetOrgById(organizationId); + if (organization == null) + { + throw new NotFoundException(); + } + + if (string.IsNullOrWhiteSpace(organization.GatewayCustomerId)) + { + throw new BadRequestException("Your account has no payment method available."); + } + + var existingPlan = StaticStore.Plans.FirstOrDefault(p => p.Type == organization.PlanType); + if (existingPlan == null) + { + throw new BadRequestException("Existing plan not found."); + } + + var newPlan = StaticStore.Plans.FirstOrDefault(p => p.Type == upgrade.Plan && !p.Disabled); + if (newPlan == null) + { + throw new BadRequestException("Plan not found."); + } + + if (existingPlan.Type == newPlan.Type) + { + throw new BadRequestException("Organization is already on this plan."); + } + + if (existingPlan.UpgradeSortOrder >= newPlan.UpgradeSortOrder) + { + throw new BadRequestException("You cannot upgrade to this plan."); + } + + if (existingPlan.Type != PlanType.Free) + { + throw new BadRequestException("You can only upgrade from the free plan. Contact support."); + } + + ValidateOrganizationUpgradeParameters(newPlan, upgrade); + + var newPlanSeats = (short)(newPlan.BaseSeats + + (newPlan.HasAdditionalSeatsOption ? upgrade.AdditionalSeats : 0)); + if (!organization.Seats.HasValue || organization.Seats.Value > newPlanSeats) + { + var userCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(organization.Id); + if (userCount > newPlanSeats) + { + throw new BadRequestException($"Your organization currently has {userCount} seats filled. " + + $"Your new plan only has ({newPlanSeats}) seats. Remove some users."); } } - public async Task CancelSubscriptionAsync(Guid organizationId, bool? endOfPeriod = null) + if (newPlan.MaxCollections.HasValue && (!organization.MaxCollections.HasValue || + organization.MaxCollections.Value > newPlan.MaxCollections.Value)) { - var organization = await GetOrgById(organizationId); - if (organization == null) + var collectionCount = await _collectionRepository.GetCountByOrganizationIdAsync(organization.Id); + if (collectionCount > newPlan.MaxCollections.Value) { - throw new NotFoundException(); + throw new BadRequestException($"Your organization currently has {collectionCount} collections. " + + $"Your new plan allows for a maximum of ({newPlan.MaxCollections.Value}) collections. " + + "Remove some collections."); } + } - var eop = endOfPeriod.GetValueOrDefault(true); - if (!endOfPeriod.HasValue && organization.ExpirationDate.HasValue && - organization.ExpirationDate.Value < DateTime.UtcNow) + if (!newPlan.HasGroups && organization.UseGroups) + { + var groups = await _groupRepository.GetManyByOrganizationIdAsync(organization.Id); + if (groups.Any()) { - eop = false; + throw new BadRequestException($"Your new plan does not allow the groups feature. " + + $"Remove your groups."); } + } - await _paymentService.CancelSubscriptionAsync(organization, eop); + if (!newPlan.HasPolicies && organization.UsePolicies) + { + var policies = await _policyRepository.GetManyByOrganizationIdAsync(organization.Id); + if (policies.Any(p => p.Enabled)) + { + throw new BadRequestException($"Your new plan does not allow the policies feature. " + + $"Disable your policies."); + } + } + + if (!newPlan.HasSso && organization.UseSso) + { + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organization.Id); + if (ssoConfig != null && ssoConfig.Enabled) + { + throw new BadRequestException($"Your new plan does not allow the SSO feature. " + + $"Disable your SSO configuration."); + } + } + + if (!newPlan.HasKeyConnector && organization.UseKeyConnector) + { + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organization.Id); + if (ssoConfig != null && ssoConfig.GetData().KeyConnectorEnabled) + { + throw new BadRequestException("Your new plan does not allow the Key Connector feature. " + + "Disable your Key Connector."); + } + } + + if (!newPlan.HasResetPassword && organization.UseResetPassword) + { + var resetPasswordPolicy = + await _policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword); + if (resetPasswordPolicy != null && resetPasswordPolicy.Enabled) + { + throw new BadRequestException("Your new plan does not allow the Password Reset feature. " + + "Disable your Password Reset policy."); + } + } + + if (!newPlan.HasScim && organization.UseScim) + { + var scimConnections = await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(organization.Id, + OrganizationConnectionType.Scim); + if (scimConnections != null && scimConnections.Any(c => c.GetConfig()?.Enabled == true)) + { + throw new BadRequestException("Your new plan does not allow the SCIM feature. " + + "Disable your SCIM configuration."); + } + } + + // TODO: Check storage? + + string paymentIntentClientSecret = null; + var success = true; + if (string.IsNullOrWhiteSpace(organization.GatewaySubscriptionId)) + { + paymentIntentClientSecret = await _paymentService.UpgradeFreeOrganizationAsync(organization, newPlan, + upgrade.AdditionalStorageGb, upgrade.AdditionalSeats, upgrade.PremiumAccessAddon, upgrade.TaxInfo); + success = string.IsNullOrWhiteSpace(paymentIntentClientSecret); + } + else + { + // TODO: Update existing sub + throw new BadRequestException("You can only upgrade from the free plan. Contact support."); + } + + organization.BusinessName = upgrade.BusinessName; + organization.PlanType = newPlan.Type; + organization.Seats = (short)(newPlan.BaseSeats + upgrade.AdditionalSeats); + organization.MaxCollections = newPlan.MaxCollections; + organization.UseGroups = newPlan.HasGroups; + organization.UseDirectory = newPlan.HasDirectory; + organization.UseEvents = newPlan.HasEvents; + organization.UseTotp = newPlan.HasTotp; + organization.Use2fa = newPlan.Has2fa; + organization.UseApi = newPlan.HasApi; + organization.SelfHost = newPlan.HasSelfHost; + organization.UsePolicies = newPlan.HasPolicies; + organization.MaxStorageGb = !newPlan.BaseStorageGb.HasValue ? + (short?)null : (short)(newPlan.BaseStorageGb.Value + upgrade.AdditionalStorageGb); + organization.UseGroups = newPlan.HasGroups; + organization.UseDirectory = newPlan.HasDirectory; + organization.UseEvents = newPlan.HasEvents; + organization.UseTotp = newPlan.HasTotp; + organization.Use2fa = newPlan.Has2fa; + organization.UseApi = newPlan.HasApi; + organization.UseSso = newPlan.HasSso; + organization.UseKeyConnector = newPlan.HasKeyConnector; + organization.UseScim = newPlan.HasScim; + organization.UseResetPassword = newPlan.HasResetPassword; + organization.SelfHost = newPlan.HasSelfHost; + organization.UsersGetPremium = newPlan.UsersGetPremium || upgrade.PremiumAccessAddon; + organization.Plan = newPlan.Name; + organization.Enabled = success; + organization.PublicKey = upgrade.PublicKey; + organization.PrivateKey = upgrade.PrivateKey; + await ReplaceAndUpdateCache(organization); + if (success) + { await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.CancelSubscription, organization) + new ReferenceEvent(ReferenceEventType.UpgradePlan, organization) { - EndOfPeriod = endOfPeriod, + PlanName = newPlan.Name, + PlanType = newPlan.Type, + OldPlanName = existingPlan.Name, + OldPlanType = existingPlan.Type, + Seats = organization.Seats, + Storage = organization.MaxStorageGb, }); } - public async Task ReinstateSubscriptionAsync(Guid organizationId) - { - var organization = await GetOrgById(organizationId); - if (organization == null) - { - throw new NotFoundException(); - } + return new Tuple(success, paymentIntentClientSecret); + } - await _paymentService.ReinstateSubscriptionAsync(organization); - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.ReinstateSubscription, organization)); + public async Task AdjustStorageAsync(Guid organizationId, short storageAdjustmentGb) + { + var organization = await GetOrgById(organizationId); + if (organization == null) + { + throw new NotFoundException(); } - public async Task> UpgradePlanAsync(Guid organizationId, OrganizationUpgrade upgrade) + var plan = StaticStore.Plans.FirstOrDefault(p => p.Type == organization.PlanType); + if (plan == null) { - var organization = await GetOrgById(organizationId); - if (organization == null) + throw new BadRequestException("Existing plan not found."); + } + + if (!plan.HasAdditionalStorageOption) + { + throw new BadRequestException("Plan does not allow additional storage."); + } + + var secret = await BillingHelpers.AdjustStorageAsync(_paymentService, organization, storageAdjustmentGb, + plan.StripeStoragePlanId); + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.AdjustStorage, organization) { - throw new NotFoundException(); + PlanName = plan.Name, + PlanType = plan.Type, + Storage = storageAdjustmentGb, + }); + await ReplaceAndUpdateCache(organization); + return secret; + } + + public async Task UpdateSubscription(Guid organizationId, int seatAdjustment, int? maxAutoscaleSeats) + { + var organization = await GetOrgById(organizationId); + if (organization == null) + { + throw new NotFoundException(); + } + + var newSeatCount = organization.Seats + seatAdjustment; + if (maxAutoscaleSeats.HasValue && newSeatCount > maxAutoscaleSeats.Value) + { + throw new BadRequestException("Cannot set max seat autoscaling below seat count."); + } + + if (seatAdjustment != 0) + { + await AdjustSeatsAsync(organization, seatAdjustment); + } + if (maxAutoscaleSeats != organization.MaxAutoscaleSeats) + { + await UpdateAutoscalingAsync(organization, maxAutoscaleSeats); + } + } + + private async Task UpdateAutoscalingAsync(Organization organization, int? maxAutoscaleSeats) + { + + if (maxAutoscaleSeats.HasValue && + organization.Seats.HasValue && + maxAutoscaleSeats.Value < organization.Seats.Value) + { + throw new BadRequestException($"Cannot set max seat autoscaling below current seat count."); + } + + var plan = StaticStore.Plans.FirstOrDefault(p => p.Type == organization.PlanType); + if (plan == null) + { + throw new BadRequestException("Existing plan not found."); + } + + if (!plan.AllowSeatAutoscale) + { + throw new BadRequestException("Your plan does not allow seat autoscaling."); + } + + if (plan.MaxUsers.HasValue && maxAutoscaleSeats.HasValue && + maxAutoscaleSeats > plan.MaxUsers) + { + throw new BadRequestException(string.Concat($"Your plan has a seat limit of {plan.MaxUsers}, ", + $"but you have specified a max autoscale count of {maxAutoscaleSeats}.", + "Reduce your max autoscale seat count.")); + } + + organization.MaxAutoscaleSeats = maxAutoscaleSeats; + + await ReplaceAndUpdateCache(organization); + } + + public async Task AdjustSeatsAsync(Guid organizationId, int seatAdjustment, DateTime? prorationDate = null) + { + var organization = await GetOrgById(organizationId); + if (organization == null) + { + throw new NotFoundException(); + } + + return await AdjustSeatsAsync(organization, seatAdjustment, prorationDate); + } + + private async Task AdjustSeatsAsync(Organization organization, int seatAdjustment, DateTime? prorationDate = null, IEnumerable ownerEmails = null) + { + if (organization.Seats == null) + { + throw new BadRequestException("Organization has no seat limit, no need to adjust seats"); + } + + if (string.IsNullOrWhiteSpace(organization.GatewayCustomerId)) + { + throw new BadRequestException("No payment method found."); + } + + if (string.IsNullOrWhiteSpace(organization.GatewaySubscriptionId)) + { + throw new BadRequestException("No subscription found."); + } + + var plan = StaticStore.Plans.FirstOrDefault(p => p.Type == organization.PlanType); + if (plan == null) + { + throw new BadRequestException("Existing plan not found."); + } + + if (!plan.HasAdditionalSeatsOption) + { + throw new BadRequestException("Plan does not allow additional seats."); + } + + var newSeatTotal = organization.Seats.Value + seatAdjustment; + if (plan.BaseSeats > newSeatTotal) + { + throw new BadRequestException($"Plan has a minimum of {plan.BaseSeats} seats."); + } + + if (newSeatTotal <= 0) + { + throw new BadRequestException("You must have at least 1 seat."); + } + + var additionalSeats = newSeatTotal - plan.BaseSeats; + if (plan.MaxAdditionalSeats.HasValue && additionalSeats > plan.MaxAdditionalSeats.Value) + { + throw new BadRequestException($"Organization plan allows a maximum of " + + $"{plan.MaxAdditionalSeats.Value} additional seats."); + } + + if (!organization.Seats.HasValue || organization.Seats.Value > newSeatTotal) + { + var userCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(organization.Id); + if (userCount > newSeatTotal) + { + throw new BadRequestException($"Your organization currently has {userCount} seats filled. " + + $"Your new plan only has ({newSeatTotal}) seats. Remove some users."); } + } - if (string.IsNullOrWhiteSpace(organization.GatewayCustomerId)) + var paymentIntentClientSecret = await _paymentService.AdjustSeatsAsync(organization, plan, additionalSeats, prorationDate); + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.AdjustSeats, organization) { - throw new BadRequestException("Your account has no payment method available."); - } + PlanName = plan.Name, + PlanType = plan.Type, + Seats = newSeatTotal, + PreviousSeats = organization.Seats + }); + organization.Seats = (short?)newSeatTotal; + await ReplaceAndUpdateCache(organization); - var existingPlan = StaticStore.Plans.FirstOrDefault(p => p.Type == organization.PlanType); - if (existingPlan == null) + if (organization.Seats.HasValue && organization.MaxAutoscaleSeats.HasValue && organization.Seats == organization.MaxAutoscaleSeats) + { + try { - throw new BadRequestException("Existing plan not found."); - } - - var newPlan = StaticStore.Plans.FirstOrDefault(p => p.Type == upgrade.Plan && !p.Disabled); - if (newPlan == null) - { - throw new BadRequestException("Plan not found."); - } - - if (existingPlan.Type == newPlan.Type) - { - throw new BadRequestException("Organization is already on this plan."); - } - - if (existingPlan.UpgradeSortOrder >= newPlan.UpgradeSortOrder) - { - throw new BadRequestException("You cannot upgrade to this plan."); - } - - if (existingPlan.Type != PlanType.Free) - { - throw new BadRequestException("You can only upgrade from the free plan. Contact support."); - } - - ValidateOrganizationUpgradeParameters(newPlan, upgrade); - - var newPlanSeats = (short)(newPlan.BaseSeats + - (newPlan.HasAdditionalSeatsOption ? upgrade.AdditionalSeats : 0)); - if (!organization.Seats.HasValue || organization.Seats.Value > newPlanSeats) - { - var userCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(organization.Id); - if (userCount > newPlanSeats) + if (ownerEmails == null) { - throw new BadRequestException($"Your organization currently has {userCount} seats filled. " + - $"Your new plan only has ({newPlanSeats}) seats. Remove some users."); + ownerEmails = (await _organizationUserRepository.GetManyByMinimumRoleAsync(organization.Id, + OrganizationUserType.Owner)).Select(u => u.Email).Distinct(); } + await _mailService.SendOrganizationMaxSeatLimitReachedEmailAsync(organization, organization.MaxAutoscaleSeats.Value, ownerEmails); } - - if (newPlan.MaxCollections.HasValue && (!organization.MaxCollections.HasValue || - organization.MaxCollections.Value > newPlan.MaxCollections.Value)) + catch (Exception e) { - var collectionCount = await _collectionRepository.GetCountByOrganizationIdAsync(organization.Id); - if (collectionCount > newPlan.MaxCollections.Value) + _logger.LogError(e, "Error encountered notifying organization owners of seat limit reached."); + } + } + + return paymentIntentClientSecret; + } + + public async Task VerifyBankAsync(Guid organizationId, int amount1, int amount2) + { + var organization = await GetOrgById(organizationId); + if (organization == null) + { + throw new NotFoundException(); + } + + if (string.IsNullOrWhiteSpace(organization.GatewayCustomerId)) + { + throw new GatewayException("Not a gateway customer."); + } + + var bankService = new BankAccountService(); + var customerService = new CustomerService(); + var customer = await customerService.GetAsync(organization.GatewayCustomerId, + new CustomerGetOptions { Expand = new List { "sources" } }); + if (customer == null) + { + throw new GatewayException("Cannot find customer."); + } + + var bankAccount = customer.Sources + .FirstOrDefault(s => s is BankAccount && ((BankAccount)s).Status != "verified") as BankAccount; + if (bankAccount == null) + { + throw new GatewayException("Cannot find an unverified bank account."); + } + + try + { + var result = await bankService.VerifyAsync(organization.GatewayCustomerId, bankAccount.Id, + new BankAccountVerifyOptions { Amounts = new List { amount1, amount2 } }); + if (result.Status != "verified") + { + throw new GatewayException("Unable to verify account."); + } + } + catch (StripeException e) + { + throw new GatewayException(e.Message); + } + } + + public async Task> SignUpAsync(OrganizationSignup signup, + bool provider = false) + { + var plan = StaticStore.Plans.FirstOrDefault(p => p.Type == signup.Plan); + if (!(plan is { LegacyYear: null })) + { + throw new BadRequestException("Invalid plan selected."); + } + + if (plan.Disabled) + { + throw new BadRequestException("Plan not found."); + } + + if (!provider) + { + await ValidateSignUpPoliciesAsync(signup.Owner.Id); + } + + ValidateOrganizationUpgradeParameters(plan, signup); + + var organization = new Organization + { + // Pre-generate the org id so that we can save it with the Stripe subscription.. + Id = CoreHelpers.GenerateComb(), + Name = signup.Name, + BillingEmail = signup.BillingEmail, + BusinessName = signup.BusinessName, + PlanType = plan.Type, + Seats = (short)(plan.BaseSeats + signup.AdditionalSeats), + MaxCollections = plan.MaxCollections, + MaxStorageGb = !plan.BaseStorageGb.HasValue ? + (short?)null : (short)(plan.BaseStorageGb.Value + signup.AdditionalStorageGb), + UsePolicies = plan.HasPolicies, + UseSso = plan.HasSso, + UseGroups = plan.HasGroups, + UseEvents = plan.HasEvents, + UseDirectory = plan.HasDirectory, + UseTotp = plan.HasTotp, + Use2fa = plan.Has2fa, + UseApi = plan.HasApi, + UseResetPassword = plan.HasResetPassword, + SelfHost = plan.HasSelfHost, + UsersGetPremium = plan.UsersGetPremium || signup.PremiumAccessAddon, + UseScim = plan.HasScim, + Plan = plan.Name, + Gateway = null, + ReferenceData = signup.Owner.ReferenceData, + Enabled = true, + LicenseKey = CoreHelpers.SecureRandomString(20), + PublicKey = signup.PublicKey, + PrivateKey = signup.PrivateKey, + CreationDate = DateTime.UtcNow, + RevisionDate = DateTime.UtcNow, + }; + + if (plan.Type == PlanType.Free && !provider) + { + var adminCount = + await _organizationUserRepository.GetCountByFreeOrganizationAdminUserAsync(signup.Owner.Id); + if (adminCount > 0) + { + throw new BadRequestException("You can only be an admin of one free organization."); + } + } + else if (plan.Type != PlanType.Free) + { + await _paymentService.PurchaseOrganizationAsync(organization, signup.PaymentMethodType.Value, + signup.PaymentToken, plan, signup.AdditionalStorageGb, signup.AdditionalSeats, + signup.PremiumAccessAddon, signup.TaxInfo); + } + + var ownerId = provider ? default : signup.Owner.Id; + var returnValue = await SignUpAsync(organization, ownerId, signup.OwnerKey, signup.CollectionName, true); + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.Signup, organization) + { + PlanName = plan.Name, + PlanType = plan.Type, + Seats = returnValue.Item1.Seats, + Storage = returnValue.Item1.MaxStorageGb, + }); + return returnValue; + } + + private async Task ValidateSignUpPoliciesAsync(Guid ownerId) + { + var singleOrgPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(ownerId, PolicyType.SingleOrg); + if (singleOrgPolicyCount > 0) + { + throw new BadRequestException("You may not create an organization. You belong to an organization " + + "which has a policy that prohibits you from being a member of any other organization."); + } + } + + public async Task> SignUpAsync( + OrganizationLicense license, User owner, string ownerKey, string collectionName, string publicKey, + string privateKey) + { + if (license?.LicenseType != null && license.LicenseType != LicenseType.Organization) + { + throw new BadRequestException("Premium licenses cannot be applied to an organization. " + + "Upload this license from your personal account settings page."); + } + + if (license == null || !_licensingService.VerifyLicense(license)) + { + throw new BadRequestException("Invalid license."); + } + + if (!license.CanUse(_globalSettings)) + { + throw new BadRequestException("Invalid license. Make sure your license allows for on-premise " + + "hosting of organizations and that the installation id matches your current installation."); + } + + if (license.PlanType != PlanType.Custom && + StaticStore.Plans.FirstOrDefault(p => p.Type == license.PlanType && !p.Disabled) == null) + { + throw new BadRequestException("Plan not found."); + } + + var enabledOrgs = await _organizationRepository.GetManyByEnabledAsync(); + if (enabledOrgs.Any(o => o.LicenseKey.Equals(license.LicenseKey))) + { + throw new BadRequestException("License is already in use by another organization."); + } + + await ValidateSignUpPoliciesAsync(owner.Id); + + var organization = new Organization + { + Name = license.Name, + BillingEmail = license.BillingEmail, + BusinessName = license.BusinessName, + PlanType = license.PlanType, + Seats = license.Seats, + MaxCollections = license.MaxCollections, + MaxStorageGb = _globalSettings.SelfHosted ? 10240 : license.MaxStorageGb, // 10 TB + UsePolicies = license.UsePolicies, + UseSso = license.UseSso, + UseKeyConnector = license.UseKeyConnector, + UseScim = license.UseScim, + UseGroups = license.UseGroups, + UseDirectory = license.UseDirectory, + UseEvents = license.UseEvents, + UseTotp = license.UseTotp, + Use2fa = license.Use2fa, + UseApi = license.UseApi, + UseResetPassword = license.UseResetPassword, + Plan = license.Plan, + SelfHost = license.SelfHost, + UsersGetPremium = license.UsersGetPremium, + Gateway = null, + GatewayCustomerId = null, + GatewaySubscriptionId = null, + ReferenceData = owner.ReferenceData, + Enabled = license.Enabled, + ExpirationDate = license.Expires, + LicenseKey = license.LicenseKey, + PublicKey = publicKey, + PrivateKey = privateKey, + CreationDate = DateTime.UtcNow, + RevisionDate = DateTime.UtcNow + }; + + var result = await SignUpAsync(organization, owner.Id, ownerKey, collectionName, false); + + var dir = $"{_globalSettings.LicenseDirectory}/organization"; + Directory.CreateDirectory(dir); + await using var fs = new FileStream(Path.Combine(dir, $"{organization.Id}.json"), FileMode.Create); + await JsonSerializer.SerializeAsync(fs, license, JsonHelpers.Indented); + return result; + } + + private async Task> SignUpAsync(Organization organization, + Guid ownerId, string ownerKey, string collectionName, bool withPayment) + { + try + { + await _organizationRepository.CreateAsync(organization); + await _organizationApiKeyRepository.CreateAsync(new OrganizationApiKey + { + OrganizationId = organization.Id, + ApiKey = CoreHelpers.SecureRandomString(30), + Type = OrganizationApiKeyType.Default, + RevisionDate = DateTime.UtcNow, + }); + await _applicationCacheService.UpsertOrganizationAbilityAsync(organization); + + if (!string.IsNullOrWhiteSpace(collectionName)) + { + var defaultCollection = new Collection { - throw new BadRequestException($"Your organization currently has {collectionCount} collections. " + - $"Your new plan allows for a maximum of ({newPlan.MaxCollections.Value}) collections. " + - "Remove some collections."); - } + Name = collectionName, + OrganizationId = organization.Id, + CreationDate = organization.CreationDate, + RevisionDate = organization.CreationDate + }; + await _collectionRepository.CreateAsync(defaultCollection); } - if (!newPlan.HasGroups && organization.UseGroups) + OrganizationUser orgUser = null; + if (ownerId != default) { - var groups = await _groupRepository.GetManyByOrganizationIdAsync(organization.Id); - if (groups.Any()) + orgUser = new OrganizationUser { - throw new BadRequestException($"Your new plan does not allow the groups feature. " + - $"Remove your groups."); - } + OrganizationId = organization.Id, + UserId = ownerId, + Key = ownerKey, + Type = OrganizationUserType.Owner, + Status = OrganizationUserStatusType.Confirmed, + AccessAll = true, + CreationDate = organization.CreationDate, + RevisionDate = organization.CreationDate + }; + + await _organizationUserRepository.CreateAsync(orgUser); + + var deviceIds = await GetUserDeviceIdsAsync(orgUser.UserId.Value); + await _pushRegistrationService.AddUserRegistrationOrganizationAsync(deviceIds, + organization.Id.ToString()); + await _pushNotificationService.PushSyncOrgKeysAsync(ownerId); } - if (!newPlan.HasPolicies && organization.UsePolicies) + return new Tuple(organization, orgUser); + } + catch + { + if (withPayment) { - var policies = await _policyRepository.GetManyByOrganizationIdAsync(organization.Id); - if (policies.Any(p => p.Enabled)) - { - throw new BadRequestException($"Your new plan does not allow the policies feature. " + - $"Disable your policies."); - } + await _paymentService.CancelAndRecoverChargesAsync(organization); } - if (!newPlan.HasSso && organization.UseSso) + if (organization.Id != default(Guid)) { - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organization.Id); - if (ssoConfig != null && ssoConfig.Enabled) - { - throw new BadRequestException($"Your new plan does not allow the SSO feature. " + - $"Disable your SSO configuration."); - } + await _organizationRepository.DeleteAsync(organization); + await _applicationCacheService.DeleteOrganizationAbilityAsync(organization.Id); } - if (!newPlan.HasKeyConnector && organization.UseKeyConnector) - { - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organization.Id); - if (ssoConfig != null && ssoConfig.GetData().KeyConnectorEnabled) - { - throw new BadRequestException("Your new plan does not allow the Key Connector feature. " + - "Disable your Key Connector."); - } - } + throw; + } + } - if (!newPlan.HasResetPassword && organization.UseResetPassword) - { - var resetPasswordPolicy = - await _policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword); - if (resetPasswordPolicy != null && resetPasswordPolicy.Enabled) - { - throw new BadRequestException("Your new plan does not allow the Password Reset feature. " + - "Disable your Password Reset policy."); - } - } + public async Task UpdateLicenseAsync(Guid organizationId, OrganizationLicense license) + { + var organization = await GetOrgById(organizationId); + if (organization == null) + { + throw new NotFoundException(); + } - if (!newPlan.HasScim && organization.UseScim) - { - var scimConnections = await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(organization.Id, - OrganizationConnectionType.Scim); - if (scimConnections != null && scimConnections.Any(c => c.GetConfig()?.Enabled == true)) - { - throw new BadRequestException("Your new plan does not allow the SCIM feature. " + - "Disable your SCIM configuration."); - } - } + if (!_globalSettings.SelfHosted) + { + throw new InvalidOperationException("Licenses require self hosting."); + } - // TODO: Check storage? + if (license?.LicenseType != null && license.LicenseType != LicenseType.Organization) + { + throw new BadRequestException("Premium licenses cannot be applied to an organization. " + + "Upload this license from your personal account settings page."); + } - string paymentIntentClientSecret = null; - var success = true; - if (string.IsNullOrWhiteSpace(organization.GatewaySubscriptionId)) - { - paymentIntentClientSecret = await _paymentService.UpgradeFreeOrganizationAsync(organization, newPlan, - upgrade.AdditionalStorageGb, upgrade.AdditionalSeats, upgrade.PremiumAccessAddon, upgrade.TaxInfo); - success = string.IsNullOrWhiteSpace(paymentIntentClientSecret); - } - else - { - // TODO: Update existing sub - throw new BadRequestException("You can only upgrade from the free plan. Contact support."); - } + if (license == null || !_licensingService.VerifyLicense(license)) + { + throw new BadRequestException("Invalid license."); + } - organization.BusinessName = upgrade.BusinessName; - organization.PlanType = newPlan.Type; - organization.Seats = (short)(newPlan.BaseSeats + upgrade.AdditionalSeats); - organization.MaxCollections = newPlan.MaxCollections; - organization.UseGroups = newPlan.HasGroups; - organization.UseDirectory = newPlan.HasDirectory; - organization.UseEvents = newPlan.HasEvents; - organization.UseTotp = newPlan.HasTotp; - organization.Use2fa = newPlan.Has2fa; - organization.UseApi = newPlan.HasApi; - organization.SelfHost = newPlan.HasSelfHost; - organization.UsePolicies = newPlan.HasPolicies; - organization.MaxStorageGb = !newPlan.BaseStorageGb.HasValue ? - (short?)null : (short)(newPlan.BaseStorageGb.Value + upgrade.AdditionalStorageGb); - organization.UseGroups = newPlan.HasGroups; - organization.UseDirectory = newPlan.HasDirectory; - organization.UseEvents = newPlan.HasEvents; - organization.UseTotp = newPlan.HasTotp; - organization.Use2fa = newPlan.Has2fa; - organization.UseApi = newPlan.HasApi; - organization.UseSso = newPlan.HasSso; - organization.UseKeyConnector = newPlan.HasKeyConnector; - organization.UseScim = newPlan.HasScim; - organization.UseResetPassword = newPlan.HasResetPassword; - organization.SelfHost = newPlan.HasSelfHost; - organization.UsersGetPremium = newPlan.UsersGetPremium || upgrade.PremiumAccessAddon; - organization.Plan = newPlan.Name; - organization.Enabled = success; - organization.PublicKey = upgrade.PublicKey; - organization.PrivateKey = upgrade.PrivateKey; - await ReplaceAndUpdateCache(organization); - if (success) + if (!license.CanUse(_globalSettings)) + { + throw new BadRequestException("Invalid license. Make sure your license allows for on-premise " + + "hosting of organizations and that the installation id matches your current installation."); + } + + var enabledOrgs = await _organizationRepository.GetManyByEnabledAsync(); + if (enabledOrgs.Any(o => o.LicenseKey.Equals(license.LicenseKey) && o.Id != organizationId)) + { + throw new BadRequestException("License is already in use by another organization."); + } + + if (license.Seats.HasValue && + (!organization.Seats.HasValue || organization.Seats.Value > license.Seats.Value)) + { + var userCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(organization.Id); + if (userCount > license.Seats.Value) { + throw new BadRequestException($"Your organization currently has {userCount} seats filled. " + + $"Your new license only has ({license.Seats.Value}) seats. Remove some users."); + } + } + + if (license.MaxCollections.HasValue && (!organization.MaxCollections.HasValue || + organization.MaxCollections.Value > license.MaxCollections.Value)) + { + var collectionCount = await _collectionRepository.GetCountByOrganizationIdAsync(organization.Id); + if (collectionCount > license.MaxCollections.Value) + { + throw new BadRequestException($"Your organization currently has {collectionCount} collections. " + + $"Your new license allows for a maximum of ({license.MaxCollections.Value}) collections. " + + "Remove some collections."); + } + } + + if (!license.UseGroups && organization.UseGroups) + { + var groups = await _groupRepository.GetManyByOrganizationIdAsync(organization.Id); + if (groups.Count > 0) + { + throw new BadRequestException($"Your organization currently has {groups.Count} groups. " + + $"Your new license does not allow for the use of groups. Remove all groups."); + } + } + + if (!license.UsePolicies && organization.UsePolicies) + { + var policies = await _policyRepository.GetManyByOrganizationIdAsync(organization.Id); + if (policies.Any(p => p.Enabled)) + { + throw new BadRequestException($"Your organization currently has {policies.Count} enabled " + + $"policies. Your new license does not allow for the use of policies. Disable all policies."); + } + } + + if (!license.UseSso && organization.UseSso) + { + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organization.Id); + if (ssoConfig != null && ssoConfig.Enabled) + { + throw new BadRequestException($"Your organization currently has a SSO configuration. " + + $"Your new license does not allow for the use of SSO. Disable your SSO configuration."); + } + } + + if (!license.UseKeyConnector && organization.UseKeyConnector) + { + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organization.Id); + if (ssoConfig != null && ssoConfig.GetData().KeyConnectorEnabled) + { + throw new BadRequestException($"Your organization currently has Key Connector enabled. " + + $"Your new license does not allow for the use of Key Connector. Disable your Key Connector."); + } + } + + if (!license.UseScim && organization.UseScim) + { + var scimConnections = await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(organization.Id, + OrganizationConnectionType.Scim); + if (scimConnections != null && scimConnections.Any(c => c.GetConfig()?.Enabled == true)) + { + throw new BadRequestException("Your new plan does not allow the SCIM feature. " + + "Disable your SCIM configuration."); + } + } + + if (!license.UseResetPassword && organization.UseResetPassword) + { + var resetPasswordPolicy = + await _policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword); + if (resetPasswordPolicy != null && resetPasswordPolicy.Enabled) + { + throw new BadRequestException("Your new license does not allow the Password Reset feature. " + + "Disable your Password Reset policy."); + } + } + + var dir = $"{_globalSettings.LicenseDirectory}/organization"; + Directory.CreateDirectory(dir); + await using var fs = new FileStream(Path.Combine(dir, $"{organization.Id}.json"), FileMode.Create); + await JsonSerializer.SerializeAsync(fs, license, JsonHelpers.Indented); + + organization.Name = license.Name; + organization.BusinessName = license.BusinessName; + organization.BillingEmail = license.BillingEmail; + organization.PlanType = license.PlanType; + organization.Seats = license.Seats; + organization.MaxCollections = license.MaxCollections; + organization.UseGroups = license.UseGroups; + organization.UseDirectory = license.UseDirectory; + organization.UseEvents = license.UseEvents; + organization.UseTotp = license.UseTotp; + organization.Use2fa = license.Use2fa; + organization.UseApi = license.UseApi; + organization.UsePolicies = license.UsePolicies; + organization.UseSso = license.UseSso; + organization.UseKeyConnector = license.UseKeyConnector; + organization.UseScim = license.UseScim; + organization.UseResetPassword = license.UseResetPassword; + organization.SelfHost = license.SelfHost; + organization.UsersGetPremium = license.UsersGetPremium; + organization.Plan = license.Plan; + organization.Enabled = license.Enabled; + organization.ExpirationDate = license.Expires; + organization.LicenseKey = license.LicenseKey; + organization.RevisionDate = DateTime.UtcNow; + await ReplaceAndUpdateCache(organization); + } + + public async Task DeleteAsync(Organization organization) + { + await ValidateDeleteOrganizationAsync(organization); + + if (!string.IsNullOrWhiteSpace(organization.GatewaySubscriptionId)) + { + try + { + var eop = !organization.ExpirationDate.HasValue || + organization.ExpirationDate.Value >= DateTime.UtcNow; + await _paymentService.CancelSubscriptionAsync(organization, eop); await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.UpgradePlan, organization) - { - PlanName = newPlan.Name, - PlanType = newPlan.Type, - OldPlanName = existingPlan.Name, - OldPlanType = existingPlan.Type, - Seats = organization.Seats, - Storage = organization.MaxStorageGb, - }); + new ReferenceEvent(ReferenceEventType.DeleteAccount, organization)); } - - return new Tuple(success, paymentIntentClientSecret); + catch (GatewayException) { } } - public async Task AdjustStorageAsync(Guid organizationId, short storageAdjustmentGb) + await _organizationRepository.DeleteAsync(organization); + await _applicationCacheService.DeleteOrganizationAbilityAsync(organization.Id); + } + + public async Task EnableAsync(Guid organizationId, DateTime? expirationDate) + { + var org = await GetOrgById(organizationId); + if (org != null && !org.Enabled && org.Gateway.HasValue) { - var organization = await GetOrgById(organizationId); - if (organization == null) - { - throw new NotFoundException(); - } + org.Enabled = true; + org.ExpirationDate = expirationDate; + org.RevisionDate = DateTime.UtcNow; + await ReplaceAndUpdateCache(org); + } + } - var plan = StaticStore.Plans.FirstOrDefault(p => p.Type == organization.PlanType); - if (plan == null) - { - throw new BadRequestException("Existing plan not found."); - } + public async Task DisableAsync(Guid organizationId, DateTime? expirationDate) + { + var org = await GetOrgById(organizationId); + if (org != null && org.Enabled) + { + org.Enabled = false; + org.ExpirationDate = expirationDate; + org.RevisionDate = DateTime.UtcNow; + await ReplaceAndUpdateCache(org); - if (!plan.HasAdditionalStorageOption) - { - throw new BadRequestException("Plan does not allow additional storage."); - } + // TODO: send email to owners? + } + } - var secret = await BillingHelpers.AdjustStorageAsync(_paymentService, organization, storageAdjustmentGb, - plan.StripeStoragePlanId); - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.AdjustStorage, organization) - { - PlanName = plan.Name, - PlanType = plan.Type, - Storage = storageAdjustmentGb, - }); - await ReplaceAndUpdateCache(organization); - return secret; + public async Task UpdateExpirationDateAsync(Guid organizationId, DateTime? expirationDate) + { + var org = await GetOrgById(organizationId); + if (org != null) + { + org.ExpirationDate = expirationDate; + org.RevisionDate = DateTime.UtcNow; + await ReplaceAndUpdateCache(org); + } + } + + public async Task EnableAsync(Guid organizationId) + { + var org = await GetOrgById(organizationId); + if (org != null && !org.Enabled) + { + org.Enabled = true; + await ReplaceAndUpdateCache(org); + } + } + + public async Task UpdateAsync(Organization organization, bool updateBilling = false) + { + if (organization.Id == default(Guid)) + { + throw new ApplicationException("Cannot create org this way. Call SignUpAsync."); } - public async Task UpdateSubscription(Guid organizationId, int seatAdjustment, int? maxAutoscaleSeats) + if (!string.IsNullOrWhiteSpace(organization.Identifier)) { - var organization = await GetOrgById(organizationId); - if (organization == null) + var orgById = await _organizationRepository.GetByIdentifierAsync(organization.Identifier); + if (orgById != null && orgById.Id != organization.Id) { - throw new NotFoundException(); - } - - var newSeatCount = organization.Seats + seatAdjustment; - if (maxAutoscaleSeats.HasValue && newSeatCount > maxAutoscaleSeats.Value) - { - throw new BadRequestException("Cannot set max seat autoscaling below seat count."); - } - - if (seatAdjustment != 0) - { - await AdjustSeatsAsync(organization, seatAdjustment); - } - if (maxAutoscaleSeats != organization.MaxAutoscaleSeats) - { - await UpdateAutoscalingAsync(organization, maxAutoscaleSeats); + throw new BadRequestException("Identifier already in use by another organization."); } } - private async Task UpdateAutoscalingAsync(Organization organization, int? maxAutoscaleSeats) + await ReplaceAndUpdateCache(organization, EventType.Organization_Updated); + + if (updateBilling && !string.IsNullOrWhiteSpace(organization.GatewayCustomerId)) { - - if (maxAutoscaleSeats.HasValue && - organization.Seats.HasValue && - maxAutoscaleSeats.Value < organization.Seats.Value) + var customerService = new CustomerService(); + await customerService.UpdateAsync(organization.GatewayCustomerId, new CustomerUpdateOptions { - throw new BadRequestException($"Cannot set max seat autoscaling below current seat count."); - } + Email = organization.BillingEmail, + Description = organization.BusinessName + }); + } + } - var plan = StaticStore.Plans.FirstOrDefault(p => p.Type == organization.PlanType); - if (plan == null) - { - throw new BadRequestException("Existing plan not found."); - } - - if (!plan.AllowSeatAutoscale) - { - throw new BadRequestException("Your plan does not allow seat autoscaling."); - } - - if (plan.MaxUsers.HasValue && maxAutoscaleSeats.HasValue && - maxAutoscaleSeats > plan.MaxUsers) - { - throw new BadRequestException(string.Concat($"Your plan has a seat limit of {plan.MaxUsers}, ", - $"but you have specified a max autoscale count of {maxAutoscaleSeats}.", - "Reduce your max autoscale seat count.")); - } - - organization.MaxAutoscaleSeats = maxAutoscaleSeats; - - await ReplaceAndUpdateCache(organization); + public async Task UpdateTwoFactorProviderAsync(Organization organization, TwoFactorProviderType type) + { + if (!type.ToString().Contains("Organization")) + { + throw new ArgumentException("Not an organization provider type."); } - public async Task AdjustSeatsAsync(Guid organizationId, int seatAdjustment, DateTime? prorationDate = null) + if (!organization.Use2fa) { - var organization = await GetOrgById(organizationId); - if (organization == null) - { - throw new NotFoundException(); - } - - return await AdjustSeatsAsync(organization, seatAdjustment, prorationDate); + throw new BadRequestException("Organization cannot use 2FA."); } - private async Task AdjustSeatsAsync(Organization organization, int seatAdjustment, DateTime? prorationDate = null, IEnumerable ownerEmails = null) + var providers = organization.GetTwoFactorProviders(); + if (!providers?.ContainsKey(type) ?? true) { - if (organization.Seats == null) + return; + } + + providers[type].Enabled = true; + organization.SetTwoFactorProviders(providers); + await UpdateAsync(organization); + } + + public async Task DisableTwoFactorProviderAsync(Organization organization, TwoFactorProviderType type) + { + if (!type.ToString().Contains("Organization")) + { + throw new ArgumentException("Not an organization provider type."); + } + + var providers = organization.GetTwoFactorProviders(); + if (!providers?.ContainsKey(type) ?? true) + { + return; + } + + providers.Remove(type); + organization.SetTwoFactorProviders(providers); + await UpdateAsync(organization); + } + + public async Task> InviteUsersAsync(Guid organizationId, Guid? invitingUserId, + IEnumerable<(OrganizationUserInvite invite, string externalId)> invites) + { + var organization = await GetOrgById(organizationId); + var initialSeatCount = organization.Seats; + if (organization == null || invites.Any(i => i.invite.Emails == null)) + { + throw new NotFoundException(); + } + + var inviteTypes = new HashSet(invites.Where(i => i.invite.Type.HasValue) + .Select(i => i.invite.Type.Value)); + if (invitingUserId.HasValue && inviteTypes.Count > 0) + { + foreach (var type in inviteTypes) { - throw new BadRequestException("Organization has no seat limit, no need to adjust seats"); + await ValidateOrganizationUserUpdatePermissions(organizationId, type, null); } + } - if (string.IsNullOrWhiteSpace(organization.GatewayCustomerId)) + var newSeatsRequired = 0; + var existingEmails = new HashSet(await _organizationUserRepository.SelectKnownEmailsAsync( + organizationId, invites.SelectMany(i => i.invite.Emails), false), StringComparer.InvariantCultureIgnoreCase); + if (organization.Seats.HasValue) + { + var userCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(organizationId); + var availableSeats = organization.Seats.Value - userCount; + newSeatsRequired = invites.Sum(i => i.invite.Emails.Count()) - existingEmails.Count() - availableSeats; + } + + if (newSeatsRequired > 0) + { + var (canScale, failureReason) = CanScale(organization, newSeatsRequired); + if (!canScale) { - throw new BadRequestException("No payment method found."); + throw new BadRequestException(failureReason); } + } - if (string.IsNullOrWhiteSpace(organization.GatewaySubscriptionId)) - { - throw new BadRequestException("No subscription found."); - } + var invitedAreAllOwners = invites.All(i => i.invite.Type == OrganizationUserType.Owner); + if (!invitedAreAllOwners && !await HasConfirmedOwnersExceptAsync(organizationId, new Guid[] { })) + { + throw new BadRequestException("Organization must have at least one confirmed owner."); + } - var plan = StaticStore.Plans.FirstOrDefault(p => p.Type == organization.PlanType); - if (plan == null) - { - throw new BadRequestException("Existing plan not found."); - } - if (!plan.HasAdditionalSeatsOption) - { - throw new BadRequestException("Plan does not allow additional seats."); - } - - var newSeatTotal = organization.Seats.Value + seatAdjustment; - if (plan.BaseSeats > newSeatTotal) - { - throw new BadRequestException($"Plan has a minimum of {plan.BaseSeats} seats."); - } - - if (newSeatTotal <= 0) - { - throw new BadRequestException("You must have at least 1 seat."); - } - - var additionalSeats = newSeatTotal - plan.BaseSeats; - if (plan.MaxAdditionalSeats.HasValue && additionalSeats > plan.MaxAdditionalSeats.Value) - { - throw new BadRequestException($"Organization plan allows a maximum of " + - $"{plan.MaxAdditionalSeats.Value} additional seats."); - } - - if (!organization.Seats.HasValue || organization.Seats.Value > newSeatTotal) - { - var userCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(organization.Id); - if (userCount > newSeatTotal) - { - throw new BadRequestException($"Your organization currently has {userCount} seats filled. " + - $"Your new plan only has ({newSeatTotal}) seats. Remove some users."); - } - } - - var paymentIntentClientSecret = await _paymentService.AdjustSeatsAsync(organization, plan, additionalSeats, prorationDate); - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.AdjustSeats, organization) - { - PlanName = plan.Name, - PlanType = plan.Type, - Seats = newSeatTotal, - PreviousSeats = organization.Seats - }); - organization.Seats = (short?)newSeatTotal; - await ReplaceAndUpdateCache(organization); - - if (organization.Seats.HasValue && organization.MaxAutoscaleSeats.HasValue && organization.Seats == organization.MaxAutoscaleSeats) + var orgUsers = new List(); + var limitedCollectionOrgUsers = new List<(OrganizationUser, IEnumerable)>(); + var orgUserInvitedCount = 0; + var exceptions = new List(); + var events = new List<(OrganizationUser, EventType, DateTime?)>(); + foreach (var (invite, externalId) in invites) + { + // Prevent duplicate invitations + foreach (var email in invite.Emails.Distinct()) { try { - if (ownerEmails == null) + // Make sure user is not already invited + if (existingEmails.Contains(email)) { - ownerEmails = (await _organizationUserRepository.GetManyByMinimumRoleAsync(organization.Id, - OrganizationUserType.Owner)).Select(u => u.Email).Distinct(); + continue; } - await _mailService.SendOrganizationMaxSeatLimitReachedEmailAsync(organization, organization.MaxAutoscaleSeats.Value, ownerEmails); + + var orgUser = new OrganizationUser + { + OrganizationId = organizationId, + UserId = null, + Email = email.ToLowerInvariant(), + Key = null, + Type = invite.Type.Value, + Status = OrganizationUserStatusType.Invited, + AccessAll = invite.AccessAll, + ExternalId = externalId, + CreationDate = DateTime.UtcNow, + RevisionDate = DateTime.UtcNow, + }; + + if (invite.Permissions != null) + { + orgUser.Permissions = JsonSerializer.Serialize(invite.Permissions, JsonHelpers.CamelCase); + } + + if (!orgUser.AccessAll && invite.Collections.Any()) + { + limitedCollectionOrgUsers.Add((orgUser, invite.Collections)); + } + else + { + orgUsers.Add(orgUser); + } + + events.Add((orgUser, EventType.OrganizationUser_Invited, DateTime.UtcNow)); + orgUserInvitedCount++; } catch (Exception e) { - _logger.LogError(e, "Error encountered notifying organization owners of seat limit reached."); + exceptions.Add(e); } } - - return paymentIntentClientSecret; } - public async Task VerifyBankAsync(Guid organizationId, int amount1, int amount2) + if (exceptions.Any()) { - var organization = await GetOrgById(organizationId); - if (organization == null) + throw new AggregateException("One or more errors occurred while inviting users.", exceptions); + } + + var prorationDate = DateTime.UtcNow; + try + { + await _organizationUserRepository.CreateManyAsync(orgUsers); + foreach (var (orgUser, collections) in limitedCollectionOrgUsers) { - throw new NotFoundException(); + await _organizationUserRepository.CreateAsync(orgUser, collections); } - if (string.IsNullOrWhiteSpace(organization.GatewayCustomerId)) + if (!await _currentContext.ManageUsers(organization.Id)) { - throw new GatewayException("Not a gateway customer."); + throw new BadRequestException("Cannot add seats. Cannot manage organization users."); } - var bankService = new BankAccountService(); - var customerService = new CustomerService(); - var customer = await customerService.GetAsync(organization.GatewayCustomerId, - new CustomerGetOptions { Expand = new List { "sources" } }); - if (customer == null) - { - throw new GatewayException("Cannot find customer."); - } + await AutoAddSeatsAsync(organization, newSeatsRequired, prorationDate); + await SendInvitesAsync(orgUsers.Concat(limitedCollectionOrgUsers.Select(u => u.Item1)), organization); + await _eventService.LogOrganizationUserEventsAsync(events); - var bankAccount = customer.Sources - .FirstOrDefault(s => s is BankAccount && ((BankAccount)s).Status != "verified") as BankAccount; - if (bankAccount == null) - { - throw new GatewayException("Cannot find an unverified bank account."); - } - - try - { - var result = await bankService.VerifyAsync(organization.GatewayCustomerId, bankAccount.Id, - new BankAccountVerifyOptions { Amounts = new List { amount1, amount2 } }); - if (result.Status != "verified") + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.InvitedUsers, organization) { - throw new GatewayException("Unable to verify account."); - } - } - catch (StripeException e) + Users = orgUserInvitedCount + }); + } + catch (Exception e) + { + // Revert any added users. + var invitedOrgUserIds = orgUsers.Select(u => u.Id).Concat(limitedCollectionOrgUsers.Select(u => u.Item1.Id)); + await _organizationUserRepository.DeleteManyAsync(invitedOrgUserIds); + var currentSeatCount = (await _organizationRepository.GetByIdAsync(organization.Id)).Seats; + + if (initialSeatCount.HasValue && currentSeatCount.HasValue && currentSeatCount.Value != initialSeatCount.Value) { - throw new GatewayException(e.Message); + await AdjustSeatsAsync(organization, initialSeatCount.Value - currentSeatCount.Value, prorationDate); } + + exceptions.Add(e); } - public async Task> SignUpAsync(OrganizationSignup signup, - bool provider = false) + if (exceptions.Any()) { - var plan = StaticStore.Plans.FirstOrDefault(p => p.Type == signup.Plan); - if (!(plan is { LegacyYear: null })) + throw new AggregateException("One or more errors occurred while inviting users.", exceptions); + } + + return orgUsers; + } + + public async Task>> ResendInvitesAsync(Guid organizationId, Guid? invitingUserId, + IEnumerable organizationUsersId) + { + var orgUsers = await _organizationUserRepository.GetManyAsync(organizationUsersId); + var org = await GetOrgById(organizationId); + + var result = new List>(); + foreach (var orgUser in orgUsers) + { + if (orgUser.Status != OrganizationUserStatusType.Invited || orgUser.OrganizationId != organizationId) { - throw new BadRequestException("Invalid plan selected."); + result.Add(Tuple.Create(orgUser, "User invalid.")); + continue; } - if (plan.Disabled) + await SendInviteAsync(orgUser, org); + result.Add(Tuple.Create(orgUser, "")); + } + + return result; + } + + public async Task ResendInviteAsync(Guid organizationId, Guid? invitingUserId, Guid organizationUserId) + { + var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); + if (orgUser == null || orgUser.OrganizationId != organizationId || + orgUser.Status != OrganizationUserStatusType.Invited) + { + throw new BadRequestException("User invalid."); + } + + var org = await GetOrgById(orgUser.OrganizationId); + await SendInviteAsync(orgUser, org); + } + + private async Task SendInvitesAsync(IEnumerable orgUsers, Organization organization) + { + string MakeToken(OrganizationUser orgUser) => + _dataProtector.Protect($"OrganizationUserInvite {orgUser.Id} {orgUser.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); + + await _mailService.BulkSendOrganizationInviteEmailAsync(organization.Name, + orgUsers.Select(o => (o, new ExpiringToken(MakeToken(o), DateTime.UtcNow.AddDays(5))))); + } + + private async Task SendInviteAsync(OrganizationUser orgUser, Organization organization) + { + var now = DateTime.UtcNow; + var nowMillis = CoreHelpers.ToEpocMilliseconds(now); + var token = _dataProtector.Protect( + $"OrganizationUserInvite {orgUser.Id} {orgUser.Email} {nowMillis}"); + + await _mailService.SendOrganizationInviteEmailAsync(organization.Name, orgUser, new ExpiringToken(token, now.AddDays(5))); + } + + public async Task AcceptUserAsync(Guid organizationUserId, User user, string token, + IUserService userService) + { + var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); + if (orgUser == null) + { + throw new BadRequestException("User invalid."); + } + + if (!CoreHelpers.UserInviteTokenIsValid(_dataProtector, token, user.Email, orgUser.Id, _globalSettings)) + { + throw new BadRequestException("Invalid token."); + } + + var existingOrgUserCount = await _organizationUserRepository.GetCountByOrganizationAsync( + orgUser.OrganizationId, user.Email, true); + if (existingOrgUserCount > 0) + { + if (orgUser.Status == OrganizationUserStatusType.Accepted) { - throw new BadRequestException("Plan not found."); + throw new BadRequestException("Invitation already accepted. You will receive an email when your organization membership is confirmed."); } + throw new BadRequestException("You are already part of this organization."); + } - if (!provider) + if (string.IsNullOrWhiteSpace(orgUser.Email) || + !orgUser.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase)) + { + throw new BadRequestException("User email does not match invite."); + } + + return await AcceptUserAsync(orgUser, user, userService); + } + + public async Task AcceptUserAsync(string orgIdentifier, User user, IUserService userService) + { + var org = await _organizationRepository.GetByIdentifierAsync(orgIdentifier); + if (org == null) + { + throw new BadRequestException("Organization invalid."); + } + + var usersOrgs = await _organizationUserRepository.GetManyByUserAsync(user.Id); + var orgUser = usersOrgs.FirstOrDefault(u => u.OrganizationId == org.Id); + if (orgUser == null) + { + throw new BadRequestException("User not found within organization."); + } + + return await AcceptUserAsync(orgUser, user, userService); + } + + private async Task AcceptUserAsync(OrganizationUser orgUser, User user, + IUserService userService) + { + if (orgUser.Status != OrganizationUserStatusType.Invited) + { + throw new BadRequestException("Already accepted."); + } + + if (orgUser.Type == OrganizationUserType.Owner || orgUser.Type == OrganizationUserType.Admin) + { + var org = await GetOrgById(orgUser.OrganizationId); + if (org.PlanType == PlanType.Free) { - await ValidateSignUpPoliciesAsync(signup.Owner.Id); - } - - ValidateOrganizationUpgradeParameters(plan, signup); - - var organization = new Organization - { - // Pre-generate the org id so that we can save it with the Stripe subscription.. - Id = CoreHelpers.GenerateComb(), - Name = signup.Name, - BillingEmail = signup.BillingEmail, - BusinessName = signup.BusinessName, - PlanType = plan.Type, - Seats = (short)(plan.BaseSeats + signup.AdditionalSeats), - MaxCollections = plan.MaxCollections, - MaxStorageGb = !plan.BaseStorageGb.HasValue ? - (short?)null : (short)(plan.BaseStorageGb.Value + signup.AdditionalStorageGb), - UsePolicies = plan.HasPolicies, - UseSso = plan.HasSso, - UseGroups = plan.HasGroups, - UseEvents = plan.HasEvents, - UseDirectory = plan.HasDirectory, - UseTotp = plan.HasTotp, - Use2fa = plan.Has2fa, - UseApi = plan.HasApi, - UseResetPassword = plan.HasResetPassword, - SelfHost = plan.HasSelfHost, - UsersGetPremium = plan.UsersGetPremium || signup.PremiumAccessAddon, - UseScim = plan.HasScim, - Plan = plan.Name, - Gateway = null, - ReferenceData = signup.Owner.ReferenceData, - Enabled = true, - LicenseKey = CoreHelpers.SecureRandomString(20), - PublicKey = signup.PublicKey, - PrivateKey = signup.PrivateKey, - CreationDate = DateTime.UtcNow, - RevisionDate = DateTime.UtcNow, - }; - - if (plan.Type == PlanType.Free && !provider) - { - var adminCount = - await _organizationUserRepository.GetCountByFreeOrganizationAdminUserAsync(signup.Owner.Id); + var adminCount = await _organizationUserRepository.GetCountByFreeOrganizationAdminUserAsync( + user.Id); if (adminCount > 0) { throw new BadRequestException("You can only be an admin of one free organization."); } } - else if (plan.Type != PlanType.Free) - { - await _paymentService.PurchaseOrganizationAsync(organization, signup.PaymentMethodType.Value, - signup.PaymentToken, plan, signup.AdditionalStorageGb, signup.AdditionalSeats, - signup.PremiumAccessAddon, signup.TaxInfo); - } - - var ownerId = provider ? default : signup.Owner.Id; - var returnValue = await SignUpAsync(organization, ownerId, signup.OwnerKey, signup.CollectionName, true); - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.Signup, organization) - { - PlanName = plan.Name, - PlanType = plan.Type, - Seats = returnValue.Item1.Seats, - Storage = returnValue.Item1.MaxStorageGb, - }); - return returnValue; } - private async Task ValidateSignUpPoliciesAsync(Guid ownerId) + // Enforce Single Organization Policy of organization user is trying to join + var allOrgUsers = await _organizationUserRepository.GetManyByUserAsync(user.Id); + var hasOtherOrgs = allOrgUsers.Any(ou => ou.OrganizationId != orgUser.OrganizationId); + var invitedSingleOrgPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(user.Id, + PolicyType.SingleOrg, OrganizationUserStatusType.Invited); + + if (hasOtherOrgs && invitedSingleOrgPolicies.Any(p => p.OrganizationId == orgUser.OrganizationId)) { - var singleOrgPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(ownerId, PolicyType.SingleOrg); - if (singleOrgPolicyCount > 0) + throw new BadRequestException("You may not join this organization until you leave or remove " + + "all other organizations."); + } + + // Enforce Single Organization Policy of other organizations user is a member of + var singleOrgPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(user.Id, + PolicyType.SingleOrg); + if (singleOrgPolicyCount > 0) + { + throw new BadRequestException("You cannot join this organization because you are a member of " + + "another organization which forbids it"); + } + + // Enforce Two Factor Authentication Policy of organization user is trying to join + if (!await userService.TwoFactorIsEnabledAsync(user)) + { + var invitedTwoFactorPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(user.Id, + PolicyType.TwoFactorAuthentication, OrganizationUserStatusType.Invited); + if (invitedTwoFactorPolicies.Any(p => p.OrganizationId == orgUser.OrganizationId)) { - throw new BadRequestException("You may not create an organization. You belong to an organization " + - "which has a policy that prohibits you from being a member of any other organization."); + throw new BadRequestException("You cannot join this organization until you enable " + + "two-step login on your user account."); } } - public async Task> SignUpAsync( - OrganizationLicense license, User owner, string ownerKey, string collectionName, string publicKey, - string privateKey) + orgUser.Status = OrganizationUserStatusType.Accepted; + orgUser.UserId = user.Id; + orgUser.Email = null; + + await _organizationUserRepository.ReplaceAsync(orgUser); + + var admins = await _organizationUserRepository.GetManyByMinimumRoleAsync(orgUser.OrganizationId, OrganizationUserType.Admin); + var adminEmails = admins.Select(a => a.Email).Distinct().ToList(); + + if (adminEmails.Count > 0) { - if (license?.LicenseType != null && license.LicenseType != LicenseType.Organization) - { - throw new BadRequestException("Premium licenses cannot be applied to an organization. " - + "Upload this license from your personal account settings page."); - } - - if (license == null || !_licensingService.VerifyLicense(license)) - { - throw new BadRequestException("Invalid license."); - } - - if (!license.CanUse(_globalSettings)) - { - throw new BadRequestException("Invalid license. Make sure your license allows for on-premise " + - "hosting of organizations and that the installation id matches your current installation."); - } - - if (license.PlanType != PlanType.Custom && - StaticStore.Plans.FirstOrDefault(p => p.Type == license.PlanType && !p.Disabled) == null) - { - throw new BadRequestException("Plan not found."); - } - - var enabledOrgs = await _organizationRepository.GetManyByEnabledAsync(); - if (enabledOrgs.Any(o => o.LicenseKey.Equals(license.LicenseKey))) - { - throw new BadRequestException("License is already in use by another organization."); - } - - await ValidateSignUpPoliciesAsync(owner.Id); - - var organization = new Organization - { - Name = license.Name, - BillingEmail = license.BillingEmail, - BusinessName = license.BusinessName, - PlanType = license.PlanType, - Seats = license.Seats, - MaxCollections = license.MaxCollections, - MaxStorageGb = _globalSettings.SelfHosted ? 10240 : license.MaxStorageGb, // 10 TB - UsePolicies = license.UsePolicies, - UseSso = license.UseSso, - UseKeyConnector = license.UseKeyConnector, - UseScim = license.UseScim, - UseGroups = license.UseGroups, - UseDirectory = license.UseDirectory, - UseEvents = license.UseEvents, - UseTotp = license.UseTotp, - Use2fa = license.Use2fa, - UseApi = license.UseApi, - UseResetPassword = license.UseResetPassword, - Plan = license.Plan, - SelfHost = license.SelfHost, - UsersGetPremium = license.UsersGetPremium, - Gateway = null, - GatewayCustomerId = null, - GatewaySubscriptionId = null, - ReferenceData = owner.ReferenceData, - Enabled = license.Enabled, - ExpirationDate = license.Expires, - LicenseKey = license.LicenseKey, - PublicKey = publicKey, - PrivateKey = privateKey, - CreationDate = DateTime.UtcNow, - RevisionDate = DateTime.UtcNow - }; - - var result = await SignUpAsync(organization, owner.Id, ownerKey, collectionName, false); - - var dir = $"{_globalSettings.LicenseDirectory}/organization"; - Directory.CreateDirectory(dir); - await using var fs = new FileStream(Path.Combine(dir, $"{organization.Id}.json"), FileMode.Create); - await JsonSerializer.SerializeAsync(fs, license, JsonHelpers.Indented); - return result; + var organization = await _organizationRepository.GetByIdAsync(orgUser.OrganizationId); + await _mailService.SendOrganizationAcceptedEmailAsync(organization, user.Email, adminEmails); } - private async Task> SignUpAsync(Organization organization, - Guid ownerId, string ownerKey, string collectionName, bool withPayment) + return orgUser; + } + + public async Task ConfirmUserAsync(Guid organizationId, Guid organizationUserId, string key, + Guid confirmingUserId, IUserService userService) + { + var result = await ConfirmUsersAsync(organizationId, new Dictionary() { { organizationUserId, key } }, + confirmingUserId, userService); + + if (!result.Any()) { + throw new BadRequestException("User not valid."); + } + + var (orgUser, error) = result[0]; + if (error != "") + { + throw new BadRequestException(error); + } + return orgUser; + } + + public async Task>> ConfirmUsersAsync(Guid organizationId, Dictionary keys, + Guid confirmingUserId, IUserService userService) + { + var organizationUsers = await _organizationUserRepository.GetManyAsync(keys.Keys); + var validOrganizationUsers = organizationUsers + .Where(u => u.Status == OrganizationUserStatusType.Accepted && u.OrganizationId == organizationId && u.UserId != null) + .ToList(); + + if (!validOrganizationUsers.Any()) + { + return new List>(); + } + + var validOrganizationUserIds = validOrganizationUsers.Select(u => u.UserId.Value).ToList(); + + var organization = await GetOrgById(organizationId); + var policies = await _policyRepository.GetManyByOrganizationIdAsync(organizationId); + var usersOrgs = await _organizationUserRepository.GetManyByManyUsersAsync(validOrganizationUserIds); + var users = await _userRepository.GetManyAsync(validOrganizationUserIds); + + var keyedFilteredUsers = validOrganizationUsers.ToDictionary(u => u.UserId.Value, u => u); + var keyedOrganizationUsers = usersOrgs.GroupBy(u => u.UserId.Value) + .ToDictionary(u => u.Key, u => u.ToList()); + + var succeededUsers = new List(); + var result = new List>(); + + foreach (var user in users) + { + if (!keyedFilteredUsers.ContainsKey(user.Id)) + { + continue; + } + var orgUser = keyedFilteredUsers[user.Id]; + var orgUsers = keyedOrganizationUsers.GetValueOrDefault(user.Id, new List()); try { - await _organizationRepository.CreateAsync(organization); - await _organizationApiKeyRepository.CreateAsync(new OrganizationApiKey + if (organization.PlanType == PlanType.Free && (orgUser.Type == OrganizationUserType.Admin + || orgUser.Type == OrganizationUserType.Owner)) { - OrganizationId = organization.Id, - ApiKey = CoreHelpers.SecureRandomString(30), - Type = OrganizationApiKeyType.Default, - RevisionDate = DateTime.UtcNow, - }); - await _applicationCacheService.UpsertOrganizationAbilityAsync(organization); - - if (!string.IsNullOrWhiteSpace(collectionName)) - { - var defaultCollection = new Collection - { - Name = collectionName, - OrganizationId = organization.Id, - CreationDate = organization.CreationDate, - RevisionDate = organization.CreationDate - }; - await _collectionRepository.CreateAsync(defaultCollection); - } - - OrganizationUser orgUser = null; - if (ownerId != default) - { - orgUser = new OrganizationUser - { - OrganizationId = organization.Id, - UserId = ownerId, - Key = ownerKey, - Type = OrganizationUserType.Owner, - Status = OrganizationUserStatusType.Confirmed, - AccessAll = true, - CreationDate = organization.CreationDate, - RevisionDate = organization.CreationDate - }; - - await _organizationUserRepository.CreateAsync(orgUser); - - var deviceIds = await GetUserDeviceIdsAsync(orgUser.UserId.Value); - await _pushRegistrationService.AddUserRegistrationOrganizationAsync(deviceIds, - organization.Id.ToString()); - await _pushNotificationService.PushSyncOrgKeysAsync(ownerId); - } - - return new Tuple(organization, orgUser); - } - catch - { - if (withPayment) - { - await _paymentService.CancelAndRecoverChargesAsync(organization); - } - - if (organization.Id != default(Guid)) - { - await _organizationRepository.DeleteAsync(organization); - await _applicationCacheService.DeleteOrganizationAbilityAsync(organization.Id); - } - - throw; - } - } - - public async Task UpdateLicenseAsync(Guid organizationId, OrganizationLicense license) - { - var organization = await GetOrgById(organizationId); - if (organization == null) - { - throw new NotFoundException(); - } - - if (!_globalSettings.SelfHosted) - { - throw new InvalidOperationException("Licenses require self hosting."); - } - - if (license?.LicenseType != null && license.LicenseType != LicenseType.Organization) - { - throw new BadRequestException("Premium licenses cannot be applied to an organization. " - + "Upload this license from your personal account settings page."); - } - - if (license == null || !_licensingService.VerifyLicense(license)) - { - throw new BadRequestException("Invalid license."); - } - - if (!license.CanUse(_globalSettings)) - { - throw new BadRequestException("Invalid license. Make sure your license allows for on-premise " + - "hosting of organizations and that the installation id matches your current installation."); - } - - var enabledOrgs = await _organizationRepository.GetManyByEnabledAsync(); - if (enabledOrgs.Any(o => o.LicenseKey.Equals(license.LicenseKey) && o.Id != organizationId)) - { - throw new BadRequestException("License is already in use by another organization."); - } - - if (license.Seats.HasValue && - (!organization.Seats.HasValue || organization.Seats.Value > license.Seats.Value)) - { - var userCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(organization.Id); - if (userCount > license.Seats.Value) - { - throw new BadRequestException($"Your organization currently has {userCount} seats filled. " + - $"Your new license only has ({license.Seats.Value}) seats. Remove some users."); - } - } - - if (license.MaxCollections.HasValue && (!organization.MaxCollections.HasValue || - organization.MaxCollections.Value > license.MaxCollections.Value)) - { - var collectionCount = await _collectionRepository.GetCountByOrganizationIdAsync(organization.Id); - if (collectionCount > license.MaxCollections.Value) - { - throw new BadRequestException($"Your organization currently has {collectionCount} collections. " + - $"Your new license allows for a maximum of ({license.MaxCollections.Value}) collections. " + - "Remove some collections."); - } - } - - if (!license.UseGroups && organization.UseGroups) - { - var groups = await _groupRepository.GetManyByOrganizationIdAsync(organization.Id); - if (groups.Count > 0) - { - throw new BadRequestException($"Your organization currently has {groups.Count} groups. " + - $"Your new license does not allow for the use of groups. Remove all groups."); - } - } - - if (!license.UsePolicies && organization.UsePolicies) - { - var policies = await _policyRepository.GetManyByOrganizationIdAsync(organization.Id); - if (policies.Any(p => p.Enabled)) - { - throw new BadRequestException($"Your organization currently has {policies.Count} enabled " + - $"policies. Your new license does not allow for the use of policies. Disable all policies."); - } - } - - if (!license.UseSso && organization.UseSso) - { - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organization.Id); - if (ssoConfig != null && ssoConfig.Enabled) - { - throw new BadRequestException($"Your organization currently has a SSO configuration. " + - $"Your new license does not allow for the use of SSO. Disable your SSO configuration."); - } - } - - if (!license.UseKeyConnector && organization.UseKeyConnector) - { - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organization.Id); - if (ssoConfig != null && ssoConfig.GetData().KeyConnectorEnabled) - { - throw new BadRequestException($"Your organization currently has Key Connector enabled. " + - $"Your new license does not allow for the use of Key Connector. Disable your Key Connector."); - } - } - - if (!license.UseScim && organization.UseScim) - { - var scimConnections = await _organizationConnectionRepository.GetByOrganizationIdTypeAsync(organization.Id, - OrganizationConnectionType.Scim); - if (scimConnections != null && scimConnections.Any(c => c.GetConfig()?.Enabled == true)) - { - throw new BadRequestException("Your new plan does not allow the SCIM feature. " + - "Disable your SCIM configuration."); - } - } - - if (!license.UseResetPassword && organization.UseResetPassword) - { - var resetPasswordPolicy = - await _policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword); - if (resetPasswordPolicy != null && resetPasswordPolicy.Enabled) - { - throw new BadRequestException("Your new license does not allow the Password Reset feature. " - + "Disable your Password Reset policy."); - } - } - - var dir = $"{_globalSettings.LicenseDirectory}/organization"; - Directory.CreateDirectory(dir); - await using var fs = new FileStream(Path.Combine(dir, $"{organization.Id}.json"), FileMode.Create); - await JsonSerializer.SerializeAsync(fs, license, JsonHelpers.Indented); - - organization.Name = license.Name; - organization.BusinessName = license.BusinessName; - organization.BillingEmail = license.BillingEmail; - organization.PlanType = license.PlanType; - organization.Seats = license.Seats; - organization.MaxCollections = license.MaxCollections; - organization.UseGroups = license.UseGroups; - organization.UseDirectory = license.UseDirectory; - organization.UseEvents = license.UseEvents; - organization.UseTotp = license.UseTotp; - organization.Use2fa = license.Use2fa; - organization.UseApi = license.UseApi; - organization.UsePolicies = license.UsePolicies; - organization.UseSso = license.UseSso; - organization.UseKeyConnector = license.UseKeyConnector; - organization.UseScim = license.UseScim; - organization.UseResetPassword = license.UseResetPassword; - organization.SelfHost = license.SelfHost; - organization.UsersGetPremium = license.UsersGetPremium; - organization.Plan = license.Plan; - organization.Enabled = license.Enabled; - organization.ExpirationDate = license.Expires; - organization.LicenseKey = license.LicenseKey; - organization.RevisionDate = DateTime.UtcNow; - await ReplaceAndUpdateCache(organization); - } - - public async Task DeleteAsync(Organization organization) - { - await ValidateDeleteOrganizationAsync(organization); - - if (!string.IsNullOrWhiteSpace(organization.GatewaySubscriptionId)) - { - try - { - var eop = !organization.ExpirationDate.HasValue || - organization.ExpirationDate.Value >= DateTime.UtcNow; - await _paymentService.CancelSubscriptionAsync(organization, eop); - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.DeleteAccount, organization)); - } - catch (GatewayException) { } - } - - await _organizationRepository.DeleteAsync(organization); - await _applicationCacheService.DeleteOrganizationAbilityAsync(organization.Id); - } - - public async Task EnableAsync(Guid organizationId, DateTime? expirationDate) - { - var org = await GetOrgById(organizationId); - if (org != null && !org.Enabled && org.Gateway.HasValue) - { - org.Enabled = true; - org.ExpirationDate = expirationDate; - org.RevisionDate = DateTime.UtcNow; - await ReplaceAndUpdateCache(org); - } - } - - public async Task DisableAsync(Guid organizationId, DateTime? expirationDate) - { - var org = await GetOrgById(organizationId); - if (org != null && org.Enabled) - { - org.Enabled = false; - org.ExpirationDate = expirationDate; - org.RevisionDate = DateTime.UtcNow; - await ReplaceAndUpdateCache(org); - - // TODO: send email to owners? - } - } - - public async Task UpdateExpirationDateAsync(Guid organizationId, DateTime? expirationDate) - { - var org = await GetOrgById(organizationId); - if (org != null) - { - org.ExpirationDate = expirationDate; - org.RevisionDate = DateTime.UtcNow; - await ReplaceAndUpdateCache(org); - } - } - - public async Task EnableAsync(Guid organizationId) - { - var org = await GetOrgById(organizationId); - if (org != null && !org.Enabled) - { - org.Enabled = true; - await ReplaceAndUpdateCache(org); - } - } - - public async Task UpdateAsync(Organization organization, bool updateBilling = false) - { - if (organization.Id == default(Guid)) - { - throw new ApplicationException("Cannot create org this way. Call SignUpAsync."); - } - - if (!string.IsNullOrWhiteSpace(organization.Identifier)) - { - var orgById = await _organizationRepository.GetByIdentifierAsync(organization.Identifier); - if (orgById != null && orgById.Id != organization.Id) - { - throw new BadRequestException("Identifier already in use by another organization."); - } - } - - await ReplaceAndUpdateCache(organization, EventType.Organization_Updated); - - if (updateBilling && !string.IsNullOrWhiteSpace(organization.GatewayCustomerId)) - { - var customerService = new CustomerService(); - await customerService.UpdateAsync(organization.GatewayCustomerId, new CustomerUpdateOptions - { - Email = organization.BillingEmail, - Description = organization.BusinessName - }); - } - } - - public async Task UpdateTwoFactorProviderAsync(Organization organization, TwoFactorProviderType type) - { - if (!type.ToString().Contains("Organization")) - { - throw new ArgumentException("Not an organization provider type."); - } - - if (!organization.Use2fa) - { - throw new BadRequestException("Organization cannot use 2FA."); - } - - var providers = organization.GetTwoFactorProviders(); - if (!providers?.ContainsKey(type) ?? true) - { - return; - } - - providers[type].Enabled = true; - organization.SetTwoFactorProviders(providers); - await UpdateAsync(organization); - } - - public async Task DisableTwoFactorProviderAsync(Organization organization, TwoFactorProviderType type) - { - if (!type.ToString().Contains("Organization")) - { - throw new ArgumentException("Not an organization provider type."); - } - - var providers = organization.GetTwoFactorProviders(); - if (!providers?.ContainsKey(type) ?? true) - { - return; - } - - providers.Remove(type); - organization.SetTwoFactorProviders(providers); - await UpdateAsync(organization); - } - - public async Task> InviteUsersAsync(Guid organizationId, Guid? invitingUserId, - IEnumerable<(OrganizationUserInvite invite, string externalId)> invites) - { - var organization = await GetOrgById(organizationId); - var initialSeatCount = organization.Seats; - if (organization == null || invites.Any(i => i.invite.Emails == null)) - { - throw new NotFoundException(); - } - - var inviteTypes = new HashSet(invites.Where(i => i.invite.Type.HasValue) - .Select(i => i.invite.Type.Value)); - if (invitingUserId.HasValue && inviteTypes.Count > 0) - { - foreach (var type in inviteTypes) - { - await ValidateOrganizationUserUpdatePermissions(organizationId, type, null); - } - } - - var newSeatsRequired = 0; - var existingEmails = new HashSet(await _organizationUserRepository.SelectKnownEmailsAsync( - organizationId, invites.SelectMany(i => i.invite.Emails), false), StringComparer.InvariantCultureIgnoreCase); - if (organization.Seats.HasValue) - { - var userCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(organizationId); - var availableSeats = organization.Seats.Value - userCount; - newSeatsRequired = invites.Sum(i => i.invite.Emails.Count()) - existingEmails.Count() - availableSeats; - } - - if (newSeatsRequired > 0) - { - var (canScale, failureReason) = CanScale(organization, newSeatsRequired); - if (!canScale) - { - throw new BadRequestException(failureReason); - } - } - - var invitedAreAllOwners = invites.All(i => i.invite.Type == OrganizationUserType.Owner); - if (!invitedAreAllOwners && !await HasConfirmedOwnersExceptAsync(organizationId, new Guid[] { })) - { - throw new BadRequestException("Organization must have at least one confirmed owner."); - } - - - var orgUsers = new List(); - var limitedCollectionOrgUsers = new List<(OrganizationUser, IEnumerable)>(); - var orgUserInvitedCount = 0; - var exceptions = new List(); - var events = new List<(OrganizationUser, EventType, DateTime?)>(); - foreach (var (invite, externalId) in invites) - { - // Prevent duplicate invitations - foreach (var email in invite.Emails.Distinct()) - { - try - { - // Make sure user is not already invited - if (existingEmails.Contains(email)) - { - continue; - } - - var orgUser = new OrganizationUser - { - OrganizationId = organizationId, - UserId = null, - Email = email.ToLowerInvariant(), - Key = null, - Type = invite.Type.Value, - Status = OrganizationUserStatusType.Invited, - AccessAll = invite.AccessAll, - ExternalId = externalId, - CreationDate = DateTime.UtcNow, - RevisionDate = DateTime.UtcNow, - }; - - if (invite.Permissions != null) - { - orgUser.Permissions = JsonSerializer.Serialize(invite.Permissions, JsonHelpers.CamelCase); - } - - if (!orgUser.AccessAll && invite.Collections.Any()) - { - limitedCollectionOrgUsers.Add((orgUser, invite.Collections)); - } - else - { - orgUsers.Add(orgUser); - } - - events.Add((orgUser, EventType.OrganizationUser_Invited, DateTime.UtcNow)); - orgUserInvitedCount++; - } - catch (Exception e) - { - exceptions.Add(e); - } - } - } - - if (exceptions.Any()) - { - throw new AggregateException("One or more errors occurred while inviting users.", exceptions); - } - - var prorationDate = DateTime.UtcNow; - try - { - await _organizationUserRepository.CreateManyAsync(orgUsers); - foreach (var (orgUser, collections) in limitedCollectionOrgUsers) - { - await _organizationUserRepository.CreateAsync(orgUser, collections); - } - - if (!await _currentContext.ManageUsers(organization.Id)) - { - throw new BadRequestException("Cannot add seats. Cannot manage organization users."); - } - - await AutoAddSeatsAsync(organization, newSeatsRequired, prorationDate); - await SendInvitesAsync(orgUsers.Concat(limitedCollectionOrgUsers.Select(u => u.Item1)), organization); - await _eventService.LogOrganizationUserEventsAsync(events); - - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.InvitedUsers, organization) - { - Users = orgUserInvitedCount - }); - } - catch (Exception e) - { - // Revert any added users. - var invitedOrgUserIds = orgUsers.Select(u => u.Id).Concat(limitedCollectionOrgUsers.Select(u => u.Item1.Id)); - await _organizationUserRepository.DeleteManyAsync(invitedOrgUserIds); - var currentSeatCount = (await _organizationRepository.GetByIdAsync(organization.Id)).Seats; - - if (initialSeatCount.HasValue && currentSeatCount.HasValue && currentSeatCount.Value != initialSeatCount.Value) - { - await AdjustSeatsAsync(organization, initialSeatCount.Value - currentSeatCount.Value, prorationDate); - } - - exceptions.Add(e); - } - - if (exceptions.Any()) - { - throw new AggregateException("One or more errors occurred while inviting users.", exceptions); - } - - return orgUsers; - } - - public async Task>> ResendInvitesAsync(Guid organizationId, Guid? invitingUserId, - IEnumerable organizationUsersId) - { - var orgUsers = await _organizationUserRepository.GetManyAsync(organizationUsersId); - var org = await GetOrgById(organizationId); - - var result = new List>(); - foreach (var orgUser in orgUsers) - { - if (orgUser.Status != OrganizationUserStatusType.Invited || orgUser.OrganizationId != organizationId) - { - result.Add(Tuple.Create(orgUser, "User invalid.")); - continue; - } - - await SendInviteAsync(orgUser, org); - result.Add(Tuple.Create(orgUser, "")); - } - - return result; - } - - public async Task ResendInviteAsync(Guid organizationId, Guid? invitingUserId, Guid organizationUserId) - { - var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); - if (orgUser == null || orgUser.OrganizationId != organizationId || - orgUser.Status != OrganizationUserStatusType.Invited) - { - throw new BadRequestException("User invalid."); - } - - var org = await GetOrgById(orgUser.OrganizationId); - await SendInviteAsync(orgUser, org); - } - - private async Task SendInvitesAsync(IEnumerable orgUsers, Organization organization) - { - string MakeToken(OrganizationUser orgUser) => - _dataProtector.Protect($"OrganizationUserInvite {orgUser.Id} {orgUser.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); - - await _mailService.BulkSendOrganizationInviteEmailAsync(organization.Name, - orgUsers.Select(o => (o, new ExpiringToken(MakeToken(o), DateTime.UtcNow.AddDays(5))))); - } - - private async Task SendInviteAsync(OrganizationUser orgUser, Organization organization) - { - var now = DateTime.UtcNow; - var nowMillis = CoreHelpers.ToEpocMilliseconds(now); - var token = _dataProtector.Protect( - $"OrganizationUserInvite {orgUser.Id} {orgUser.Email} {nowMillis}"); - - await _mailService.SendOrganizationInviteEmailAsync(organization.Name, orgUser, new ExpiringToken(token, now.AddDays(5))); - } - - public async Task AcceptUserAsync(Guid organizationUserId, User user, string token, - IUserService userService) - { - var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); - if (orgUser == null) - { - throw new BadRequestException("User invalid."); - } - - if (!CoreHelpers.UserInviteTokenIsValid(_dataProtector, token, user.Email, orgUser.Id, _globalSettings)) - { - throw new BadRequestException("Invalid token."); - } - - var existingOrgUserCount = await _organizationUserRepository.GetCountByOrganizationAsync( - orgUser.OrganizationId, user.Email, true); - if (existingOrgUserCount > 0) - { - if (orgUser.Status == OrganizationUserStatusType.Accepted) - { - throw new BadRequestException("Invitation already accepted. You will receive an email when your organization membership is confirmed."); - } - throw new BadRequestException("You are already part of this organization."); - } - - if (string.IsNullOrWhiteSpace(orgUser.Email) || - !orgUser.Email.Equals(user.Email, StringComparison.InvariantCultureIgnoreCase)) - { - throw new BadRequestException("User email does not match invite."); - } - - return await AcceptUserAsync(orgUser, user, userService); - } - - public async Task AcceptUserAsync(string orgIdentifier, User user, IUserService userService) - { - var org = await _organizationRepository.GetByIdentifierAsync(orgIdentifier); - if (org == null) - { - throw new BadRequestException("Organization invalid."); - } - - var usersOrgs = await _organizationUserRepository.GetManyByUserAsync(user.Id); - var orgUser = usersOrgs.FirstOrDefault(u => u.OrganizationId == org.Id); - if (orgUser == null) - { - throw new BadRequestException("User not found within organization."); - } - - return await AcceptUserAsync(orgUser, user, userService); - } - - private async Task AcceptUserAsync(OrganizationUser orgUser, User user, - IUserService userService) - { - if (orgUser.Status != OrganizationUserStatusType.Invited) - { - throw new BadRequestException("Already accepted."); - } - - if (orgUser.Type == OrganizationUserType.Owner || orgUser.Type == OrganizationUserType.Admin) - { - var org = await GetOrgById(orgUser.OrganizationId); - if (org.PlanType == PlanType.Free) - { - var adminCount = await _organizationUserRepository.GetCountByFreeOrganizationAdminUserAsync( - user.Id); + // Since free organizations only supports a few users there is not much point in avoiding N+1 queries for this. + var adminCount = await _organizationUserRepository.GetCountByFreeOrganizationAdminUserAsync(user.Id); if (adminCount > 0) { - throw new BadRequestException("You can only be an admin of one free organization."); + throw new BadRequestException("User can only be an admin of one free organization."); } } + + await CheckPolicies(policies, organizationId, user, orgUsers, userService); + orgUser.Status = OrganizationUserStatusType.Confirmed; + orgUser.Key = keys[orgUser.Id]; + orgUser.Email = null; + + await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_Confirmed); + await _mailService.SendOrganizationConfirmedEmailAsync(organization.Name, user.Email); + await DeleteAndPushUserRegistrationAsync(organizationId, user.Id); + succeededUsers.Add(orgUser); + result.Add(Tuple.Create(orgUser, "")); } - - // Enforce Single Organization Policy of organization user is trying to join - var allOrgUsers = await _organizationUserRepository.GetManyByUserAsync(user.Id); - var hasOtherOrgs = allOrgUsers.Any(ou => ou.OrganizationId != orgUser.OrganizationId); - var invitedSingleOrgPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(user.Id, - PolicyType.SingleOrg, OrganizationUserStatusType.Invited); - - if (hasOtherOrgs && invitedSingleOrgPolicies.Any(p => p.OrganizationId == orgUser.OrganizationId)) + catch (BadRequestException e) { - throw new BadRequestException("You may not join this organization until you leave or remove " + - "all other organizations."); + result.Add(Tuple.Create(orgUser, e.Message)); } - - // Enforce Single Organization Policy of other organizations user is a member of - var singleOrgPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(user.Id, - PolicyType.SingleOrg); - if (singleOrgPolicyCount > 0) - { - throw new BadRequestException("You cannot join this organization because you are a member of " + - "another organization which forbids it"); - } - - // Enforce Two Factor Authentication Policy of organization user is trying to join - if (!await userService.TwoFactorIsEnabledAsync(user)) - { - var invitedTwoFactorPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(user.Id, - PolicyType.TwoFactorAuthentication, OrganizationUserStatusType.Invited); - if (invitedTwoFactorPolicies.Any(p => p.OrganizationId == orgUser.OrganizationId)) - { - throw new BadRequestException("You cannot join this organization until you enable " + - "two-step login on your user account."); - } - } - - orgUser.Status = OrganizationUserStatusType.Accepted; - orgUser.UserId = user.Id; - orgUser.Email = null; - - await _organizationUserRepository.ReplaceAsync(orgUser); - - var admins = await _organizationUserRepository.GetManyByMinimumRoleAsync(orgUser.OrganizationId, OrganizationUserType.Admin); - var adminEmails = admins.Select(a => a.Email).Distinct().ToList(); - - if (adminEmails.Count > 0) - { - var organization = await _organizationRepository.GetByIdAsync(orgUser.OrganizationId); - await _mailService.SendOrganizationAcceptedEmailAsync(organization, user.Email, adminEmails); - } - - return orgUser; } - public async Task ConfirmUserAsync(Guid organizationId, Guid organizationUserId, string key, - Guid confirmingUserId, IUserService userService) + await _organizationUserRepository.ReplaceManyAsync(succeededUsers); + + return result; + } + + internal (bool canScale, string failureReason) CanScale(Organization organization, + int seatsToAdd) + { + var failureReason = ""; + if (_globalSettings.SelfHosted) { - var result = await ConfirmUsersAsync(organizationId, new Dictionary() { { organizationUserId, key } }, - confirmingUserId, userService); - - if (!result.Any()) - { - throw new BadRequestException("User not valid."); - } - - var (orgUser, error) = result[0]; - if (error != "") - { - throw new BadRequestException(error); - } - return orgUser; + failureReason = "Cannot autoscale on self-hosted instance."; + return (false, failureReason); } - public async Task>> ConfirmUsersAsync(Guid organizationId, Dictionary keys, - Guid confirmingUserId, IUserService userService) + if (seatsToAdd < 1) { - var organizationUsers = await _organizationUserRepository.GetManyAsync(keys.Keys); - var validOrganizationUsers = organizationUsers - .Where(u => u.Status == OrganizationUserStatusType.Accepted && u.OrganizationId == organizationId && u.UserId != null) - .ToList(); - - if (!validOrganizationUsers.Any()) - { - return new List>(); - } - - var validOrganizationUserIds = validOrganizationUsers.Select(u => u.UserId.Value).ToList(); - - var organization = await GetOrgById(organizationId); - var policies = await _policyRepository.GetManyByOrganizationIdAsync(organizationId); - var usersOrgs = await _organizationUserRepository.GetManyByManyUsersAsync(validOrganizationUserIds); - var users = await _userRepository.GetManyAsync(validOrganizationUserIds); - - var keyedFilteredUsers = validOrganizationUsers.ToDictionary(u => u.UserId.Value, u => u); - var keyedOrganizationUsers = usersOrgs.GroupBy(u => u.UserId.Value) - .ToDictionary(u => u.Key, u => u.ToList()); - - var succeededUsers = new List(); - var result = new List>(); - - foreach (var user in users) - { - if (!keyedFilteredUsers.ContainsKey(user.Id)) - { - continue; - } - var orgUser = keyedFilteredUsers[user.Id]; - var orgUsers = keyedOrganizationUsers.GetValueOrDefault(user.Id, new List()); - try - { - if (organization.PlanType == PlanType.Free && (orgUser.Type == OrganizationUserType.Admin - || orgUser.Type == OrganizationUserType.Owner)) - { - // Since free organizations only supports a few users there is not much point in avoiding N+1 queries for this. - var adminCount = await _organizationUserRepository.GetCountByFreeOrganizationAdminUserAsync(user.Id); - if (adminCount > 0) - { - throw new BadRequestException("User can only be an admin of one free organization."); - } - } - - await CheckPolicies(policies, organizationId, user, orgUsers, userService); - orgUser.Status = OrganizationUserStatusType.Confirmed; - orgUser.Key = keys[orgUser.Id]; - orgUser.Email = null; - - await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_Confirmed); - await _mailService.SendOrganizationConfirmedEmailAsync(organization.Name, user.Email); - await DeleteAndPushUserRegistrationAsync(organizationId, user.Id); - succeededUsers.Add(orgUser); - result.Add(Tuple.Create(orgUser, "")); - } - catch (BadRequestException e) - { - result.Add(Tuple.Create(orgUser, e.Message)); - } - } - - await _organizationUserRepository.ReplaceManyAsync(succeededUsers); - - return result; - } - - internal (bool canScale, string failureReason) CanScale(Organization organization, - int seatsToAdd) - { - var failureReason = ""; - if (_globalSettings.SelfHosted) - { - failureReason = "Cannot autoscale on self-hosted instance."; - return (false, failureReason); - } - - if (seatsToAdd < 1) - { - return (true, failureReason); - } - - if (organization.Seats.HasValue && - organization.MaxAutoscaleSeats.HasValue && - organization.MaxAutoscaleSeats.Value < organization.Seats.Value + seatsToAdd) - { - return (false, $"Cannot invite new users. Seat limit has been reached."); - } - return (true, failureReason); } - public async Task AutoAddSeatsAsync(Organization organization, int seatsToAdd, DateTime? prorationDate = null) + if (organization.Seats.HasValue && + organization.MaxAutoscaleSeats.HasValue && + organization.MaxAutoscaleSeats.Value < organization.Seats.Value + seatsToAdd) { - if (seatsToAdd < 1 || !organization.Seats.HasValue) - { - return; - } - - var (canScale, failureMessage) = CanScale(organization, seatsToAdd); - if (!canScale) - { - throw new BadRequestException(failureMessage); - } - - var ownerEmails = (await _organizationUserRepository.GetManyByMinimumRoleAsync(organization.Id, - OrganizationUserType.Owner)).Select(u => u.Email).Distinct(); - var initialSeatCount = organization.Seats.Value; - - await AdjustSeatsAsync(organization, seatsToAdd, prorationDate, ownerEmails); - - if (!organization.OwnersNotifiedOfAutoscaling.HasValue) - { - await _mailService.SendOrganizationAutoscaledEmailAsync(organization, initialSeatCount, - ownerEmails); - organization.OwnersNotifiedOfAutoscaling = DateTime.UtcNow; - await _organizationRepository.UpsertAsync(organization); - } + return (false, $"Cannot invite new users. Seat limit has been reached."); } - private async Task CheckPolicies(ICollection policies, Guid organizationId, User user, - ICollection userOrgs, IUserService userService) - { - var usingTwoFactorPolicy = policies.Any(p => p.Type == PolicyType.TwoFactorAuthentication && p.Enabled); - if (usingTwoFactorPolicy && !await userService.TwoFactorIsEnabledAsync(user)) - { - throw new BadRequestException("User does not have two-step login enabled."); - } + return (true, failureReason); + } - var usingSingleOrgPolicy = policies.Any(p => p.Type == PolicyType.SingleOrg && p.Enabled); - if (usingSingleOrgPolicy) - { - if (userOrgs.Any(ou => ou.OrganizationId != organizationId && ou.Status != OrganizationUserStatusType.Invited)) - { - throw new BadRequestException("User is a member of another organization."); - } - } + public async Task AutoAddSeatsAsync(Organization organization, int seatsToAdd, DateTime? prorationDate = null) + { + if (seatsToAdd < 1 || !organization.Seats.HasValue) + { + return; } - public async Task SaveUserAsync(OrganizationUser user, Guid? savingUserId, - IEnumerable collections) + var (canScale, failureMessage) = CanScale(organization, seatsToAdd); + if (!canScale) { - if (user.Id.Equals(default(Guid))) - { - throw new BadRequestException("Invite the user first."); - } - - var originalUser = await _organizationUserRepository.GetByIdAsync(user.Id); - if (user.Equals(originalUser)) - { - throw new BadRequestException("Please make changes before saving."); - } - - if (savingUserId.HasValue) - { - await ValidateOrganizationUserUpdatePermissions(user.OrganizationId, user.Type, originalUser.Type); - } - - if (user.Type != OrganizationUserType.Owner && - !await HasConfirmedOwnersExceptAsync(user.OrganizationId, new[] { user.Id })) - { - throw new BadRequestException("Organization must have at least one confirmed owner."); - } - - if (user.AccessAll) - { - // We don't need any collections if we're flagged to have all access. - collections = new List(); - } - await _organizationUserRepository.ReplaceAsync(user, collections); - await _eventService.LogOrganizationUserEventAsync(user, EventType.OrganizationUser_Updated); + throw new BadRequestException(failureMessage); } - public async Task DeleteUserAsync(Guid organizationId, Guid organizationUserId, Guid? deletingUserId) + var ownerEmails = (await _organizationUserRepository.GetManyByMinimumRoleAsync(organization.Id, + OrganizationUserType.Owner)).Select(u => u.Email).Distinct(); + var initialSeatCount = organization.Seats.Value; + + await AdjustSeatsAsync(organization, seatsToAdd, prorationDate, ownerEmails); + + if (!organization.OwnersNotifiedOfAutoscaling.HasValue) { - var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); - if (orgUser == null || orgUser.OrganizationId != organizationId) - { - throw new BadRequestException("User not valid."); - } - - if (deletingUserId.HasValue && orgUser.UserId == deletingUserId.Value) - { - throw new BadRequestException("You cannot remove yourself."); - } - - if (orgUser.Type == OrganizationUserType.Owner && deletingUserId.HasValue && - !await _currentContext.OrganizationOwner(organizationId)) - { - throw new BadRequestException("Only owners can delete other owners."); - } - - if (!await HasConfirmedOwnersExceptAsync(organizationId, new[] { organizationUserId })) - { - throw new BadRequestException("Organization must have at least one confirmed owner."); - } - - await _organizationUserRepository.DeleteAsync(orgUser); - await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_Removed); - - if (orgUser.UserId.HasValue) - { - await DeleteAndPushUserRegistrationAsync(organizationId, orgUser.UserId.Value); - } - } - - public async Task DeleteUserAsync(Guid organizationId, Guid userId) - { - var orgUser = await _organizationUserRepository.GetByOrganizationAsync(organizationId, userId); - if (orgUser == null) - { - throw new NotFoundException(); - } - - if (!await HasConfirmedOwnersExceptAsync(organizationId, new[] { orgUser.Id })) - { - throw new BadRequestException("Organization must have at least one confirmed owner."); - } - - await _organizationUserRepository.DeleteAsync(orgUser); - await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_Removed); - - if (orgUser.UserId.HasValue) - { - await DeleteAndPushUserRegistrationAsync(organizationId, orgUser.UserId.Value); - } - } - - public async Task>> DeleteUsersAsync(Guid organizationId, - IEnumerable organizationUsersId, - Guid? deletingUserId) - { - var orgUsers = await _organizationUserRepository.GetManyAsync(organizationUsersId); - var filteredUsers = orgUsers.Where(u => u.OrganizationId == organizationId) - .ToList(); - - if (!filteredUsers.Any()) - { - throw new BadRequestException("Users invalid."); - } - - if (!await HasConfirmedOwnersExceptAsync(organizationId, organizationUsersId)) - { - throw new BadRequestException("Organization must have at least one confirmed owner."); - } - - var deletingUserIsOwner = false; - if (deletingUserId.HasValue) - { - deletingUserIsOwner = await _currentContext.OrganizationOwner(organizationId); - } - - var result = new List>(); - var deletedUserIds = new List(); - foreach (var orgUser in filteredUsers) - { - try - { - if (deletingUserId.HasValue && orgUser.UserId == deletingUserId) - { - throw new BadRequestException("You cannot remove yourself."); - } - - if (orgUser.Type == OrganizationUserType.Owner && deletingUserId.HasValue && !deletingUserIsOwner) - { - throw new BadRequestException("Only owners can delete other owners."); - } - - await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_Removed); - - if (orgUser.UserId.HasValue) - { - await DeleteAndPushUserRegistrationAsync(organizationId, orgUser.UserId.Value); - } - result.Add(Tuple.Create(orgUser, "")); - deletedUserIds.Add(orgUser.Id); - } - catch (BadRequestException e) - { - result.Add(Tuple.Create(orgUser, e.Message)); - } - - await _organizationUserRepository.DeleteManyAsync(deletedUserIds); - } - - return result; - } - - public async Task HasConfirmedOwnersExceptAsync(Guid organizationId, IEnumerable organizationUsersId, bool includeProvider = true) - { - var confirmedOwners = await GetConfirmedOwnersAsync(organizationId); - var confirmedOwnersIds = confirmedOwners.Select(u => u.Id); - bool hasOtherOwner = confirmedOwnersIds.Except(organizationUsersId).Any(); - if (!hasOtherOwner && includeProvider) - { - return (await _currentContext.ProviderIdForOrg(organizationId)).HasValue; - } - return hasOtherOwner; - } - - public async Task UpdateUserGroupsAsync(OrganizationUser organizationUser, IEnumerable groupIds, Guid? loggedInUserId) - { - if (loggedInUserId.HasValue) - { - await ValidateOrganizationUserUpdatePermissions(organizationUser.OrganizationId, organizationUser.Type, null); - } - await _organizationUserRepository.UpdateGroupsAsync(organizationUser.Id, groupIds); - await _eventService.LogOrganizationUserEventAsync(organizationUser, - EventType.OrganizationUser_UpdatedGroups); - } - - public async Task UpdateUserResetPasswordEnrollmentAsync(Guid organizationId, Guid userId, string resetPasswordKey, Guid? callingUserId) - { - // Org User must be the same as the calling user and the organization ID associated with the user must match passed org ID - var orgUser = await _organizationUserRepository.GetByOrganizationAsync(organizationId, userId); - if (!callingUserId.HasValue || orgUser == null || orgUser.UserId != callingUserId.Value || - orgUser.OrganizationId != organizationId) - { - throw new BadRequestException("User not valid."); - } - - // Make sure the organization has the ability to use password reset - var org = await _organizationRepository.GetByIdAsync(organizationId); - if (org == null || !org.UseResetPassword) - { - throw new BadRequestException("Organization does not allow password reset enrollment."); - } - - // Make sure the organization has the policy enabled - var resetPasswordPolicy = - await _policyRepository.GetByOrganizationIdTypeAsync(organizationId, PolicyType.ResetPassword); - if (resetPasswordPolicy == null || !resetPasswordPolicy.Enabled) - { - throw new BadRequestException("Organization does not have the password reset policy enabled."); - } - - // Block the user from withdrawal if auto enrollment is enabled - if (resetPasswordKey == null && resetPasswordPolicy.Data != null) - { - var data = JsonSerializer.Deserialize(resetPasswordPolicy.Data, JsonHelpers.IgnoreCase); - - if (data?.AutoEnrollEnabled ?? false) - { - throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to withdraw from Password Reset."); - } - } - - orgUser.ResetPasswordKey = resetPasswordKey; - await _organizationUserRepository.ReplaceAsync(orgUser); - await _eventService.LogOrganizationUserEventAsync(orgUser, resetPasswordKey != null ? - EventType.OrganizationUser_ResetPassword_Enroll : EventType.OrganizationUser_ResetPassword_Withdraw); - } - - public async Task GenerateLicenseAsync(Guid organizationId, Guid installationId) - { - var organization = await GetOrgById(organizationId); - return await GenerateLicenseAsync(organization, installationId); - } - - public async Task GenerateLicenseAsync(Organization organization, Guid installationId, - int? version = null) - { - if (organization == null) - { - throw new NotFoundException(); - } - - var installation = await _installationRepository.GetByIdAsync(installationId); - if (installation == null || !installation.Enabled) - { - throw new BadRequestException("Invalid installation id"); - } - - var subInfo = await _paymentService.GetSubscriptionAsync(organization); - return new OrganizationLicense(organization, subInfo, installationId, _licensingService, version); - } - - public async Task InviteUserAsync(Guid organizationId, Guid? invitingUserId, string email, - OrganizationUserType type, bool accessAll, string externalId, IEnumerable collections) - { - var invite = new OrganizationUserInvite() - { - Emails = new List { email }, - Type = type, - AccessAll = accessAll, - Collections = collections, - }; - var results = await InviteUsersAsync(organizationId, invitingUserId, - new (OrganizationUserInvite, string)[] { (invite, externalId) }); - var result = results.FirstOrDefault(); - if (result == null) - { - throw new BadRequestException("This user has already been invited."); - } - return result; - } - - public async Task ImportAsync(Guid organizationId, - Guid? importingUserId, - IEnumerable groups, - IEnumerable newUsers, - IEnumerable removeUserExternalIds, - bool overwriteExisting) - { - var organization = await GetOrgById(organizationId); - if (organization == null) - { - throw new NotFoundException(); - } - - if (!organization.UseDirectory) - { - throw new BadRequestException("Organization cannot use directory syncing."); - } - - var newUsersSet = new HashSet(newUsers?.Select(u => u.ExternalId) ?? new List()); - var existingUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId); - var existingExternalUsers = existingUsers.Where(u => !string.IsNullOrWhiteSpace(u.ExternalId)).ToList(); - var existingExternalUsersIdDict = existingExternalUsers.ToDictionary(u => u.ExternalId, u => u.Id); - - // Users - - // Remove Users - if (removeUserExternalIds?.Any() ?? false) - { - var removeUsersSet = new HashSet(removeUserExternalIds); - var existingUsersDict = existingExternalUsers.ToDictionary(u => u.ExternalId); - - await _organizationUserRepository.DeleteManyAsync(removeUsersSet - .Except(newUsersSet) - .Where(u => existingUsersDict.ContainsKey(u) && existingUsersDict[u].Type != OrganizationUserType.Owner) - .Select(u => existingUsersDict[u].Id)); - } - - if (overwriteExisting) - { - // Remove existing external users that are not in new user set - var usersToDelete = existingExternalUsers.Where(u => - u.Type != OrganizationUserType.Owner && - !newUsersSet.Contains(u.ExternalId) && - existingExternalUsersIdDict.ContainsKey(u.ExternalId)); - await _organizationUserRepository.DeleteManyAsync(usersToDelete.Select(u => u.Id)); - foreach (var deletedUser in usersToDelete) - { - existingExternalUsersIdDict.Remove(deletedUser.ExternalId); - } - } - - if (newUsers?.Any() ?? false) - { - // Marry existing users - var existingUsersEmailsDict = existingUsers - .Where(u => string.IsNullOrWhiteSpace(u.ExternalId)) - .ToDictionary(u => u.Email); - var newUsersEmailsDict = newUsers.ToDictionary(u => u.Email); - var usersToAttach = existingUsersEmailsDict.Keys.Intersect(newUsersEmailsDict.Keys).ToList(); - var usersToUpsert = new List(); - foreach (var user in usersToAttach) - { - var orgUserDetails = existingUsersEmailsDict[user]; - var orgUser = await _organizationUserRepository.GetByIdAsync(orgUserDetails.Id); - if (orgUser != null) - { - orgUser.ExternalId = newUsersEmailsDict[user].ExternalId; - usersToUpsert.Add(orgUser); - existingExternalUsersIdDict.Add(orgUser.ExternalId, orgUser.Id); - } - } - await _organizationUserRepository.UpsertManyAsync(usersToUpsert); - - // Add new users - var existingUsersSet = new HashSet(existingExternalUsersIdDict.Keys); - var usersToAdd = newUsersSet.Except(existingUsersSet).ToList(); - - var seatsAvailable = int.MaxValue; - var enoughSeatsAvailable = true; - if (organization.Seats.HasValue) - { - var userCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(organizationId); - seatsAvailable = organization.Seats.Value - userCount; - enoughSeatsAvailable = seatsAvailable >= usersToAdd.Count; - } - - var userInvites = new List<(OrganizationUserInvite, string)>(); - foreach (var user in newUsers) - { - if (!usersToAdd.Contains(user.ExternalId) || string.IsNullOrWhiteSpace(user.Email)) - { - continue; - } - - try - { - var invite = new OrganizationUserInvite - { - Emails = new List { user.Email }, - Type = OrganizationUserType.User, - AccessAll = false, - Collections = new List(), - }; - userInvites.Add((invite, user.ExternalId)); - } - catch (BadRequestException) - { - // Thrown when the user is already invited to the organization - continue; - } - } - - var invitedUsers = await InviteUsersAsync(organizationId, importingUserId, userInvites); - foreach (var invitedUser in invitedUsers) - { - existingExternalUsersIdDict.Add(invitedUser.ExternalId, invitedUser.Id); - } - } - - - // Groups - if (groups?.Any() ?? false) - { - if (!organization.UseGroups) - { - throw new BadRequestException("Organization cannot use groups."); - } - - var groupsDict = groups.ToDictionary(g => g.Group.ExternalId); - var existingGroups = await _groupRepository.GetManyByOrganizationIdAsync(organizationId); - var existingExternalGroups = existingGroups - .Where(u => !string.IsNullOrWhiteSpace(u.ExternalId)).ToList(); - var existingExternalGroupsDict = existingExternalGroups.ToDictionary(g => g.ExternalId); - - var newGroups = groups - .Where(g => !existingExternalGroupsDict.ContainsKey(g.Group.ExternalId)) - .Select(g => g.Group); - - foreach (var group in newGroups) - { - group.CreationDate = group.RevisionDate = DateTime.UtcNow; - - await _groupRepository.CreateAsync(group); - await UpdateUsersAsync(group, groupsDict[group.ExternalId].ExternalUserIds, - existingExternalUsersIdDict); - } - - var updateGroups = existingExternalGroups - .Where(g => groupsDict.ContainsKey(g.ExternalId)) - .ToList(); - - if (updateGroups.Any()) - { - var groupUsers = await _groupRepository.GetManyGroupUsersByOrganizationIdAsync(organizationId); - var existingGroupUsers = groupUsers - .GroupBy(gu => gu.GroupId) - .ToDictionary(g => g.Key, g => new HashSet(g.Select(gr => gr.OrganizationUserId))); - - foreach (var group in updateGroups) - { - var updatedGroup = groupsDict[group.ExternalId].Group; - if (group.Name != updatedGroup.Name) - { - group.RevisionDate = DateTime.UtcNow; - group.Name = updatedGroup.Name; - - await _groupRepository.ReplaceAsync(group); - } - - await UpdateUsersAsync(group, groupsDict[group.ExternalId].ExternalUserIds, - existingExternalUsersIdDict, - existingGroupUsers.ContainsKey(group.Id) ? existingGroupUsers[group.Id] : null); - } - } - } - - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.DirectorySynced, organization)); - } - - public async Task DeleteSsoUserAsync(Guid userId, Guid? organizationId) - { - await _ssoUserRepository.DeleteAsync(userId, organizationId); - if (organizationId.HasValue) - { - var organizationUser = await _organizationUserRepository.GetByOrganizationAsync(organizationId.Value, userId); - if (organizationUser != null) - { - await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_UnlinkedSso); - } - } - } - - public async Task UpdateOrganizationKeysAsync(Guid orgId, string publicKey, string privateKey) - { - if (!await _currentContext.ManageResetPassword(orgId)) - { - throw new UnauthorizedAccessException(); - } - - // If the keys already exist, error out - var org = await _organizationRepository.GetByIdAsync(orgId); - if (org.PublicKey != null && org.PrivateKey != null) - { - throw new BadRequestException("Organization Keys already exist"); - } - - // Update org with generated public/private key - org.PublicKey = publicKey; - org.PrivateKey = privateKey; - await UpdateAsync(org); - - return org; - } - - private async Task UpdateUsersAsync(Group group, HashSet groupUsers, - Dictionary existingUsersIdDict, HashSet existingUsers = null) - { - var availableUsers = groupUsers.Intersect(existingUsersIdDict.Keys); - var users = new HashSet(availableUsers.Select(u => existingUsersIdDict[u])); - if (existingUsers != null && existingUsers.Count == users.Count && users.SetEquals(existingUsers)) - { - return; - } - - await _groupRepository.UpdateUsersAsync(group.Id, users); - } - - private async Task> GetConfirmedOwnersAsync(Guid organizationId) - { - var owners = await _organizationUserRepository.GetManyByOrganizationAsync(organizationId, - OrganizationUserType.Owner); - return owners.Where(o => o.Status == OrganizationUserStatusType.Confirmed); - } - - private async Task DeleteAndPushUserRegistrationAsync(Guid organizationId, Guid userId) - { - var deviceIds = await GetUserDeviceIdsAsync(userId); - await _pushRegistrationService.DeleteUserRegistrationOrganizationAsync(deviceIds, - organizationId.ToString()); - await _pushNotificationService.PushSyncOrgKeysAsync(userId); - } - - - private async Task> GetUserDeviceIdsAsync(Guid userId) - { - var devices = await _deviceRepository.GetManyByUserIdAsync(userId); - return devices.Where(d => !string.IsNullOrWhiteSpace(d.PushToken)).Select(d => d.Id.ToString()); - } - - private async Task ReplaceAndUpdateCache(Organization org, EventType? orgEvent = null) - { - await _organizationRepository.ReplaceAsync(org); - await _applicationCacheService.UpsertOrganizationAbilityAsync(org); - - if (orgEvent.HasValue) - { - await _eventService.LogOrganizationEventAsync(org, orgEvent.Value); - } - } - - private async Task GetOrgById(Guid id) - { - return await _organizationRepository.GetByIdAsync(id); - } - - private void ValidateOrganizationUpgradeParameters(Models.StaticStore.Plan plan, OrganizationUpgrade upgrade) - { - if (!plan.HasAdditionalStorageOption && upgrade.AdditionalStorageGb > 0) - { - throw new BadRequestException("Plan does not allow additional storage."); - } - - if (upgrade.AdditionalStorageGb < 0) - { - throw new BadRequestException("You can't subtract storage!"); - } - - if (!plan.HasPremiumAccessOption && upgrade.PremiumAccessAddon) - { - throw new BadRequestException("This plan does not allow you to buy the premium access addon."); - } - - if (plan.BaseSeats + upgrade.AdditionalSeats <= 0) - { - throw new BadRequestException("You do not have any seats!"); - } - - if (upgrade.AdditionalSeats < 0) - { - throw new BadRequestException("You can't subtract seats!"); - } - - if (!plan.HasAdditionalSeatsOption && upgrade.AdditionalSeats > 0) - { - throw new BadRequestException("Plan does not allow additional users."); - } - - if (plan.HasAdditionalSeatsOption && plan.MaxAdditionalSeats.HasValue && - upgrade.AdditionalSeats > plan.MaxAdditionalSeats.Value) - { - throw new BadRequestException($"Selected plan allows a maximum of " + - $"{plan.MaxAdditionalSeats.GetValueOrDefault(0)} additional users."); - } - } - - private async Task ValidateOrganizationUserUpdatePermissions(Guid organizationId, OrganizationUserType newType, - OrganizationUserType? oldType) - { - if (await _currentContext.OrganizationOwner(organizationId)) - { - return; - } - - if (oldType == OrganizationUserType.Owner || newType == OrganizationUserType.Owner) - { - throw new BadRequestException("Only an Owner can configure another Owner's account."); - } - - if (await _currentContext.OrganizationAdmin(organizationId)) - { - return; - } - - if (oldType == OrganizationUserType.Custom || newType == OrganizationUserType.Custom) - { - throw new BadRequestException("Only Owners and Admins can configure Custom accounts."); - } - - if (!await _currentContext.ManageUsers(organizationId)) - { - throw new BadRequestException("Your account does not have permission to manage users."); - } - - if (oldType == OrganizationUserType.Admin || newType == OrganizationUserType.Admin) - { - throw new BadRequestException("Custom users can not manage Admins or Owners."); - } - } - - private async Task ValidateDeleteOrganizationAsync(Organization organization) - { - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organization.Id); - if (ssoConfig?.GetData()?.KeyConnectorEnabled == true) - { - throw new BadRequestException("You cannot delete an Organization that is using Key Connector."); - } - } - - public async Task RevokeUserAsync(OrganizationUser organizationUser, Guid? revokingUserId) - { - if (organizationUser.Status == OrganizationUserStatusType.Revoked) - { - throw new BadRequestException("Already revoked."); - } - - if (revokingUserId.HasValue && organizationUser.UserId == revokingUserId.Value) - { - throw new BadRequestException("You cannot revoke yourself."); - } - - if (organizationUser.Type == OrganizationUserType.Owner && revokingUserId.HasValue && - !await _currentContext.OrganizationOwner(organizationUser.OrganizationId)) - { - throw new BadRequestException("Only owners can revoke other owners."); - } - - if (!await HasConfirmedOwnersExceptAsync(organizationUser.OrganizationId, new[] { organizationUser.Id })) - { - throw new BadRequestException("Organization must have at least one confirmed owner."); - } - - await _organizationUserRepository.RevokeAsync(organizationUser.Id); - organizationUser.Status = OrganizationUserStatusType.Revoked; - await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Revoked); - } - - public async Task>> RevokeUsersAsync(Guid organizationId, - IEnumerable organizationUserIds, Guid? revokingUserId) - { - var orgUsers = await _organizationUserRepository.GetManyAsync(organizationUserIds); - var filteredUsers = orgUsers.Where(u => u.OrganizationId == organizationId) - .ToList(); - - if (!filteredUsers.Any()) - { - throw new BadRequestException("Users invalid."); - } - - if (!await HasConfirmedOwnersExceptAsync(organizationId, organizationUserIds)) - { - throw new BadRequestException("Organization must have at least one confirmed owner."); - } - - var deletingUserIsOwner = false; - if (revokingUserId.HasValue) - { - deletingUserIsOwner = await _currentContext.OrganizationOwner(organizationId); - } - - var result = new List>(); - - foreach (var organizationUser in filteredUsers) - { - try - { - if (organizationUser.Status == OrganizationUserStatusType.Revoked) - { - throw new BadRequestException("Already revoked."); - } - - if (revokingUserId.HasValue && organizationUser.UserId == revokingUserId) - { - throw new BadRequestException("You cannot revoke yourself."); - } - - if (organizationUser.Type == OrganizationUserType.Owner && revokingUserId.HasValue && !deletingUserIsOwner) - { - throw new BadRequestException("Only owners can revoke other owners."); - } - - await _organizationUserRepository.RevokeAsync(organizationUser.Id); - organizationUser.Status = OrganizationUserStatusType.Revoked; - await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Revoked); - - result.Add(Tuple.Create(organizationUser, "")); - } - catch (BadRequestException e) - { - result.Add(Tuple.Create(organizationUser, e.Message)); - } - } - - return result; - } - - public async Task RestoreUserAsync(OrganizationUser organizationUser, Guid? restoringUserId, IUserService userService) - { - if (organizationUser.Status != OrganizationUserStatusType.Revoked) - { - throw new BadRequestException("Already active."); - } - - if (restoringUserId.HasValue && organizationUser.UserId == restoringUserId.Value) - { - throw new BadRequestException("You cannot restore yourself."); - } - - if (organizationUser.Type == OrganizationUserType.Owner && restoringUserId.HasValue && - !await _currentContext.OrganizationOwner(organizationUser.OrganizationId)) - { - throw new BadRequestException("Only owners can restore other owners."); - } - - await CheckPoliciesBeforeRestoreAsync(organizationUser, userService); - - var status = GetPriorActiveOrganizationUserStatusType(organizationUser); - - await _organizationUserRepository.RestoreAsync(organizationUser.Id, status); - organizationUser.Status = status; - await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Restored); - } - - public async Task>> RestoreUsersAsync(Guid organizationId, - IEnumerable organizationUserIds, Guid? restoringUserId, IUserService userService) - { - var orgUsers = await _organizationUserRepository.GetManyAsync(organizationUserIds); - var filteredUsers = orgUsers.Where(u => u.OrganizationId == organizationId) - .ToList(); - - if (!filteredUsers.Any()) - { - throw new BadRequestException("Users invalid."); - } - - var deletingUserIsOwner = false; - if (restoringUserId.HasValue) - { - deletingUserIsOwner = await _currentContext.OrganizationOwner(organizationId); - } - - var result = new List>(); - - foreach (var organizationUser in filteredUsers) - { - try - { - if (organizationUser.Status != OrganizationUserStatusType.Revoked) - { - throw new BadRequestException("Already active."); - } - - if (restoringUserId.HasValue && organizationUser.UserId == restoringUserId) - { - throw new BadRequestException("You cannot restore yourself."); - } - - if (organizationUser.Type == OrganizationUserType.Owner && restoringUserId.HasValue && !deletingUserIsOwner) - { - throw new BadRequestException("Only owners can restore other owners."); - } - - await CheckPoliciesBeforeRestoreAsync(organizationUser, userService); - - var status = GetPriorActiveOrganizationUserStatusType(organizationUser); - - await _organizationUserRepository.RestoreAsync(organizationUser.Id, status); - organizationUser.Status = status; - await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Restored); - - result.Add(Tuple.Create(organizationUser, "")); - } - catch (BadRequestException e) - { - result.Add(Tuple.Create(organizationUser, e.Message)); - } - } - - return result; - } - - private async Task CheckPoliciesBeforeRestoreAsync(OrganizationUser orgUser, IUserService userService) - { - // An invited OrganizationUser isn't linked with a user account yet, so these checks are irrelevant - // The user will be subject to the same checks when they try to accept the invite - if (GetPriorActiveOrganizationUserStatusType(orgUser) == OrganizationUserStatusType.Invited) - { - return; - } - - var userId = orgUser.UserId.Value; - - // Enforce Single Organization Policy of organization user is being restored to - var allOrgUsers = await _organizationUserRepository.GetManyByUserAsync(userId); - var hasOtherOrgs = allOrgUsers.Any(ou => ou.OrganizationId != orgUser.OrganizationId); - var singleOrgPoliciesApplyingToRevokedUsers = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(userId, - PolicyType.SingleOrg, OrganizationUserStatusType.Revoked); - var singleOrgPolicyApplies = singleOrgPoliciesApplyingToRevokedUsers.Any(p => p.OrganizationId == orgUser.OrganizationId); - - if (hasOtherOrgs && singleOrgPolicyApplies) - { - throw new BadRequestException("You cannot restore this user until " + - "they leave or remove all other organizations."); - } - - // Enforce Single Organization Policy of other organizations user is a member of - var singleOrgPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(userId, - PolicyType.SingleOrg); - if (singleOrgPolicyCount > 0) - { - throw new BadRequestException("You cannot restore this user because they are a member of " + - "another organization which forbids it"); - } - - // Enforce Two Factor Authentication Policy of organization user is trying to join - var user = await _userRepository.GetByIdAsync(userId); - if (!await userService.TwoFactorIsEnabledAsync(user)) - { - var invitedTwoFactorPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(userId, - PolicyType.TwoFactorAuthentication, OrganizationUserStatusType.Invited); - if (invitedTwoFactorPolicies.Any(p => p.OrganizationId == orgUser.OrganizationId)) - { - throw new BadRequestException("You cannot restore this user until they enable " + - "two-step login on their user account."); - } - } - } - - static OrganizationUserStatusType GetPriorActiveOrganizationUserStatusType(OrganizationUser organizationUser) - { - // Determine status to revert back to - var status = OrganizationUserStatusType.Invited; - if (organizationUser.UserId.HasValue && string.IsNullOrWhiteSpace(organizationUser.Email)) - { - // Has UserId & Email is null, then Accepted - status = OrganizationUserStatusType.Accepted; - if (!string.IsNullOrWhiteSpace(organizationUser.Key)) - { - // We have an org key for this user, user was confirmed - status = OrganizationUserStatusType.Confirmed; - } - } - - return status; + await _mailService.SendOrganizationAutoscaledEmailAsync(organization, initialSeatCount, + ownerEmails); + organization.OwnersNotifiedOfAutoscaling = DateTime.UtcNow; + await _organizationRepository.UpsertAsync(organization); } } + + private async Task CheckPolicies(ICollection policies, Guid organizationId, User user, + ICollection userOrgs, IUserService userService) + { + var usingTwoFactorPolicy = policies.Any(p => p.Type == PolicyType.TwoFactorAuthentication && p.Enabled); + if (usingTwoFactorPolicy && !await userService.TwoFactorIsEnabledAsync(user)) + { + throw new BadRequestException("User does not have two-step login enabled."); + } + + var usingSingleOrgPolicy = policies.Any(p => p.Type == PolicyType.SingleOrg && p.Enabled); + if (usingSingleOrgPolicy) + { + if (userOrgs.Any(ou => ou.OrganizationId != organizationId && ou.Status != OrganizationUserStatusType.Invited)) + { + throw new BadRequestException("User is a member of another organization."); + } + } + } + + public async Task SaveUserAsync(OrganizationUser user, Guid? savingUserId, + IEnumerable collections) + { + if (user.Id.Equals(default(Guid))) + { + throw new BadRequestException("Invite the user first."); + } + + var originalUser = await _organizationUserRepository.GetByIdAsync(user.Id); + if (user.Equals(originalUser)) + { + throw new BadRequestException("Please make changes before saving."); + } + + if (savingUserId.HasValue) + { + await ValidateOrganizationUserUpdatePermissions(user.OrganizationId, user.Type, originalUser.Type); + } + + if (user.Type != OrganizationUserType.Owner && + !await HasConfirmedOwnersExceptAsync(user.OrganizationId, new[] { user.Id })) + { + throw new BadRequestException("Organization must have at least one confirmed owner."); + } + + if (user.AccessAll) + { + // We don't need any collections if we're flagged to have all access. + collections = new List(); + } + await _organizationUserRepository.ReplaceAsync(user, collections); + await _eventService.LogOrganizationUserEventAsync(user, EventType.OrganizationUser_Updated); + } + + public async Task DeleteUserAsync(Guid organizationId, Guid organizationUserId, Guid? deletingUserId) + { + var orgUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); + if (orgUser == null || orgUser.OrganizationId != organizationId) + { + throw new BadRequestException("User not valid."); + } + + if (deletingUserId.HasValue && orgUser.UserId == deletingUserId.Value) + { + throw new BadRequestException("You cannot remove yourself."); + } + + if (orgUser.Type == OrganizationUserType.Owner && deletingUserId.HasValue && + !await _currentContext.OrganizationOwner(organizationId)) + { + throw new BadRequestException("Only owners can delete other owners."); + } + + if (!await HasConfirmedOwnersExceptAsync(organizationId, new[] { organizationUserId })) + { + throw new BadRequestException("Organization must have at least one confirmed owner."); + } + + await _organizationUserRepository.DeleteAsync(orgUser); + await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_Removed); + + if (orgUser.UserId.HasValue) + { + await DeleteAndPushUserRegistrationAsync(organizationId, orgUser.UserId.Value); + } + } + + public async Task DeleteUserAsync(Guid organizationId, Guid userId) + { + var orgUser = await _organizationUserRepository.GetByOrganizationAsync(organizationId, userId); + if (orgUser == null) + { + throw new NotFoundException(); + } + + if (!await HasConfirmedOwnersExceptAsync(organizationId, new[] { orgUser.Id })) + { + throw new BadRequestException("Organization must have at least one confirmed owner."); + } + + await _organizationUserRepository.DeleteAsync(orgUser); + await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_Removed); + + if (orgUser.UserId.HasValue) + { + await DeleteAndPushUserRegistrationAsync(organizationId, orgUser.UserId.Value); + } + } + + public async Task>> DeleteUsersAsync(Guid organizationId, + IEnumerable organizationUsersId, + Guid? deletingUserId) + { + var orgUsers = await _organizationUserRepository.GetManyAsync(organizationUsersId); + var filteredUsers = orgUsers.Where(u => u.OrganizationId == organizationId) + .ToList(); + + if (!filteredUsers.Any()) + { + throw new BadRequestException("Users invalid."); + } + + if (!await HasConfirmedOwnersExceptAsync(organizationId, organizationUsersId)) + { + throw new BadRequestException("Organization must have at least one confirmed owner."); + } + + var deletingUserIsOwner = false; + if (deletingUserId.HasValue) + { + deletingUserIsOwner = await _currentContext.OrganizationOwner(organizationId); + } + + var result = new List>(); + var deletedUserIds = new List(); + foreach (var orgUser in filteredUsers) + { + try + { + if (deletingUserId.HasValue && orgUser.UserId == deletingUserId) + { + throw new BadRequestException("You cannot remove yourself."); + } + + if (orgUser.Type == OrganizationUserType.Owner && deletingUserId.HasValue && !deletingUserIsOwner) + { + throw new BadRequestException("Only owners can delete other owners."); + } + + await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_Removed); + + if (orgUser.UserId.HasValue) + { + await DeleteAndPushUserRegistrationAsync(organizationId, orgUser.UserId.Value); + } + result.Add(Tuple.Create(orgUser, "")); + deletedUserIds.Add(orgUser.Id); + } + catch (BadRequestException e) + { + result.Add(Tuple.Create(orgUser, e.Message)); + } + + await _organizationUserRepository.DeleteManyAsync(deletedUserIds); + } + + return result; + } + + public async Task HasConfirmedOwnersExceptAsync(Guid organizationId, IEnumerable organizationUsersId, bool includeProvider = true) + { + var confirmedOwners = await GetConfirmedOwnersAsync(organizationId); + var confirmedOwnersIds = confirmedOwners.Select(u => u.Id); + bool hasOtherOwner = confirmedOwnersIds.Except(organizationUsersId).Any(); + if (!hasOtherOwner && includeProvider) + { + return (await _currentContext.ProviderIdForOrg(organizationId)).HasValue; + } + return hasOtherOwner; + } + + public async Task UpdateUserGroupsAsync(OrganizationUser organizationUser, IEnumerable groupIds, Guid? loggedInUserId) + { + if (loggedInUserId.HasValue) + { + await ValidateOrganizationUserUpdatePermissions(organizationUser.OrganizationId, organizationUser.Type, null); + } + await _organizationUserRepository.UpdateGroupsAsync(organizationUser.Id, groupIds); + await _eventService.LogOrganizationUserEventAsync(organizationUser, + EventType.OrganizationUser_UpdatedGroups); + } + + public async Task UpdateUserResetPasswordEnrollmentAsync(Guid organizationId, Guid userId, string resetPasswordKey, Guid? callingUserId) + { + // Org User must be the same as the calling user and the organization ID associated with the user must match passed org ID + var orgUser = await _organizationUserRepository.GetByOrganizationAsync(organizationId, userId); + if (!callingUserId.HasValue || orgUser == null || orgUser.UserId != callingUserId.Value || + orgUser.OrganizationId != organizationId) + { + throw new BadRequestException("User not valid."); + } + + // Make sure the organization has the ability to use password reset + var org = await _organizationRepository.GetByIdAsync(organizationId); + if (org == null || !org.UseResetPassword) + { + throw new BadRequestException("Organization does not allow password reset enrollment."); + } + + // Make sure the organization has the policy enabled + var resetPasswordPolicy = + await _policyRepository.GetByOrganizationIdTypeAsync(organizationId, PolicyType.ResetPassword); + if (resetPasswordPolicy == null || !resetPasswordPolicy.Enabled) + { + throw new BadRequestException("Organization does not have the password reset policy enabled."); + } + + // Block the user from withdrawal if auto enrollment is enabled + if (resetPasswordKey == null && resetPasswordPolicy.Data != null) + { + var data = JsonSerializer.Deserialize(resetPasswordPolicy.Data, JsonHelpers.IgnoreCase); + + if (data?.AutoEnrollEnabled ?? false) + { + throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to withdraw from Password Reset."); + } + } + + orgUser.ResetPasswordKey = resetPasswordKey; + await _organizationUserRepository.ReplaceAsync(orgUser); + await _eventService.LogOrganizationUserEventAsync(orgUser, resetPasswordKey != null ? + EventType.OrganizationUser_ResetPassword_Enroll : EventType.OrganizationUser_ResetPassword_Withdraw); + } + + public async Task GenerateLicenseAsync(Guid organizationId, Guid installationId) + { + var organization = await GetOrgById(organizationId); + return await GenerateLicenseAsync(organization, installationId); + } + + public async Task GenerateLicenseAsync(Organization organization, Guid installationId, + int? version = null) + { + if (organization == null) + { + throw new NotFoundException(); + } + + var installation = await _installationRepository.GetByIdAsync(installationId); + if (installation == null || !installation.Enabled) + { + throw new BadRequestException("Invalid installation id"); + } + + var subInfo = await _paymentService.GetSubscriptionAsync(organization); + return new OrganizationLicense(organization, subInfo, installationId, _licensingService, version); + } + + public async Task InviteUserAsync(Guid organizationId, Guid? invitingUserId, string email, + OrganizationUserType type, bool accessAll, string externalId, IEnumerable collections) + { + var invite = new OrganizationUserInvite() + { + Emails = new List { email }, + Type = type, + AccessAll = accessAll, + Collections = collections, + }; + var results = await InviteUsersAsync(organizationId, invitingUserId, + new (OrganizationUserInvite, string)[] { (invite, externalId) }); + var result = results.FirstOrDefault(); + if (result == null) + { + throw new BadRequestException("This user has already been invited."); + } + return result; + } + + public async Task ImportAsync(Guid organizationId, + Guid? importingUserId, + IEnumerable groups, + IEnumerable newUsers, + IEnumerable removeUserExternalIds, + bool overwriteExisting) + { + var organization = await GetOrgById(organizationId); + if (organization == null) + { + throw new NotFoundException(); + } + + if (!organization.UseDirectory) + { + throw new BadRequestException("Organization cannot use directory syncing."); + } + + var newUsersSet = new HashSet(newUsers?.Select(u => u.ExternalId) ?? new List()); + var existingUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId); + var existingExternalUsers = existingUsers.Where(u => !string.IsNullOrWhiteSpace(u.ExternalId)).ToList(); + var existingExternalUsersIdDict = existingExternalUsers.ToDictionary(u => u.ExternalId, u => u.Id); + + // Users + + // Remove Users + if (removeUserExternalIds?.Any() ?? false) + { + var removeUsersSet = new HashSet(removeUserExternalIds); + var existingUsersDict = existingExternalUsers.ToDictionary(u => u.ExternalId); + + await _organizationUserRepository.DeleteManyAsync(removeUsersSet + .Except(newUsersSet) + .Where(u => existingUsersDict.ContainsKey(u) && existingUsersDict[u].Type != OrganizationUserType.Owner) + .Select(u => existingUsersDict[u].Id)); + } + + if (overwriteExisting) + { + // Remove existing external users that are not in new user set + var usersToDelete = existingExternalUsers.Where(u => + u.Type != OrganizationUserType.Owner && + !newUsersSet.Contains(u.ExternalId) && + existingExternalUsersIdDict.ContainsKey(u.ExternalId)); + await _organizationUserRepository.DeleteManyAsync(usersToDelete.Select(u => u.Id)); + foreach (var deletedUser in usersToDelete) + { + existingExternalUsersIdDict.Remove(deletedUser.ExternalId); + } + } + + if (newUsers?.Any() ?? false) + { + // Marry existing users + var existingUsersEmailsDict = existingUsers + .Where(u => string.IsNullOrWhiteSpace(u.ExternalId)) + .ToDictionary(u => u.Email); + var newUsersEmailsDict = newUsers.ToDictionary(u => u.Email); + var usersToAttach = existingUsersEmailsDict.Keys.Intersect(newUsersEmailsDict.Keys).ToList(); + var usersToUpsert = new List(); + foreach (var user in usersToAttach) + { + var orgUserDetails = existingUsersEmailsDict[user]; + var orgUser = await _organizationUserRepository.GetByIdAsync(orgUserDetails.Id); + if (orgUser != null) + { + orgUser.ExternalId = newUsersEmailsDict[user].ExternalId; + usersToUpsert.Add(orgUser); + existingExternalUsersIdDict.Add(orgUser.ExternalId, orgUser.Id); + } + } + await _organizationUserRepository.UpsertManyAsync(usersToUpsert); + + // Add new users + var existingUsersSet = new HashSet(existingExternalUsersIdDict.Keys); + var usersToAdd = newUsersSet.Except(existingUsersSet).ToList(); + + var seatsAvailable = int.MaxValue; + var enoughSeatsAvailable = true; + if (organization.Seats.HasValue) + { + var userCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(organizationId); + seatsAvailable = organization.Seats.Value - userCount; + enoughSeatsAvailable = seatsAvailable >= usersToAdd.Count; + } + + var userInvites = new List<(OrganizationUserInvite, string)>(); + foreach (var user in newUsers) + { + if (!usersToAdd.Contains(user.ExternalId) || string.IsNullOrWhiteSpace(user.Email)) + { + continue; + } + + try + { + var invite = new OrganizationUserInvite + { + Emails = new List { user.Email }, + Type = OrganizationUserType.User, + AccessAll = false, + Collections = new List(), + }; + userInvites.Add((invite, user.ExternalId)); + } + catch (BadRequestException) + { + // Thrown when the user is already invited to the organization + continue; + } + } + + var invitedUsers = await InviteUsersAsync(organizationId, importingUserId, userInvites); + foreach (var invitedUser in invitedUsers) + { + existingExternalUsersIdDict.Add(invitedUser.ExternalId, invitedUser.Id); + } + } + + + // Groups + if (groups?.Any() ?? false) + { + if (!organization.UseGroups) + { + throw new BadRequestException("Organization cannot use groups."); + } + + var groupsDict = groups.ToDictionary(g => g.Group.ExternalId); + var existingGroups = await _groupRepository.GetManyByOrganizationIdAsync(organizationId); + var existingExternalGroups = existingGroups + .Where(u => !string.IsNullOrWhiteSpace(u.ExternalId)).ToList(); + var existingExternalGroupsDict = existingExternalGroups.ToDictionary(g => g.ExternalId); + + var newGroups = groups + .Where(g => !existingExternalGroupsDict.ContainsKey(g.Group.ExternalId)) + .Select(g => g.Group); + + foreach (var group in newGroups) + { + group.CreationDate = group.RevisionDate = DateTime.UtcNow; + + await _groupRepository.CreateAsync(group); + await UpdateUsersAsync(group, groupsDict[group.ExternalId].ExternalUserIds, + existingExternalUsersIdDict); + } + + var updateGroups = existingExternalGroups + .Where(g => groupsDict.ContainsKey(g.ExternalId)) + .ToList(); + + if (updateGroups.Any()) + { + var groupUsers = await _groupRepository.GetManyGroupUsersByOrganizationIdAsync(organizationId); + var existingGroupUsers = groupUsers + .GroupBy(gu => gu.GroupId) + .ToDictionary(g => g.Key, g => new HashSet(g.Select(gr => gr.OrganizationUserId))); + + foreach (var group in updateGroups) + { + var updatedGroup = groupsDict[group.ExternalId].Group; + if (group.Name != updatedGroup.Name) + { + group.RevisionDate = DateTime.UtcNow; + group.Name = updatedGroup.Name; + + await _groupRepository.ReplaceAsync(group); + } + + await UpdateUsersAsync(group, groupsDict[group.ExternalId].ExternalUserIds, + existingExternalUsersIdDict, + existingGroupUsers.ContainsKey(group.Id) ? existingGroupUsers[group.Id] : null); + } + } + } + + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.DirectorySynced, organization)); + } + + public async Task DeleteSsoUserAsync(Guid userId, Guid? organizationId) + { + await _ssoUserRepository.DeleteAsync(userId, organizationId); + if (organizationId.HasValue) + { + var organizationUser = await _organizationUserRepository.GetByOrganizationAsync(organizationId.Value, userId); + if (organizationUser != null) + { + await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_UnlinkedSso); + } + } + } + + public async Task UpdateOrganizationKeysAsync(Guid orgId, string publicKey, string privateKey) + { + if (!await _currentContext.ManageResetPassword(orgId)) + { + throw new UnauthorizedAccessException(); + } + + // If the keys already exist, error out + var org = await _organizationRepository.GetByIdAsync(orgId); + if (org.PublicKey != null && org.PrivateKey != null) + { + throw new BadRequestException("Organization Keys already exist"); + } + + // Update org with generated public/private key + org.PublicKey = publicKey; + org.PrivateKey = privateKey; + await UpdateAsync(org); + + return org; + } + + private async Task UpdateUsersAsync(Group group, HashSet groupUsers, + Dictionary existingUsersIdDict, HashSet existingUsers = null) + { + var availableUsers = groupUsers.Intersect(existingUsersIdDict.Keys); + var users = new HashSet(availableUsers.Select(u => existingUsersIdDict[u])); + if (existingUsers != null && existingUsers.Count == users.Count && users.SetEquals(existingUsers)) + { + return; + } + + await _groupRepository.UpdateUsersAsync(group.Id, users); + } + + private async Task> GetConfirmedOwnersAsync(Guid organizationId) + { + var owners = await _organizationUserRepository.GetManyByOrganizationAsync(organizationId, + OrganizationUserType.Owner); + return owners.Where(o => o.Status == OrganizationUserStatusType.Confirmed); + } + + private async Task DeleteAndPushUserRegistrationAsync(Guid organizationId, Guid userId) + { + var deviceIds = await GetUserDeviceIdsAsync(userId); + await _pushRegistrationService.DeleteUserRegistrationOrganizationAsync(deviceIds, + organizationId.ToString()); + await _pushNotificationService.PushSyncOrgKeysAsync(userId); + } + + + private async Task> GetUserDeviceIdsAsync(Guid userId) + { + var devices = await _deviceRepository.GetManyByUserIdAsync(userId); + return devices.Where(d => !string.IsNullOrWhiteSpace(d.PushToken)).Select(d => d.Id.ToString()); + } + + private async Task ReplaceAndUpdateCache(Organization org, EventType? orgEvent = null) + { + await _organizationRepository.ReplaceAsync(org); + await _applicationCacheService.UpsertOrganizationAbilityAsync(org); + + if (orgEvent.HasValue) + { + await _eventService.LogOrganizationEventAsync(org, orgEvent.Value); + } + } + + private async Task GetOrgById(Guid id) + { + return await _organizationRepository.GetByIdAsync(id); + } + + private void ValidateOrganizationUpgradeParameters(Models.StaticStore.Plan plan, OrganizationUpgrade upgrade) + { + if (!plan.HasAdditionalStorageOption && upgrade.AdditionalStorageGb > 0) + { + throw new BadRequestException("Plan does not allow additional storage."); + } + + if (upgrade.AdditionalStorageGb < 0) + { + throw new BadRequestException("You can't subtract storage!"); + } + + if (!plan.HasPremiumAccessOption && upgrade.PremiumAccessAddon) + { + throw new BadRequestException("This plan does not allow you to buy the premium access addon."); + } + + if (plan.BaseSeats + upgrade.AdditionalSeats <= 0) + { + throw new BadRequestException("You do not have any seats!"); + } + + if (upgrade.AdditionalSeats < 0) + { + throw new BadRequestException("You can't subtract seats!"); + } + + if (!plan.HasAdditionalSeatsOption && upgrade.AdditionalSeats > 0) + { + throw new BadRequestException("Plan does not allow additional users."); + } + + if (plan.HasAdditionalSeatsOption && plan.MaxAdditionalSeats.HasValue && + upgrade.AdditionalSeats > plan.MaxAdditionalSeats.Value) + { + throw new BadRequestException($"Selected plan allows a maximum of " + + $"{plan.MaxAdditionalSeats.GetValueOrDefault(0)} additional users."); + } + } + + private async Task ValidateOrganizationUserUpdatePermissions(Guid organizationId, OrganizationUserType newType, + OrganizationUserType? oldType) + { + if (await _currentContext.OrganizationOwner(organizationId)) + { + return; + } + + if (oldType == OrganizationUserType.Owner || newType == OrganizationUserType.Owner) + { + throw new BadRequestException("Only an Owner can configure another Owner's account."); + } + + if (await _currentContext.OrganizationAdmin(organizationId)) + { + return; + } + + if (oldType == OrganizationUserType.Custom || newType == OrganizationUserType.Custom) + { + throw new BadRequestException("Only Owners and Admins can configure Custom accounts."); + } + + if (!await _currentContext.ManageUsers(organizationId)) + { + throw new BadRequestException("Your account does not have permission to manage users."); + } + + if (oldType == OrganizationUserType.Admin || newType == OrganizationUserType.Admin) + { + throw new BadRequestException("Custom users can not manage Admins or Owners."); + } + } + + private async Task ValidateDeleteOrganizationAsync(Organization organization) + { + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organization.Id); + if (ssoConfig?.GetData()?.KeyConnectorEnabled == true) + { + throw new BadRequestException("You cannot delete an Organization that is using Key Connector."); + } + } + + public async Task RevokeUserAsync(OrganizationUser organizationUser, Guid? revokingUserId) + { + if (organizationUser.Status == OrganizationUserStatusType.Revoked) + { + throw new BadRequestException("Already revoked."); + } + + if (revokingUserId.HasValue && organizationUser.UserId == revokingUserId.Value) + { + throw new BadRequestException("You cannot revoke yourself."); + } + + if (organizationUser.Type == OrganizationUserType.Owner && revokingUserId.HasValue && + !await _currentContext.OrganizationOwner(organizationUser.OrganizationId)) + { + throw new BadRequestException("Only owners can revoke other owners."); + } + + if (!await HasConfirmedOwnersExceptAsync(organizationUser.OrganizationId, new[] { organizationUser.Id })) + { + throw new BadRequestException("Organization must have at least one confirmed owner."); + } + + await _organizationUserRepository.RevokeAsync(organizationUser.Id); + organizationUser.Status = OrganizationUserStatusType.Revoked; + await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Revoked); + } + + public async Task>> RevokeUsersAsync(Guid organizationId, + IEnumerable organizationUserIds, Guid? revokingUserId) + { + var orgUsers = await _organizationUserRepository.GetManyAsync(organizationUserIds); + var filteredUsers = orgUsers.Where(u => u.OrganizationId == organizationId) + .ToList(); + + if (!filteredUsers.Any()) + { + throw new BadRequestException("Users invalid."); + } + + if (!await HasConfirmedOwnersExceptAsync(organizationId, organizationUserIds)) + { + throw new BadRequestException("Organization must have at least one confirmed owner."); + } + + var deletingUserIsOwner = false; + if (revokingUserId.HasValue) + { + deletingUserIsOwner = await _currentContext.OrganizationOwner(organizationId); + } + + var result = new List>(); + + foreach (var organizationUser in filteredUsers) + { + try + { + if (organizationUser.Status == OrganizationUserStatusType.Revoked) + { + throw new BadRequestException("Already revoked."); + } + + if (revokingUserId.HasValue && organizationUser.UserId == revokingUserId) + { + throw new BadRequestException("You cannot revoke yourself."); + } + + if (organizationUser.Type == OrganizationUserType.Owner && revokingUserId.HasValue && !deletingUserIsOwner) + { + throw new BadRequestException("Only owners can revoke other owners."); + } + + await _organizationUserRepository.RevokeAsync(organizationUser.Id); + organizationUser.Status = OrganizationUserStatusType.Revoked; + await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Revoked); + + result.Add(Tuple.Create(organizationUser, "")); + } + catch (BadRequestException e) + { + result.Add(Tuple.Create(organizationUser, e.Message)); + } + } + + return result; + } + + public async Task RestoreUserAsync(OrganizationUser organizationUser, Guid? restoringUserId, IUserService userService) + { + if (organizationUser.Status != OrganizationUserStatusType.Revoked) + { + throw new BadRequestException("Already active."); + } + + if (restoringUserId.HasValue && organizationUser.UserId == restoringUserId.Value) + { + throw new BadRequestException("You cannot restore yourself."); + } + + if (organizationUser.Type == OrganizationUserType.Owner && restoringUserId.HasValue && + !await _currentContext.OrganizationOwner(organizationUser.OrganizationId)) + { + throw new BadRequestException("Only owners can restore other owners."); + } + + await CheckPoliciesBeforeRestoreAsync(organizationUser, userService); + + var status = GetPriorActiveOrganizationUserStatusType(organizationUser); + + await _organizationUserRepository.RestoreAsync(organizationUser.Id, status); + organizationUser.Status = status; + await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Restored); + } + + public async Task>> RestoreUsersAsync(Guid organizationId, + IEnumerable organizationUserIds, Guid? restoringUserId, IUserService userService) + { + var orgUsers = await _organizationUserRepository.GetManyAsync(organizationUserIds); + var filteredUsers = orgUsers.Where(u => u.OrganizationId == organizationId) + .ToList(); + + if (!filteredUsers.Any()) + { + throw new BadRequestException("Users invalid."); + } + + var deletingUserIsOwner = false; + if (restoringUserId.HasValue) + { + deletingUserIsOwner = await _currentContext.OrganizationOwner(organizationId); + } + + var result = new List>(); + + foreach (var organizationUser in filteredUsers) + { + try + { + if (organizationUser.Status != OrganizationUserStatusType.Revoked) + { + throw new BadRequestException("Already active."); + } + + if (restoringUserId.HasValue && organizationUser.UserId == restoringUserId) + { + throw new BadRequestException("You cannot restore yourself."); + } + + if (organizationUser.Type == OrganizationUserType.Owner && restoringUserId.HasValue && !deletingUserIsOwner) + { + throw new BadRequestException("Only owners can restore other owners."); + } + + await CheckPoliciesBeforeRestoreAsync(organizationUser, userService); + + var status = GetPriorActiveOrganizationUserStatusType(organizationUser); + + await _organizationUserRepository.RestoreAsync(organizationUser.Id, status); + organizationUser.Status = status; + await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Restored); + + result.Add(Tuple.Create(organizationUser, "")); + } + catch (BadRequestException e) + { + result.Add(Tuple.Create(organizationUser, e.Message)); + } + } + + return result; + } + + private async Task CheckPoliciesBeforeRestoreAsync(OrganizationUser orgUser, IUserService userService) + { + // An invited OrganizationUser isn't linked with a user account yet, so these checks are irrelevant + // The user will be subject to the same checks when they try to accept the invite + if (GetPriorActiveOrganizationUserStatusType(orgUser) == OrganizationUserStatusType.Invited) + { + return; + } + + var userId = orgUser.UserId.Value; + + // Enforce Single Organization Policy of organization user is being restored to + var allOrgUsers = await _organizationUserRepository.GetManyByUserAsync(userId); + var hasOtherOrgs = allOrgUsers.Any(ou => ou.OrganizationId != orgUser.OrganizationId); + var singleOrgPoliciesApplyingToRevokedUsers = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(userId, + PolicyType.SingleOrg, OrganizationUserStatusType.Revoked); + var singleOrgPolicyApplies = singleOrgPoliciesApplyingToRevokedUsers.Any(p => p.OrganizationId == orgUser.OrganizationId); + + if (hasOtherOrgs && singleOrgPolicyApplies) + { + throw new BadRequestException("You cannot restore this user until " + + "they leave or remove all other organizations."); + } + + // Enforce Single Organization Policy of other organizations user is a member of + var singleOrgPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(userId, + PolicyType.SingleOrg); + if (singleOrgPolicyCount > 0) + { + throw new BadRequestException("You cannot restore this user because they are a member of " + + "another organization which forbids it"); + } + + // Enforce Two Factor Authentication Policy of organization user is trying to join + var user = await _userRepository.GetByIdAsync(userId); + if (!await userService.TwoFactorIsEnabledAsync(user)) + { + var invitedTwoFactorPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(userId, + PolicyType.TwoFactorAuthentication, OrganizationUserStatusType.Invited); + if (invitedTwoFactorPolicies.Any(p => p.OrganizationId == orgUser.OrganizationId)) + { + throw new BadRequestException("You cannot restore this user until they enable " + + "two-step login on their user account."); + } + } + } + + static OrganizationUserStatusType GetPriorActiveOrganizationUserStatusType(OrganizationUser organizationUser) + { + // Determine status to revert back to + var status = OrganizationUserStatusType.Invited; + if (organizationUser.UserId.HasValue && string.IsNullOrWhiteSpace(organizationUser.Email)) + { + // Has UserId & Email is null, then Accepted + status = OrganizationUserStatusType.Accepted; + if (!string.IsNullOrWhiteSpace(organizationUser.Key)) + { + // We have an org key for this user, user was confirmed + status = OrganizationUserStatusType.Confirmed; + } + } + + return status; + } } diff --git a/src/Core/Services/Implementations/PolicyService.cs b/src/Core/Services/Implementations/PolicyService.cs index e84a124e6..938975f59 100644 --- a/src/Core/Services/Implementations/PolicyService.cs +++ b/src/Core/Services/Implementations/PolicyService.cs @@ -3,170 +3,169 @@ using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; -namespace Bit.Core.Services -{ - public class PolicyService : IPolicyService - { - private readonly IEventService _eventService; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IPolicyRepository _policyRepository; - private readonly ISsoConfigRepository _ssoConfigRepository; - private readonly IMailService _mailService; +namespace Bit.Core.Services; - public PolicyService( - IEventService eventService, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IPolicyRepository policyRepository, - ISsoConfigRepository ssoConfigRepository, - IMailService mailService) +public class PolicyService : IPolicyService +{ + private readonly IEventService _eventService; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IPolicyRepository _policyRepository; + private readonly ISsoConfigRepository _ssoConfigRepository; + private readonly IMailService _mailService; + + public PolicyService( + IEventService eventService, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IPolicyRepository policyRepository, + ISsoConfigRepository ssoConfigRepository, + IMailService mailService) + { + _eventService = eventService; + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _policyRepository = policyRepository; + _ssoConfigRepository = ssoConfigRepository; + _mailService = mailService; + } + + public async Task SaveAsync(Policy policy, IUserService userService, IOrganizationService organizationService, + Guid? savingUserId) + { + var org = await _organizationRepository.GetByIdAsync(policy.OrganizationId); + if (org == null) { - _eventService = eventService; - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _policyRepository = policyRepository; - _ssoConfigRepository = ssoConfigRepository; - _mailService = mailService; + throw new BadRequestException("Organization not found"); } - public async Task SaveAsync(Policy policy, IUserService userService, IOrganizationService organizationService, - Guid? savingUserId) + if (!org.UsePolicies) { - var org = await _organizationRepository.GetByIdAsync(policy.OrganizationId); - if (org == null) - { - throw new BadRequestException("Organization not found"); - } + throw new BadRequestException("This organization cannot use policies."); + } - if (!org.UsePolicies) - { - throw new BadRequestException("This organization cannot use policies."); - } - - // Handle dependent policy checks - switch (policy.Type) - { - case PolicyType.SingleOrg: - if (!policy.Enabled) - { - await RequiredBySsoAsync(org); - await RequiredByVaultTimeoutAsync(org); - await RequiredByKeyConnectorAsync(org); - } - break; - - case PolicyType.RequireSso: - if (policy.Enabled) - { - await DependsOnSingleOrgAsync(org); - } - else - { - await RequiredByKeyConnectorAsync(org); - } - break; - - case PolicyType.MaximumVaultTimeout: - if (policy.Enabled) - { - await DependsOnSingleOrgAsync(org); - } - break; - } - - var now = DateTime.UtcNow; - if (policy.Id == default(Guid)) - { - policy.CreationDate = now; - } - - if (policy.Enabled) - { - var currentPolicy = await _policyRepository.GetByIdAsync(policy.Id); - if (!currentPolicy?.Enabled ?? true) + // Handle dependent policy checks + switch (policy.Type) + { + case PolicyType.SingleOrg: + if (!policy.Enabled) { - var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync( - policy.OrganizationId); - var removableOrgUsers = orgUsers.Where(ou => - ou.Status != Enums.OrganizationUserStatusType.Invited && - ou.Type != Enums.OrganizationUserType.Owner && ou.Type != Enums.OrganizationUserType.Admin && - ou.UserId != savingUserId); - switch (policy.Type) - { - case Enums.PolicyType.TwoFactorAuthentication: - foreach (var orgUser in removableOrgUsers) + await RequiredBySsoAsync(org); + await RequiredByVaultTimeoutAsync(org); + await RequiredByKeyConnectorAsync(org); + } + break; + + case PolicyType.RequireSso: + if (policy.Enabled) + { + await DependsOnSingleOrgAsync(org); + } + else + { + await RequiredByKeyConnectorAsync(org); + } + break; + + case PolicyType.MaximumVaultTimeout: + if (policy.Enabled) + { + await DependsOnSingleOrgAsync(org); + } + break; + } + + var now = DateTime.UtcNow; + if (policy.Id == default(Guid)) + { + policy.CreationDate = now; + } + + if (policy.Enabled) + { + var currentPolicy = await _policyRepository.GetByIdAsync(policy.Id); + if (!currentPolicy?.Enabled ?? true) + { + var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync( + policy.OrganizationId); + var removableOrgUsers = orgUsers.Where(ou => + ou.Status != Enums.OrganizationUserStatusType.Invited && + ou.Type != Enums.OrganizationUserType.Owner && ou.Type != Enums.OrganizationUserType.Admin && + ou.UserId != savingUserId); + switch (policy.Type) + { + case Enums.PolicyType.TwoFactorAuthentication: + foreach (var orgUser in removableOrgUsers) + { + if (!await userService.TwoFactorIsEnabledAsync(orgUser)) { - if (!await userService.TwoFactorIsEnabledAsync(orgUser)) - { - await organizationService.DeleteUserAsync(policy.OrganizationId, orgUser.Id, - savingUserId); - await _mailService.SendOrganizationUserRemovedForPolicyTwoStepEmailAsync( - org.Name, orgUser.Email); - } + await organizationService.DeleteUserAsync(policy.OrganizationId, orgUser.Id, + savingUserId); + await _mailService.SendOrganizationUserRemovedForPolicyTwoStepEmailAsync( + org.Name, orgUser.Email); } - break; - case Enums.PolicyType.SingleOrg: - var userOrgs = await _organizationUserRepository.GetManyByManyUsersAsync( - removableOrgUsers.Select(ou => ou.UserId.Value)); - foreach (var orgUser in removableOrgUsers) + } + break; + case Enums.PolicyType.SingleOrg: + var userOrgs = await _organizationUserRepository.GetManyByManyUsersAsync( + removableOrgUsers.Select(ou => ou.UserId.Value)); + foreach (var orgUser in removableOrgUsers) + { + if (userOrgs.Any(ou => ou.UserId == orgUser.UserId + && ou.OrganizationId != org.Id + && ou.Status != OrganizationUserStatusType.Invited)) { - if (userOrgs.Any(ou => ou.UserId == orgUser.UserId - && ou.OrganizationId != org.Id - && ou.Status != OrganizationUserStatusType.Invited)) - { - await organizationService.DeleteUserAsync(policy.OrganizationId, orgUser.Id, - savingUserId); - await _mailService.SendOrganizationUserRemovedForPolicySingleOrgEmailAsync( - org.Name, orgUser.Email); - } + await organizationService.DeleteUserAsync(policy.OrganizationId, orgUser.Id, + savingUserId); + await _mailService.SendOrganizationUserRemovedForPolicySingleOrgEmailAsync( + org.Name, orgUser.Email); } - break; - default: - break; - } + } + break; + default: + break; } } - policy.RevisionDate = now; - await _policyRepository.UpsertAsync(policy); - await _eventService.LogPolicyEventAsync(policy, Enums.EventType.Policy_Updated); } + policy.RevisionDate = now; + await _policyRepository.UpsertAsync(policy); + await _eventService.LogPolicyEventAsync(policy, Enums.EventType.Policy_Updated); + } - private async Task DependsOnSingleOrgAsync(Organization org) + private async Task DependsOnSingleOrgAsync(Organization org) + { + var singleOrg = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.SingleOrg); + if (singleOrg?.Enabled != true) { - var singleOrg = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.SingleOrg); - if (singleOrg?.Enabled != true) - { - throw new BadRequestException("Single Organization policy not enabled."); - } + throw new BadRequestException("Single Organization policy not enabled."); } + } - private async Task RequiredBySsoAsync(Organization org) + private async Task RequiredBySsoAsync(Organization org) + { + var requireSso = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.RequireSso); + if (requireSso?.Enabled == true) { - var requireSso = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.RequireSso); - if (requireSso?.Enabled == true) - { - throw new BadRequestException("Single Sign-On Authentication policy is enabled."); - } + throw new BadRequestException("Single Sign-On Authentication policy is enabled."); } + } - private async Task RequiredByKeyConnectorAsync(Organization org) + private async Task RequiredByKeyConnectorAsync(Organization org) + { + + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(org.Id); + if (ssoConfig?.GetData()?.KeyConnectorEnabled == true) { - - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(org.Id); - if (ssoConfig?.GetData()?.KeyConnectorEnabled == true) - { - throw new BadRequestException("Key Connector is enabled."); - } + throw new BadRequestException("Key Connector is enabled."); } + } - private async Task RequiredByVaultTimeoutAsync(Organization org) + private async Task RequiredByVaultTimeoutAsync(Organization org) + { + var vaultTimeout = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.MaximumVaultTimeout); + if (vaultTimeout?.Enabled == true) { - var vaultTimeout = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.MaximumVaultTimeout); - if (vaultTimeout?.Enabled == true) - { - throw new BadRequestException("Maximum Vault Timeout policy is enabled."); - } + throw new BadRequestException("Maximum Vault Timeout policy is enabled."); } } } diff --git a/src/Core/Services/Implementations/RelayPushNotificationService.cs b/src/Core/Services/Implementations/RelayPushNotificationService.cs index b66cb7ca1..b3670ad7b 100644 --- a/src/Core/Services/Implementations/RelayPushNotificationService.cs +++ b/src/Core/Services/Implementations/RelayPushNotificationService.cs @@ -8,219 +8,218 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class RelayPushNotificationService : BaseIdentityClientService, IPushNotificationService { - public class RelayPushNotificationService : BaseIdentityClientService, IPushNotificationService + private readonly IDeviceRepository _deviceRepository; + private readonly IHttpContextAccessor _httpContextAccessor; + + public RelayPushNotificationService( + IHttpClientFactory httpFactory, + IDeviceRepository deviceRepository, + GlobalSettings globalSettings, + IHttpContextAccessor httpContextAccessor, + ILogger logger) + : base( + httpFactory, + globalSettings.PushRelayBaseUri, + globalSettings.Installation.IdentityUri, + "api.push", + $"installation.{globalSettings.Installation.Id}", + globalSettings.Installation.Key, + logger) { - private readonly IDeviceRepository _deviceRepository; - private readonly IHttpContextAccessor _httpContextAccessor; + _deviceRepository = deviceRepository; + _httpContextAccessor = httpContextAccessor; + } - public RelayPushNotificationService( - IHttpClientFactory httpFactory, - IDeviceRepository deviceRepository, - GlobalSettings globalSettings, - IHttpContextAccessor httpContextAccessor, - ILogger logger) - : base( - httpFactory, - globalSettings.PushRelayBaseUri, - globalSettings.Installation.IdentityUri, - "api.push", - $"installation.{globalSettings.Installation.Id}", - globalSettings.Installation.Key, - logger) + public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) + { + await PushCipherAsync(cipher, PushType.SyncCipherCreate, collectionIds); + } + + public async Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds) + { + await PushCipherAsync(cipher, PushType.SyncCipherUpdate, collectionIds); + } + + public async Task PushSyncCipherDeleteAsync(Cipher cipher) + { + await PushCipherAsync(cipher, PushType.SyncLoginDelete, null); + } + + private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable collectionIds) + { + if (cipher.OrganizationId.HasValue) { - _deviceRepository = deviceRepository; - _httpContextAccessor = httpContextAccessor; + // We cannot send org pushes since access logic is much more complicated than just the fact that they belong + // to the organization. Potentially we could blindly send to just users that have the access all permission + // device registration needs to be more granular to handle that appropriately. A more brute force approach could + // me to send "full sync" push to all org users, but that has the potential to DDOS the API in bursts. + + // await SendPayloadToOrganizationAsync(cipher.OrganizationId.Value, type, message, true); } - - public async Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) + else if (cipher.UserId.HasValue) { - await PushCipherAsync(cipher, PushType.SyncCipherCreate, collectionIds); - } - - public async Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds) - { - await PushCipherAsync(cipher, PushType.SyncCipherUpdate, collectionIds); - } - - public async Task PushSyncCipherDeleteAsync(Cipher cipher) - { - await PushCipherAsync(cipher, PushType.SyncLoginDelete, null); - } - - private async Task PushCipherAsync(Cipher cipher, PushType type, IEnumerable collectionIds) - { - if (cipher.OrganizationId.HasValue) + var message = new SyncCipherPushNotification { - // We cannot send org pushes since access logic is much more complicated than just the fact that they belong - // to the organization. Potentially we could blindly send to just users that have the access all permission - // device registration needs to be more granular to handle that appropriately. A more brute force approach could - // me to send "full sync" push to all org users, but that has the potential to DDOS the API in bursts. - - // await SendPayloadToOrganizationAsync(cipher.OrganizationId.Value, type, message, true); - } - else if (cipher.UserId.HasValue) - { - var message = new SyncCipherPushNotification - { - Id = cipher.Id, - UserId = cipher.UserId, - OrganizationId = cipher.OrganizationId, - RevisionDate = cipher.RevisionDate, - }; - - await SendPayloadToUserAsync(cipher.UserId.Value, type, message, true); - } - } - - public async Task PushSyncFolderCreateAsync(Folder folder) - { - await PushFolderAsync(folder, PushType.SyncFolderCreate); - } - - public async Task PushSyncFolderUpdateAsync(Folder folder) - { - await PushFolderAsync(folder, PushType.SyncFolderUpdate); - } - - public async Task PushSyncFolderDeleteAsync(Folder folder) - { - await PushFolderAsync(folder, PushType.SyncFolderDelete); - } - - private async Task PushFolderAsync(Folder folder, PushType type) - { - var message = new SyncFolderPushNotification - { - Id = folder.Id, - UserId = folder.UserId, - RevisionDate = folder.RevisionDate + Id = cipher.Id, + UserId = cipher.UserId, + OrganizationId = cipher.OrganizationId, + RevisionDate = cipher.RevisionDate, }; - await SendPayloadToUserAsync(folder.UserId, type, message, true); - } - - public async Task PushSyncCiphersAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncCiphers); - } - - public async Task PushSyncVaultAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncVault); - } - - public async Task PushSyncOrgKeysAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncOrgKeys); - } - - public async Task PushSyncSettingsAsync(Guid userId) - { - await PushUserAsync(userId, PushType.SyncSettings); - } - - public async Task PushLogOutAsync(Guid userId) - { - await PushUserAsync(userId, PushType.LogOut); - } - - private async Task PushUserAsync(Guid userId, PushType type) - { - var message = new UserPushNotification - { - UserId = userId, - Date = DateTime.UtcNow - }; - - await SendPayloadToUserAsync(userId, type, message, false); - } - - public async Task PushSyncSendCreateAsync(Send send) - { - await PushSendAsync(send, PushType.SyncSendCreate); - } - - public async Task PushSyncSendUpdateAsync(Send send) - { - await PushSendAsync(send, PushType.SyncSendUpdate); - } - - public async Task PushSyncSendDeleteAsync(Send send) - { - await PushSendAsync(send, PushType.SyncSendDelete); - } - - private async Task PushSendAsync(Send send, PushType type) - { - if (send.UserId.HasValue) - { - var message = new SyncSendPushNotification - { - Id = send.Id, - UserId = send.UserId.Value, - RevisionDate = send.RevisionDate - }; - - await SendPayloadToUserAsync(message.UserId, type, message, true); - } - } - - private async Task SendPayloadToUserAsync(Guid userId, PushType type, object payload, bool excludeCurrentContext) - { - var request = new PushSendRequestModel - { - UserId = userId.ToString(), - Type = type, - Payload = payload - }; - - await AddCurrentContextAsync(request, excludeCurrentContext); - await SendAsync(HttpMethod.Post, "push/send", request); - } - - private async Task SendPayloadToOrganizationAsync(Guid orgId, PushType type, object payload, bool excludeCurrentContext) - { - var request = new PushSendRequestModel - { - OrganizationId = orgId.ToString(), - Type = type, - Payload = payload - }; - - await AddCurrentContextAsync(request, excludeCurrentContext); - await SendAsync(HttpMethod.Post, "push/send", request); - } - - private async Task AddCurrentContextAsync(PushSendRequestModel request, bool addIdentifier) - { - var currentContext = _httpContextAccessor?.HttpContext?. - RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; - if (!string.IsNullOrWhiteSpace(currentContext?.DeviceIdentifier)) - { - var device = await _deviceRepository.GetByIdentifierAsync(currentContext.DeviceIdentifier); - if (device != null) - { - request.DeviceId = device.Id.ToString(); - } - if (addIdentifier) - { - request.Identifier = currentContext.DeviceIdentifier; - } - } - } - - public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, - string deviceId = null) - { - throw new NotImplementedException(); - } - - public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, - string deviceId = null) - { - throw new NotImplementedException(); + await SendPayloadToUserAsync(cipher.UserId.Value, type, message, true); } } + + public async Task PushSyncFolderCreateAsync(Folder folder) + { + await PushFolderAsync(folder, PushType.SyncFolderCreate); + } + + public async Task PushSyncFolderUpdateAsync(Folder folder) + { + await PushFolderAsync(folder, PushType.SyncFolderUpdate); + } + + public async Task PushSyncFolderDeleteAsync(Folder folder) + { + await PushFolderAsync(folder, PushType.SyncFolderDelete); + } + + private async Task PushFolderAsync(Folder folder, PushType type) + { + var message = new SyncFolderPushNotification + { + Id = folder.Id, + UserId = folder.UserId, + RevisionDate = folder.RevisionDate + }; + + await SendPayloadToUserAsync(folder.UserId, type, message, true); + } + + public async Task PushSyncCiphersAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncCiphers); + } + + public async Task PushSyncVaultAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncVault); + } + + public async Task PushSyncOrgKeysAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncOrgKeys); + } + + public async Task PushSyncSettingsAsync(Guid userId) + { + await PushUserAsync(userId, PushType.SyncSettings); + } + + public async Task PushLogOutAsync(Guid userId) + { + await PushUserAsync(userId, PushType.LogOut); + } + + private async Task PushUserAsync(Guid userId, PushType type) + { + var message = new UserPushNotification + { + UserId = userId, + Date = DateTime.UtcNow + }; + + await SendPayloadToUserAsync(userId, type, message, false); + } + + public async Task PushSyncSendCreateAsync(Send send) + { + await PushSendAsync(send, PushType.SyncSendCreate); + } + + public async Task PushSyncSendUpdateAsync(Send send) + { + await PushSendAsync(send, PushType.SyncSendUpdate); + } + + public async Task PushSyncSendDeleteAsync(Send send) + { + await PushSendAsync(send, PushType.SyncSendDelete); + } + + private async Task PushSendAsync(Send send, PushType type) + { + if (send.UserId.HasValue) + { + var message = new SyncSendPushNotification + { + Id = send.Id, + UserId = send.UserId.Value, + RevisionDate = send.RevisionDate + }; + + await SendPayloadToUserAsync(message.UserId, type, message, true); + } + } + + private async Task SendPayloadToUserAsync(Guid userId, PushType type, object payload, bool excludeCurrentContext) + { + var request = new PushSendRequestModel + { + UserId = userId.ToString(), + Type = type, + Payload = payload + }; + + await AddCurrentContextAsync(request, excludeCurrentContext); + await SendAsync(HttpMethod.Post, "push/send", request); + } + + private async Task SendPayloadToOrganizationAsync(Guid orgId, PushType type, object payload, bool excludeCurrentContext) + { + var request = new PushSendRequestModel + { + OrganizationId = orgId.ToString(), + Type = type, + Payload = payload + }; + + await AddCurrentContextAsync(request, excludeCurrentContext); + await SendAsync(HttpMethod.Post, "push/send", request); + } + + private async Task AddCurrentContextAsync(PushSendRequestModel request, bool addIdentifier) + { + var currentContext = _httpContextAccessor?.HttpContext?. + RequestServices.GetService(typeof(ICurrentContext)) as ICurrentContext; + if (!string.IsNullOrWhiteSpace(currentContext?.DeviceIdentifier)) + { + var device = await _deviceRepository.GetByIdentifierAsync(currentContext.DeviceIdentifier); + if (device != null) + { + request.DeviceId = device.Id.ToString(); + } + if (addIdentifier) + { + request.Identifier = currentContext.DeviceIdentifier; + } + } + } + + public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, + string deviceId = null) + { + throw new NotImplementedException(); + } + + public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, + string deviceId = null) + { + throw new NotImplementedException(); + } } diff --git a/src/Core/Services/Implementations/RelayPushRegistrationService.cs b/src/Core/Services/Implementations/RelayPushRegistrationService.cs index 82ae88799..2e3087421 100644 --- a/src/Core/Services/Implementations/RelayPushRegistrationService.cs +++ b/src/Core/Services/Implementations/RelayPushRegistrationService.cs @@ -3,65 +3,64 @@ using Bit.Core.Models.Api; using Bit.Core.Settings; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class RelayPushRegistrationService : BaseIdentityClientService, IPushRegistrationService { - public class RelayPushRegistrationService : BaseIdentityClientService, IPushRegistrationService + + public RelayPushRegistrationService( + IHttpClientFactory httpFactory, + GlobalSettings globalSettings, + ILogger logger) + : base( + httpFactory, + globalSettings.PushRelayBaseUri, + globalSettings.Installation.IdentityUri, + "api.push", + $"installation.{globalSettings.Installation.Id}", + globalSettings.Installation.Key, + logger) { + } - public RelayPushRegistrationService( - IHttpClientFactory httpFactory, - GlobalSettings globalSettings, - ILogger logger) - : base( - httpFactory, - globalSettings.PushRelayBaseUri, - globalSettings.Installation.IdentityUri, - "api.push", - $"installation.{globalSettings.Installation.Id}", - globalSettings.Installation.Key, - logger) + public async Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId, + string identifier, DeviceType type) + { + var requestModel = new PushRegistrationRequestModel { + DeviceId = deviceId, + Identifier = identifier, + PushToken = pushToken, + Type = type, + UserId = userId + }; + await SendAsync(HttpMethod.Post, "push/register", requestModel); + } + + public async Task DeleteRegistrationAsync(string deviceId) + { + await SendAsync(HttpMethod.Delete, string.Concat("push/", deviceId)); + } + + public async Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) + { + if (!deviceIds.Any()) + { + return; } - public async Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId, - string identifier, DeviceType type) + var requestModel = new PushUpdateRequestModel(deviceIds, organizationId); + await SendAsync(HttpMethod.Put, "push/add-organization", requestModel); + } + + public async Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) + { + if (!deviceIds.Any()) { - var requestModel = new PushRegistrationRequestModel - { - DeviceId = deviceId, - Identifier = identifier, - PushToken = pushToken, - Type = type, - UserId = userId - }; - await SendAsync(HttpMethod.Post, "push/register", requestModel); + return; } - public async Task DeleteRegistrationAsync(string deviceId) - { - await SendAsync(HttpMethod.Delete, string.Concat("push/", deviceId)); - } - - public async Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) - { - if (!deviceIds.Any()) - { - return; - } - - var requestModel = new PushUpdateRequestModel(deviceIds, organizationId); - await SendAsync(HttpMethod.Put, "push/add-organization", requestModel); - } - - public async Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) - { - if (!deviceIds.Any()) - { - return; - } - - var requestModel = new PushUpdateRequestModel(deviceIds, organizationId); - await SendAsync(HttpMethod.Put, "push/delete-organization", requestModel); - } + var requestModel = new PushUpdateRequestModel(deviceIds, organizationId); + await SendAsync(HttpMethod.Put, "push/delete-organization", requestModel); } } diff --git a/src/Core/Services/Implementations/RepositoryEventWriteService.cs b/src/Core/Services/Implementations/RepositoryEventWriteService.cs index 11d028340..a8299c1e8 100644 --- a/src/Core/Services/Implementations/RepositoryEventWriteService.cs +++ b/src/Core/Services/Implementations/RepositoryEventWriteService.cs @@ -1,26 +1,25 @@ using Bit.Core.Models.Data; using Bit.Core.Repositories; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class RepositoryEventWriteService : IEventWriteService { - public class RepositoryEventWriteService : IEventWriteService + private readonly IEventRepository _eventRepository; + + public RepositoryEventWriteService( + IEventRepository eventRepository) { - private readonly IEventRepository _eventRepository; + _eventRepository = eventRepository; + } - public RepositoryEventWriteService( - IEventRepository eventRepository) - { - _eventRepository = eventRepository; - } + public async Task CreateAsync(IEvent e) + { + await _eventRepository.CreateAsync(e); + } - public async Task CreateAsync(IEvent e) - { - await _eventRepository.CreateAsync(e); - } - - public async Task CreateManyAsync(IEnumerable e) - { - await _eventRepository.CreateManyAsync(e); - } + public async Task CreateManyAsync(IEnumerable e) + { + await _eventRepository.CreateManyAsync(e); } } diff --git a/src/Core/Services/Implementations/SendGridMailDeliveryService.cs b/src/Core/Services/Implementations/SendGridMailDeliveryService.cs index 0e34b170d..a35d11997 100644 --- a/src/Core/Services/Implementations/SendGridMailDeliveryService.cs +++ b/src/Core/Services/Implementations/SendGridMailDeliveryService.cs @@ -6,110 +6,109 @@ using Microsoft.Extensions.Logging; using SendGrid; using SendGrid.Helpers.Mail; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class SendGridMailDeliveryService : IMailDeliveryService, IDisposable { - public class SendGridMailDeliveryService : IMailDeliveryService, IDisposable + private readonly GlobalSettings _globalSettings; + private readonly IWebHostEnvironment _hostingEnvironment; + private readonly ILogger _logger; + private readonly ISendGridClient _client; + private readonly string _senderTag; + private readonly string _replyToEmail; + + public SendGridMailDeliveryService( + GlobalSettings globalSettings, + IWebHostEnvironment hostingEnvironment, + ILogger logger) + : this(new SendGridClient(globalSettings.Mail.SendGridApiKey), + globalSettings, hostingEnvironment, logger) { - private readonly GlobalSettings _globalSettings; - private readonly IWebHostEnvironment _hostingEnvironment; - private readonly ILogger _logger; - private readonly ISendGridClient _client; - private readonly string _senderTag; - private readonly string _replyToEmail; + } - public SendGridMailDeliveryService( - GlobalSettings globalSettings, - IWebHostEnvironment hostingEnvironment, - ILogger logger) - : this(new SendGridClient(globalSettings.Mail.SendGridApiKey), - globalSettings, hostingEnvironment, logger) + public void Dispose() + { + // TODO: nothing to dispose + } + + public SendGridMailDeliveryService( + ISendGridClient client, + GlobalSettings globalSettings, + IWebHostEnvironment hostingEnvironment, + ILogger logger) + { + if (string.IsNullOrWhiteSpace(globalSettings.Mail?.SendGridApiKey)) { + throw new ArgumentNullException(nameof(globalSettings.Mail.SendGridApiKey)); } - public void Dispose() + _globalSettings = globalSettings; + _hostingEnvironment = hostingEnvironment; + _logger = logger; + _client = client; + _senderTag = $"Server_{globalSettings.ProjectName?.Replace(' ', '_')}"; + _replyToEmail = CoreHelpers.PunyEncode(globalSettings.Mail.ReplyToEmail); + } + + public async Task SendEmailAsync(MailMessage message) + { + var msg = new SendGridMessage(); + msg.SetFrom(new EmailAddress(_replyToEmail, _globalSettings.SiteName)); + msg.AddTos(message.ToEmails.Select(e => new EmailAddress(CoreHelpers.PunyEncode(e))).ToList()); + if (message.BccEmails?.Any() ?? false) { - // TODO: nothing to dispose + msg.AddBccs(message.BccEmails.Select(e => new EmailAddress(CoreHelpers.PunyEncode(e))).ToList()); } - public SendGridMailDeliveryService( - ISendGridClient client, - GlobalSettings globalSettings, - IWebHostEnvironment hostingEnvironment, - ILogger logger) - { - if (string.IsNullOrWhiteSpace(globalSettings.Mail?.SendGridApiKey)) - { - throw new ArgumentNullException(nameof(globalSettings.Mail.SendGridApiKey)); - } + msg.SetSubject(message.Subject); + msg.AddContent(MimeType.Text, message.TextContent); + msg.AddContent(MimeType.Html, message.HtmlContent); - _globalSettings = globalSettings; - _hostingEnvironment = hostingEnvironment; - _logger = logger; - _client = client; - _senderTag = $"Server_{globalSettings.ProjectName?.Replace(' ', '_')}"; - _replyToEmail = CoreHelpers.PunyEncode(globalSettings.Mail.ReplyToEmail); + msg.AddCategory($"type:{message.Category}"); + msg.AddCategory($"env:{_hostingEnvironment.EnvironmentName}"); + msg.AddCategory($"sender:{_senderTag}"); + + msg.SetClickTracking(false, false); + msg.SetOpenTracking(false); + + if (message.MetaData != null && + message.MetaData.ContainsKey("SendGridBypassListManagement") && + Convert.ToBoolean(message.MetaData["SendGridBypassListManagement"])) + { + msg.SetBypassListManagement(true); } - public async Task SendEmailAsync(MailMessage message) + try { - var msg = new SendGridMessage(); - msg.SetFrom(new EmailAddress(_replyToEmail, _globalSettings.SiteName)); - msg.AddTos(message.ToEmails.Select(e => new EmailAddress(CoreHelpers.PunyEncode(e))).ToList()); - if (message.BccEmails?.Any() ?? false) + var success = await SendAsync(msg, false); + if (!success) { - msg.AddBccs(message.BccEmails.Select(e => new EmailAddress(CoreHelpers.PunyEncode(e))).ToList()); - } - - msg.SetSubject(message.Subject); - msg.AddContent(MimeType.Text, message.TextContent); - msg.AddContent(MimeType.Html, message.HtmlContent); - - msg.AddCategory($"type:{message.Category}"); - msg.AddCategory($"env:{_hostingEnvironment.EnvironmentName}"); - msg.AddCategory($"sender:{_senderTag}"); - - msg.SetClickTracking(false, false); - msg.SetOpenTracking(false); - - if (message.MetaData != null && - message.MetaData.ContainsKey("SendGridBypassListManagement") && - Convert.ToBoolean(message.MetaData["SendGridBypassListManagement"])) - { - msg.SetBypassListManagement(true); - } - - try - { - var success = await SendAsync(msg, false); - if (!success) - { - _logger.LogWarning("Failed to send email. Retrying..."); - await SendAsync(msg, true); - } - } - catch (Exception e) - { - _logger.LogWarning(e, "Failed to send email (with exception). Retrying..."); + _logger.LogWarning("Failed to send email. Retrying..."); await SendAsync(msg, true); - throw; } } - - private async Task SendAsync(SendGridMessage message, bool retry) + catch (Exception e) { - if (retry) - { - // wait and try again - await Task.Delay(2000); - } - - var response = await _client.SendEmailAsync(message); - if (!response.IsSuccessStatusCode) - { - var responseBody = await response.Body.ReadAsStringAsync(); - _logger.LogError("SendGrid email sending failed with {0}: {1}", response.StatusCode, responseBody); - } - return response.IsSuccessStatusCode; + _logger.LogWarning(e, "Failed to send email (with exception). Retrying..."); + await SendAsync(msg, true); + throw; } } + + private async Task SendAsync(SendGridMessage message, bool retry) + { + if (retry) + { + // wait and try again + await Task.Delay(2000); + } + + var response = await _client.SendEmailAsync(message); + if (!response.IsSuccessStatusCode) + { + var responseBody = await response.Body.ReadAsStringAsync(); + _logger.LogError("SendGrid email sending failed with {0}: {1}", response.StatusCode, responseBody); + } + return response.IsSuccessStatusCode; + } } diff --git a/src/Core/Services/Implementations/SendService.cs b/src/Core/Services/Implementations/SendService.cs index 6c41c6d9c..f16f2da41 100644 --- a/src/Core/Services/Implementations/SendService.cs +++ b/src/Core/Services/Implementations/SendService.cs @@ -11,330 +11,329 @@ using Bit.Core.Settings; using Bit.Core.Utilities; using Microsoft.AspNetCore.Identity; -namespace Bit.Core.Services -{ - public class SendService : ISendService - { - public const long MAX_FILE_SIZE = Constants.FileSize501mb; - public const string MAX_FILE_SIZE_READABLE = "500 MB"; - private readonly ISendRepository _sendRepository; - private readonly IUserRepository _userRepository; - private readonly IPolicyRepository _policyRepository; - private readonly IUserService _userService; - private readonly IOrganizationRepository _organizationRepository; - private readonly ISendFileStorageService _sendFileStorageService; - private readonly IPasswordHasher _passwordHasher; - private readonly IPushNotificationService _pushService; - private readonly IReferenceEventService _referenceEventService; - private readonly GlobalSettings _globalSettings; - private readonly ICurrentContext _currentContext; - private const long _fileSizeLeeway = 1024L * 1024L; // 1MB +namespace Bit.Core.Services; - public SendService( - ISendRepository sendRepository, - IUserRepository userRepository, - IUserService userService, - IOrganizationRepository organizationRepository, - ISendFileStorageService sendFileStorageService, - IPasswordHasher passwordHasher, - IPushNotificationService pushService, - IReferenceEventService referenceEventService, - GlobalSettings globalSettings, - IPolicyRepository policyRepository, - ICurrentContext currentContext) +public class SendService : ISendService +{ + public const long MAX_FILE_SIZE = Constants.FileSize501mb; + public const string MAX_FILE_SIZE_READABLE = "500 MB"; + private readonly ISendRepository _sendRepository; + private readonly IUserRepository _userRepository; + private readonly IPolicyRepository _policyRepository; + private readonly IUserService _userService; + private readonly IOrganizationRepository _organizationRepository; + private readonly ISendFileStorageService _sendFileStorageService; + private readonly IPasswordHasher _passwordHasher; + private readonly IPushNotificationService _pushService; + private readonly IReferenceEventService _referenceEventService; + private readonly GlobalSettings _globalSettings; + private readonly ICurrentContext _currentContext; + private const long _fileSizeLeeway = 1024L * 1024L; // 1MB + + public SendService( + ISendRepository sendRepository, + IUserRepository userRepository, + IUserService userService, + IOrganizationRepository organizationRepository, + ISendFileStorageService sendFileStorageService, + IPasswordHasher passwordHasher, + IPushNotificationService pushService, + IReferenceEventService referenceEventService, + GlobalSettings globalSettings, + IPolicyRepository policyRepository, + ICurrentContext currentContext) + { + _sendRepository = sendRepository; + _userRepository = userRepository; + _userService = userService; + _policyRepository = policyRepository; + _organizationRepository = organizationRepository; + _sendFileStorageService = sendFileStorageService; + _passwordHasher = passwordHasher; + _pushService = pushService; + _referenceEventService = referenceEventService; + _globalSettings = globalSettings; + _currentContext = currentContext; + } + + public async Task SaveSendAsync(Send send) + { + // Make sure user can save Sends + await ValidateUserCanSaveAsync(send.UserId, send); + + if (send.Id == default(Guid)) { - _sendRepository = sendRepository; - _userRepository = userRepository; - _userService = userService; - _policyRepository = policyRepository; - _organizationRepository = organizationRepository; - _sendFileStorageService = sendFileStorageService; - _passwordHasher = passwordHasher; - _pushService = pushService; - _referenceEventService = referenceEventService; - _globalSettings = globalSettings; - _currentContext = currentContext; + await _sendRepository.CreateAsync(send); + await _pushService.PushSyncSendCreateAsync(send); + await RaiseReferenceEventAsync(send, ReferenceEventType.SendCreated); + } + else + { + send.RevisionDate = DateTime.UtcNow; + await _sendRepository.UpsertAsync(send); + await _pushService.PushSyncSendUpdateAsync(send); + } + } + + public async Task SaveFileSendAsync(Send send, SendFileData data, long fileLength) + { + if (send.Type != SendType.File) + { + throw new BadRequestException("Send is not of type \"file\"."); } - public async Task SaveSendAsync(Send send) + if (fileLength < 1) { - // Make sure user can save Sends - await ValidateUserCanSaveAsync(send.UserId, send); + throw new BadRequestException("No file data."); + } - if (send.Id == default(Guid)) + var storageBytesRemaining = await StorageRemainingForSendAsync(send); + + if (storageBytesRemaining < fileLength) + { + throw new BadRequestException("Not enough storage available."); + } + + var fileId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false); + + try + { + data.Id = fileId; + data.Size = fileLength; + data.Validated = false; + send.Data = JsonSerializer.Serialize(data, + JsonHelpers.IgnoreWritingNull); + await SaveSendAsync(send); + return await _sendFileStorageService.GetSendFileUploadUrlAsync(send, fileId); + } + catch + { + // Clean up since this is not transactional + await _sendFileStorageService.DeleteFileAsync(send, fileId); + throw; + } + } + + public async Task UploadFileToExistingSendAsync(Stream stream, Send send) + { + if (send?.Data == null) + { + throw new BadRequestException("Send does not have file data"); + } + + if (send.Type != SendType.File) + { + throw new BadRequestException("Not a File Type Send."); + } + + var data = JsonSerializer.Deserialize(send.Data); + + if (data.Validated) + { + throw new BadRequestException("File has already been uploaded."); + } + + await _sendFileStorageService.UploadNewFileAsync(stream, send, data.Id); + + if (!await ValidateSendFile(send)) + { + throw new BadRequestException("File received does not match expected file length."); + } + } + + public async Task ValidateSendFile(Send send) + { + var fileData = JsonSerializer.Deserialize(send.Data); + + var (valid, realSize) = await _sendFileStorageService.ValidateFileAsync(send, fileData.Id, fileData.Size, _fileSizeLeeway); + + if (!valid || realSize > MAX_FILE_SIZE) + { + // File reported differs in size from that promised. Must be a rogue client. Delete Send + await DeleteSendAsync(send); + return false; + } + + // Update Send data if necessary + if (realSize != fileData.Size) + { + fileData.Size = realSize.Value; + } + fileData.Validated = true; + send.Data = JsonSerializer.Serialize(fileData, + JsonHelpers.IgnoreWritingNull); + await SaveSendAsync(send); + + return valid; + } + + public async Task DeleteSendAsync(Send send) + { + await _sendRepository.DeleteAsync(send); + if (send.Type == Enums.SendType.File) + { + var data = JsonSerializer.Deserialize(send.Data); + await _sendFileStorageService.DeleteFileAsync(send, data.Id); + } + await _pushService.PushSyncSendDeleteAsync(send); + } + + public (bool grant, bool passwordRequiredError, bool passwordInvalidError) SendCanBeAccessed(Send send, + string password) + { + var now = DateTime.UtcNow; + if (send == null || send.MaxAccessCount.GetValueOrDefault(int.MaxValue) <= send.AccessCount || + send.ExpirationDate.GetValueOrDefault(DateTime.MaxValue) < now || send.Disabled || + send.DeletionDate < now) + { + return (false, false, false); + } + if (!string.IsNullOrWhiteSpace(send.Password)) + { + if (string.IsNullOrWhiteSpace(password)) { - await _sendRepository.CreateAsync(send); - await _pushService.PushSyncSendCreateAsync(send); - await RaiseReferenceEventAsync(send, ReferenceEventType.SendCreated); + return (false, true, false); + } + var passwordResult = _passwordHasher.VerifyHashedPassword(new User(), send.Password, password); + if (passwordResult == PasswordVerificationResult.SuccessRehashNeeded) + { + send.Password = HashPassword(password); + } + if (passwordResult == PasswordVerificationResult.Failed) + { + return (false, false, true); + } + } + + return (true, false, false); + } + + // Response: Send, password required, password invalid + public async Task<(string, bool, bool)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password) + { + if (send.Type != SendType.File) + { + throw new BadRequestException("Can only get a download URL for a file type of Send"); + } + + var (grantAccess, passwordRequired, passwordInvalid) = SendCanBeAccessed(send, password); + + if (!grantAccess) + { + return (null, passwordRequired, passwordInvalid); + } + + send.AccessCount++; + await _sendRepository.ReplaceAsync(send); + await _pushService.PushSyncSendUpdateAsync(send); + return (await _sendFileStorageService.GetSendFileDownloadUrlAsync(send, fileId), false, false); + } + + // Response: Send, password required, password invalid + public async Task<(Send, bool, bool)> AccessAsync(Guid sendId, string password) + { + var send = await _sendRepository.GetByIdAsync(sendId); + var (grantAccess, passwordRequired, passwordInvalid) = SendCanBeAccessed(send, password); + + if (!grantAccess) + { + return (null, passwordRequired, passwordInvalid); + } + + // TODO: maybe move this to a simple ++ sproc? + if (send.Type != SendType.File) + { + // File sends are incremented during file download + send.AccessCount++; + } + + await _sendRepository.ReplaceAsync(send); + await _pushService.PushSyncSendUpdateAsync(send); + await RaiseReferenceEventAsync(send, ReferenceEventType.SendAccessed); + return (send, false, false); + } + + private async Task RaiseReferenceEventAsync(Send send, ReferenceEventType eventType) + { + await _referenceEventService.RaiseEventAsync(new ReferenceEvent + { + Id = send.UserId ?? default, + Type = eventType, + Source = ReferenceEventSource.User, + SendType = send.Type, + MaxAccessCount = send.MaxAccessCount, + HasPassword = !string.IsNullOrWhiteSpace(send.Password), + }); + } + + public string HashPassword(string password) + { + return _passwordHasher.HashPassword(new User(), password); + } + + private async Task ValidateUserCanSaveAsync(Guid? userId, Send send) + { + if (!userId.HasValue || (!_currentContext.Organizations?.Any() ?? true)) + { + return; + } + + var disableSendPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(userId.Value, + PolicyType.DisableSend); + if (disableSendPolicyCount > 0) + { + throw new BadRequestException("Due to an Enterprise Policy, you are only able to delete an existing Send."); + } + + if (send.HideEmail.GetValueOrDefault()) + { + var sendOptionsPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(userId.Value, PolicyType.SendOptions); + if (sendOptionsPolicies.Any(p => p.GetDataModel()?.DisableHideEmail ?? false)) + { + throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to hide your email address from recipients when creating or editing a Send."); + } + } + } + + private async Task StorageRemainingForSendAsync(Send send) + { + var storageBytesRemaining = 0L; + if (send.UserId.HasValue) + { + var user = await _userRepository.GetByIdAsync(send.UserId.Value); + if (!await _userService.CanAccessPremium(user)) + { + throw new BadRequestException("You must have premium status to use file Sends."); + } + + if (!user.EmailVerified) + { + throw new BadRequestException("You must confirm your email to use file Sends."); + } + + if (user.Premium) + { + storageBytesRemaining = user.StorageBytesRemaining(); } else { - send.RevisionDate = DateTime.UtcNow; - await _sendRepository.UpsertAsync(send); - await _pushService.PushSyncSendUpdateAsync(send); + // Users that get access to file storage/premium from their organization get the default + // 1 GB max storage. + storageBytesRemaining = user.StorageBytesRemaining( + _globalSettings.SelfHosted ? (short)10240 : (short)1); } } - - public async Task SaveFileSendAsync(Send send, SendFileData data, long fileLength) + else if (send.OrganizationId.HasValue) { - if (send.Type != SendType.File) + var org = await _organizationRepository.GetByIdAsync(send.OrganizationId.Value); + if (!org.MaxStorageGb.HasValue) { - throw new BadRequestException("Send is not of type \"file\"."); + throw new BadRequestException("This organization cannot use file sends."); } - if (fileLength < 1) - { - throw new BadRequestException("No file data."); - } - - var storageBytesRemaining = await StorageRemainingForSendAsync(send); - - if (storageBytesRemaining < fileLength) - { - throw new BadRequestException("Not enough storage available."); - } - - var fileId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false); - - try - { - data.Id = fileId; - data.Size = fileLength; - data.Validated = false; - send.Data = JsonSerializer.Serialize(data, - JsonHelpers.IgnoreWritingNull); - await SaveSendAsync(send); - return await _sendFileStorageService.GetSendFileUploadUrlAsync(send, fileId); - } - catch - { - // Clean up since this is not transactional - await _sendFileStorageService.DeleteFileAsync(send, fileId); - throw; - } + storageBytesRemaining = org.StorageBytesRemaining(); } - public async Task UploadFileToExistingSendAsync(Stream stream, Send send) - { - if (send?.Data == null) - { - throw new BadRequestException("Send does not have file data"); - } - - if (send.Type != SendType.File) - { - throw new BadRequestException("Not a File Type Send."); - } - - var data = JsonSerializer.Deserialize(send.Data); - - if (data.Validated) - { - throw new BadRequestException("File has already been uploaded."); - } - - await _sendFileStorageService.UploadNewFileAsync(stream, send, data.Id); - - if (!await ValidateSendFile(send)) - { - throw new BadRequestException("File received does not match expected file length."); - } - } - - public async Task ValidateSendFile(Send send) - { - var fileData = JsonSerializer.Deserialize(send.Data); - - var (valid, realSize) = await _sendFileStorageService.ValidateFileAsync(send, fileData.Id, fileData.Size, _fileSizeLeeway); - - if (!valid || realSize > MAX_FILE_SIZE) - { - // File reported differs in size from that promised. Must be a rogue client. Delete Send - await DeleteSendAsync(send); - return false; - } - - // Update Send data if necessary - if (realSize != fileData.Size) - { - fileData.Size = realSize.Value; - } - fileData.Validated = true; - send.Data = JsonSerializer.Serialize(fileData, - JsonHelpers.IgnoreWritingNull); - await SaveSendAsync(send); - - return valid; - } - - public async Task DeleteSendAsync(Send send) - { - await _sendRepository.DeleteAsync(send); - if (send.Type == Enums.SendType.File) - { - var data = JsonSerializer.Deserialize(send.Data); - await _sendFileStorageService.DeleteFileAsync(send, data.Id); - } - await _pushService.PushSyncSendDeleteAsync(send); - } - - public (bool grant, bool passwordRequiredError, bool passwordInvalidError) SendCanBeAccessed(Send send, - string password) - { - var now = DateTime.UtcNow; - if (send == null || send.MaxAccessCount.GetValueOrDefault(int.MaxValue) <= send.AccessCount || - send.ExpirationDate.GetValueOrDefault(DateTime.MaxValue) < now || send.Disabled || - send.DeletionDate < now) - { - return (false, false, false); - } - if (!string.IsNullOrWhiteSpace(send.Password)) - { - if (string.IsNullOrWhiteSpace(password)) - { - return (false, true, false); - } - var passwordResult = _passwordHasher.VerifyHashedPassword(new User(), send.Password, password); - if (passwordResult == PasswordVerificationResult.SuccessRehashNeeded) - { - send.Password = HashPassword(password); - } - if (passwordResult == PasswordVerificationResult.Failed) - { - return (false, false, true); - } - } - - return (true, false, false); - } - - // Response: Send, password required, password invalid - public async Task<(string, bool, bool)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password) - { - if (send.Type != SendType.File) - { - throw new BadRequestException("Can only get a download URL for a file type of Send"); - } - - var (grantAccess, passwordRequired, passwordInvalid) = SendCanBeAccessed(send, password); - - if (!grantAccess) - { - return (null, passwordRequired, passwordInvalid); - } - - send.AccessCount++; - await _sendRepository.ReplaceAsync(send); - await _pushService.PushSyncSendUpdateAsync(send); - return (await _sendFileStorageService.GetSendFileDownloadUrlAsync(send, fileId), false, false); - } - - // Response: Send, password required, password invalid - public async Task<(Send, bool, bool)> AccessAsync(Guid sendId, string password) - { - var send = await _sendRepository.GetByIdAsync(sendId); - var (grantAccess, passwordRequired, passwordInvalid) = SendCanBeAccessed(send, password); - - if (!grantAccess) - { - return (null, passwordRequired, passwordInvalid); - } - - // TODO: maybe move this to a simple ++ sproc? - if (send.Type != SendType.File) - { - // File sends are incremented during file download - send.AccessCount++; - } - - await _sendRepository.ReplaceAsync(send); - await _pushService.PushSyncSendUpdateAsync(send); - await RaiseReferenceEventAsync(send, ReferenceEventType.SendAccessed); - return (send, false, false); - } - - private async Task RaiseReferenceEventAsync(Send send, ReferenceEventType eventType) - { - await _referenceEventService.RaiseEventAsync(new ReferenceEvent - { - Id = send.UserId ?? default, - Type = eventType, - Source = ReferenceEventSource.User, - SendType = send.Type, - MaxAccessCount = send.MaxAccessCount, - HasPassword = !string.IsNullOrWhiteSpace(send.Password), - }); - } - - public string HashPassword(string password) - { - return _passwordHasher.HashPassword(new User(), password); - } - - private async Task ValidateUserCanSaveAsync(Guid? userId, Send send) - { - if (!userId.HasValue || (!_currentContext.Organizations?.Any() ?? true)) - { - return; - } - - var disableSendPolicyCount = await _policyRepository.GetCountByTypeApplicableToUserIdAsync(userId.Value, - PolicyType.DisableSend); - if (disableSendPolicyCount > 0) - { - throw new BadRequestException("Due to an Enterprise Policy, you are only able to delete an existing Send."); - } - - if (send.HideEmail.GetValueOrDefault()) - { - var sendOptionsPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(userId.Value, PolicyType.SendOptions); - if (sendOptionsPolicies.Any(p => p.GetDataModel()?.DisableHideEmail ?? false)) - { - throw new BadRequestException("Due to an Enterprise Policy, you are not allowed to hide your email address from recipients when creating or editing a Send."); - } - } - } - - private async Task StorageRemainingForSendAsync(Send send) - { - var storageBytesRemaining = 0L; - if (send.UserId.HasValue) - { - var user = await _userRepository.GetByIdAsync(send.UserId.Value); - if (!await _userService.CanAccessPremium(user)) - { - throw new BadRequestException("You must have premium status to use file Sends."); - } - - if (!user.EmailVerified) - { - throw new BadRequestException("You must confirm your email to use file Sends."); - } - - if (user.Premium) - { - storageBytesRemaining = user.StorageBytesRemaining(); - } - else - { - // Users that get access to file storage/premium from their organization get the default - // 1 GB max storage. - storageBytesRemaining = user.StorageBytesRemaining( - _globalSettings.SelfHosted ? (short)10240 : (short)1); - } - } - else if (send.OrganizationId.HasValue) - { - var org = await _organizationRepository.GetByIdAsync(send.OrganizationId.Value); - if (!org.MaxStorageGb.HasValue) - { - throw new BadRequestException("This organization cannot use file sends."); - } - - storageBytesRemaining = org.StorageBytesRemaining(); - } - - return storageBytesRemaining; - } + return storageBytesRemaining; } } diff --git a/src/Core/Services/Implementations/SsoConfigService.cs b/src/Core/Services/Implementations/SsoConfigService.cs index 4af794967..5f44cb931 100644 --- a/src/Core/Services/Implementations/SsoConfigService.cs +++ b/src/Core/Services/Implementations/SsoConfigService.cs @@ -3,105 +3,104 @@ using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class SsoConfigService : ISsoConfigService { - public class SsoConfigService : ISsoConfigService + private readonly ISsoConfigRepository _ssoConfigRepository; + private readonly IPolicyRepository _policyRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IEventService _eventService; + + public SsoConfigService( + ISsoConfigRepository ssoConfigRepository, + IPolicyRepository policyRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IEventService eventService) { - private readonly ISsoConfigRepository _ssoConfigRepository; - private readonly IPolicyRepository _policyRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IEventService _eventService; + _ssoConfigRepository = ssoConfigRepository; + _policyRepository = policyRepository; + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + _eventService = eventService; + } - public SsoConfigService( - ISsoConfigRepository ssoConfigRepository, - IPolicyRepository policyRepository, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IEventService eventService) + public async Task SaveAsync(SsoConfig config, Organization organization) + { + var now = DateTime.UtcNow; + config.RevisionDate = now; + if (config.Id == default) { - _ssoConfigRepository = ssoConfigRepository; - _policyRepository = policyRepository; - _organizationRepository = organizationRepository; - _organizationUserRepository = organizationUserRepository; - _eventService = eventService; + config.CreationDate = now; } - public async Task SaveAsync(SsoConfig config, Organization organization) + var useKeyConnector = config.GetData().KeyConnectorEnabled; + if (useKeyConnector) { - var now = DateTime.UtcNow; - config.RevisionDate = now; - if (config.Id == default) - { - config.CreationDate = now; - } - - var useKeyConnector = config.GetData().KeyConnectorEnabled; - if (useKeyConnector) - { - await VerifyDependenciesAsync(config, organization); - } - - var oldConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(config.OrganizationId); - var disabledKeyConnector = oldConfig?.GetData()?.KeyConnectorEnabled == true && !useKeyConnector; - if (disabledKeyConnector && await AnyOrgUserHasKeyConnectorEnabledAsync(config.OrganizationId)) - { - throw new BadRequestException("Key Connector cannot be disabled at this moment."); - } - - await LogEventsAsync(config, oldConfig); - await _ssoConfigRepository.UpsertAsync(config); + await VerifyDependenciesAsync(config, organization); } - private async Task AnyOrgUserHasKeyConnectorEnabledAsync(Guid organizationId) + var oldConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(config.OrganizationId); + var disabledKeyConnector = oldConfig?.GetData()?.KeyConnectorEnabled == true && !useKeyConnector; + if (disabledKeyConnector && await AnyOrgUserHasKeyConnectorEnabledAsync(config.OrganizationId)) { - var userDetails = - await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId); - return userDetails.Any(u => u.UsesKeyConnector); + throw new BadRequestException("Key Connector cannot be disabled at this moment."); } - private async Task VerifyDependenciesAsync(SsoConfig config, Organization organization) + await LogEventsAsync(config, oldConfig); + await _ssoConfigRepository.UpsertAsync(config); + } + + private async Task AnyOrgUserHasKeyConnectorEnabledAsync(Guid organizationId) + { + var userDetails = + await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId); + return userDetails.Any(u => u.UsesKeyConnector); + } + + private async Task VerifyDependenciesAsync(SsoConfig config, Organization organization) + { + if (!organization.UseKeyConnector) { - if (!organization.UseKeyConnector) - { - throw new BadRequestException("Organization cannot use Key Connector."); - } - - var singleOrgPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(config.OrganizationId, PolicyType.SingleOrg); - if (singleOrgPolicy is not { Enabled: true }) - { - throw new BadRequestException("Key Connector requires the Single Organization policy to be enabled."); - } - - var ssoPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(config.OrganizationId, PolicyType.RequireSso); - if (ssoPolicy is not { Enabled: true }) - { - throw new BadRequestException("Key Connector requires the Single Sign-On Authentication policy to be enabled."); - } - - if (!config.Enabled) - { - throw new BadRequestException("You must enable SSO to use Key Connector."); - } + throw new BadRequestException("Organization cannot use Key Connector."); } - private async Task LogEventsAsync(SsoConfig config, SsoConfig oldConfig) + var singleOrgPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(config.OrganizationId, PolicyType.SingleOrg); + if (singleOrgPolicy is not { Enabled: true }) { - var organization = await _organizationRepository.GetByIdAsync(config.OrganizationId); - if (oldConfig?.Enabled != config.Enabled) - { - var e = config.Enabled ? EventType.Organization_EnabledSso : EventType.Organization_DisabledSso; - await _eventService.LogOrganizationEventAsync(organization, e); - } + throw new BadRequestException("Key Connector requires the Single Organization policy to be enabled."); + } - var keyConnectorEnabled = config.GetData().KeyConnectorEnabled; - if (oldConfig?.GetData()?.KeyConnectorEnabled != keyConnectorEnabled) - { - var e = keyConnectorEnabled - ? EventType.Organization_EnabledKeyConnector - : EventType.Organization_DisabledKeyConnector; - await _eventService.LogOrganizationEventAsync(organization, e); - } + var ssoPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(config.OrganizationId, PolicyType.RequireSso); + if (ssoPolicy is not { Enabled: true }) + { + throw new BadRequestException("Key Connector requires the Single Sign-On Authentication policy to be enabled."); + } + + if (!config.Enabled) + { + throw new BadRequestException("You must enable SSO to use Key Connector."); + } + } + + private async Task LogEventsAsync(SsoConfig config, SsoConfig oldConfig) + { + var organization = await _organizationRepository.GetByIdAsync(config.OrganizationId); + if (oldConfig?.Enabled != config.Enabled) + { + var e = config.Enabled ? EventType.Organization_EnabledSso : EventType.Organization_DisabledSso; + await _eventService.LogOrganizationEventAsync(organization, e); + } + + var keyConnectorEnabled = config.GetData().KeyConnectorEnabled; + if (oldConfig?.GetData()?.KeyConnectorEnabled != keyConnectorEnabled) + { + var e = keyConnectorEnabled + ? EventType.Organization_EnabledKeyConnector + : EventType.Organization_DisabledKeyConnector; + await _eventService.LogOrganizationEventAsync(organization, e); } } } diff --git a/src/Core/Services/Implementations/StripeAdapter.cs b/src/Core/Services/Implementations/StripeAdapter.cs index eb467dd57..b4776bc6e 100644 --- a/src/Core/Services/Implementations/StripeAdapter.cs +++ b/src/Core/Services/Implementations/StripeAdapter.cs @@ -1,218 +1,217 @@ using Bit.Core.Models.BitStripe; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class StripeAdapter : IStripeAdapter { - public class StripeAdapter : IStripeAdapter + private readonly Stripe.CustomerService _customerService; + private readonly Stripe.SubscriptionService _subscriptionService; + private readonly Stripe.InvoiceService _invoiceService; + private readonly Stripe.PaymentMethodService _paymentMethodService; + private readonly Stripe.TaxRateService _taxRateService; + private readonly Stripe.TaxIdService _taxIdService; + private readonly Stripe.ChargeService _chargeService; + private readonly Stripe.RefundService _refundService; + private readonly Stripe.CardService _cardService; + private readonly Stripe.BankAccountService _bankAccountService; + private readonly Stripe.PriceService _priceService; + private readonly Stripe.TestHelpers.TestClockService _testClockService; + + public StripeAdapter() { - private readonly Stripe.CustomerService _customerService; - private readonly Stripe.SubscriptionService _subscriptionService; - private readonly Stripe.InvoiceService _invoiceService; - private readonly Stripe.PaymentMethodService _paymentMethodService; - private readonly Stripe.TaxRateService _taxRateService; - private readonly Stripe.TaxIdService _taxIdService; - private readonly Stripe.ChargeService _chargeService; - private readonly Stripe.RefundService _refundService; - private readonly Stripe.CardService _cardService; - private readonly Stripe.BankAccountService _bankAccountService; - private readonly Stripe.PriceService _priceService; - private readonly Stripe.TestHelpers.TestClockService _testClockService; + _customerService = new Stripe.CustomerService(); + _subscriptionService = new Stripe.SubscriptionService(); + _invoiceService = new Stripe.InvoiceService(); + _paymentMethodService = new Stripe.PaymentMethodService(); + _taxRateService = new Stripe.TaxRateService(); + _taxIdService = new Stripe.TaxIdService(); + _chargeService = new Stripe.ChargeService(); + _refundService = new Stripe.RefundService(); + _cardService = new Stripe.CardService(); + _bankAccountService = new Stripe.BankAccountService(); + _priceService = new Stripe.PriceService(); + _testClockService = new Stripe.TestHelpers.TestClockService(); + } - public StripeAdapter() + public Task CustomerCreateAsync(Stripe.CustomerCreateOptions options) + { + return _customerService.CreateAsync(options); + } + + public Task CustomerGetAsync(string id, Stripe.CustomerGetOptions options = null) + { + return _customerService.GetAsync(id, options); + } + + public Task CustomerUpdateAsync(string id, Stripe.CustomerUpdateOptions options = null) + { + return _customerService.UpdateAsync(id, options); + } + + public Task CustomerDeleteAsync(string id) + { + return _customerService.DeleteAsync(id); + } + + public Task SubscriptionCreateAsync(Stripe.SubscriptionCreateOptions options) + { + return _subscriptionService.CreateAsync(options); + } + + public Task SubscriptionGetAsync(string id, Stripe.SubscriptionGetOptions options = null) + { + return _subscriptionService.GetAsync(id, options); + } + + public Task SubscriptionUpdateAsync(string id, + Stripe.SubscriptionUpdateOptions options = null) + { + return _subscriptionService.UpdateAsync(id, options); + } + + public Task SubscriptionCancelAsync(string Id, Stripe.SubscriptionCancelOptions options = null) + { + return _subscriptionService.CancelAsync(Id, options); + } + + public Task InvoiceUpcomingAsync(Stripe.UpcomingInvoiceOptions options) + { + return _invoiceService.UpcomingAsync(options); + } + + public Task InvoiceGetAsync(string id, Stripe.InvoiceGetOptions options) + { + return _invoiceService.GetAsync(id, options); + } + + public Task> InvoiceListAsync(Stripe.InvoiceListOptions options) + { + return _invoiceService.ListAsync(options); + } + + public Task InvoiceUpdateAsync(string id, Stripe.InvoiceUpdateOptions options) + { + return _invoiceService.UpdateAsync(id, options); + } + + public Task InvoiceFinalizeInvoiceAsync(string id, Stripe.InvoiceFinalizeOptions options) + { + return _invoiceService.FinalizeInvoiceAsync(id, options); + } + + public Task InvoiceSendInvoiceAsync(string id, Stripe.InvoiceSendOptions options) + { + return _invoiceService.SendInvoiceAsync(id, options); + } + + public Task InvoicePayAsync(string id, Stripe.InvoicePayOptions options = null) + { + return _invoiceService.PayAsync(id, options); + } + + public Task InvoiceDeleteAsync(string id, Stripe.InvoiceDeleteOptions options = null) + { + return _invoiceService.DeleteAsync(id, options); + } + + public Task InvoiceVoidInvoiceAsync(string id, Stripe.InvoiceVoidOptions options = null) + { + return _invoiceService.VoidInvoiceAsync(id, options); + } + + public IEnumerable PaymentMethodListAutoPaging(Stripe.PaymentMethodListOptions options) + { + return _paymentMethodService.ListAutoPaging(options); + } + + public Task PaymentMethodAttachAsync(string id, Stripe.PaymentMethodAttachOptions options = null) + { + return _paymentMethodService.AttachAsync(id, options); + } + + public Task PaymentMethodDetachAsync(string id, Stripe.PaymentMethodDetachOptions options = null) + { + return _paymentMethodService.DetachAsync(id, options); + } + + public Task TaxRateCreateAsync(Stripe.TaxRateCreateOptions options) + { + return _taxRateService.CreateAsync(options); + } + + public Task TaxRateUpdateAsync(string id, Stripe.TaxRateUpdateOptions options) + { + return _taxRateService.UpdateAsync(id, options); + } + + public Task TaxIdCreateAsync(string id, Stripe.TaxIdCreateOptions options) + { + return _taxIdService.CreateAsync(id, options); + } + + public Task TaxIdDeleteAsync(string customerId, string taxIdId, + Stripe.TaxIdDeleteOptions options = null) + { + return _taxIdService.DeleteAsync(customerId, taxIdId); + } + + public Task> ChargeListAsync(Stripe.ChargeListOptions options) + { + return _chargeService.ListAsync(options); + } + + public Task RefundCreateAsync(Stripe.RefundCreateOptions options) + { + return _refundService.CreateAsync(options); + } + + public Task CardDeleteAsync(string customerId, string cardId, Stripe.CardDeleteOptions options = null) + { + return _cardService.DeleteAsync(customerId, cardId, options); + } + + public Task BankAccountCreateAsync(string customerId, Stripe.BankAccountCreateOptions options = null) + { + return _bankAccountService.CreateAsync(customerId, options); + } + + public Task BankAccountDeleteAsync(string customerId, string bankAccount, Stripe.BankAccountDeleteOptions options = null) + { + return _bankAccountService.DeleteAsync(customerId, bankAccount, options); + } + + public async Task> SubscriptionListAsync(StripeSubscriptionListOptions options) + { + if (!options.SelectAll) { - _customerService = new Stripe.CustomerService(); - _subscriptionService = new Stripe.SubscriptionService(); - _invoiceService = new Stripe.InvoiceService(); - _paymentMethodService = new Stripe.PaymentMethodService(); - _taxRateService = new Stripe.TaxRateService(); - _taxIdService = new Stripe.TaxIdService(); - _chargeService = new Stripe.ChargeService(); - _refundService = new Stripe.RefundService(); - _cardService = new Stripe.CardService(); - _bankAccountService = new Stripe.BankAccountService(); - _priceService = new Stripe.PriceService(); - _testClockService = new Stripe.TestHelpers.TestClockService(); + return (await _subscriptionService.ListAsync(options.ToStripeApiOptions())).Data; } - public Task CustomerCreateAsync(Stripe.CustomerCreateOptions options) + options.Limit = 100; + var items = new List(); + await foreach (var i in _subscriptionService.ListAutoPagingAsync(options.ToStripeApiOptions())) { - return _customerService.CreateAsync(options); + items.Add(i); } + return items; + } - public Task CustomerGetAsync(string id, Stripe.CustomerGetOptions options = null) + public async Task> PriceListAsync(Stripe.PriceListOptions options = null) + { + return await _priceService.ListAsync(options); + } + + public async Task> TestClockListAsync() + { + var items = new List(); + var options = new Stripe.TestHelpers.TestClockListOptions() { - return _customerService.GetAsync(id, options); - } - - public Task CustomerUpdateAsync(string id, Stripe.CustomerUpdateOptions options = null) + Limit = 100 + }; + await foreach (var i in _testClockService.ListAutoPagingAsync(options)) { - return _customerService.UpdateAsync(id, options); - } - - public Task CustomerDeleteAsync(string id) - { - return _customerService.DeleteAsync(id); - } - - public Task SubscriptionCreateAsync(Stripe.SubscriptionCreateOptions options) - { - return _subscriptionService.CreateAsync(options); - } - - public Task SubscriptionGetAsync(string id, Stripe.SubscriptionGetOptions options = null) - { - return _subscriptionService.GetAsync(id, options); - } - - public Task SubscriptionUpdateAsync(string id, - Stripe.SubscriptionUpdateOptions options = null) - { - return _subscriptionService.UpdateAsync(id, options); - } - - public Task SubscriptionCancelAsync(string Id, Stripe.SubscriptionCancelOptions options = null) - { - return _subscriptionService.CancelAsync(Id, options); - } - - public Task InvoiceUpcomingAsync(Stripe.UpcomingInvoiceOptions options) - { - return _invoiceService.UpcomingAsync(options); - } - - public Task InvoiceGetAsync(string id, Stripe.InvoiceGetOptions options) - { - return _invoiceService.GetAsync(id, options); - } - - public Task> InvoiceListAsync(Stripe.InvoiceListOptions options) - { - return _invoiceService.ListAsync(options); - } - - public Task InvoiceUpdateAsync(string id, Stripe.InvoiceUpdateOptions options) - { - return _invoiceService.UpdateAsync(id, options); - } - - public Task InvoiceFinalizeInvoiceAsync(string id, Stripe.InvoiceFinalizeOptions options) - { - return _invoiceService.FinalizeInvoiceAsync(id, options); - } - - public Task InvoiceSendInvoiceAsync(string id, Stripe.InvoiceSendOptions options) - { - return _invoiceService.SendInvoiceAsync(id, options); - } - - public Task InvoicePayAsync(string id, Stripe.InvoicePayOptions options = null) - { - return _invoiceService.PayAsync(id, options); - } - - public Task InvoiceDeleteAsync(string id, Stripe.InvoiceDeleteOptions options = null) - { - return _invoiceService.DeleteAsync(id, options); - } - - public Task InvoiceVoidInvoiceAsync(string id, Stripe.InvoiceVoidOptions options = null) - { - return _invoiceService.VoidInvoiceAsync(id, options); - } - - public IEnumerable PaymentMethodListAutoPaging(Stripe.PaymentMethodListOptions options) - { - return _paymentMethodService.ListAutoPaging(options); - } - - public Task PaymentMethodAttachAsync(string id, Stripe.PaymentMethodAttachOptions options = null) - { - return _paymentMethodService.AttachAsync(id, options); - } - - public Task PaymentMethodDetachAsync(string id, Stripe.PaymentMethodDetachOptions options = null) - { - return _paymentMethodService.DetachAsync(id, options); - } - - public Task TaxRateCreateAsync(Stripe.TaxRateCreateOptions options) - { - return _taxRateService.CreateAsync(options); - } - - public Task TaxRateUpdateAsync(string id, Stripe.TaxRateUpdateOptions options) - { - return _taxRateService.UpdateAsync(id, options); - } - - public Task TaxIdCreateAsync(string id, Stripe.TaxIdCreateOptions options) - { - return _taxIdService.CreateAsync(id, options); - } - - public Task TaxIdDeleteAsync(string customerId, string taxIdId, - Stripe.TaxIdDeleteOptions options = null) - { - return _taxIdService.DeleteAsync(customerId, taxIdId); - } - - public Task> ChargeListAsync(Stripe.ChargeListOptions options) - { - return _chargeService.ListAsync(options); - } - - public Task RefundCreateAsync(Stripe.RefundCreateOptions options) - { - return _refundService.CreateAsync(options); - } - - public Task CardDeleteAsync(string customerId, string cardId, Stripe.CardDeleteOptions options = null) - { - return _cardService.DeleteAsync(customerId, cardId, options); - } - - public Task BankAccountCreateAsync(string customerId, Stripe.BankAccountCreateOptions options = null) - { - return _bankAccountService.CreateAsync(customerId, options); - } - - public Task BankAccountDeleteAsync(string customerId, string bankAccount, Stripe.BankAccountDeleteOptions options = null) - { - return _bankAccountService.DeleteAsync(customerId, bankAccount, options); - } - - public async Task> SubscriptionListAsync(StripeSubscriptionListOptions options) - { - if (!options.SelectAll) - { - return (await _subscriptionService.ListAsync(options.ToStripeApiOptions())).Data; - } - - options.Limit = 100; - var items = new List(); - await foreach (var i in _subscriptionService.ListAutoPagingAsync(options.ToStripeApiOptions())) - { - items.Add(i); - } - return items; - } - - public async Task> PriceListAsync(Stripe.PriceListOptions options = null) - { - return await _priceService.ListAsync(options); - } - - public async Task> TestClockListAsync() - { - var items = new List(); - var options = new Stripe.TestHelpers.TestClockListOptions() - { - Limit = 100 - }; - await foreach (var i in _testClockService.ListAutoPagingAsync(options)) - { - items.Add(i); - } - return items; + items.Add(i); } + return items; } } diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Services/Implementations/StripePaymentService.cs index e3ed16368..25561db4b 100644 --- a/src/Core/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Services/Implementations/StripePaymentService.cs @@ -8,74 +8,388 @@ using Microsoft.Extensions.Logging; using StaticStore = Bit.Core.Models.StaticStore; using TaxRate = Bit.Core.Entities.TaxRate; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class StripePaymentService : IPaymentService { - public class StripePaymentService : IPaymentService + private const string PremiumPlanId = "premium-annually"; + private const string PremiumPlanAppleIapId = "premium-annually-appleiap"; + private const decimal PremiumPlanAppleIapPrice = 14.99M; + private const string StoragePlanId = "storage-gb-annually"; + + private readonly ITransactionRepository _transactionRepository; + private readonly IUserRepository _userRepository; + private readonly IAppleIapService _appleIapService; + private readonly ILogger _logger; + private readonly Braintree.IBraintreeGateway _btGateway; + private readonly ITaxRateRepository _taxRateRepository; + private readonly IStripeAdapter _stripeAdapter; + + public StripePaymentService( + ITransactionRepository transactionRepository, + IUserRepository userRepository, + IAppleIapService appleIapService, + ILogger logger, + ITaxRateRepository taxRateRepository, + IStripeAdapter stripeAdapter, + Braintree.IBraintreeGateway braintreeGateway) { - private const string PremiumPlanId = "premium-annually"; - private const string PremiumPlanAppleIapId = "premium-annually-appleiap"; - private const decimal PremiumPlanAppleIapPrice = 14.99M; - private const string StoragePlanId = "storage-gb-annually"; + _transactionRepository = transactionRepository; + _userRepository = userRepository; + _appleIapService = appleIapService; + _logger = logger; + _taxRateRepository = taxRateRepository; + _stripeAdapter = stripeAdapter; + _btGateway = braintreeGateway; + } - private readonly ITransactionRepository _transactionRepository; - private readonly IUserRepository _userRepository; - private readonly IAppleIapService _appleIapService; - private readonly ILogger _logger; - private readonly Braintree.IBraintreeGateway _btGateway; - private readonly ITaxRateRepository _taxRateRepository; - private readonly IStripeAdapter _stripeAdapter; + public async Task PurchaseOrganizationAsync(Organization org, PaymentMethodType paymentMethodType, + string paymentToken, StaticStore.Plan plan, short additionalStorageGb, + int additionalSeats, bool premiumAccessAddon, TaxInfo taxInfo) + { + Braintree.Customer braintreeCustomer = null; + string stipeCustomerSourceToken = null; + string stipeCustomerPaymentMethodId = null; + var stripeCustomerMetadata = new Dictionary(); + var stripePaymentMethod = paymentMethodType == PaymentMethodType.Card || + paymentMethodType == PaymentMethodType.BankAccount; - public StripePaymentService( - ITransactionRepository transactionRepository, - IUserRepository userRepository, - IAppleIapService appleIapService, - ILogger logger, - ITaxRateRepository taxRateRepository, - IStripeAdapter stripeAdapter, - Braintree.IBraintreeGateway braintreeGateway) + if (stripePaymentMethod && !string.IsNullOrWhiteSpace(paymentToken)) { - _transactionRepository = transactionRepository; - _userRepository = userRepository; - _appleIapService = appleIapService; - _logger = logger; - _taxRateRepository = taxRateRepository; - _stripeAdapter = stripeAdapter; - _btGateway = braintreeGateway; + if (paymentToken.StartsWith("pm_")) + { + stipeCustomerPaymentMethodId = paymentToken; + } + else + { + stipeCustomerSourceToken = paymentToken; + } + } + else if (paymentMethodType == PaymentMethodType.PayPal) + { + var randomSuffix = Utilities.CoreHelpers.RandomString(3, upper: false, numeric: false); + var customerResult = await _btGateway.Customer.CreateAsync(new Braintree.CustomerRequest + { + PaymentMethodNonce = paymentToken, + Email = org.BillingEmail, + Id = org.BraintreeCustomerIdPrefix() + org.Id.ToString("N").ToLower() + randomSuffix, + CustomFields = new Dictionary + { + [org.BraintreeIdField()] = org.Id.ToString() + } + }); + + if (!customerResult.IsSuccess() || customerResult.Target.PaymentMethods.Length == 0) + { + throw new GatewayException("Failed to create PayPal customer record."); + } + + braintreeCustomer = customerResult.Target; + stripeCustomerMetadata.Add("btCustomerId", braintreeCustomer.Id); + } + else + { + throw new GatewayException("Payment method is not supported at this time."); } - public async Task PurchaseOrganizationAsync(Organization org, PaymentMethodType paymentMethodType, - string paymentToken, StaticStore.Plan plan, short additionalStorageGb, - int additionalSeats, bool premiumAccessAddon, TaxInfo taxInfo) + if (taxInfo != null && !string.IsNullOrWhiteSpace(taxInfo.BillingAddressCountry) && !string.IsNullOrWhiteSpace(taxInfo.BillingAddressPostalCode)) { - Braintree.Customer braintreeCustomer = null; - string stipeCustomerSourceToken = null; - string stipeCustomerPaymentMethodId = null; - var stripeCustomerMetadata = new Dictionary(); - var stripePaymentMethod = paymentMethodType == PaymentMethodType.Card || - paymentMethodType == PaymentMethodType.BankAccount; - - if (stripePaymentMethod && !string.IsNullOrWhiteSpace(paymentToken)) + var taxRateSearch = new TaxRate { - if (paymentToken.StartsWith("pm_")) + Country = taxInfo.BillingAddressCountry, + PostalCode = taxInfo.BillingAddressPostalCode + }; + var taxRates = await _taxRateRepository.GetByLocationAsync(taxRateSearch); + + // should only be one tax rate per country/zip combo + var taxRate = taxRates.FirstOrDefault(); + if (taxRate != null) + { + taxInfo.StripeTaxRateId = taxRate.Id; + } + } + + var subCreateOptions = new OrganizationPurchaseSubscriptionOptions(org, plan, taxInfo, additionalSeats, additionalStorageGb, premiumAccessAddon); + + Stripe.Customer customer = null; + Stripe.Subscription subscription; + try + { + customer = await _stripeAdapter.CustomerCreateAsync(new Stripe.CustomerCreateOptions + { + Description = org.BusinessName, + Email = org.BillingEmail, + Source = stipeCustomerSourceToken, + PaymentMethod = stipeCustomerPaymentMethodId, + Metadata = stripeCustomerMetadata, + InvoiceSettings = new Stripe.CustomerInvoiceSettingsOptions { - stipeCustomerPaymentMethodId = paymentToken; - } - else + DefaultPaymentMethod = stipeCustomerPaymentMethodId + }, + Address = new Stripe.AddressOptions { - stipeCustomerSourceToken = paymentToken; + Country = taxInfo.BillingAddressCountry, + PostalCode = taxInfo.BillingAddressPostalCode, + // Line1 is required in Stripe's API, suggestion in Docs is to use Business Name intead. + Line1 = taxInfo.BillingAddressLine1 ?? string.Empty, + Line2 = taxInfo.BillingAddressLine2, + City = taxInfo.BillingAddressCity, + State = taxInfo.BillingAddressState, + }, + TaxIdData = !taxInfo.HasTaxId ? null : new List + { + new Stripe.CustomerTaxIdDataOptions + { + Type = taxInfo.TaxIdType, + Value = taxInfo.TaxIdNumber, + }, + }, + }); + subCreateOptions.AddExpand("latest_invoice.payment_intent"); + subCreateOptions.Customer = customer.Id; + subscription = await _stripeAdapter.SubscriptionCreateAsync(subCreateOptions); + if (subscription.Status == "incomplete" && subscription.LatestInvoice?.PaymentIntent != null) + { + if (subscription.LatestInvoice.PaymentIntent.Status == "requires_payment_method") + { + await _stripeAdapter.SubscriptionCancelAsync(subscription.Id, new Stripe.SubscriptionCancelOptions()); + throw new GatewayException("Payment method was declined."); } } - else if (paymentMethodType == PaymentMethodType.PayPal) + } + catch (Exception ex) + { + _logger.LogError(ex, "Error creating customer, walking back operation."); + if (customer != null) + { + await _stripeAdapter.CustomerDeleteAsync(customer.Id); + } + if (braintreeCustomer != null) + { + await _btGateway.Customer.DeleteAsync(braintreeCustomer.Id); + } + throw; + } + + org.Gateway = GatewayType.Stripe; + org.GatewayCustomerId = customer.Id; + org.GatewaySubscriptionId = subscription.Id; + + if (subscription.Status == "incomplete" && + subscription.LatestInvoice?.PaymentIntent?.Status == "requires_action") + { + org.Enabled = false; + return subscription.LatestInvoice.PaymentIntent.ClientSecret; + } + else + { + org.Enabled = true; + org.ExpirationDate = subscription.CurrentPeriodEnd; + return null; + } + } + + private async Task ChangeOrganizationSponsorship(Organization org, OrganizationSponsorship sponsorship, bool applySponsorship) + { + var existingPlan = Utilities.StaticStore.GetPlan(org.PlanType); + var sponsoredPlan = sponsorship != null ? + Utilities.StaticStore.GetSponsoredPlan(sponsorship.PlanSponsorshipType.Value) : + null; + var subscriptionUpdate = new SponsorOrganizationSubscriptionUpdate(existingPlan, sponsoredPlan, applySponsorship); + + await FinalizeSubscriptionChangeAsync(org, subscriptionUpdate, DateTime.UtcNow); + + var sub = await _stripeAdapter.SubscriptionGetAsync(org.GatewaySubscriptionId); + org.ExpirationDate = sub.CurrentPeriodEnd; + sponsorship.ValidUntil = sub.CurrentPeriodEnd; + + } + + public Task SponsorOrganizationAsync(Organization org, OrganizationSponsorship sponsorship) => + ChangeOrganizationSponsorship(org, sponsorship, true); + + public Task RemoveOrganizationSponsorshipAsync(Organization org, OrganizationSponsorship sponsorship) => + ChangeOrganizationSponsorship(org, sponsorship, false); + + public async Task UpgradeFreeOrganizationAsync(Organization org, StaticStore.Plan plan, + short additionalStorageGb, int additionalSeats, bool premiumAccessAddon, TaxInfo taxInfo) + { + if (!string.IsNullOrWhiteSpace(org.GatewaySubscriptionId)) + { + throw new BadRequestException("Organization already has a subscription."); + } + + var customerOptions = new Stripe.CustomerGetOptions(); + customerOptions.AddExpand("default_source"); + customerOptions.AddExpand("invoice_settings.default_payment_method"); + var customer = await _stripeAdapter.CustomerGetAsync(org.GatewayCustomerId, customerOptions); + if (customer == null) + { + throw new GatewayException("Could not find customer payment profile."); + } + + if (taxInfo != null && !string.IsNullOrWhiteSpace(taxInfo.BillingAddressCountry) && !string.IsNullOrWhiteSpace(taxInfo.BillingAddressPostalCode)) + { + var taxRateSearch = new TaxRate + { + Country = taxInfo.BillingAddressCountry, + PostalCode = taxInfo.BillingAddressPostalCode + }; + var taxRates = await _taxRateRepository.GetByLocationAsync(taxRateSearch); + + // should only be one tax rate per country/zip combo + var taxRate = taxRates.FirstOrDefault(); + if (taxRate != null) + { + taxInfo.StripeTaxRateId = taxRate.Id; + } + } + + var subCreateOptions = new OrganizationUpgradeSubscriptionOptions(customer.Id, org, plan, taxInfo, additionalSeats, additionalStorageGb, premiumAccessAddon); + var (stripePaymentMethod, paymentMethodType) = IdentifyPaymentMethod(customer, subCreateOptions); + + var subscription = await ChargeForNewSubscriptionAsync(org, customer, false, + stripePaymentMethod, paymentMethodType, subCreateOptions, null); + org.GatewaySubscriptionId = subscription.Id; + + if (subscription.Status == "incomplete" && + subscription.LatestInvoice?.PaymentIntent?.Status == "requires_action") + { + org.Enabled = false; + return subscription.LatestInvoice.PaymentIntent.ClientSecret; + } + else + { + org.Enabled = true; + org.ExpirationDate = subscription.CurrentPeriodEnd; + return null; + } + } + + private (bool stripePaymentMethod, PaymentMethodType PaymentMethodType) IdentifyPaymentMethod( + Stripe.Customer customer, Stripe.SubscriptionCreateOptions subCreateOptions) + { + var stripePaymentMethod = false; + var paymentMethodType = PaymentMethodType.Credit; + var hasBtCustomerId = customer.Metadata.ContainsKey("btCustomerId"); + if (hasBtCustomerId) + { + paymentMethodType = PaymentMethodType.PayPal; + } + else + { + if (customer.InvoiceSettings?.DefaultPaymentMethod?.Type == "card") + { + paymentMethodType = PaymentMethodType.Card; + stripePaymentMethod = true; + } + else if (customer.DefaultSource != null) + { + if (customer.DefaultSource is Stripe.Card || customer.DefaultSource is Stripe.SourceCard) + { + paymentMethodType = PaymentMethodType.Card; + stripePaymentMethod = true; + } + else if (customer.DefaultSource is Stripe.BankAccount || customer.DefaultSource is Stripe.SourceAchDebit) + { + paymentMethodType = PaymentMethodType.BankAccount; + stripePaymentMethod = true; + } + } + else + { + var paymentMethod = GetLatestCardPaymentMethod(customer.Id); + if (paymentMethod != null) + { + paymentMethodType = PaymentMethodType.Card; + stripePaymentMethod = true; + subCreateOptions.DefaultPaymentMethod = paymentMethod.Id; + } + } + } + return (stripePaymentMethod, paymentMethodType); + } + + public async Task PurchasePremiumAsync(User user, PaymentMethodType paymentMethodType, + string paymentToken, short additionalStorageGb, TaxInfo taxInfo) + { + if (paymentMethodType != PaymentMethodType.Credit && string.IsNullOrWhiteSpace(paymentToken)) + { + throw new BadRequestException("Payment token is required."); + } + if (paymentMethodType == PaymentMethodType.Credit && + (user.Gateway != GatewayType.Stripe || string.IsNullOrWhiteSpace(user.GatewayCustomerId))) + { + throw new BadRequestException("Your account does not have any credit available."); + } + if (paymentMethodType == PaymentMethodType.BankAccount || paymentMethodType == PaymentMethodType.GoogleInApp) + { + throw new GatewayException("Payment method is not supported at this time."); + } + if ((paymentMethodType == PaymentMethodType.GoogleInApp || + paymentMethodType == PaymentMethodType.AppleInApp) && additionalStorageGb > 0) + { + throw new BadRequestException("You cannot add storage with this payment method."); + } + + var createdStripeCustomer = false; + Stripe.Customer customer = null; + Braintree.Customer braintreeCustomer = null; + var stripePaymentMethod = paymentMethodType == PaymentMethodType.Card || + paymentMethodType == PaymentMethodType.BankAccount || paymentMethodType == PaymentMethodType.Credit; + + string stipeCustomerPaymentMethodId = null; + string stipeCustomerSourceToken = null; + if (stripePaymentMethod && !string.IsNullOrWhiteSpace(paymentToken)) + { + if (paymentToken.StartsWith("pm_")) + { + stipeCustomerPaymentMethodId = paymentToken; + } + else + { + stipeCustomerSourceToken = paymentToken; + } + } + + if (user.Gateway == GatewayType.Stripe && !string.IsNullOrWhiteSpace(user.GatewayCustomerId)) + { + if (!string.IsNullOrWhiteSpace(paymentToken)) + { + try + { + await UpdatePaymentMethodAsync(user, paymentMethodType, paymentToken, true, taxInfo); + } + catch (Exception e) + { + var message = e.Message.ToLowerInvariant(); + if (message.Contains("apple") || message.Contains("in-app")) + { + throw; + } + } + } + try + { + customer = await _stripeAdapter.CustomerGetAsync(user.GatewayCustomerId); + } + catch { } + } + + if (customer == null && !string.IsNullOrWhiteSpace(paymentToken)) + { + var stripeCustomerMetadata = new Dictionary(); + if (paymentMethodType == PaymentMethodType.PayPal) { var randomSuffix = Utilities.CoreHelpers.RandomString(3, upper: false, numeric: false); var customerResult = await _btGateway.Customer.CreateAsync(new Braintree.CustomerRequest { PaymentMethodNonce = paymentToken, - Email = org.BillingEmail, - Id = org.BraintreeCustomerIdPrefix() + org.Id.ToString("N").ToLower() + randomSuffix, + Email = user.Email, + Id = user.BraintreeCustomerIdPrefix() + user.Id.ToString("N").ToLower() + randomSuffix, CustomFields = new Dictionary { - [org.BraintreeIdField()] = org.Id.ToString() + [user.BraintreeIdField()] = user.Id.ToString() } }); @@ -87,1701 +401,1386 @@ namespace Bit.Core.Services braintreeCustomer = customerResult.Target; stripeCustomerMetadata.Add("btCustomerId", braintreeCustomer.Id); } - else + else if (paymentMethodType == PaymentMethodType.AppleInApp) + { + var verifiedReceiptStatus = await _appleIapService.GetVerifiedReceiptStatusAsync(paymentToken); + if (verifiedReceiptStatus == null) + { + throw new GatewayException("Cannot verify apple in-app purchase."); + } + var receiptOriginalTransactionId = verifiedReceiptStatus.GetOriginalTransactionId(); + await VerifyAppleReceiptNotInUseAsync(receiptOriginalTransactionId, user); + await _appleIapService.SaveReceiptAsync(verifiedReceiptStatus, user.Id); + stripeCustomerMetadata.Add("appleReceipt", receiptOriginalTransactionId); + } + else if (!stripePaymentMethod) { throw new GatewayException("Payment method is not supported at this time."); } - if (taxInfo != null && !string.IsNullOrWhiteSpace(taxInfo.BillingAddressCountry) && !string.IsNullOrWhiteSpace(taxInfo.BillingAddressPostalCode)) + customer = await _stripeAdapter.CustomerCreateAsync(new Stripe.CustomerCreateOptions { - var taxRateSearch = new TaxRate + Description = user.Name, + Email = user.Email, + Metadata = stripeCustomerMetadata, + PaymentMethod = stipeCustomerPaymentMethodId, + Source = stipeCustomerSourceToken, + InvoiceSettings = new Stripe.CustomerInvoiceSettingsOptions + { + DefaultPaymentMethod = stipeCustomerPaymentMethodId + }, + Address = new Stripe.AddressOptions + { + Line1 = string.Empty, + Country = taxInfo.BillingAddressCountry, + PostalCode = taxInfo.BillingAddressPostalCode, + }, + }); + createdStripeCustomer = true; + } + + if (customer == null) + { + throw new GatewayException("Could not set up customer payment profile."); + } + + var subCreateOptions = new Stripe.SubscriptionCreateOptions + { + Customer = customer.Id, + Items = new List(), + Metadata = new Dictionary + { + [user.GatewayIdField()] = user.Id.ToString() + } + }; + + subCreateOptions.Items.Add(new Stripe.SubscriptionItemOptions + { + Plan = paymentMethodType == PaymentMethodType.AppleInApp ? PremiumPlanAppleIapId : PremiumPlanId, + Quantity = 1, + }); + + if (!string.IsNullOrWhiteSpace(taxInfo?.BillingAddressCountry) + && !string.IsNullOrWhiteSpace(taxInfo?.BillingAddressPostalCode)) + { + var taxRates = await _taxRateRepository.GetByLocationAsync( + new TaxRate() { Country = taxInfo.BillingAddressCountry, PostalCode = taxInfo.BillingAddressPostalCode + } + ); + var taxRate = taxRates.FirstOrDefault(); + if (taxRate != null) + { + subCreateOptions.DefaultTaxRates = new List(1) + { + taxRate.Id }; - var taxRates = await _taxRateRepository.GetByLocationAsync(taxRateSearch); - - // should only be one tax rate per country/zip combo - var taxRate = taxRates.FirstOrDefault(); - if (taxRate != null) - { - taxInfo.StripeTaxRateId = taxRate.Id; - } } + } - var subCreateOptions = new OrganizationPurchaseSubscriptionOptions(org, plan, taxInfo, additionalSeats, additionalStorageGb, premiumAccessAddon); - - Stripe.Customer customer = null; - Stripe.Subscription subscription; - try + if (additionalStorageGb > 0) + { + subCreateOptions.Items.Add(new Stripe.SubscriptionItemOptions { - customer = await _stripeAdapter.CustomerCreateAsync(new Stripe.CustomerCreateOptions + Plan = StoragePlanId, + Quantity = additionalStorageGb + }); + } + + var subscription = await ChargeForNewSubscriptionAsync(user, customer, createdStripeCustomer, + stripePaymentMethod, paymentMethodType, subCreateOptions, braintreeCustomer); + + user.Gateway = GatewayType.Stripe; + user.GatewayCustomerId = customer.Id; + user.GatewaySubscriptionId = subscription.Id; + + if (subscription.Status == "incomplete" && + subscription.LatestInvoice?.PaymentIntent?.Status == "requires_action") + { + return subscription.LatestInvoice.PaymentIntent.ClientSecret; + } + else + { + user.Premium = true; + user.PremiumExpirationDate = subscription.CurrentPeriodEnd; + return null; + } + } + + private async Task ChargeForNewSubscriptionAsync(ISubscriber subcriber, Stripe.Customer customer, + bool createdStripeCustomer, bool stripePaymentMethod, PaymentMethodType paymentMethodType, + Stripe.SubscriptionCreateOptions subCreateOptions, Braintree.Customer braintreeCustomer) + { + var addedCreditToStripeCustomer = false; + Braintree.Transaction braintreeTransaction = null; + Transaction appleTransaction = null; + + var subInvoiceMetadata = new Dictionary(); + Stripe.Subscription subscription = null; + try + { + if (!stripePaymentMethod) + { + var previewInvoice = await _stripeAdapter.InvoiceUpcomingAsync(new Stripe.UpcomingInvoiceOptions { - Description = org.BusinessName, - Email = org.BillingEmail, - Source = stipeCustomerSourceToken, - PaymentMethod = stipeCustomerPaymentMethodId, - Metadata = stripeCustomerMetadata, - InvoiceSettings = new Stripe.CustomerInvoiceSettingsOptions - { - DefaultPaymentMethod = stipeCustomerPaymentMethodId - }, - Address = new Stripe.AddressOptions - { - Country = taxInfo.BillingAddressCountry, - PostalCode = taxInfo.BillingAddressPostalCode, - // Line1 is required in Stripe's API, suggestion in Docs is to use Business Name intead. - Line1 = taxInfo.BillingAddressLine1 ?? string.Empty, - Line2 = taxInfo.BillingAddressLine2, - City = taxInfo.BillingAddressCity, - State = taxInfo.BillingAddressState, - }, - TaxIdData = !taxInfo.HasTaxId ? null : new List - { - new Stripe.CustomerTaxIdDataOptions - { - Type = taxInfo.TaxIdType, - Value = taxInfo.TaxIdNumber, - }, - }, + Customer = customer.Id, + SubscriptionItems = ToInvoiceSubscriptionItemOptions(subCreateOptions.Items), + SubscriptionDefaultTaxRates = subCreateOptions.DefaultTaxRates, }); - subCreateOptions.AddExpand("latest_invoice.payment_intent"); - subCreateOptions.Customer = customer.Id; - subscription = await _stripeAdapter.SubscriptionCreateAsync(subCreateOptions); - if (subscription.Status == "incomplete" && subscription.LatestInvoice?.PaymentIntent != null) + + if (previewInvoice.AmountDue > 0) { - if (subscription.LatestInvoice.PaymentIntent.Status == "requires_payment_method") + var appleReceiptOrigTransactionId = customer.Metadata != null && + customer.Metadata.ContainsKey("appleReceipt") ? customer.Metadata["appleReceipt"] : null; + var braintreeCustomerId = customer.Metadata != null && + customer.Metadata.ContainsKey("btCustomerId") ? customer.Metadata["btCustomerId"] : null; + if (!string.IsNullOrWhiteSpace(appleReceiptOrigTransactionId)) { - await _stripeAdapter.SubscriptionCancelAsync(subscription.Id, new Stripe.SubscriptionCancelOptions()); - throw new GatewayException("Payment method was declined."); + if (!subcriber.IsUser()) + { + throw new GatewayException("In-app purchase is only allowed for users."); + } + + var appleReceipt = await _appleIapService.GetReceiptAsync( + appleReceiptOrigTransactionId); + var verifiedAppleReceipt = await _appleIapService.GetVerifiedReceiptStatusAsync( + appleReceipt.Item1); + if (verifiedAppleReceipt == null) + { + throw new GatewayException("Failed to get Apple in-app purchase receipt data."); + } + subInvoiceMetadata.Add("appleReceipt", verifiedAppleReceipt.GetOriginalTransactionId()); + var lastTransactionId = verifiedAppleReceipt.GetLastTransactionId(); + subInvoiceMetadata.Add("appleReceiptTransactionId", lastTransactionId); + var existingTransaction = await _transactionRepository.GetByGatewayIdAsync( + GatewayType.AppStore, lastTransactionId); + if (existingTransaction == null) + { + appleTransaction = verifiedAppleReceipt.BuildTransactionFromLastTransaction( + PremiumPlanAppleIapPrice, subcriber.Id); + appleTransaction.Type = TransactionType.Charge; + await _transactionRepository.CreateAsync(appleTransaction); + } } + else if (!string.IsNullOrWhiteSpace(braintreeCustomerId)) + { + var btInvoiceAmount = (previewInvoice.AmountDue / 100M); + var transactionResult = await _btGateway.Transaction.SaleAsync( + new Braintree.TransactionRequest + { + Amount = btInvoiceAmount, + CustomerId = braintreeCustomerId, + Options = new Braintree.TransactionOptionsRequest + { + SubmitForSettlement = true, + PayPal = new Braintree.TransactionOptionsPayPalRequest + { + CustomField = $"{subcriber.BraintreeIdField()}:{subcriber.Id}" + } + }, + CustomFields = new Dictionary + { + [subcriber.BraintreeIdField()] = subcriber.Id.ToString() + } + }); + + if (!transactionResult.IsSuccess()) + { + throw new GatewayException("Failed to charge PayPal customer."); + } + + braintreeTransaction = transactionResult.Target; + subInvoiceMetadata.Add("btTransactionId", braintreeTransaction.Id); + subInvoiceMetadata.Add("btPayPalTransactionId", + braintreeTransaction.PayPalDetails.AuthorizationId); + } + else + { + throw new GatewayException("No payment was able to be collected."); + } + + await _stripeAdapter.CustomerUpdateAsync(customer.Id, new Stripe.CustomerUpdateOptions + { + Balance = customer.Balance - previewInvoice.AmountDue + }); + addedCreditToStripeCustomer = true; } } - catch (Exception ex) + else if (paymentMethodType == PaymentMethodType.Credit) { - _logger.LogError(ex, "Error creating customer, walking back operation."); - if (customer != null) + var previewInvoice = await _stripeAdapter.InvoiceUpcomingAsync(new Stripe.UpcomingInvoiceOptions + { + Customer = customer.Id, + SubscriptionItems = ToInvoiceSubscriptionItemOptions(subCreateOptions.Items), + SubscriptionDefaultTaxRates = subCreateOptions.DefaultTaxRates, + }); + if (previewInvoice.AmountDue > 0) + { + throw new GatewayException("Your account does not have enough credit available."); + } + } + + subCreateOptions.OffSession = true; + subCreateOptions.AddExpand("latest_invoice.payment_intent"); + subscription = await _stripeAdapter.SubscriptionCreateAsync(subCreateOptions); + if (subscription.Status == "incomplete" && subscription.LatestInvoice?.PaymentIntent != null) + { + if (subscription.LatestInvoice.PaymentIntent.Status == "requires_payment_method") + { + await _stripeAdapter.SubscriptionCancelAsync(subscription.Id, new Stripe.SubscriptionCancelOptions()); + throw new GatewayException("Payment method was declined."); + } + } + + if (!stripePaymentMethod && subInvoiceMetadata.Any()) + { + var invoices = await _stripeAdapter.InvoiceListAsync(new Stripe.InvoiceListOptions + { + Subscription = subscription.Id + }); + + var invoice = invoices?.FirstOrDefault(); + if (invoice == null) + { + throw new GatewayException("Invoice not found."); + } + + await _stripeAdapter.InvoiceUpdateAsync(invoice.Id, new Stripe.InvoiceUpdateOptions + { + Metadata = subInvoiceMetadata + }); + } + + return subscription; + } + catch (Exception e) + { + if (customer != null) + { + if (createdStripeCustomer) { await _stripeAdapter.CustomerDeleteAsync(customer.Id); } - if (braintreeCustomer != null) + else if (addedCreditToStripeCustomer || customer.Balance < 0) { - await _btGateway.Customer.DeleteAsync(braintreeCustomer.Id); + await _stripeAdapter.CustomerUpdateAsync(customer.Id, new Stripe.CustomerUpdateOptions + { + Balance = customer.Balance + }); } - throw; + } + if (braintreeTransaction != null) + { + await _btGateway.Transaction.RefundAsync(braintreeTransaction.Id); + } + if (braintreeCustomer != null) + { + await _btGateway.Customer.DeleteAsync(braintreeCustomer.Id); + } + if (appleTransaction != null) + { + await _transactionRepository.DeleteAsync(appleTransaction); } - org.Gateway = GatewayType.Stripe; - org.GatewayCustomerId = customer.Id; - org.GatewaySubscriptionId = subscription.Id; + if (e is Stripe.StripeException strEx && + (strEx.StripeError?.Message?.Contains("cannot be used because it is not verified") ?? false)) + { + throw new GatewayException("Bank account is not yet verified."); + } - if (subscription.Status == "incomplete" && - subscription.LatestInvoice?.PaymentIntent?.Status == "requires_action") - { - org.Enabled = false; - return subscription.LatestInvoice.PaymentIntent.ClientSecret; - } - else - { - org.Enabled = true; - org.ExpirationDate = subscription.CurrentPeriodEnd; - return null; - } + throw; + } + } + + private List ToInvoiceSubscriptionItemOptions( + List subItemOptions) + { + return subItemOptions.Select(si => new Stripe.InvoiceSubscriptionItemOptions + { + Plan = si.Plan, + Quantity = si.Quantity + }).ToList(); + } + + private async Task FinalizeSubscriptionChangeAsync(IStorableSubscriber storableSubscriber, + SubscriptionUpdate subscriptionUpdate, DateTime? prorationDate) + { + // remember, when in doubt, throw + + var sub = await _stripeAdapter.SubscriptionGetAsync(storableSubscriber.GatewaySubscriptionId); + if (sub == null) + { + throw new GatewayException("Subscription not found."); } - private async Task ChangeOrganizationSponsorship(Organization org, OrganizationSponsorship sponsorship, bool applySponsorship) + prorationDate ??= DateTime.UtcNow; + var collectionMethod = sub.CollectionMethod; + var daysUntilDue = sub.DaysUntilDue; + var chargeNow = collectionMethod == "charge_automatically"; + var updatedItemOptions = subscriptionUpdate.UpgradeItemsOptions(sub); + + var subUpdateOptions = new Stripe.SubscriptionUpdateOptions { - var existingPlan = Utilities.StaticStore.GetPlan(org.PlanType); - var sponsoredPlan = sponsorship != null ? - Utilities.StaticStore.GetSponsoredPlan(sponsorship.PlanSponsorshipType.Value) : - null; - var subscriptionUpdate = new SponsorOrganizationSubscriptionUpdate(existingPlan, sponsoredPlan, applySponsorship); - - await FinalizeSubscriptionChangeAsync(org, subscriptionUpdate, DateTime.UtcNow); - - var sub = await _stripeAdapter.SubscriptionGetAsync(org.GatewaySubscriptionId); - org.ExpirationDate = sub.CurrentPeriodEnd; - sponsorship.ValidUntil = sub.CurrentPeriodEnd; + Items = updatedItemOptions, + ProrationBehavior = "always_invoice", + DaysUntilDue = daysUntilDue ?? 1, + CollectionMethod = "send_invoice", + ProrationDate = prorationDate, + }; + if (!subscriptionUpdate.UpdateNeeded(sub)) + { + // No need to update subscription, quantity matches + return null; } - public Task SponsorOrganizationAsync(Organization org, OrganizationSponsorship sponsorship) => - ChangeOrganizationSponsorship(org, sponsorship, true); + var customer = await _stripeAdapter.CustomerGetAsync(sub.CustomerId); - public Task RemoveOrganizationSponsorshipAsync(Organization org, OrganizationSponsorship sponsorship) => - ChangeOrganizationSponsorship(org, sponsorship, false); - - public async Task UpgradeFreeOrganizationAsync(Organization org, StaticStore.Plan plan, - short additionalStorageGb, int additionalSeats, bool premiumAccessAddon, TaxInfo taxInfo) + if (!string.IsNullOrWhiteSpace(customer?.Address?.Country) + && !string.IsNullOrWhiteSpace(customer?.Address?.PostalCode)) { - if (!string.IsNullOrWhiteSpace(org.GatewaySubscriptionId)) - { - throw new BadRequestException("Organization already has a subscription."); - } - - var customerOptions = new Stripe.CustomerGetOptions(); - customerOptions.AddExpand("default_source"); - customerOptions.AddExpand("invoice_settings.default_payment_method"); - var customer = await _stripeAdapter.CustomerGetAsync(org.GatewayCustomerId, customerOptions); - if (customer == null) - { - throw new GatewayException("Could not find customer payment profile."); - } - - if (taxInfo != null && !string.IsNullOrWhiteSpace(taxInfo.BillingAddressCountry) && !string.IsNullOrWhiteSpace(taxInfo.BillingAddressPostalCode)) - { - var taxRateSearch = new TaxRate + var taxRates = await _taxRateRepository.GetByLocationAsync( + new TaxRate() { - Country = taxInfo.BillingAddressCountry, - PostalCode = taxInfo.BillingAddressPostalCode + Country = customer.Address.Country, + PostalCode = customer.Address.PostalCode + } + ); + var taxRate = taxRates.FirstOrDefault(); + if (taxRate != null && !sub.DefaultTaxRates.Any(x => x.Equals(taxRate.Id))) + { + subUpdateOptions.DefaultTaxRates = new List(1) + { + taxRate.Id }; - var taxRates = await _taxRateRepository.GetByLocationAsync(taxRateSearch); - - // should only be one tax rate per country/zip combo - var taxRate = taxRates.FirstOrDefault(); - if (taxRate != null) - { - taxInfo.StripeTaxRateId = taxRate.Id; - } - } - - var subCreateOptions = new OrganizationUpgradeSubscriptionOptions(customer.Id, org, plan, taxInfo, additionalSeats, additionalStorageGb, premiumAccessAddon); - var (stripePaymentMethod, paymentMethodType) = IdentifyPaymentMethod(customer, subCreateOptions); - - var subscription = await ChargeForNewSubscriptionAsync(org, customer, false, - stripePaymentMethod, paymentMethodType, subCreateOptions, null); - org.GatewaySubscriptionId = subscription.Id; - - if (subscription.Status == "incomplete" && - subscription.LatestInvoice?.PaymentIntent?.Status == "requires_action") - { - org.Enabled = false; - return subscription.LatestInvoice.PaymentIntent.ClientSecret; - } - else - { - org.Enabled = true; - org.ExpirationDate = subscription.CurrentPeriodEnd; - return null; } } - private (bool stripePaymentMethod, PaymentMethodType PaymentMethodType) IdentifyPaymentMethod( - Stripe.Customer customer, Stripe.SubscriptionCreateOptions subCreateOptions) + string paymentIntentClientSecret = null; + try { - var stripePaymentMethod = false; - var paymentMethodType = PaymentMethodType.Credit; - var hasBtCustomerId = customer.Metadata.ContainsKey("btCustomerId"); - if (hasBtCustomerId) - { - paymentMethodType = PaymentMethodType.PayPal; - } - else - { - if (customer.InvoiceSettings?.DefaultPaymentMethod?.Type == "card") - { - paymentMethodType = PaymentMethodType.Card; - stripePaymentMethod = true; - } - else if (customer.DefaultSource != null) - { - if (customer.DefaultSource is Stripe.Card || customer.DefaultSource is Stripe.SourceCard) - { - paymentMethodType = PaymentMethodType.Card; - stripePaymentMethod = true; - } - else if (customer.DefaultSource is Stripe.BankAccount || customer.DefaultSource is Stripe.SourceAchDebit) - { - paymentMethodType = PaymentMethodType.BankAccount; - stripePaymentMethod = true; - } - } - else - { - var paymentMethod = GetLatestCardPaymentMethod(customer.Id); - if (paymentMethod != null) - { - paymentMethodType = PaymentMethodType.Card; - stripePaymentMethod = true; - subCreateOptions.DefaultPaymentMethod = paymentMethod.Id; - } - } - } - return (stripePaymentMethod, paymentMethodType); - } + var subResponse = await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, subUpdateOptions); - public async Task PurchasePremiumAsync(User user, PaymentMethodType paymentMethodType, - string paymentToken, short additionalStorageGb, TaxInfo taxInfo) - { - if (paymentMethodType != PaymentMethodType.Credit && string.IsNullOrWhiteSpace(paymentToken)) + var invoice = await _stripeAdapter.InvoiceGetAsync(subResponse?.LatestInvoiceId, new Stripe.InvoiceGetOptions()); + if (invoice == null) { - throw new BadRequestException("Payment token is required."); - } - if (paymentMethodType == PaymentMethodType.Credit && - (user.Gateway != GatewayType.Stripe || string.IsNullOrWhiteSpace(user.GatewayCustomerId))) - { - throw new BadRequestException("Your account does not have any credit available."); - } - if (paymentMethodType == PaymentMethodType.BankAccount || paymentMethodType == PaymentMethodType.GoogleInApp) - { - throw new GatewayException("Payment method is not supported at this time."); - } - if ((paymentMethodType == PaymentMethodType.GoogleInApp || - paymentMethodType == PaymentMethodType.AppleInApp) && additionalStorageGb > 0) - { - throw new BadRequestException("You cannot add storage with this payment method."); + throw new BadRequestException("Unable to locate draft invoice for subscription update."); } - var createdStripeCustomer = false; - Stripe.Customer customer = null; - Braintree.Customer braintreeCustomer = null; - var stripePaymentMethod = paymentMethodType == PaymentMethodType.Card || - paymentMethodType == PaymentMethodType.BankAccount || paymentMethodType == PaymentMethodType.Credit; - - string stipeCustomerPaymentMethodId = null; - string stipeCustomerSourceToken = null; - if (stripePaymentMethod && !string.IsNullOrWhiteSpace(paymentToken)) + if (invoice.AmountDue > 0 && updatedItemOptions.Any(i => i.Quantity > 0)) { - if (paymentToken.StartsWith("pm_")) - { - stipeCustomerPaymentMethodId = paymentToken; - } - else - { - stipeCustomerSourceToken = paymentToken; - } - } - - if (user.Gateway == GatewayType.Stripe && !string.IsNullOrWhiteSpace(user.GatewayCustomerId)) - { - if (!string.IsNullOrWhiteSpace(paymentToken)) - { - try - { - await UpdatePaymentMethodAsync(user, paymentMethodType, paymentToken, true, taxInfo); - } - catch (Exception e) - { - var message = e.Message.ToLowerInvariant(); - if (message.Contains("apple") || message.Contains("in-app")) - { - throw; - } - } - } try { - customer = await _stripeAdapter.CustomerGetAsync(user.GatewayCustomerId); - } - catch { } - } - - if (customer == null && !string.IsNullOrWhiteSpace(paymentToken)) - { - var stripeCustomerMetadata = new Dictionary(); - if (paymentMethodType == PaymentMethodType.PayPal) - { - var randomSuffix = Utilities.CoreHelpers.RandomString(3, upper: false, numeric: false); - var customerResult = await _btGateway.Customer.CreateAsync(new Braintree.CustomerRequest + if (chargeNow) { - PaymentMethodNonce = paymentToken, - Email = user.Email, - Id = user.BraintreeCustomerIdPrefix() + user.Id.ToString("N").ToLower() + randomSuffix, - CustomFields = new Dictionary - { - [user.BraintreeIdField()] = user.Id.ToString() - } - }); - - if (!customerResult.IsSuccess() || customerResult.Target.PaymentMethods.Length == 0) - { - throw new GatewayException("Failed to create PayPal customer record."); + paymentIntentClientSecret = await PayInvoiceAfterSubscriptionChangeAsync( + storableSubscriber, invoice); } - - braintreeCustomer = customerResult.Target; - stripeCustomerMetadata.Add("btCustomerId", braintreeCustomer.Id); - } - else if (paymentMethodType == PaymentMethodType.AppleInApp) - { - var verifiedReceiptStatus = await _appleIapService.GetVerifiedReceiptStatusAsync(paymentToken); - if (verifiedReceiptStatus == null) + else { - throw new GatewayException("Cannot verify apple in-app purchase."); - } - var receiptOriginalTransactionId = verifiedReceiptStatus.GetOriginalTransactionId(); - await VerifyAppleReceiptNotInUseAsync(receiptOriginalTransactionId, user); - await _appleIapService.SaveReceiptAsync(verifiedReceiptStatus, user.Id); - stripeCustomerMetadata.Add("appleReceipt", receiptOriginalTransactionId); - } - else if (!stripePaymentMethod) - { - throw new GatewayException("Payment method is not supported at this time."); - } - - customer = await _stripeAdapter.CustomerCreateAsync(new Stripe.CustomerCreateOptions - { - Description = user.Name, - Email = user.Email, - Metadata = stripeCustomerMetadata, - PaymentMethod = stipeCustomerPaymentMethodId, - Source = stipeCustomerSourceToken, - InvoiceSettings = new Stripe.CustomerInvoiceSettingsOptions - { - DefaultPaymentMethod = stipeCustomerPaymentMethodId - }, - Address = new Stripe.AddressOptions - { - Line1 = string.Empty, - Country = taxInfo.BillingAddressCountry, - PostalCode = taxInfo.BillingAddressPostalCode, - }, - }); - createdStripeCustomer = true; - } - - if (customer == null) - { - throw new GatewayException("Could not set up customer payment profile."); - } - - var subCreateOptions = new Stripe.SubscriptionCreateOptions - { - Customer = customer.Id, - Items = new List(), - Metadata = new Dictionary - { - [user.GatewayIdField()] = user.Id.ToString() - } - }; - - subCreateOptions.Items.Add(new Stripe.SubscriptionItemOptions - { - Plan = paymentMethodType == PaymentMethodType.AppleInApp ? PremiumPlanAppleIapId : PremiumPlanId, - Quantity = 1, - }); - - if (!string.IsNullOrWhiteSpace(taxInfo?.BillingAddressCountry) - && !string.IsNullOrWhiteSpace(taxInfo?.BillingAddressPostalCode)) - { - var taxRates = await _taxRateRepository.GetByLocationAsync( - new TaxRate() - { - Country = taxInfo.BillingAddressCountry, - PostalCode = taxInfo.BillingAddressPostalCode - } - ); - var taxRate = taxRates.FirstOrDefault(); - if (taxRate != null) - { - subCreateOptions.DefaultTaxRates = new List(1) - { - taxRate.Id - }; - } - } - - if (additionalStorageGb > 0) - { - subCreateOptions.Items.Add(new Stripe.SubscriptionItemOptions - { - Plan = StoragePlanId, - Quantity = additionalStorageGb - }); - } - - var subscription = await ChargeForNewSubscriptionAsync(user, customer, createdStripeCustomer, - stripePaymentMethod, paymentMethodType, subCreateOptions, braintreeCustomer); - - user.Gateway = GatewayType.Stripe; - user.GatewayCustomerId = customer.Id; - user.GatewaySubscriptionId = subscription.Id; - - if (subscription.Status == "incomplete" && - subscription.LatestInvoice?.PaymentIntent?.Status == "requires_action") - { - return subscription.LatestInvoice.PaymentIntent.ClientSecret; - } - else - { - user.Premium = true; - user.PremiumExpirationDate = subscription.CurrentPeriodEnd; - return null; - } - } - - private async Task ChargeForNewSubscriptionAsync(ISubscriber subcriber, Stripe.Customer customer, - bool createdStripeCustomer, bool stripePaymentMethod, PaymentMethodType paymentMethodType, - Stripe.SubscriptionCreateOptions subCreateOptions, Braintree.Customer braintreeCustomer) - { - var addedCreditToStripeCustomer = false; - Braintree.Transaction braintreeTransaction = null; - Transaction appleTransaction = null; - - var subInvoiceMetadata = new Dictionary(); - Stripe.Subscription subscription = null; - try - { - if (!stripePaymentMethod) - { - var previewInvoice = await _stripeAdapter.InvoiceUpcomingAsync(new Stripe.UpcomingInvoiceOptions - { - Customer = customer.Id, - SubscriptionItems = ToInvoiceSubscriptionItemOptions(subCreateOptions.Items), - SubscriptionDefaultTaxRates = subCreateOptions.DefaultTaxRates, - }); - - if (previewInvoice.AmountDue > 0) - { - var appleReceiptOrigTransactionId = customer.Metadata != null && - customer.Metadata.ContainsKey("appleReceipt") ? customer.Metadata["appleReceipt"] : null; - var braintreeCustomerId = customer.Metadata != null && - customer.Metadata.ContainsKey("btCustomerId") ? customer.Metadata["btCustomerId"] : null; - if (!string.IsNullOrWhiteSpace(appleReceiptOrigTransactionId)) + invoice = await _stripeAdapter.InvoiceFinalizeInvoiceAsync(subResponse.LatestInvoiceId, new Stripe.InvoiceFinalizeOptions { - if (!subcriber.IsUser()) - { - throw new GatewayException("In-app purchase is only allowed for users."); - } - - var appleReceipt = await _appleIapService.GetReceiptAsync( - appleReceiptOrigTransactionId); - var verifiedAppleReceipt = await _appleIapService.GetVerifiedReceiptStatusAsync( - appleReceipt.Item1); - if (verifiedAppleReceipt == null) - { - throw new GatewayException("Failed to get Apple in-app purchase receipt data."); - } - subInvoiceMetadata.Add("appleReceipt", verifiedAppleReceipt.GetOriginalTransactionId()); - var lastTransactionId = verifiedAppleReceipt.GetLastTransactionId(); - subInvoiceMetadata.Add("appleReceiptTransactionId", lastTransactionId); - var existingTransaction = await _transactionRepository.GetByGatewayIdAsync( - GatewayType.AppStore, lastTransactionId); - if (existingTransaction == null) - { - appleTransaction = verifiedAppleReceipt.BuildTransactionFromLastTransaction( - PremiumPlanAppleIapPrice, subcriber.Id); - appleTransaction.Type = TransactionType.Charge; - await _transactionRepository.CreateAsync(appleTransaction); - } - } - else if (!string.IsNullOrWhiteSpace(braintreeCustomerId)) - { - var btInvoiceAmount = (previewInvoice.AmountDue / 100M); - var transactionResult = await _btGateway.Transaction.SaleAsync( - new Braintree.TransactionRequest - { - Amount = btInvoiceAmount, - CustomerId = braintreeCustomerId, - Options = new Braintree.TransactionOptionsRequest - { - SubmitForSettlement = true, - PayPal = new Braintree.TransactionOptionsPayPalRequest - { - CustomField = $"{subcriber.BraintreeIdField()}:{subcriber.Id}" - } - }, - CustomFields = new Dictionary - { - [subcriber.BraintreeIdField()] = subcriber.Id.ToString() - } - }); - - if (!transactionResult.IsSuccess()) - { - throw new GatewayException("Failed to charge PayPal customer."); - } - - braintreeTransaction = transactionResult.Target; - subInvoiceMetadata.Add("btTransactionId", braintreeTransaction.Id); - subInvoiceMetadata.Add("btPayPalTransactionId", - braintreeTransaction.PayPalDetails.AuthorizationId); - } - else - { - throw new GatewayException("No payment was able to be collected."); - } - - await _stripeAdapter.CustomerUpdateAsync(customer.Id, new Stripe.CustomerUpdateOptions - { - Balance = customer.Balance - previewInvoice.AmountDue + AutoAdvance = false, }); - addedCreditToStripeCustomer = true; + await _stripeAdapter.InvoiceSendInvoiceAsync(invoice.Id, new Stripe.InvoiceSendOptions()); + paymentIntentClientSecret = null; } } - else if (paymentMethodType == PaymentMethodType.Credit) - { - var previewInvoice = await _stripeAdapter.InvoiceUpcomingAsync(new Stripe.UpcomingInvoiceOptions - { - Customer = customer.Id, - SubscriptionItems = ToInvoiceSubscriptionItemOptions(subCreateOptions.Items), - SubscriptionDefaultTaxRates = subCreateOptions.DefaultTaxRates, - }); - if (previewInvoice.AmountDue > 0) - { - throw new GatewayException("Your account does not have enough credit available."); - } - } - - subCreateOptions.OffSession = true; - subCreateOptions.AddExpand("latest_invoice.payment_intent"); - subscription = await _stripeAdapter.SubscriptionCreateAsync(subCreateOptions); - if (subscription.Status == "incomplete" && subscription.LatestInvoice?.PaymentIntent != null) - { - if (subscription.LatestInvoice.PaymentIntent.Status == "requires_payment_method") - { - await _stripeAdapter.SubscriptionCancelAsync(subscription.Id, new Stripe.SubscriptionCancelOptions()); - throw new GatewayException("Payment method was declined."); - } - } - - if (!stripePaymentMethod && subInvoiceMetadata.Any()) - { - var invoices = await _stripeAdapter.InvoiceListAsync(new Stripe.InvoiceListOptions - { - Subscription = subscription.Id - }); - - var invoice = invoices?.FirstOrDefault(); - if (invoice == null) - { - throw new GatewayException("Invoice not found."); - } - - await _stripeAdapter.InvoiceUpdateAsync(invoice.Id, new Stripe.InvoiceUpdateOptions - { - Metadata = subInvoiceMetadata - }); - } - - return subscription; - } - catch (Exception e) - { - if (customer != null) - { - if (createdStripeCustomer) - { - await _stripeAdapter.CustomerDeleteAsync(customer.Id); - } - else if (addedCreditToStripeCustomer || customer.Balance < 0) - { - await _stripeAdapter.CustomerUpdateAsync(customer.Id, new Stripe.CustomerUpdateOptions - { - Balance = customer.Balance - }); - } - } - if (braintreeTransaction != null) - { - await _btGateway.Transaction.RefundAsync(braintreeTransaction.Id); - } - if (braintreeCustomer != null) - { - await _btGateway.Customer.DeleteAsync(braintreeCustomer.Id); - } - if (appleTransaction != null) - { - await _transactionRepository.DeleteAsync(appleTransaction); - } - - if (e is Stripe.StripeException strEx && - (strEx.StripeError?.Message?.Contains("cannot be used because it is not verified") ?? false)) - { - throw new GatewayException("Bank account is not yet verified."); - } - - throw; - } - } - - private List ToInvoiceSubscriptionItemOptions( - List subItemOptions) - { - return subItemOptions.Select(si => new Stripe.InvoiceSubscriptionItemOptions - { - Plan = si.Plan, - Quantity = si.Quantity - }).ToList(); - } - - private async Task FinalizeSubscriptionChangeAsync(IStorableSubscriber storableSubscriber, - SubscriptionUpdate subscriptionUpdate, DateTime? prorationDate) - { - // remember, when in doubt, throw - - var sub = await _stripeAdapter.SubscriptionGetAsync(storableSubscriber.GatewaySubscriptionId); - if (sub == null) - { - throw new GatewayException("Subscription not found."); - } - - prorationDate ??= DateTime.UtcNow; - var collectionMethod = sub.CollectionMethod; - var daysUntilDue = sub.DaysUntilDue; - var chargeNow = collectionMethod == "charge_automatically"; - var updatedItemOptions = subscriptionUpdate.UpgradeItemsOptions(sub); - - var subUpdateOptions = new Stripe.SubscriptionUpdateOptions - { - Items = updatedItemOptions, - ProrationBehavior = "always_invoice", - DaysUntilDue = daysUntilDue ?? 1, - CollectionMethod = "send_invoice", - ProrationDate = prorationDate, - }; - - if (!subscriptionUpdate.UpdateNeeded(sub)) - { - // No need to update subscription, quantity matches - return null; - } - - var customer = await _stripeAdapter.CustomerGetAsync(sub.CustomerId); - - if (!string.IsNullOrWhiteSpace(customer?.Address?.Country) - && !string.IsNullOrWhiteSpace(customer?.Address?.PostalCode)) - { - var taxRates = await _taxRateRepository.GetByLocationAsync( - new TaxRate() - { - Country = customer.Address.Country, - PostalCode = customer.Address.PostalCode - } - ); - var taxRate = taxRates.FirstOrDefault(); - if (taxRate != null && !sub.DefaultTaxRates.Any(x => x.Equals(taxRate.Id))) - { - subUpdateOptions.DefaultTaxRates = new List(1) - { - taxRate.Id - }; - } - } - - string paymentIntentClientSecret = null; - try - { - var subResponse = await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, subUpdateOptions); - - var invoice = await _stripeAdapter.InvoiceGetAsync(subResponse?.LatestInvoiceId, new Stripe.InvoiceGetOptions()); - if (invoice == null) - { - throw new BadRequestException("Unable to locate draft invoice for subscription update."); - } - - if (invoice.AmountDue > 0 && updatedItemOptions.Any(i => i.Quantity > 0)) - { - try - { - if (chargeNow) - { - paymentIntentClientSecret = await PayInvoiceAfterSubscriptionChangeAsync( - storableSubscriber, invoice); - } - else - { - invoice = await _stripeAdapter.InvoiceFinalizeInvoiceAsync(subResponse.LatestInvoiceId, new Stripe.InvoiceFinalizeOptions - { - AutoAdvance = false, - }); - await _stripeAdapter.InvoiceSendInvoiceAsync(invoice.Id, new Stripe.InvoiceSendOptions()); - paymentIntentClientSecret = null; - } - } - catch - { - // Need to revert the subscription - await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, new Stripe.SubscriptionUpdateOptions - { - Items = subscriptionUpdate.RevertItemsOptions(sub), - // This proration behavior prevents a false "credit" from - // being applied forward to the next month's invoice - ProrationBehavior = "none", - CollectionMethod = collectionMethod, - DaysUntilDue = daysUntilDue, - }); - throw; - } - } - else if (!invoice.Paid) - { - // Pay invoice with no charge to customer this completes the invoice immediately without waiting the scheduled 1h - invoice = await _stripeAdapter.InvoicePayAsync(subResponse.LatestInvoiceId); - paymentIntentClientSecret = null; - } - - } - finally - { - // Change back the subscription collection method and/or days until due - if (collectionMethod != "send_invoice" || daysUntilDue == null) + catch { + // Need to revert the subscription await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, new Stripe.SubscriptionUpdateOptions { + Items = subscriptionUpdate.RevertItemsOptions(sub), + // This proration behavior prevents a false "credit" from + // being applied forward to the next month's invoice + ProrationBehavior = "none", CollectionMethod = collectionMethod, DaysUntilDue = daysUntilDue, }); - } - } - - return paymentIntentClientSecret; - } - - public Task AdjustSeatsAsync(Organization organization, StaticStore.Plan plan, int additionalSeats, DateTime? prorationDate = null) - { - return FinalizeSubscriptionChangeAsync(organization, new SeatSubscriptionUpdate(organization, plan, additionalSeats), prorationDate); - } - - public Task AdjustStorageAsync(IStorableSubscriber storableSubscriber, int additionalStorage, - string storagePlanId, DateTime? prorationDate = null) - { - return FinalizeSubscriptionChangeAsync(storableSubscriber, new StorageSubscriptionUpdate(storagePlanId, additionalStorage), prorationDate); - } - - public async Task CancelAndRecoverChargesAsync(ISubscriber subscriber) - { - if (!string.IsNullOrWhiteSpace(subscriber.GatewaySubscriptionId)) - { - await _stripeAdapter.SubscriptionCancelAsync(subscriber.GatewaySubscriptionId, - new Stripe.SubscriptionCancelOptions()); - } - - if (string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) - { - return; - } - - var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); - if (customer == null) - { - return; - } - - if (customer.Metadata.ContainsKey("btCustomerId")) - { - var transactionRequest = new Braintree.TransactionSearchRequest() - .CustomerId.Is(customer.Metadata["btCustomerId"]); - var transactions = _btGateway.Transaction.Search(transactionRequest); - - if ((transactions?.MaximumCount ?? 0) > 0) - { - var txs = transactions.Cast().Where(c => c.RefundedTransactionId == null); - foreach (var transaction in txs) - { - await _btGateway.Transaction.RefundAsync(transaction.Id); - } - } - - await _btGateway.Customer.DeleteAsync(customer.Metadata["btCustomerId"]); - } - else - { - var charges = await _stripeAdapter.ChargeListAsync(new Stripe.ChargeListOptions - { - Customer = subscriber.GatewayCustomerId - }); - - if (charges?.Data != null) - { - foreach (var charge in charges.Data.Where(c => c.Captured && !c.Refunded)) - { - await _stripeAdapter.RefundCreateAsync(new Stripe.RefundCreateOptions { Charge = charge.Id }); - } - } - } - - await _stripeAdapter.CustomerDeleteAsync(subscriber.GatewayCustomerId); - } - - public async Task PayInvoiceAfterSubscriptionChangeAsync(ISubscriber subscriber, Stripe.Invoice invoice) - { - var customerOptions = new Stripe.CustomerGetOptions(); - customerOptions.AddExpand("default_source"); - customerOptions.AddExpand("invoice_settings.default_payment_method"); - var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerOptions); - var usingInAppPaymentMethod = customer.Metadata.ContainsKey("appleReceipt"); - if (usingInAppPaymentMethod) - { - throw new BadRequestException("Cannot perform this action with in-app purchase payment method. " + - "Contact support."); - } - - string paymentIntentClientSecret = null; - - // Invoice them and pay now instead of waiting until Stripe does this automatically. - - string cardPaymentMethodId = null; - if (!customer.Metadata.ContainsKey("btCustomerId")) - { - var hasDefaultCardPaymentMethod = customer.InvoiceSettings?.DefaultPaymentMethod?.Type == "card"; - var hasDefaultValidSource = customer.DefaultSource != null && - (customer.DefaultSource is Stripe.Card || customer.DefaultSource is Stripe.BankAccount); - if (!hasDefaultCardPaymentMethod && !hasDefaultValidSource) - { - cardPaymentMethodId = GetLatestCardPaymentMethod(customer.Id)?.Id; - if (cardPaymentMethodId == null) - { - // We're going to delete this draft invoice, it can't be paid - try - { - await _stripeAdapter.InvoiceDeleteAsync(invoice.Id); - } - catch - { - await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id, new Stripe.InvoiceFinalizeOptions - { - AutoAdvance = false - }); - await _stripeAdapter.InvoiceVoidInvoiceAsync(invoice.Id); - } - throw new BadRequestException("No payment method is available."); - } - } - } - - Braintree.Transaction braintreeTransaction = null; - try - { - // Finalize the invoice (from Draft) w/o auto-advance so we - // can attempt payment manually. - invoice = await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id, new Stripe.InvoiceFinalizeOptions - { - AutoAdvance = false, - }); - var invoicePayOptions = new Stripe.InvoicePayOptions - { - PaymentMethod = cardPaymentMethodId, - }; - if (customer?.Metadata?.ContainsKey("btCustomerId") ?? false) - { - invoicePayOptions.PaidOutOfBand = true; - var btInvoiceAmount = (invoice.AmountDue / 100M); - var transactionResult = await _btGateway.Transaction.SaleAsync( - new Braintree.TransactionRequest - { - Amount = btInvoiceAmount, - CustomerId = customer.Metadata["btCustomerId"], - Options = new Braintree.TransactionOptionsRequest - { - SubmitForSettlement = true, - PayPal = new Braintree.TransactionOptionsPayPalRequest - { - CustomField = $"{subscriber.BraintreeIdField()}:{subscriber.Id}" - } - }, - CustomFields = new Dictionary - { - [subscriber.BraintreeIdField()] = subscriber.Id.ToString() - } - }); - - if (!transactionResult.IsSuccess()) - { - throw new GatewayException("Failed to charge PayPal customer."); - } - - braintreeTransaction = transactionResult.Target; - invoice = await _stripeAdapter.InvoiceUpdateAsync(invoice.Id, new Stripe.InvoiceUpdateOptions - { - Metadata = new Dictionary - { - ["btTransactionId"] = braintreeTransaction.Id, - ["btPayPalTransactionId"] = - braintreeTransaction.PayPalDetails.AuthorizationId - }, - }); - invoicePayOptions.PaidOutOfBand = true; - } - - try - { - invoice = await _stripeAdapter.InvoicePayAsync(invoice.Id, invoicePayOptions); - } - catch (Stripe.StripeException e) - { - if (e.HttpStatusCode == System.Net.HttpStatusCode.PaymentRequired && - e.StripeError?.Code == "invoice_payment_intent_requires_action") - { - // SCA required, get intent client secret - var invoiceGetOptions = new Stripe.InvoiceGetOptions(); - invoiceGetOptions.AddExpand("payment_intent"); - invoice = await _stripeAdapter.InvoiceGetAsync(invoice.Id, invoiceGetOptions); - paymentIntentClientSecret = invoice?.PaymentIntent?.ClientSecret; - } - else - { - throw new GatewayException("Unable to pay invoice."); - } - } - } - catch (Exception e) - { - if (braintreeTransaction != null) - { - await _btGateway.Transaction.RefundAsync(braintreeTransaction.Id); - } - if (invoice != null) - { - if (invoice.Status == "paid") - { - // It's apparently paid, so we need to return w/o throwing an exception - return paymentIntentClientSecret; - } - - invoice = await _stripeAdapter.InvoiceVoidInvoiceAsync(invoice.Id, new Stripe.InvoiceVoidOptions()); - - // HACK: Workaround for customer balance credit - if (invoice.StartingBalance < 0) - { - // Customer had a balance applied to this invoice. Since we can't fully trust Stripe to - // credit it back to the customer (even though their docs claim they will), we need to - // check that balance against the current customer balance and determine if it needs to be re-applied - customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerOptions); - - // Assumption: Customer balance should now be $0, otherwise payment would not have failed. - if (customer.Balance == 0) - { - await _stripeAdapter.CustomerUpdateAsync(customer.Id, new Stripe.CustomerUpdateOptions - { - Balance = invoice.StartingBalance - }); - } - } - } - - if (e is Stripe.StripeException strEx && - (strEx.StripeError?.Message?.Contains("cannot be used because it is not verified") ?? false)) - { - throw new GatewayException("Bank account is not yet verified."); - } - - // Let the caller perform any subscription change cleanup - throw; - } - return paymentIntentClientSecret; - } - - public async Task CancelSubscriptionAsync(ISubscriber subscriber, bool endOfPeriod = false, - bool skipInAppPurchaseCheck = false) - { - if (subscriber == null) - { - throw new ArgumentNullException(nameof(subscriber)); - } - - if (string.IsNullOrWhiteSpace(subscriber.GatewaySubscriptionId)) - { - throw new GatewayException("No subscription."); - } - - if (!string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId) && !skipInAppPurchaseCheck) - { - var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); - if (customer.Metadata.ContainsKey("appleReceipt")) - { - throw new BadRequestException("You are required to manage your subscription from the app store."); - } - } - - var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); - if (sub == null) - { - throw new GatewayException("Subscription was not found."); - } - - if (sub.CanceledAt.HasValue || sub.Status == "canceled" || sub.Status == "unpaid" || - sub.Status == "incomplete_expired") - { - // Already canceled - return; - } - - try - { - var canceledSub = endOfPeriod ? - await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, - new Stripe.SubscriptionUpdateOptions { CancelAtPeriodEnd = true }) : - await _stripeAdapter.SubscriptionCancelAsync(sub.Id, new Stripe.SubscriptionCancelOptions()); - if (!canceledSub.CanceledAt.HasValue) - { - throw new GatewayException("Unable to cancel subscription."); - } - } - catch (Stripe.StripeException e) - { - if (e.Message != $"No such subscription: {subscriber.GatewaySubscriptionId}") - { throw; } } + else if (!invoice.Paid) + { + // Pay invoice with no charge to customer this completes the invoice immediately without waiting the scheduled 1h + invoice = await _stripeAdapter.InvoicePayAsync(subResponse.LatestInvoiceId); + paymentIntentClientSecret = null; + } + } - - public async Task ReinstateSubscriptionAsync(ISubscriber subscriber) + finally { - if (subscriber == null) + // Change back the subscription collection method and/or days until due + if (collectionMethod != "send_invoice" || daysUntilDue == null) { - throw new ArgumentNullException(nameof(subscriber)); - } - - if (string.IsNullOrWhiteSpace(subscriber.GatewaySubscriptionId)) - { - throw new GatewayException("No subscription."); - } - - var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); - if (sub == null) - { - throw new GatewayException("Subscription was not found."); - } - - if ((sub.Status != "active" && sub.Status != "trialing" && !sub.Status.StartsWith("incomplete")) || - !sub.CanceledAt.HasValue) - { - throw new GatewayException("Subscription is not marked for cancellation."); - } - - var updatedSub = await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, - new Stripe.SubscriptionUpdateOptions { CancelAtPeriodEnd = false }); - if (updatedSub.CanceledAt.HasValue) - { - throw new GatewayException("Unable to reinstate subscription."); + await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, new Stripe.SubscriptionUpdateOptions + { + CollectionMethod = collectionMethod, + DaysUntilDue = daysUntilDue, + }); } } - public async Task UpdatePaymentMethodAsync(ISubscriber subscriber, PaymentMethodType paymentMethodType, - string paymentToken, bool allowInAppPurchases = false, TaxInfo taxInfo = null) + return paymentIntentClientSecret; + } + + public Task AdjustSeatsAsync(Organization organization, StaticStore.Plan plan, int additionalSeats, DateTime? prorationDate = null) + { + return FinalizeSubscriptionChangeAsync(organization, new SeatSubscriptionUpdate(organization, plan, additionalSeats), prorationDate); + } + + public Task AdjustStorageAsync(IStorableSubscriber storableSubscriber, int additionalStorage, + string storagePlanId, DateTime? prorationDate = null) + { + return FinalizeSubscriptionChangeAsync(storableSubscriber, new StorageSubscriptionUpdate(storagePlanId, additionalStorage), prorationDate); + } + + public async Task CancelAndRecoverChargesAsync(ISubscriber subscriber) + { + if (!string.IsNullOrWhiteSpace(subscriber.GatewaySubscriptionId)) { - if (subscriber == null) + await _stripeAdapter.SubscriptionCancelAsync(subscriber.GatewaySubscriptionId, + new Stripe.SubscriptionCancelOptions()); + } + + if (string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) + { + return; + } + + var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); + if (customer == null) + { + return; + } + + if (customer.Metadata.ContainsKey("btCustomerId")) + { + var transactionRequest = new Braintree.TransactionSearchRequest() + .CustomerId.Is(customer.Metadata["btCustomerId"]); + var transactions = _btGateway.Transaction.Search(transactionRequest); + + if ((transactions?.MaximumCount ?? 0) > 0) { - throw new ArgumentNullException(nameof(subscriber)); - } - - if (subscriber.Gateway.HasValue && subscriber.Gateway.Value != GatewayType.Stripe) - { - throw new GatewayException("Switching from one payment type to another is not supported. " + - "Contact us for assistance."); - } - - var createdCustomer = false; - AppleReceiptStatus appleReceiptStatus = null; - Braintree.Customer braintreeCustomer = null; - string stipeCustomerSourceToken = null; - string stipeCustomerPaymentMethodId = null; - var stripeCustomerMetadata = new Dictionary(); - var stripePaymentMethod = paymentMethodType == PaymentMethodType.Card || - paymentMethodType == PaymentMethodType.BankAccount; - var inAppPurchase = paymentMethodType == PaymentMethodType.AppleInApp || - paymentMethodType == PaymentMethodType.GoogleInApp; - - Stripe.Customer customer = null; - - if (!allowInAppPurchases && inAppPurchase) - { - throw new GatewayException("In-app purchase payment method is not allowed."); - } - - if (!subscriber.IsUser() && inAppPurchase) - { - throw new GatewayException("In-app purchase payment method is only allowed for users."); - } - - if (!string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) - { - var options = new Stripe.CustomerGetOptions(); - options.AddExpand("sources"); - customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, options); - if (customer.Metadata?.Any() ?? false) + var txs = transactions.Cast().Where(c => c.RefundedTransactionId == null); + foreach (var transaction in txs) { - stripeCustomerMetadata = customer.Metadata; + await _btGateway.Transaction.RefundAsync(transaction.Id); } } - if (inAppPurchase && customer != null && customer.Balance != 0) + await _btGateway.Customer.DeleteAsync(customer.Metadata["btCustomerId"]); + } + else + { + var charges = await _stripeAdapter.ChargeListAsync(new Stripe.ChargeListOptions { - throw new GatewayException("Customer balance cannot exist when using in-app purchases."); - } + Customer = subscriber.GatewayCustomerId + }); - if (!inAppPurchase && customer != null && stripeCustomerMetadata.ContainsKey("appleReceipt")) + if (charges?.Data != null) { - throw new GatewayException("Cannot change from in-app payment method. Contact support."); - } - - var hadBtCustomer = stripeCustomerMetadata.ContainsKey("btCustomerId"); - if (stripePaymentMethod) - { - if (paymentToken.StartsWith("pm_")) + foreach (var charge in charges.Data.Where(c => c.Captured && !c.Refunded)) { - stipeCustomerPaymentMethodId = paymentToken; - } - else - { - stipeCustomerSourceToken = paymentToken; + await _stripeAdapter.RefundCreateAsync(new Stripe.RefundCreateOptions { Charge = charge.Id }); } } - else if (paymentMethodType == PaymentMethodType.PayPal) + } + + await _stripeAdapter.CustomerDeleteAsync(subscriber.GatewayCustomerId); + } + + public async Task PayInvoiceAfterSubscriptionChangeAsync(ISubscriber subscriber, Stripe.Invoice invoice) + { + var customerOptions = new Stripe.CustomerGetOptions(); + customerOptions.AddExpand("default_source"); + customerOptions.AddExpand("invoice_settings.default_payment_method"); + var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerOptions); + var usingInAppPaymentMethod = customer.Metadata.ContainsKey("appleReceipt"); + if (usingInAppPaymentMethod) + { + throw new BadRequestException("Cannot perform this action with in-app purchase payment method. " + + "Contact support."); + } + + string paymentIntentClientSecret = null; + + // Invoice them and pay now instead of waiting until Stripe does this automatically. + + string cardPaymentMethodId = null; + if (!customer.Metadata.ContainsKey("btCustomerId")) + { + var hasDefaultCardPaymentMethod = customer.InvoiceSettings?.DefaultPaymentMethod?.Type == "card"; + var hasDefaultValidSource = customer.DefaultSource != null && + (customer.DefaultSource is Stripe.Card || customer.DefaultSource is Stripe.BankAccount); + if (!hasDefaultCardPaymentMethod && !hasDefaultValidSource) { - if (hadBtCustomer) + cardPaymentMethodId = GetLatestCardPaymentMethod(customer.Id)?.Id; + if (cardPaymentMethodId == null) { - var pmResult = await _btGateway.PaymentMethod.CreateAsync(new Braintree.PaymentMethodRequest + // We're going to delete this draft invoice, it can't be paid + try { - CustomerId = stripeCustomerMetadata["btCustomerId"], - PaymentMethodNonce = paymentToken - }); - - if (pmResult.IsSuccess()) + await _stripeAdapter.InvoiceDeleteAsync(invoice.Id); + } + catch { - var customerResult = await _btGateway.Customer.UpdateAsync( - stripeCustomerMetadata["btCustomerId"], new Braintree.CustomerRequest + await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id, new Stripe.InvoiceFinalizeOptions + { + AutoAdvance = false + }); + await _stripeAdapter.InvoiceVoidInvoiceAsync(invoice.Id); + } + throw new BadRequestException("No payment method is available."); + } + } + } + + Braintree.Transaction braintreeTransaction = null; + try + { + // Finalize the invoice (from Draft) w/o auto-advance so we + // can attempt payment manually. + invoice = await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id, new Stripe.InvoiceFinalizeOptions + { + AutoAdvance = false, + }); + var invoicePayOptions = new Stripe.InvoicePayOptions + { + PaymentMethod = cardPaymentMethodId, + }; + if (customer?.Metadata?.ContainsKey("btCustomerId") ?? false) + { + invoicePayOptions.PaidOutOfBand = true; + var btInvoiceAmount = (invoice.AmountDue / 100M); + var transactionResult = await _btGateway.Transaction.SaleAsync( + new Braintree.TransactionRequest + { + Amount = btInvoiceAmount, + CustomerId = customer.Metadata["btCustomerId"], + Options = new Braintree.TransactionOptionsRequest + { + SubmitForSettlement = true, + PayPal = new Braintree.TransactionOptionsPayPalRequest { - DefaultPaymentMethodToken = pmResult.Target.Token - }); - - if (customerResult.IsSuccess() && customerResult.Target.PaymentMethods.Length > 0) - { - braintreeCustomer = customerResult.Target; - } - else - { - await _btGateway.PaymentMethod.DeleteAsync(pmResult.Target.Token); - hadBtCustomer = false; - } - } - else - { - hadBtCustomer = false; - } - } - - if (!hadBtCustomer) - { - var customerResult = await _btGateway.Customer.CreateAsync(new Braintree.CustomerRequest - { - PaymentMethodNonce = paymentToken, - Email = subscriber.BillingEmailAddress(), - Id = subscriber.BraintreeCustomerIdPrefix() + subscriber.Id.ToString("N").ToLower() + - Utilities.CoreHelpers.RandomString(3, upper: false, numeric: false), + CustomField = $"{subscriber.BraintreeIdField()}:{subscriber.Id}" + } + }, CustomFields = new Dictionary { [subscriber.BraintreeIdField()] = subscriber.Id.ToString() } }); - if (!customerResult.IsSuccess() || customerResult.Target.PaymentMethods.Length == 0) + if (!transactionResult.IsSuccess()) + { + throw new GatewayException("Failed to charge PayPal customer."); + } + + braintreeTransaction = transactionResult.Target; + invoice = await _stripeAdapter.InvoiceUpdateAsync(invoice.Id, new Stripe.InvoiceUpdateOptions + { + Metadata = new Dictionary { - throw new GatewayException("Failed to create PayPal customer record."); - } - - braintreeCustomer = customerResult.Target; - } - } - else if (paymentMethodType == PaymentMethodType.AppleInApp) - { - appleReceiptStatus = await _appleIapService.GetVerifiedReceiptStatusAsync(paymentToken); - if (appleReceiptStatus == null) - { - throw new GatewayException("Cannot verify Apple in-app purchase."); - } - await VerifyAppleReceiptNotInUseAsync(appleReceiptStatus.GetOriginalTransactionId(), subscriber); - } - else - { - throw new GatewayException("Payment method is not supported at this time."); - } - - if (stripeCustomerMetadata.ContainsKey("btCustomerId")) - { - if (braintreeCustomer?.Id != stripeCustomerMetadata["btCustomerId"]) - { - var nowSec = Utilities.CoreHelpers.ToEpocSeconds(DateTime.UtcNow); - stripeCustomerMetadata.Add($"btCustomerId_{nowSec}", stripeCustomerMetadata["btCustomerId"]); - } - stripeCustomerMetadata["btCustomerId"] = braintreeCustomer?.Id; - } - else if (!string.IsNullOrWhiteSpace(braintreeCustomer?.Id)) - { - stripeCustomerMetadata.Add("btCustomerId", braintreeCustomer.Id); - } - - if (appleReceiptStatus != null) - { - var originalTransactionId = appleReceiptStatus.GetOriginalTransactionId(); - if (stripeCustomerMetadata.ContainsKey("appleReceipt")) - { - if (originalTransactionId != stripeCustomerMetadata["appleReceipt"]) - { - var nowSec = Utilities.CoreHelpers.ToEpocSeconds(DateTime.UtcNow); - stripeCustomerMetadata.Add($"appleReceipt_{nowSec}", stripeCustomerMetadata["appleReceipt"]); - } - stripeCustomerMetadata["appleReceipt"] = originalTransactionId; - } - else - { - stripeCustomerMetadata.Add("appleReceipt", originalTransactionId); - } - await _appleIapService.SaveReceiptAsync(appleReceiptStatus, subscriber.Id); + ["btTransactionId"] = braintreeTransaction.Id, + ["btPayPalTransactionId"] = + braintreeTransaction.PayPalDetails.AuthorizationId + }, + }); + invoicePayOptions.PaidOutOfBand = true; } try { - if (customer == null) + invoice = await _stripeAdapter.InvoicePayAsync(invoice.Id, invoicePayOptions); + } + catch (Stripe.StripeException e) + { + if (e.HttpStatusCode == System.Net.HttpStatusCode.PaymentRequired && + e.StripeError?.Code == "invoice_payment_intent_requires_action") { - customer = await _stripeAdapter.CustomerCreateAsync(new Stripe.CustomerCreateOptions - { - Description = subscriber.BillingName(), - Email = subscriber.BillingEmailAddress(), - Metadata = stripeCustomerMetadata, - Source = stipeCustomerSourceToken, - PaymentMethod = stipeCustomerPaymentMethodId, - InvoiceSettings = new Stripe.CustomerInvoiceSettingsOptions - { - DefaultPaymentMethod = stipeCustomerPaymentMethodId - }, - Address = taxInfo == null ? null : new Stripe.AddressOptions - { - Country = taxInfo.BillingAddressCountry, - PostalCode = taxInfo.BillingAddressPostalCode, - Line1 = taxInfo.BillingAddressLine1 ?? string.Empty, - Line2 = taxInfo.BillingAddressLine2, - City = taxInfo.BillingAddressCity, - State = taxInfo.BillingAddressState, - }, - Expand = new List { "sources" }, - }); - - subscriber.Gateway = GatewayType.Stripe; - subscriber.GatewayCustomerId = customer.Id; - createdCustomer = true; + // SCA required, get intent client secret + var invoiceGetOptions = new Stripe.InvoiceGetOptions(); + invoiceGetOptions.AddExpand("payment_intent"); + invoice = await _stripeAdapter.InvoiceGetAsync(invoice.Id, invoiceGetOptions); + paymentIntentClientSecret = invoice?.PaymentIntent?.ClientSecret; } - - if (!createdCustomer) + else { - string defaultSourceId = null; - string defaultPaymentMethodId = null; - if (stripePaymentMethod) - { - if (!string.IsNullOrWhiteSpace(stipeCustomerSourceToken) && paymentToken.StartsWith("btok_")) - { - var bankAccount = await _stripeAdapter.BankAccountCreateAsync(customer.Id, new Stripe.BankAccountCreateOptions - { - Source = paymentToken - }); - defaultSourceId = bankAccount.Id; - } - else if (!string.IsNullOrWhiteSpace(stipeCustomerPaymentMethodId)) - { - await _stripeAdapter.PaymentMethodAttachAsync(stipeCustomerPaymentMethodId, - new Stripe.PaymentMethodAttachOptions { Customer = customer.Id }); - defaultPaymentMethodId = stipeCustomerPaymentMethodId; - } - } - - if (customer.Sources != null) - { - foreach (var source in customer.Sources.Where(s => s.Id != defaultSourceId)) - { - if (source is Stripe.BankAccount) - { - await _stripeAdapter.BankAccountDeleteAsync(customer.Id, source.Id); - } - else if (source is Stripe.Card) - { - await _stripeAdapter.CardDeleteAsync(customer.Id, source.Id); - } - } - } - - var cardPaymentMethods = _stripeAdapter.PaymentMethodListAutoPaging(new Stripe.PaymentMethodListOptions - { - Customer = customer.Id, - Type = "card" - }); - foreach (var cardMethod in cardPaymentMethods.Where(m => m.Id != defaultPaymentMethodId)) - { - await _stripeAdapter.PaymentMethodDetachAsync(cardMethod.Id, new Stripe.PaymentMethodDetachOptions()); - } - - customer = await _stripeAdapter.CustomerUpdateAsync(customer.Id, new Stripe.CustomerUpdateOptions - { - Metadata = stripeCustomerMetadata, - DefaultSource = defaultSourceId, - InvoiceSettings = new Stripe.CustomerInvoiceSettingsOptions - { - DefaultPaymentMethod = defaultPaymentMethodId - }, - Address = taxInfo == null ? null : new Stripe.AddressOptions - { - Country = taxInfo.BillingAddressCountry, - PostalCode = taxInfo.BillingAddressPostalCode, - Line1 = taxInfo.BillingAddressLine1 ?? string.Empty, - Line2 = taxInfo.BillingAddressLine2, - City = taxInfo.BillingAddressCity, - State = taxInfo.BillingAddressState, - }, - }); + throw new GatewayException("Unable to pay invoice."); } } - catch + } + catch (Exception e) + { + if (braintreeTransaction != null) { - if (braintreeCustomer != null && !hadBtCustomer) + await _btGateway.Transaction.RefundAsync(braintreeTransaction.Id); + } + if (invoice != null) + { + if (invoice.Status == "paid") { - await _btGateway.Customer.DeleteAsync(braintreeCustomer.Id); - } - throw; - } - - return createdCustomer; - } - - public async Task CreditAccountAsync(ISubscriber subscriber, decimal creditAmount) - { - Stripe.Customer customer = null; - var customerExists = subscriber.Gateway == GatewayType.Stripe && - !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId); - if (customerExists) - { - customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); - } - else - { - customer = await _stripeAdapter.CustomerCreateAsync(new Stripe.CustomerCreateOptions - { - Email = subscriber.BillingEmailAddress(), - Description = subscriber.BillingName(), - }); - subscriber.Gateway = GatewayType.Stripe; - subscriber.GatewayCustomerId = customer.Id; - } - await _stripeAdapter.CustomerUpdateAsync(customer.Id, new Stripe.CustomerUpdateOptions - { - Balance = customer.Balance - (long)(creditAmount * 100) - }); - return !customerExists; - } - - public async Task GetBillingAsync(ISubscriber subscriber) - { - var customer = await GetCustomerAsync(subscriber.GatewayCustomerId, GetCustomerPaymentOptions()); - var billingInfo = new BillingInfo - { - Balance = GetBillingBalance(customer), - PaymentSource = await GetBillingPaymentSourceAsync(customer), - Invoices = await GetBillingInvoicesAsync(customer), - Transactions = await GetBillingTransactionsAsync(subscriber) - }; - - return billingInfo; - } - - public async Task GetBillingBalanceAndSourceAsync(ISubscriber subscriber) - { - var customer = await GetCustomerAsync(subscriber.GatewayCustomerId, GetCustomerPaymentOptions()); - var billingInfo = new BillingInfo - { - Balance = GetBillingBalance(customer), - PaymentSource = await GetBillingPaymentSourceAsync(customer) - }; - - return billingInfo; - } - - public async Task GetBillingHistoryAsync(ISubscriber subscriber) - { - var customer = await GetCustomerAsync(subscriber.GatewayCustomerId); - var billingInfo = new BillingInfo - { - Transactions = await GetBillingTransactionsAsync(subscriber), - Invoices = await GetBillingInvoicesAsync(customer) - }; - - return billingInfo; - } - - public async Task GetSubscriptionAsync(ISubscriber subscriber) - { - var subscriptionInfo = new SubscriptionInfo(); - - if (subscriber.IsUser() && !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) - { - var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); - subscriptionInfo.UsingInAppPurchase = customer.Metadata.ContainsKey("appleReceipt"); - } - - if (!string.IsNullOrWhiteSpace(subscriber.GatewaySubscriptionId)) - { - var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); - if (sub != null) - { - subscriptionInfo.Subscription = new SubscriptionInfo.BillingSubscription(sub); + // It's apparently paid, so we need to return w/o throwing an exception + return paymentIntentClientSecret; } - if (!sub.CanceledAt.HasValue && !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) + invoice = await _stripeAdapter.InvoiceVoidInvoiceAsync(invoice.Id, new Stripe.InvoiceVoidOptions()); + + // HACK: Workaround for customer balance credit + if (invoice.StartingBalance < 0) { - try + // Customer had a balance applied to this invoice. Since we can't fully trust Stripe to + // credit it back to the customer (even though their docs claim they will), we need to + // check that balance against the current customer balance and determine if it needs to be re-applied + customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerOptions); + + // Assumption: Customer balance should now be $0, otherwise payment would not have failed. + if (customer.Balance == 0) { - var upcomingInvoice = await _stripeAdapter.InvoiceUpcomingAsync( - new Stripe.UpcomingInvoiceOptions { Customer = subscriber.GatewayCustomerId }); - if (upcomingInvoice != null) + await _stripeAdapter.CustomerUpdateAsync(customer.Id, new Stripe.CustomerUpdateOptions { - subscriptionInfo.UpcomingInvoice = - new SubscriptionInfo.BillingUpcomingInvoice(upcomingInvoice); - } - } - catch (Stripe.StripeException) { } - } - } - - return subscriptionInfo; - } - - public async Task GetTaxInfoAsync(ISubscriber subscriber) - { - if (subscriber == null || string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) - { - return null; - } - - var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, - new Stripe.CustomerGetOptions { Expand = new List { "tax_ids" } }); - - if (customer == null) - { - return null; - } - - var address = customer.Address; - var taxId = customer.TaxIds?.FirstOrDefault(); - - // Line1 is required, so if missing we're using the subscriber name - // see: https://stripe.com/docs/api/customers/create#create_customer-address-line1 - if (address != null && string.IsNullOrWhiteSpace(address.Line1)) - { - address.Line1 = null; - } - - return new TaxInfo - { - TaxIdNumber = taxId?.Value, - BillingAddressLine1 = address?.Line1, - BillingAddressLine2 = address?.Line2, - BillingAddressCity = address?.City, - BillingAddressState = address?.State, - BillingAddressPostalCode = address?.PostalCode, - BillingAddressCountry = address?.Country, - }; - } - - public async Task SaveTaxInfoAsync(ISubscriber subscriber, TaxInfo taxInfo) - { - if (subscriber != null && !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) - { - var customer = await _stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, new Stripe.CustomerUpdateOptions - { - Address = new Stripe.AddressOptions - { - Line1 = taxInfo.BillingAddressLine1 ?? string.Empty, - Line2 = taxInfo.BillingAddressLine2, - City = taxInfo.BillingAddressCity, - State = taxInfo.BillingAddressState, - PostalCode = taxInfo.BillingAddressPostalCode, - Country = taxInfo.BillingAddressCountry, - }, - Expand = new List { "tax_ids" } - }); - - if (!subscriber.IsUser() && customer != null) - { - var taxId = customer.TaxIds?.FirstOrDefault(); - - if (taxId != null) - { - await _stripeAdapter.TaxIdDeleteAsync(customer.Id, taxId.Id); - } - if (!string.IsNullOrWhiteSpace(taxInfo.TaxIdNumber) && - !string.IsNullOrWhiteSpace(taxInfo.TaxIdType)) - { - await _stripeAdapter.TaxIdCreateAsync(customer.Id, new Stripe.TaxIdCreateOptions - { - Type = taxInfo.TaxIdType, - Value = taxInfo.TaxIdNumber, + Balance = invoice.StartingBalance }); } } } - } - public async Task CreateTaxRateAsync(TaxRate taxRate) - { - var stripeTaxRateOptions = new Stripe.TaxRateCreateOptions() + if (e is Stripe.StripeException strEx && + (strEx.StripeError?.Message?.Contains("cannot be used because it is not verified") ?? false)) { - DisplayName = $"{taxRate.Country} - {taxRate.PostalCode}", - Inclusive = false, - Percentage = taxRate.Rate, - Active = true - }; - var stripeTaxRate = await _stripeAdapter.TaxRateCreateAsync(stripeTaxRateOptions); - taxRate.Id = stripeTaxRate.Id; - await _taxRateRepository.CreateAsync(taxRate); - return taxRate; - } - - public async Task UpdateTaxRateAsync(TaxRate taxRate) - { - if (string.IsNullOrWhiteSpace(taxRate.Id)) - { - return; + throw new GatewayException("Bank account is not yet verified."); } - await ArchiveTaxRateAsync(taxRate); - await CreateTaxRateAsync(taxRate); + // Let the caller perform any subscription change cleanup + throw; + } + return paymentIntentClientSecret; + } + + public async Task CancelSubscriptionAsync(ISubscriber subscriber, bool endOfPeriod = false, + bool skipInAppPurchaseCheck = false) + { + if (subscriber == null) + { + throw new ArgumentNullException(nameof(subscriber)); } - public async Task ArchiveTaxRateAsync(TaxRate taxRate) + if (string.IsNullOrWhiteSpace(subscriber.GatewaySubscriptionId)) { - if (string.IsNullOrWhiteSpace(taxRate.Id)) - { - return; - } + throw new GatewayException("No subscription."); + } - var updatedStripeTaxRate = await _stripeAdapter.TaxRateUpdateAsync( - taxRate.Id, - new Stripe.TaxRateUpdateOptions() { Active = false } - ); - if (!updatedStripeTaxRate.Active) + if (!string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId) && !skipInAppPurchaseCheck) + { + var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); + if (customer.Metadata.ContainsKey("appleReceipt")) { - taxRate.Active = false; - await _taxRateRepository.ArchiveAsync(taxRate); + throw new BadRequestException("You are required to manage your subscription from the app store."); } } - private Stripe.PaymentMethod GetLatestCardPaymentMethod(string customerId) + var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); + if (sub == null) { - var cardPaymentMethods = _stripeAdapter.PaymentMethodListAutoPaging( - new Stripe.PaymentMethodListOptions { Customer = customerId, Type = "card" }); - return cardPaymentMethods.OrderByDescending(m => m.Created).FirstOrDefault(); + throw new GatewayException("Subscription was not found."); } - private async Task VerifyAppleReceiptNotInUseAsync(string receiptOriginalTransactionId, ISubscriber subscriber) + if (sub.CanceledAt.HasValue || sub.Status == "canceled" || sub.Status == "unpaid" || + sub.Status == "incomplete_expired") { - var existingReceipt = await _appleIapService.GetReceiptAsync(receiptOriginalTransactionId); - if (existingReceipt != null && existingReceipt.Item2.HasValue && existingReceipt.Item2 != subscriber.Id) + // Already canceled + return; + } + + try + { + var canceledSub = endOfPeriod ? + await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, + new Stripe.SubscriptionUpdateOptions { CancelAtPeriodEnd = true }) : + await _stripeAdapter.SubscriptionCancelAsync(sub.Id, new Stripe.SubscriptionCancelOptions()); + if (!canceledSub.CanceledAt.HasValue) { - var existingUser = await _userRepository.GetByIdAsync(existingReceipt.Item2.Value); - if (existingUser != null) + throw new GatewayException("Unable to cancel subscription."); + } + } + catch (Stripe.StripeException e) + { + if (e.Message != $"No such subscription: {subscriber.GatewaySubscriptionId}") + { + throw; + } + } + } + + public async Task ReinstateSubscriptionAsync(ISubscriber subscriber) + { + if (subscriber == null) + { + throw new ArgumentNullException(nameof(subscriber)); + } + + if (string.IsNullOrWhiteSpace(subscriber.GatewaySubscriptionId)) + { + throw new GatewayException("No subscription."); + } + + var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); + if (sub == null) + { + throw new GatewayException("Subscription was not found."); + } + + if ((sub.Status != "active" && sub.Status != "trialing" && !sub.Status.StartsWith("incomplete")) || + !sub.CanceledAt.HasValue) + { + throw new GatewayException("Subscription is not marked for cancellation."); + } + + var updatedSub = await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, + new Stripe.SubscriptionUpdateOptions { CancelAtPeriodEnd = false }); + if (updatedSub.CanceledAt.HasValue) + { + throw new GatewayException("Unable to reinstate subscription."); + } + } + + public async Task UpdatePaymentMethodAsync(ISubscriber subscriber, PaymentMethodType paymentMethodType, + string paymentToken, bool allowInAppPurchases = false, TaxInfo taxInfo = null) + { + if (subscriber == null) + { + throw new ArgumentNullException(nameof(subscriber)); + } + + if (subscriber.Gateway.HasValue && subscriber.Gateway.Value != GatewayType.Stripe) + { + throw new GatewayException("Switching from one payment type to another is not supported. " + + "Contact us for assistance."); + } + + var createdCustomer = false; + AppleReceiptStatus appleReceiptStatus = null; + Braintree.Customer braintreeCustomer = null; + string stipeCustomerSourceToken = null; + string stipeCustomerPaymentMethodId = null; + var stripeCustomerMetadata = new Dictionary(); + var stripePaymentMethod = paymentMethodType == PaymentMethodType.Card || + paymentMethodType == PaymentMethodType.BankAccount; + var inAppPurchase = paymentMethodType == PaymentMethodType.AppleInApp || + paymentMethodType == PaymentMethodType.GoogleInApp; + + Stripe.Customer customer = null; + + if (!allowInAppPurchases && inAppPurchase) + { + throw new GatewayException("In-app purchase payment method is not allowed."); + } + + if (!subscriber.IsUser() && inAppPurchase) + { + throw new GatewayException("In-app purchase payment method is only allowed for users."); + } + + if (!string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) + { + var options = new Stripe.CustomerGetOptions(); + options.AddExpand("sources"); + customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, options); + if (customer.Metadata?.Any() ?? false) + { + stripeCustomerMetadata = customer.Metadata; + } + } + + if (inAppPurchase && customer != null && customer.Balance != 0) + { + throw new GatewayException("Customer balance cannot exist when using in-app purchases."); + } + + if (!inAppPurchase && customer != null && stripeCustomerMetadata.ContainsKey("appleReceipt")) + { + throw new GatewayException("Cannot change from in-app payment method. Contact support."); + } + + var hadBtCustomer = stripeCustomerMetadata.ContainsKey("btCustomerId"); + if (stripePaymentMethod) + { + if (paymentToken.StartsWith("pm_")) + { + stipeCustomerPaymentMethodId = paymentToken; + } + else + { + stipeCustomerSourceToken = paymentToken; + } + } + else if (paymentMethodType == PaymentMethodType.PayPal) + { + if (hadBtCustomer) + { + var pmResult = await _btGateway.PaymentMethod.CreateAsync(new Braintree.PaymentMethodRequest { - throw new GatewayException("Apple receipt already in use by another user."); + CustomerId = stripeCustomerMetadata["btCustomerId"], + PaymentMethodNonce = paymentToken + }); + + if (pmResult.IsSuccess()) + { + var customerResult = await _btGateway.Customer.UpdateAsync( + stripeCustomerMetadata["btCustomerId"], new Braintree.CustomerRequest + { + DefaultPaymentMethodToken = pmResult.Target.Token + }); + + if (customerResult.IsSuccess() && customerResult.Target.PaymentMethods.Length > 0) + { + braintreeCustomer = customerResult.Target; + } + else + { + await _btGateway.PaymentMethod.DeleteAsync(pmResult.Target.Token); + hadBtCustomer = false; + } + } + else + { + hadBtCustomer = false; } } - } - private decimal GetBillingBalance(Stripe.Customer customer) + if (!hadBtCustomer) + { + var customerResult = await _btGateway.Customer.CreateAsync(new Braintree.CustomerRequest + { + PaymentMethodNonce = paymentToken, + Email = subscriber.BillingEmailAddress(), + Id = subscriber.BraintreeCustomerIdPrefix() + subscriber.Id.ToString("N").ToLower() + + Utilities.CoreHelpers.RandomString(3, upper: false, numeric: false), + CustomFields = new Dictionary + { + [subscriber.BraintreeIdField()] = subscriber.Id.ToString() + } + }); + + if (!customerResult.IsSuccess() || customerResult.Target.PaymentMethods.Length == 0) + { + throw new GatewayException("Failed to create PayPal customer record."); + } + + braintreeCustomer = customerResult.Target; + } + } + else if (paymentMethodType == PaymentMethodType.AppleInApp) { - return customer != null ? customer.Balance / 100M : default; + appleReceiptStatus = await _appleIapService.GetVerifiedReceiptStatusAsync(paymentToken); + if (appleReceiptStatus == null) + { + throw new GatewayException("Cannot verify Apple in-app purchase."); + } + await VerifyAppleReceiptNotInUseAsync(appleReceiptStatus.GetOriginalTransactionId(), subscriber); + } + else + { + throw new GatewayException("Payment method is not supported at this time."); } - private async Task GetBillingPaymentSourceAsync(Stripe.Customer customer) + if (stripeCustomerMetadata.ContainsKey("btCustomerId")) + { + if (braintreeCustomer?.Id != stripeCustomerMetadata["btCustomerId"]) + { + var nowSec = Utilities.CoreHelpers.ToEpocSeconds(DateTime.UtcNow); + stripeCustomerMetadata.Add($"btCustomerId_{nowSec}", stripeCustomerMetadata["btCustomerId"]); + } + stripeCustomerMetadata["btCustomerId"] = braintreeCustomer?.Id; + } + else if (!string.IsNullOrWhiteSpace(braintreeCustomer?.Id)) + { + stripeCustomerMetadata.Add("btCustomerId", braintreeCustomer.Id); + } + + if (appleReceiptStatus != null) + { + var originalTransactionId = appleReceiptStatus.GetOriginalTransactionId(); + if (stripeCustomerMetadata.ContainsKey("appleReceipt")) + { + if (originalTransactionId != stripeCustomerMetadata["appleReceipt"]) + { + var nowSec = Utilities.CoreHelpers.ToEpocSeconds(DateTime.UtcNow); + stripeCustomerMetadata.Add($"appleReceipt_{nowSec}", stripeCustomerMetadata["appleReceipt"]); + } + stripeCustomerMetadata["appleReceipt"] = originalTransactionId; + } + else + { + stripeCustomerMetadata.Add("appleReceipt", originalTransactionId); + } + await _appleIapService.SaveReceiptAsync(appleReceiptStatus, subscriber.Id); + } + + try { if (customer == null) { - return null; - } - - if (customer.Metadata?.ContainsKey("appleReceipt") ?? false) - { - return new BillingInfo.BillingSource + customer = await _stripeAdapter.CustomerCreateAsync(new Stripe.CustomerCreateOptions { - Type = PaymentMethodType.AppleInApp - }; + Description = subscriber.BillingName(), + Email = subscriber.BillingEmailAddress(), + Metadata = stripeCustomerMetadata, + Source = stipeCustomerSourceToken, + PaymentMethod = stipeCustomerPaymentMethodId, + InvoiceSettings = new Stripe.CustomerInvoiceSettingsOptions + { + DefaultPaymentMethod = stipeCustomerPaymentMethodId + }, + Address = taxInfo == null ? null : new Stripe.AddressOptions + { + Country = taxInfo.BillingAddressCountry, + PostalCode = taxInfo.BillingAddressPostalCode, + Line1 = taxInfo.BillingAddressLine1 ?? string.Empty, + Line2 = taxInfo.BillingAddressLine2, + City = taxInfo.BillingAddressCity, + State = taxInfo.BillingAddressState, + }, + Expand = new List { "sources" }, + }); + + subscriber.Gateway = GatewayType.Stripe; + subscriber.GatewayCustomerId = customer.Id; + createdCustomer = true; } - if (customer.Metadata?.ContainsKey("btCustomerId") ?? false) + if (!createdCustomer) + { + string defaultSourceId = null; + string defaultPaymentMethodId = null; + if (stripePaymentMethod) + { + if (!string.IsNullOrWhiteSpace(stipeCustomerSourceToken) && paymentToken.StartsWith("btok_")) + { + var bankAccount = await _stripeAdapter.BankAccountCreateAsync(customer.Id, new Stripe.BankAccountCreateOptions + { + Source = paymentToken + }); + defaultSourceId = bankAccount.Id; + } + else if (!string.IsNullOrWhiteSpace(stipeCustomerPaymentMethodId)) + { + await _stripeAdapter.PaymentMethodAttachAsync(stipeCustomerPaymentMethodId, + new Stripe.PaymentMethodAttachOptions { Customer = customer.Id }); + defaultPaymentMethodId = stipeCustomerPaymentMethodId; + } + } + + if (customer.Sources != null) + { + foreach (var source in customer.Sources.Where(s => s.Id != defaultSourceId)) + { + if (source is Stripe.BankAccount) + { + await _stripeAdapter.BankAccountDeleteAsync(customer.Id, source.Id); + } + else if (source is Stripe.Card) + { + await _stripeAdapter.CardDeleteAsync(customer.Id, source.Id); + } + } + } + + var cardPaymentMethods = _stripeAdapter.PaymentMethodListAutoPaging(new Stripe.PaymentMethodListOptions + { + Customer = customer.Id, + Type = "card" + }); + foreach (var cardMethod in cardPaymentMethods.Where(m => m.Id != defaultPaymentMethodId)) + { + await _stripeAdapter.PaymentMethodDetachAsync(cardMethod.Id, new Stripe.PaymentMethodDetachOptions()); + } + + customer = await _stripeAdapter.CustomerUpdateAsync(customer.Id, new Stripe.CustomerUpdateOptions + { + Metadata = stripeCustomerMetadata, + DefaultSource = defaultSourceId, + InvoiceSettings = new Stripe.CustomerInvoiceSettingsOptions + { + DefaultPaymentMethod = defaultPaymentMethodId + }, + Address = taxInfo == null ? null : new Stripe.AddressOptions + { + Country = taxInfo.BillingAddressCountry, + PostalCode = taxInfo.BillingAddressPostalCode, + Line1 = taxInfo.BillingAddressLine1 ?? string.Empty, + Line2 = taxInfo.BillingAddressLine2, + City = taxInfo.BillingAddressCity, + State = taxInfo.BillingAddressState, + }, + }); + } + } + catch + { + if (braintreeCustomer != null && !hadBtCustomer) + { + await _btGateway.Customer.DeleteAsync(braintreeCustomer.Id); + } + throw; + } + + return createdCustomer; + } + + public async Task CreditAccountAsync(ISubscriber subscriber, decimal creditAmount) + { + Stripe.Customer customer = null; + var customerExists = subscriber.Gateway == GatewayType.Stripe && + !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId); + if (customerExists) + { + customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); + } + else + { + customer = await _stripeAdapter.CustomerCreateAsync(new Stripe.CustomerCreateOptions + { + Email = subscriber.BillingEmailAddress(), + Description = subscriber.BillingName(), + }); + subscriber.Gateway = GatewayType.Stripe; + subscriber.GatewayCustomerId = customer.Id; + } + await _stripeAdapter.CustomerUpdateAsync(customer.Id, new Stripe.CustomerUpdateOptions + { + Balance = customer.Balance - (long)(creditAmount * 100) + }); + return !customerExists; + } + + public async Task GetBillingAsync(ISubscriber subscriber) + { + var customer = await GetCustomerAsync(subscriber.GatewayCustomerId, GetCustomerPaymentOptions()); + var billingInfo = new BillingInfo + { + Balance = GetBillingBalance(customer), + PaymentSource = await GetBillingPaymentSourceAsync(customer), + Invoices = await GetBillingInvoicesAsync(customer), + Transactions = await GetBillingTransactionsAsync(subscriber) + }; + + return billingInfo; + } + + public async Task GetBillingBalanceAndSourceAsync(ISubscriber subscriber) + { + var customer = await GetCustomerAsync(subscriber.GatewayCustomerId, GetCustomerPaymentOptions()); + var billingInfo = new BillingInfo + { + Balance = GetBillingBalance(customer), + PaymentSource = await GetBillingPaymentSourceAsync(customer) + }; + + return billingInfo; + } + + public async Task GetBillingHistoryAsync(ISubscriber subscriber) + { + var customer = await GetCustomerAsync(subscriber.GatewayCustomerId); + var billingInfo = new BillingInfo + { + Transactions = await GetBillingTransactionsAsync(subscriber), + Invoices = await GetBillingInvoicesAsync(customer) + }; + + return billingInfo; + } + + public async Task GetSubscriptionAsync(ISubscriber subscriber) + { + var subscriptionInfo = new SubscriptionInfo(); + + if (subscriber.IsUser() && !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) + { + var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); + subscriptionInfo.UsingInAppPurchase = customer.Metadata.ContainsKey("appleReceipt"); + } + + if (!string.IsNullOrWhiteSpace(subscriber.GatewaySubscriptionId)) + { + var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); + if (sub != null) + { + subscriptionInfo.Subscription = new SubscriptionInfo.BillingSubscription(sub); + } + + if (!sub.CanceledAt.HasValue && !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) { try { - var braintreeCustomer = await _btGateway.Customer.FindAsync( - customer.Metadata["btCustomerId"]); - if (braintreeCustomer?.DefaultPaymentMethod != null) + var upcomingInvoice = await _stripeAdapter.InvoiceUpcomingAsync( + new Stripe.UpcomingInvoiceOptions { Customer = subscriber.GatewayCustomerId }); + if (upcomingInvoice != null) { - return new BillingInfo.BillingSource( - braintreeCustomer.DefaultPaymentMethod); + subscriptionInfo.UpcomingInvoice = + new SubscriptionInfo.BillingUpcomingInvoice(upcomingInvoice); } } - catch (Braintree.Exceptions.NotFoundException) { } + catch (Stripe.StripeException) { } } - - if (customer.InvoiceSettings?.DefaultPaymentMethod?.Type == "card") - { - return new BillingInfo.BillingSource( - customer.InvoiceSettings.DefaultPaymentMethod); - } - - if (customer.DefaultSource != null && - (customer.DefaultSource is Stripe.Card || customer.DefaultSource is Stripe.BankAccount)) - { - return new BillingInfo.BillingSource(customer.DefaultSource); - } - - var paymentMethod = GetLatestCardPaymentMethod(customer.Id); - return paymentMethod != null ? new BillingInfo.BillingSource(paymentMethod) : null; } - private Stripe.CustomerGetOptions GetCustomerPaymentOptions() + return subscriptionInfo; + } + + public async Task GetTaxInfoAsync(ISubscriber subscriber) + { + if (subscriber == null || string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) { - var customerOptions = new Stripe.CustomerGetOptions(); - customerOptions.AddExpand("default_source"); - customerOptions.AddExpand("invoice_settings.default_payment_method"); - return customerOptions; + return null; } - private async Task GetCustomerAsync(string gatewayCustomerId, Stripe.CustomerGetOptions options = null) + var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, + new Stripe.CustomerGetOptions { Expand = new List { "tax_ids" } }); + + if (customer == null) { - if (string.IsNullOrWhiteSpace(gatewayCustomerId)) - { - return null; - } - - Stripe.Customer customer = null; - try - { - customer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId, options); - } - catch (Stripe.StripeException) { } - - return customer; + return null; } - private async Task> GetBillingTransactionsAsync(ISubscriber subscriber) + var address = customer.Address; + var taxId = customer.TaxIds?.FirstOrDefault(); + + // Line1 is required, so if missing we're using the subscriber name + // see: https://stripe.com/docs/api/customers/create#create_customer-address-line1 + if (address != null && string.IsNullOrWhiteSpace(address.Line1)) { - ICollection transactions = null; - if (subscriber is User) - { - transactions = await _transactionRepository.GetManyByUserIdAsync(subscriber.Id); - } - else if (subscriber is Organization) - { - transactions = await _transactionRepository.GetManyByOrganizationIdAsync(subscriber.Id); - } - - return transactions?.OrderByDescending(i => i.CreationDate) - .Select(t => new BillingInfo.BillingTransaction(t)); - + address.Line1 = null; } - private async Task> GetBillingInvoicesAsync(Stripe.Customer customer) + return new TaxInfo { - if (customer == null) - { - return null; - } + TaxIdNumber = taxId?.Value, + BillingAddressLine1 = address?.Line1, + BillingAddressLine2 = address?.Line2, + BillingAddressCity = address?.City, + BillingAddressState = address?.State, + BillingAddressPostalCode = address?.PostalCode, + BillingAddressCountry = address?.Country, + }; + } - var invoices = await _stripeAdapter.InvoiceListAsync(new Stripe.InvoiceListOptions + public async Task SaveTaxInfoAsync(ISubscriber subscriber, TaxInfo taxInfo) + { + if (subscriber != null && !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) + { + var customer = await _stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, new Stripe.CustomerUpdateOptions { - Customer = customer.Id, - Limit = 50 + Address = new Stripe.AddressOptions + { + Line1 = taxInfo.BillingAddressLine1 ?? string.Empty, + Line2 = taxInfo.BillingAddressLine2, + City = taxInfo.BillingAddressCity, + State = taxInfo.BillingAddressState, + PostalCode = taxInfo.BillingAddressPostalCode, + Country = taxInfo.BillingAddressCountry, + }, + Expand = new List { "tax_ids" } }); - return invoices.Data.Where(i => i.Status != "void" && i.Status != "draft") - .OrderByDescending(i => i.Created).Select(i => new BillingInfo.BillingInvoice(i)); + if (!subscriber.IsUser() && customer != null) + { + var taxId = customer.TaxIds?.FirstOrDefault(); + if (taxId != null) + { + await _stripeAdapter.TaxIdDeleteAsync(customer.Id, taxId.Id); + } + if (!string.IsNullOrWhiteSpace(taxInfo.TaxIdNumber) && + !string.IsNullOrWhiteSpace(taxInfo.TaxIdType)) + { + await _stripeAdapter.TaxIdCreateAsync(customer.Id, new Stripe.TaxIdCreateOptions + { + Type = taxInfo.TaxIdType, + Value = taxInfo.TaxIdNumber, + }); + } + } } } + + public async Task CreateTaxRateAsync(TaxRate taxRate) + { + var stripeTaxRateOptions = new Stripe.TaxRateCreateOptions() + { + DisplayName = $"{taxRate.Country} - {taxRate.PostalCode}", + Inclusive = false, + Percentage = taxRate.Rate, + Active = true + }; + var stripeTaxRate = await _stripeAdapter.TaxRateCreateAsync(stripeTaxRateOptions); + taxRate.Id = stripeTaxRate.Id; + await _taxRateRepository.CreateAsync(taxRate); + return taxRate; + } + + public async Task UpdateTaxRateAsync(TaxRate taxRate) + { + if (string.IsNullOrWhiteSpace(taxRate.Id)) + { + return; + } + + await ArchiveTaxRateAsync(taxRate); + await CreateTaxRateAsync(taxRate); + } + + public async Task ArchiveTaxRateAsync(TaxRate taxRate) + { + if (string.IsNullOrWhiteSpace(taxRate.Id)) + { + return; + } + + var updatedStripeTaxRate = await _stripeAdapter.TaxRateUpdateAsync( + taxRate.Id, + new Stripe.TaxRateUpdateOptions() { Active = false } + ); + if (!updatedStripeTaxRate.Active) + { + taxRate.Active = false; + await _taxRateRepository.ArchiveAsync(taxRate); + } + } + + private Stripe.PaymentMethod GetLatestCardPaymentMethod(string customerId) + { + var cardPaymentMethods = _stripeAdapter.PaymentMethodListAutoPaging( + new Stripe.PaymentMethodListOptions { Customer = customerId, Type = "card" }); + return cardPaymentMethods.OrderByDescending(m => m.Created).FirstOrDefault(); + } + + private async Task VerifyAppleReceiptNotInUseAsync(string receiptOriginalTransactionId, ISubscriber subscriber) + { + var existingReceipt = await _appleIapService.GetReceiptAsync(receiptOriginalTransactionId); + if (existingReceipt != null && existingReceipt.Item2.HasValue && existingReceipt.Item2 != subscriber.Id) + { + var existingUser = await _userRepository.GetByIdAsync(existingReceipt.Item2.Value); + if (existingUser != null) + { + throw new GatewayException("Apple receipt already in use by another user."); + } + } + } + + private decimal GetBillingBalance(Stripe.Customer customer) + { + return customer != null ? customer.Balance / 100M : default; + } + + private async Task GetBillingPaymentSourceAsync(Stripe.Customer customer) + { + if (customer == null) + { + return null; + } + + if (customer.Metadata?.ContainsKey("appleReceipt") ?? false) + { + return new BillingInfo.BillingSource + { + Type = PaymentMethodType.AppleInApp + }; + } + + if (customer.Metadata?.ContainsKey("btCustomerId") ?? false) + { + try + { + var braintreeCustomer = await _btGateway.Customer.FindAsync( + customer.Metadata["btCustomerId"]); + if (braintreeCustomer?.DefaultPaymentMethod != null) + { + return new BillingInfo.BillingSource( + braintreeCustomer.DefaultPaymentMethod); + } + } + catch (Braintree.Exceptions.NotFoundException) { } + } + + if (customer.InvoiceSettings?.DefaultPaymentMethod?.Type == "card") + { + return new BillingInfo.BillingSource( + customer.InvoiceSettings.DefaultPaymentMethod); + } + + if (customer.DefaultSource != null && + (customer.DefaultSource is Stripe.Card || customer.DefaultSource is Stripe.BankAccount)) + { + return new BillingInfo.BillingSource(customer.DefaultSource); + } + + var paymentMethod = GetLatestCardPaymentMethod(customer.Id); + return paymentMethod != null ? new BillingInfo.BillingSource(paymentMethod) : null; + } + + private Stripe.CustomerGetOptions GetCustomerPaymentOptions() + { + var customerOptions = new Stripe.CustomerGetOptions(); + customerOptions.AddExpand("default_source"); + customerOptions.AddExpand("invoice_settings.default_payment_method"); + return customerOptions; + } + + private async Task GetCustomerAsync(string gatewayCustomerId, Stripe.CustomerGetOptions options = null) + { + if (string.IsNullOrWhiteSpace(gatewayCustomerId)) + { + return null; + } + + Stripe.Customer customer = null; + try + { + customer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId, options); + } + catch (Stripe.StripeException) { } + + return customer; + } + + private async Task> GetBillingTransactionsAsync(ISubscriber subscriber) + { + ICollection transactions = null; + if (subscriber is User) + { + transactions = await _transactionRepository.GetManyByUserIdAsync(subscriber.Id); + } + else if (subscriber is Organization) + { + transactions = await _transactionRepository.GetManyByOrganizationIdAsync(subscriber.Id); + } + + return transactions?.OrderByDescending(i => i.CreationDate) + .Select(t => new BillingInfo.BillingTransaction(t)); + + } + + private async Task> GetBillingInvoicesAsync(Stripe.Customer customer) + { + if (customer == null) + { + return null; + } + + var invoices = await _stripeAdapter.InvoiceListAsync(new Stripe.InvoiceListOptions + { + Customer = customer.Id, + Limit = 50 + }); + + return invoices.Data.Where(i => i.Status != "void" && i.Status != "draft") + .OrderByDescending(i => i.Created).Select(i => new BillingInfo.BillingInvoice(i)); + + } } diff --git a/src/Core/Services/Implementations/StripeSyncService.cs b/src/Core/Services/Implementations/StripeSyncService.cs index f042eac5c..b2700e65d 100644 --- a/src/Core/Services/Implementations/StripeSyncService.cs +++ b/src/Core/Services/Implementations/StripeSyncService.cs @@ -1,32 +1,31 @@ using Bit.Core.Exceptions; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class StripeSyncService : IStripeSyncService { - public class StripeSyncService : IStripeSyncService + private readonly IStripeAdapter _stripeAdapter; + + public StripeSyncService(IStripeAdapter stripeAdapter) { - private readonly IStripeAdapter _stripeAdapter; + _stripeAdapter = stripeAdapter; + } - public StripeSyncService(IStripeAdapter stripeAdapter) + public async Task UpdateCustomerEmailAddress(string gatewayCustomerId, string emailAddress) + { + if (string.IsNullOrWhiteSpace(gatewayCustomerId)) { - _stripeAdapter = stripeAdapter; + throw new InvalidGatewayCustomerIdException(); } - public async Task UpdateCustomerEmailAddress(string gatewayCustomerId, string emailAddress) + if (string.IsNullOrWhiteSpace(emailAddress)) { - if (string.IsNullOrWhiteSpace(gatewayCustomerId)) - { - throw new InvalidGatewayCustomerIdException(); - } - - if (string.IsNullOrWhiteSpace(emailAddress)) - { - throw new InvalidEmailException(); - } - - var customer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId); - - await _stripeAdapter.CustomerUpdateAsync(customer.Id, - new Stripe.CustomerUpdateOptions { Email = emailAddress }); + throw new InvalidEmailException(); } + + var customer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId); + + await _stripeAdapter.CustomerUpdateAsync(customer.Id, + new Stripe.CustomerUpdateOptions { Email = emailAddress }); } } diff --git a/src/Core/Services/Implementations/UserService.cs b/src/Core/Services/Implementations/UserService.cs index d54ea7bb4..46509ceda 100644 --- a/src/Core/Services/Implementations/UserService.cs +++ b/src/Core/Services/Implementations/UserService.cs @@ -17,1072 +17,980 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using File = System.IO.File; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class UserService : UserManager, IUserService, IDisposable { - public class UserService : UserManager, IUserService, IDisposable + private const string PremiumPlanId = "premium-annually"; + private const string StoragePlanId = "storage-gb-annually"; + + private readonly IUserRepository _userRepository; + private readonly ICipherRepository _cipherRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly IMailService _mailService; + private readonly IPushNotificationService _pushService; + private readonly IdentityErrorDescriber _identityErrorDescriber; + private readonly IdentityOptions _identityOptions; + private readonly IPasswordHasher _passwordHasher; + private readonly IEnumerable> _passwordValidators; + private readonly ILicensingService _licenseService; + private readonly IEventService _eventService; + private readonly IApplicationCacheService _applicationCacheService; + private readonly IPaymentService _paymentService; + private readonly IPolicyRepository _policyRepository; + private readonly IDataProtector _organizationServiceDataProtector; + private readonly IReferenceEventService _referenceEventService; + private readonly IFido2 _fido2; + private readonly ICurrentContext _currentContext; + private readonly IGlobalSettings _globalSettings; + private readonly IOrganizationService _organizationService; + private readonly IProviderUserRepository _providerUserRepository; + private readonly IDeviceRepository _deviceRepository; + private readonly IStripeSyncService _stripeSyncService; + + public UserService( + IUserRepository userRepository, + ICipherRepository cipherRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationRepository organizationRepository, + IMailService mailService, + IPushNotificationService pushService, + IUserStore store, + IOptions optionsAccessor, + IPasswordHasher passwordHasher, + IEnumerable> userValidators, + IEnumerable> passwordValidators, + ILookupNormalizer keyNormalizer, + IdentityErrorDescriber errors, + IServiceProvider services, + ILogger> logger, + ILicensingService licenseService, + IEventService eventService, + IApplicationCacheService applicationCacheService, + IDataProtectionProvider dataProtectionProvider, + IPaymentService paymentService, + IPolicyRepository policyRepository, + IReferenceEventService referenceEventService, + IFido2 fido2, + ICurrentContext currentContext, + IGlobalSettings globalSettings, + IOrganizationService organizationService, + IProviderUserRepository providerUserRepository, + IDeviceRepository deviceRepository, + IStripeSyncService stripeSyncService) + : base( + store, + optionsAccessor, + passwordHasher, + userValidators, + passwordValidators, + keyNormalizer, + errors, + services, + logger) { - private const string PremiumPlanId = "premium-annually"; - private const string StoragePlanId = "storage-gb-annually"; + _userRepository = userRepository; + _cipherRepository = cipherRepository; + _organizationUserRepository = organizationUserRepository; + _organizationRepository = organizationRepository; + _mailService = mailService; + _pushService = pushService; + _identityOptions = optionsAccessor?.Value ?? new IdentityOptions(); + _identityErrorDescriber = errors; + _passwordHasher = passwordHasher; + _passwordValidators = passwordValidators; + _licenseService = licenseService; + _eventService = eventService; + _applicationCacheService = applicationCacheService; + _paymentService = paymentService; + _policyRepository = policyRepository; + _organizationServiceDataProtector = dataProtectionProvider.CreateProtector( + "OrganizationServiceDataProtector"); + _referenceEventService = referenceEventService; + _fido2 = fido2; + _currentContext = currentContext; + _globalSettings = globalSettings; + _organizationService = organizationService; + _providerUserRepository = providerUserRepository; + _deviceRepository = deviceRepository; + _stripeSyncService = stripeSyncService; + } - private readonly IUserRepository _userRepository; - private readonly ICipherRepository _cipherRepository; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly IMailService _mailService; - private readonly IPushNotificationService _pushService; - private readonly IdentityErrorDescriber _identityErrorDescriber; - private readonly IdentityOptions _identityOptions; - private readonly IPasswordHasher _passwordHasher; - private readonly IEnumerable> _passwordValidators; - private readonly ILicensingService _licenseService; - private readonly IEventService _eventService; - private readonly IApplicationCacheService _applicationCacheService; - private readonly IPaymentService _paymentService; - private readonly IPolicyRepository _policyRepository; - private readonly IDataProtector _organizationServiceDataProtector; - private readonly IReferenceEventService _referenceEventService; - private readonly IFido2 _fido2; - private readonly ICurrentContext _currentContext; - private readonly IGlobalSettings _globalSettings; - private readonly IOrganizationService _organizationService; - private readonly IProviderUserRepository _providerUserRepository; - private readonly IDeviceRepository _deviceRepository; - private readonly IStripeSyncService _stripeSyncService; - - public UserService( - IUserRepository userRepository, - ICipherRepository cipherRepository, - IOrganizationUserRepository organizationUserRepository, - IOrganizationRepository organizationRepository, - IMailService mailService, - IPushNotificationService pushService, - IUserStore store, - IOptions optionsAccessor, - IPasswordHasher passwordHasher, - IEnumerable> userValidators, - IEnumerable> passwordValidators, - ILookupNormalizer keyNormalizer, - IdentityErrorDescriber errors, - IServiceProvider services, - ILogger> logger, - ILicensingService licenseService, - IEventService eventService, - IApplicationCacheService applicationCacheService, - IDataProtectionProvider dataProtectionProvider, - IPaymentService paymentService, - IPolicyRepository policyRepository, - IReferenceEventService referenceEventService, - IFido2 fido2, - ICurrentContext currentContext, - IGlobalSettings globalSettings, - IOrganizationService organizationService, - IProviderUserRepository providerUserRepository, - IDeviceRepository deviceRepository, - IStripeSyncService stripeSyncService) - : base( - store, - optionsAccessor, - passwordHasher, - userValidators, - passwordValidators, - keyNormalizer, - errors, - services, - logger) + public Guid? GetProperUserId(ClaimsPrincipal principal) + { + if (!Guid.TryParse(GetUserId(principal), out var userIdGuid)) { - _userRepository = userRepository; - _cipherRepository = cipherRepository; - _organizationUserRepository = organizationUserRepository; - _organizationRepository = organizationRepository; - _mailService = mailService; - _pushService = pushService; - _identityOptions = optionsAccessor?.Value ?? new IdentityOptions(); - _identityErrorDescriber = errors; - _passwordHasher = passwordHasher; - _passwordValidators = passwordValidators; - _licenseService = licenseService; - _eventService = eventService; - _applicationCacheService = applicationCacheService; - _paymentService = paymentService; - _policyRepository = policyRepository; - _organizationServiceDataProtector = dataProtectionProvider.CreateProtector( - "OrganizationServiceDataProtector"); - _referenceEventService = referenceEventService; - _fido2 = fido2; - _currentContext = currentContext; - _globalSettings = globalSettings; - _organizationService = organizationService; - _providerUserRepository = providerUserRepository; - _deviceRepository = deviceRepository; - _stripeSyncService = stripeSyncService; - } - - public Guid? GetProperUserId(ClaimsPrincipal principal) - { - if (!Guid.TryParse(GetUserId(principal), out var userIdGuid)) - { - return null; - } - - return userIdGuid; - } - - public async Task GetUserByIdAsync(string userId) - { - if (_currentContext?.User != null && - string.Equals(_currentContext.User.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase)) - { - return _currentContext.User; - } - - if (!Guid.TryParse(userId, out var userIdGuid)) - { - return null; - } - - _currentContext.User = await _userRepository.GetByIdAsync(userIdGuid); - return _currentContext.User; - } - - public async Task GetUserByIdAsync(Guid userId) - { - if (_currentContext?.User != null && _currentContext.User.Id == userId) - { - return _currentContext.User; - } - - _currentContext.User = await _userRepository.GetByIdAsync(userId); - return _currentContext.User; - } - - public async Task GetUserByPrincipalAsync(ClaimsPrincipal principal) - { - var userId = GetProperUserId(principal); - if (!userId.HasValue) - { - return null; - } - - return await GetUserByIdAsync(userId.Value); - } - - public async Task GetAccountRevisionDateByIdAsync(Guid userId) - { - return await _userRepository.GetAccountRevisionDateAsync(userId); - } - - public async Task SaveUserAsync(User user, bool push = false) - { - if (user.Id == default(Guid)) - { - throw new ApplicationException("Use register method to create a new user."); - } - - user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; - await _userRepository.ReplaceAsync(user); - - if (push) - { - // push - await _pushService.PushSyncSettingsAsync(user.Id); - } - } - - public override async Task DeleteAsync(User user) - { - // Check if user is the only owner of any organizations. - var onlyOwnerCount = await _organizationUserRepository.GetCountByOnlyOwnerAsync(user.Id); - if (onlyOwnerCount > 0) - { - var deletedOrg = false; - var orgs = await _organizationUserRepository.GetManyDetailsByUserAsync(user.Id, - OrganizationUserStatusType.Confirmed); - if (orgs.Count == 1) - { - var org = await _organizationRepository.GetByIdAsync(orgs.First().OrganizationId); - if (org != null && (!org.Enabled || string.IsNullOrWhiteSpace(org.GatewaySubscriptionId))) - { - var orgCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(org.Id); - if (orgCount <= 1) - { - await _organizationRepository.DeleteAsync(org); - deletedOrg = true; - } - } - } - - if (!deletedOrg) - { - return IdentityResult.Failed(new IdentityError - { - Description = "Cannot delete this user because it is the sole owner of at least one organization. Please delete these organizations or upgrade another user.", - }); - } - } - - var onlyOwnerProviderCount = await _providerUserRepository.GetCountByOnlyOwnerAsync(user.Id); - if (onlyOwnerProviderCount > 0) - { - return IdentityResult.Failed(new IdentityError - { - Description = "Cannot delete this user because it is the sole owner of at least one provider. Please delete these providers or upgrade another user.", - }); - } - - if (!string.IsNullOrWhiteSpace(user.GatewaySubscriptionId)) - { - try - { - await CancelPremiumAsync(user, null, true); - } - catch (GatewayException) { } - } - - await _userRepository.DeleteAsync(user); - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.DeleteAccount, user)); - await _pushService.PushLogOutAsync(user.Id); - return IdentityResult.Success; - } - - public async Task DeleteAsync(User user, string token) - { - if (!(await VerifyUserTokenAsync(user, TokenOptions.DefaultProvider, "DeleteAccount", token))) - { - return IdentityResult.Failed(ErrorDescriber.InvalidToken()); - } - - return await DeleteAsync(user); - } - - public async Task SendDeleteConfirmationAsync(string email) - { - var user = await _userRepository.GetByEmailAsync(email); - if (user == null) - { - // No user exists. - return; - } - - var token = await base.GenerateUserTokenAsync(user, TokenOptions.DefaultProvider, "DeleteAccount"); - await _mailService.SendVerifyDeleteEmailAsync(user.Email, user.Id, token); - } - - public async Task RegisterUserAsync(User user, string masterPassword, - string token, Guid? orgUserId) - { - var tokenValid = false; - if (_globalSettings.DisableUserRegistration && !string.IsNullOrWhiteSpace(token) && orgUserId.HasValue) - { - tokenValid = CoreHelpers.UserInviteTokenIsValid(_organizationServiceDataProtector, token, - user.Email, orgUserId.Value, _globalSettings); - } - - if (_globalSettings.DisableUserRegistration && !tokenValid) - { - throw new BadRequestException("Open registration has been disabled by the system administrator."); - } - - if (orgUserId.HasValue) - { - var orgUser = await _organizationUserRepository.GetByIdAsync(orgUserId.Value); - if (orgUser != null) - { - var twoFactorPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(orgUser.OrganizationId, - PolicyType.TwoFactorAuthentication); - if (twoFactorPolicy != null && twoFactorPolicy.Enabled) - { - user.SetTwoFactorProviders(new Dictionary - { - - [TwoFactorProviderType.Email] = new TwoFactorProvider - { - MetaData = new Dictionary { ["Email"] = user.Email.ToLowerInvariant() }, - Enabled = true - } - }); - SetTwoFactorProvider(user, TwoFactorProviderType.Email); - } - } - } - - user.ApiKey = CoreHelpers.SecureRandomString(30); - var result = await base.CreateAsync(user, masterPassword); - if (result == IdentityResult.Success) - { - await _mailService.SendWelcomeEmailAsync(user); - await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.Signup, user)); - } - - return result; - } - - public async Task RegisterUserAsync(User user) - { - var result = await base.CreateAsync(user); - if (result == IdentityResult.Success) - { - await _mailService.SendWelcomeEmailAsync(user); - await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.Signup, user)); - } - - return result; - } - - public async Task SendMasterPasswordHintAsync(string email) - { - var user = await _userRepository.GetByEmailAsync(email); - if (user == null) - { - // No user exists. Do we want to send an email telling them this in the future? - return; - } - - if (string.IsNullOrWhiteSpace(user.MasterPasswordHint)) - { - await _mailService.SendNoMasterPasswordHintEmailAsync(email); - return; - } - - await _mailService.SendMasterPasswordHintEmailAsync(email, user.MasterPasswordHint); - } - - public async Task SendTwoFactorEmailAsync(User user, bool isBecauseNewDeviceLogin = false) - { - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Email); - if (provider == null || provider.MetaData == null || !provider.MetaData.ContainsKey("Email")) - { - throw new ArgumentNullException("No email."); - } - - var email = ((string)provider.MetaData["Email"]).ToLowerInvariant(); - var token = await base.GenerateUserTokenAsync(user, TokenOptions.DefaultEmailProvider, - "2faEmail:" + email); - - if (isBecauseNewDeviceLogin) - { - await _mailService.SendNewDeviceLoginTwoFactorEmailAsync(email, token); - } - else - { - await _mailService.SendTwoFactorEmailAsync(email, token); - } - } - - public async Task VerifyTwoFactorEmailAsync(User user, string token) - { - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Email); - if (provider == null || provider.MetaData == null || !provider.MetaData.ContainsKey("Email")) - { - throw new ArgumentNullException("No email."); - } - - var email = ((string)provider.MetaData["Email"]).ToLowerInvariant(); - return await base.VerifyUserTokenAsync(user, TokenOptions.DefaultEmailProvider, - "2faEmail:" + email, token); - } - - public async Task StartWebAuthnRegistrationAsync(User user) - { - var providers = user.GetTwoFactorProviders(); - if (providers == null) - { - providers = new Dictionary(); - } - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); - if (provider == null) - { - provider = new TwoFactorProvider - { - Enabled = false - }; - } - if (provider.MetaData == null) - { - provider.MetaData = new Dictionary(); - } - - var fidoUser = new Fido2User - { - DisplayName = user.Name, - Name = user.Email, - Id = user.Id.ToByteArray(), - }; - - var excludeCredentials = provider.MetaData - .Where(k => k.Key.StartsWith("Key")) - .Select(k => new TwoFactorProvider.WebAuthnData((dynamic)k.Value).Descriptor) - .ToList(); - - var authenticatorSelection = new AuthenticatorSelection - { - AuthenticatorAttachment = null, - RequireResidentKey = false, - UserVerification = UserVerificationRequirement.Discouraged - }; - var options = _fido2.RequestNewCredential(fidoUser, excludeCredentials, authenticatorSelection, AttestationConveyancePreference.None); - - provider.MetaData["pending"] = options.ToJson(); - providers[TwoFactorProviderType.WebAuthn] = provider; - user.SetTwoFactorProviders(providers); - await UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.WebAuthn, false); - - return options; - } - - public async Task CompleteWebAuthRegistrationAsync(User user, int id, string name, AuthenticatorAttestationRawResponse attestationResponse) - { - var keyId = $"Key{id}"; - - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); - if (!provider?.MetaData?.ContainsKey("pending") ?? true) - { - return false; - } - - var options = CredentialCreateOptions.FromJson((string)provider.MetaData["pending"]); - - // Callback to ensure credential id is unique. Always return true since we don't care if another - // account uses the same 2fa key. - IsCredentialIdUniqueToUserAsyncDelegate callback = args => Task.FromResult(true); - - var success = await _fido2.MakeNewCredentialAsync(attestationResponse, options, callback); - - provider.MetaData.Remove("pending"); - provider.MetaData[keyId] = new TwoFactorProvider.WebAuthnData - { - Name = name, - Descriptor = new PublicKeyCredentialDescriptor(success.Result.CredentialId), - PublicKey = success.Result.PublicKey, - UserHandle = success.Result.User.Id, - SignatureCounter = success.Result.Counter, - CredType = success.Result.CredType, - RegDate = DateTime.Now, - AaGuid = success.Result.Aaguid - }; - - var providers = user.GetTwoFactorProviders(); - providers[TwoFactorProviderType.WebAuthn] = provider; - user.SetTwoFactorProviders(providers); - await UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.WebAuthn); - - return true; - } - - public async Task DeleteWebAuthnKeyAsync(User user, int id) - { - var providers = user.GetTwoFactorProviders(); - if (providers == null) - { - return false; - } - - var keyName = $"Key{id}"; - var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); - if (!provider?.MetaData?.ContainsKey(keyName) ?? true) - { - return false; - } - - if (provider.MetaData.Count < 2) - { - return false; - } - - provider.MetaData.Remove(keyName); - providers[TwoFactorProviderType.WebAuthn] = provider; - user.SetTwoFactorProviders(providers); - await UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.WebAuthn); - return true; - } - - public async Task SendEmailVerificationAsync(User user) - { - if (user.EmailVerified) - { - throw new BadRequestException("Email already verified."); - } - - var token = await base.GenerateEmailConfirmationTokenAsync(user); - await _mailService.SendVerifyEmailEmailAsync(user.Email, user.Id, token); - } - - public async Task InitiateEmailChangeAsync(User user, string newEmail) - { - var existingUser = await _userRepository.GetByEmailAsync(newEmail); - if (existingUser != null) - { - await _mailService.SendChangeEmailAlreadyExistsEmailAsync(user.Email, newEmail); - return; - } - - var token = await base.GenerateChangeEmailTokenAsync(user, newEmail); - await _mailService.SendChangeEmailEmailAsync(newEmail, token); - } - - public async Task ChangeEmailAsync(User user, string masterPassword, string newEmail, - string newMasterPassword, string token, string key) - { - var verifyPasswordResult = _passwordHasher.VerifyHashedPassword(user, user.MasterPassword, masterPassword); - if (verifyPasswordResult == PasswordVerificationResult.Failed) - { - return IdentityResult.Failed(_identityErrorDescriber.PasswordMismatch()); - } - - if (!await base.VerifyUserTokenAsync(user, _identityOptions.Tokens.ChangeEmailTokenProvider, - GetChangeEmailTokenPurpose(newEmail), token)) - { - return IdentityResult.Failed(_identityErrorDescriber.InvalidToken()); - } - - var existingUser = await _userRepository.GetByEmailAsync(newEmail); - if (existingUser != null && existingUser.Id != user.Id) - { - return IdentityResult.Failed(_identityErrorDescriber.DuplicateEmail(newEmail)); - } - - var previousState = new - { - Key = user.Key, - MasterPassword = user.MasterPassword, - SecurityStamp = user.SecurityStamp, - Email = user.Email - }; - - var result = await UpdatePasswordHash(user, newMasterPassword); - if (!result.Succeeded) - { - return result; - } - - user.Key = key; - user.Email = newEmail; - user.EmailVerified = true; - user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; - await _userRepository.ReplaceAsync(user); - - if (user.Gateway == GatewayType.Stripe) - { - - try - { - await _stripeSyncService.UpdateCustomerEmailAddress(user.GatewayCustomerId, - user.BillingEmailAddress()); - } - catch (Exception ex) - { - //if sync to strip fails, update email and securityStamp to previous - user.Key = previousState.Key; - user.Email = previousState.Email; - user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; - user.MasterPassword = previousState.MasterPassword; - user.SecurityStamp = previousState.SecurityStamp; - - await _userRepository.ReplaceAsync(user); - return IdentityResult.Failed(new IdentityError - { - Description = ex.Message - }); - } - } - - await _pushService.PushLogOutAsync(user.Id); - - return IdentityResult.Success; - } - - public override Task ChangePasswordAsync(User user, string masterPassword, string newMasterPassword) - { - throw new NotImplementedException(); - } - - public async Task ChangePasswordAsync(User user, string masterPassword, string newMasterPassword, string passwordHint, - string key) - { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - if (await CheckPasswordAsync(user, masterPassword)) - { - var result = await UpdatePasswordHash(user, newMasterPassword); - if (!result.Succeeded) - { - return result; - } - - user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; - user.Key = key; - user.MasterPasswordHint = passwordHint; - - await _userRepository.ReplaceAsync(user); - await _eventService.LogUserEventAsync(user.Id, EventType.User_ChangedPassword); - await _pushService.PushLogOutAsync(user.Id); - - return IdentityResult.Success; - } - - Logger.LogWarning("Change password failed for user {userId}.", user.Id); - return IdentityResult.Failed(_identityErrorDescriber.PasswordMismatch()); - } - - public async Task SetPasswordAsync(User user, string masterPassword, string key, - string orgIdentifier = null) - { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - if (!string.IsNullOrWhiteSpace(user.MasterPassword)) - { - Logger.LogWarning("Change password failed for user {userId} - already has password.", user.Id); - return IdentityResult.Failed(_identityErrorDescriber.UserAlreadyHasPassword()); - } - - var result = await UpdatePasswordHash(user, masterPassword, true, false); - if (!result.Succeeded) - { - return result; - } - - user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; - user.Key = key; - - await _userRepository.ReplaceAsync(user); - await _eventService.LogUserEventAsync(user.Id, EventType.User_ChangedPassword); - - if (!string.IsNullOrWhiteSpace(orgIdentifier)) - { - await _organizationService.AcceptUserAsync(orgIdentifier, user, this); - } - - return IdentityResult.Success; - } - - public async Task SetKeyConnectorKeyAsync(User user, string key, string orgIdentifier) - { - var identityResult = CheckCanUseKeyConnector(user); - if (identityResult != null) - { - return identityResult; - } - - user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; - user.Key = key; - user.UsesKeyConnector = true; - - await _userRepository.ReplaceAsync(user); - await _eventService.LogUserEventAsync(user.Id, EventType.User_MigratedKeyToKeyConnector); - - await _organizationService.AcceptUserAsync(orgIdentifier, user, this); - - return IdentityResult.Success; - } - - public async Task ConvertToKeyConnectorAsync(User user) - { - var identityResult = CheckCanUseKeyConnector(user); - if (identityResult != null) - { - return identityResult; - } - - user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; - user.MasterPassword = null; - user.UsesKeyConnector = true; - - await _userRepository.ReplaceAsync(user); - await _eventService.LogUserEventAsync(user.Id, EventType.User_MigratedKeyToKeyConnector); - - return IdentityResult.Success; - } - - private IdentityResult CheckCanUseKeyConnector(User user) - { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - if (user.UsesKeyConnector) - { - Logger.LogWarning("Already uses Key Connector."); - return IdentityResult.Failed(_identityErrorDescriber.UserAlreadyHasPassword()); - } - - if (_currentContext.Organizations.Any(u => - u.Type is OrganizationUserType.Owner or OrganizationUserType.Admin)) - { - throw new BadRequestException("Cannot use Key Connector when admin or owner of an organization."); - } - return null; } - public async Task AdminResetPasswordAsync(OrganizationUserType callingUserType, Guid orgId, Guid id, string newMasterPassword, string key) + return userIdGuid; + } + + public async Task GetUserByIdAsync(string userId) + { + if (_currentContext?.User != null && + string.Equals(_currentContext.User.Id.ToString(), userId, StringComparison.InvariantCultureIgnoreCase)) { - // Org must be able to use reset password - var org = await _organizationRepository.GetByIdAsync(orgId); - if (org == null || !org.UseResetPassword) - { - throw new BadRequestException("Organization does not allow password reset."); - } - - // Enterprise policy must be enabled - var resetPasswordPolicy = - await _policyRepository.GetByOrganizationIdTypeAsync(orgId, PolicyType.ResetPassword); - if (resetPasswordPolicy == null || !resetPasswordPolicy.Enabled) - { - throw new BadRequestException("Organization does not have the password reset policy enabled."); - } - - // Org User must be confirmed and have a ResetPasswordKey - var orgUser = await _organizationUserRepository.GetByIdAsync(id); - if (orgUser == null || orgUser.Status != OrganizationUserStatusType.Confirmed || - orgUser.OrganizationId != orgId || string.IsNullOrEmpty(orgUser.ResetPasswordKey) || - !orgUser.UserId.HasValue) - { - throw new BadRequestException("Organization User not valid"); - } - - // Calling User must be of higher/equal user type to reset user's password - var canAdjustPassword = false; - switch (callingUserType) - { - case OrganizationUserType.Owner: - canAdjustPassword = true; - break; - case OrganizationUserType.Admin: - canAdjustPassword = orgUser.Type != OrganizationUserType.Owner; - break; - case OrganizationUserType.Custom: - canAdjustPassword = orgUser.Type != OrganizationUserType.Owner && - orgUser.Type != OrganizationUserType.Admin; - break; - } - - if (!canAdjustPassword) - { - throw new BadRequestException("Calling user does not have permission to reset this user's master password"); - } - - var user = await GetUserByIdAsync(orgUser.UserId.Value); - if (user == null) - { - throw new NotFoundException(); - } - - if (user.UsesKeyConnector) - { - throw new BadRequestException("Cannot reset password of a user with Key Connector."); - } - - var result = await UpdatePasswordHash(user, newMasterPassword); - if (!result.Succeeded) - { - return result; - } - - user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; - user.Key = key; - user.ForcePasswordReset = true; - - await _userRepository.ReplaceAsync(user); - await _mailService.SendAdminResetPasswordEmailAsync(user.Email, user.Name, org.Name); - await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_AdminResetPassword); - await _pushService.PushLogOutAsync(user.Id); - - return IdentityResult.Success; + return _currentContext.User; } - public async Task UpdateTempPasswordAsync(User user, string newMasterPassword, string key, string hint) + if (!Guid.TryParse(userId, out var userIdGuid)) { - if (!user.ForcePasswordReset) - { - throw new BadRequestException("User does not have a temporary password to update."); - } - - var result = await UpdatePasswordHash(user, newMasterPassword); - if (!result.Succeeded) - { - return result; - } - - user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; - user.ForcePasswordReset = false; - user.Key = key; - user.MasterPasswordHint = hint; - - await _userRepository.ReplaceAsync(user); - await _mailService.SendUpdatedTempPasswordEmailAsync(user.Email, user.Name); - await _eventService.LogUserEventAsync(user.Id, EventType.User_UpdatedTempPassword); - await _pushService.PushLogOutAsync(user.Id); - - return IdentityResult.Success; + return null; } - public async Task ChangeKdfAsync(User user, string masterPassword, string newMasterPassword, - string key, KdfType kdf, int kdfIterations) - { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } + _currentContext.User = await _userRepository.GetByIdAsync(userIdGuid); + return _currentContext.User; + } - if (await CheckPasswordAsync(user, masterPassword)) + public async Task GetUserByIdAsync(Guid userId) + { + if (_currentContext?.User != null && _currentContext.User.Id == userId) + { + return _currentContext.User; + } + + _currentContext.User = await _userRepository.GetByIdAsync(userId); + return _currentContext.User; + } + + public async Task GetUserByPrincipalAsync(ClaimsPrincipal principal) + { + var userId = GetProperUserId(principal); + if (!userId.HasValue) + { + return null; + } + + return await GetUserByIdAsync(userId.Value); + } + + public async Task GetAccountRevisionDateByIdAsync(Guid userId) + { + return await _userRepository.GetAccountRevisionDateAsync(userId); + } + + public async Task SaveUserAsync(User user, bool push = false) + { + if (user.Id == default(Guid)) + { + throw new ApplicationException("Use register method to create a new user."); + } + + user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; + await _userRepository.ReplaceAsync(user); + + if (push) + { + // push + await _pushService.PushSyncSettingsAsync(user.Id); + } + } + + public override async Task DeleteAsync(User user) + { + // Check if user is the only owner of any organizations. + var onlyOwnerCount = await _organizationUserRepository.GetCountByOnlyOwnerAsync(user.Id); + if (onlyOwnerCount > 0) + { + var deletedOrg = false; + var orgs = await _organizationUserRepository.GetManyDetailsByUserAsync(user.Id, + OrganizationUserStatusType.Confirmed); + if (orgs.Count == 1) { - var result = await UpdatePasswordHash(user, newMasterPassword); - if (!result.Succeeded) + var org = await _organizationRepository.GetByIdAsync(orgs.First().OrganizationId); + if (org != null && (!org.Enabled || string.IsNullOrWhiteSpace(org.GatewaySubscriptionId))) { - return result; + var orgCount = await _organizationUserRepository.GetCountByOrganizationIdAsync(org.Id); + if (orgCount <= 1) + { + await _organizationRepository.DeleteAsync(org); + deletedOrg = true; + } } - - user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; - user.Key = key; - user.Kdf = kdf; - user.KdfIterations = kdfIterations; - await _userRepository.ReplaceAsync(user); - await _pushService.PushLogOutAsync(user.Id); - return IdentityResult.Success; } - Logger.LogWarning("Change KDF failed for user {userId}.", user.Id); + if (!deletedOrg) + { + return IdentityResult.Failed(new IdentityError + { + Description = "Cannot delete this user because it is the sole owner of at least one organization. Please delete these organizations or upgrade another user.", + }); + } + } + + var onlyOwnerProviderCount = await _providerUserRepository.GetCountByOnlyOwnerAsync(user.Id); + if (onlyOwnerProviderCount > 0) + { + return IdentityResult.Failed(new IdentityError + { + Description = "Cannot delete this user because it is the sole owner of at least one provider. Please delete these providers or upgrade another user.", + }); + } + + if (!string.IsNullOrWhiteSpace(user.GatewaySubscriptionId)) + { + try + { + await CancelPremiumAsync(user, null, true); + } + catch (GatewayException) { } + } + + await _userRepository.DeleteAsync(user); + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.DeleteAccount, user)); + await _pushService.PushLogOutAsync(user.Id); + return IdentityResult.Success; + } + + public async Task DeleteAsync(User user, string token) + { + if (!(await VerifyUserTokenAsync(user, TokenOptions.DefaultProvider, "DeleteAccount", token))) + { + return IdentityResult.Failed(ErrorDescriber.InvalidToken()); + } + + return await DeleteAsync(user); + } + + public async Task SendDeleteConfirmationAsync(string email) + { + var user = await _userRepository.GetByEmailAsync(email); + if (user == null) + { + // No user exists. + return; + } + + var token = await base.GenerateUserTokenAsync(user, TokenOptions.DefaultProvider, "DeleteAccount"); + await _mailService.SendVerifyDeleteEmailAsync(user.Email, user.Id, token); + } + + public async Task RegisterUserAsync(User user, string masterPassword, + string token, Guid? orgUserId) + { + var tokenValid = false; + if (_globalSettings.DisableUserRegistration && !string.IsNullOrWhiteSpace(token) && orgUserId.HasValue) + { + tokenValid = CoreHelpers.UserInviteTokenIsValid(_organizationServiceDataProtector, token, + user.Email, orgUserId.Value, _globalSettings); + } + + if (_globalSettings.DisableUserRegistration && !tokenValid) + { + throw new BadRequestException("Open registration has been disabled by the system administrator."); + } + + if (orgUserId.HasValue) + { + var orgUser = await _organizationUserRepository.GetByIdAsync(orgUserId.Value); + if (orgUser != null) + { + var twoFactorPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(orgUser.OrganizationId, + PolicyType.TwoFactorAuthentication); + if (twoFactorPolicy != null && twoFactorPolicy.Enabled) + { + user.SetTwoFactorProviders(new Dictionary + { + + [TwoFactorProviderType.Email] = new TwoFactorProvider + { + MetaData = new Dictionary { ["Email"] = user.Email.ToLowerInvariant() }, + Enabled = true + } + }); + SetTwoFactorProvider(user, TwoFactorProviderType.Email); + } + } + } + + user.ApiKey = CoreHelpers.SecureRandomString(30); + var result = await base.CreateAsync(user, masterPassword); + if (result == IdentityResult.Success) + { + await _mailService.SendWelcomeEmailAsync(user); + await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.Signup, user)); + } + + return result; + } + + public async Task RegisterUserAsync(User user) + { + var result = await base.CreateAsync(user); + if (result == IdentityResult.Success) + { + await _mailService.SendWelcomeEmailAsync(user); + await _referenceEventService.RaiseEventAsync(new ReferenceEvent(ReferenceEventType.Signup, user)); + } + + return result; + } + + public async Task SendMasterPasswordHintAsync(string email) + { + var user = await _userRepository.GetByEmailAsync(email); + if (user == null) + { + // No user exists. Do we want to send an email telling them this in the future? + return; + } + + if (string.IsNullOrWhiteSpace(user.MasterPasswordHint)) + { + await _mailService.SendNoMasterPasswordHintEmailAsync(email); + return; + } + + await _mailService.SendMasterPasswordHintEmailAsync(email, user.MasterPasswordHint); + } + + public async Task SendTwoFactorEmailAsync(User user, bool isBecauseNewDeviceLogin = false) + { + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Email); + if (provider == null || provider.MetaData == null || !provider.MetaData.ContainsKey("Email")) + { + throw new ArgumentNullException("No email."); + } + + var email = ((string)provider.MetaData["Email"]).ToLowerInvariant(); + var token = await base.GenerateUserTokenAsync(user, TokenOptions.DefaultEmailProvider, + "2faEmail:" + email); + + if (isBecauseNewDeviceLogin) + { + await _mailService.SendNewDeviceLoginTwoFactorEmailAsync(email, token); + } + else + { + await _mailService.SendTwoFactorEmailAsync(email, token); + } + } + + public async Task VerifyTwoFactorEmailAsync(User user, string token) + { + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.Email); + if (provider == null || provider.MetaData == null || !provider.MetaData.ContainsKey("Email")) + { + throw new ArgumentNullException("No email."); + } + + var email = ((string)provider.MetaData["Email"]).ToLowerInvariant(); + return await base.VerifyUserTokenAsync(user, TokenOptions.DefaultEmailProvider, + "2faEmail:" + email, token); + } + + public async Task StartWebAuthnRegistrationAsync(User user) + { + var providers = user.GetTwoFactorProviders(); + if (providers == null) + { + providers = new Dictionary(); + } + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); + if (provider == null) + { + provider = new TwoFactorProvider + { + Enabled = false + }; + } + if (provider.MetaData == null) + { + provider.MetaData = new Dictionary(); + } + + var fidoUser = new Fido2User + { + DisplayName = user.Name, + Name = user.Email, + Id = user.Id.ToByteArray(), + }; + + var excludeCredentials = provider.MetaData + .Where(k => k.Key.StartsWith("Key")) + .Select(k => new TwoFactorProvider.WebAuthnData((dynamic)k.Value).Descriptor) + .ToList(); + + var authenticatorSelection = new AuthenticatorSelection + { + AuthenticatorAttachment = null, + RequireResidentKey = false, + UserVerification = UserVerificationRequirement.Discouraged + }; + var options = _fido2.RequestNewCredential(fidoUser, excludeCredentials, authenticatorSelection, AttestationConveyancePreference.None); + + provider.MetaData["pending"] = options.ToJson(); + providers[TwoFactorProviderType.WebAuthn] = provider; + user.SetTwoFactorProviders(providers); + await UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.WebAuthn, false); + + return options; + } + + public async Task CompleteWebAuthRegistrationAsync(User user, int id, string name, AuthenticatorAttestationRawResponse attestationResponse) + { + var keyId = $"Key{id}"; + + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); + if (!provider?.MetaData?.ContainsKey("pending") ?? true) + { + return false; + } + + var options = CredentialCreateOptions.FromJson((string)provider.MetaData["pending"]); + + // Callback to ensure credential id is unique. Always return true since we don't care if another + // account uses the same 2fa key. + IsCredentialIdUniqueToUserAsyncDelegate callback = args => Task.FromResult(true); + + var success = await _fido2.MakeNewCredentialAsync(attestationResponse, options, callback); + + provider.MetaData.Remove("pending"); + provider.MetaData[keyId] = new TwoFactorProvider.WebAuthnData + { + Name = name, + Descriptor = new PublicKeyCredentialDescriptor(success.Result.CredentialId), + PublicKey = success.Result.PublicKey, + UserHandle = success.Result.User.Id, + SignatureCounter = success.Result.Counter, + CredType = success.Result.CredType, + RegDate = DateTime.Now, + AaGuid = success.Result.Aaguid + }; + + var providers = user.GetTwoFactorProviders(); + providers[TwoFactorProviderType.WebAuthn] = provider; + user.SetTwoFactorProviders(providers); + await UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.WebAuthn); + + return true; + } + + public async Task DeleteWebAuthnKeyAsync(User user, int id) + { + var providers = user.GetTwoFactorProviders(); + if (providers == null) + { + return false; + } + + var keyName = $"Key{id}"; + var provider = user.GetTwoFactorProvider(TwoFactorProviderType.WebAuthn); + if (!provider?.MetaData?.ContainsKey(keyName) ?? true) + { + return false; + } + + if (provider.MetaData.Count < 2) + { + return false; + } + + provider.MetaData.Remove(keyName); + providers[TwoFactorProviderType.WebAuthn] = provider; + user.SetTwoFactorProviders(providers); + await UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.WebAuthn); + return true; + } + + public async Task SendEmailVerificationAsync(User user) + { + if (user.EmailVerified) + { + throw new BadRequestException("Email already verified."); + } + + var token = await base.GenerateEmailConfirmationTokenAsync(user); + await _mailService.SendVerifyEmailEmailAsync(user.Email, user.Id, token); + } + + public async Task InitiateEmailChangeAsync(User user, string newEmail) + { + var existingUser = await _userRepository.GetByEmailAsync(newEmail); + if (existingUser != null) + { + await _mailService.SendChangeEmailAlreadyExistsEmailAsync(user.Email, newEmail); + return; + } + + var token = await base.GenerateChangeEmailTokenAsync(user, newEmail); + await _mailService.SendChangeEmailEmailAsync(newEmail, token); + } + + public async Task ChangeEmailAsync(User user, string masterPassword, string newEmail, + string newMasterPassword, string token, string key) + { + var verifyPasswordResult = _passwordHasher.VerifyHashedPassword(user, user.MasterPassword, masterPassword); + if (verifyPasswordResult == PasswordVerificationResult.Failed) + { return IdentityResult.Failed(_identityErrorDescriber.PasswordMismatch()); } - public async Task UpdateKeyAsync(User user, string masterPassword, string key, string privateKey, - IEnumerable ciphers, IEnumerable folders, IEnumerable sends) + if (!await base.VerifyUserTokenAsync(user, _identityOptions.Tokens.ChangeEmailTokenProvider, + GetChangeEmailTokenPurpose(newEmail), token)) { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - if (await CheckPasswordAsync(user, masterPassword)) - { - user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; - user.SecurityStamp = Guid.NewGuid().ToString(); - user.Key = key; - user.PrivateKey = privateKey; - if (ciphers.Any() || folders.Any() || sends.Any()) - { - await _cipherRepository.UpdateUserKeysAndCiphersAsync(user, ciphers, folders, sends); - } - else - { - await _userRepository.ReplaceAsync(user); - } - - await _pushService.PushLogOutAsync(user.Id); - return IdentityResult.Success; - } - - Logger.LogWarning("Update key failed for user {userId}.", user.Id); - return IdentityResult.Failed(_identityErrorDescriber.PasswordMismatch()); + return IdentityResult.Failed(_identityErrorDescriber.InvalidToken()); } - public async Task RefreshSecurityStampAsync(User user, string secret) + var existingUser = await _userRepository.GetByEmailAsync(newEmail); + if (existingUser != null && existingUser.Id != user.Id) { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - if (await VerifySecretAsync(user, secret)) - { - var result = await base.UpdateSecurityStampAsync(user); - if (!result.Succeeded) - { - return result; - } - - await SaveUserAsync(user); - await _pushService.PushLogOutAsync(user.Id); - return IdentityResult.Success; - } - - Logger.LogWarning("Refresh security stamp failed for user {userId}.", user.Id); - return IdentityResult.Failed(_identityErrorDescriber.PasswordMismatch()); + return IdentityResult.Failed(_identityErrorDescriber.DuplicateEmail(newEmail)); } - public async Task UpdateTwoFactorProviderAsync(User user, TwoFactorProviderType type, bool setEnabled = true, bool logEvent = true) + var previousState = new { - SetTwoFactorProvider(user, type, setEnabled); - await SaveUserAsync(user); - if (logEvent) - { - await _eventService.LogUserEventAsync(user.Id, EventType.User_Updated2fa); - } + Key = user.Key, + MasterPassword = user.MasterPassword, + SecurityStamp = user.SecurityStamp, + Email = user.Email + }; + + var result = await UpdatePasswordHash(user, newMasterPassword); + if (!result.Succeeded) + { + return result; } - public async Task DisableTwoFactorProviderAsync(User user, TwoFactorProviderType type, - IOrganizationService organizationService) + user.Key = key; + user.Email = newEmail; + user.EmailVerified = true; + user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; + await _userRepository.ReplaceAsync(user); + + if (user.Gateway == GatewayType.Stripe) { - var providers = user.GetTwoFactorProviders(); - if (!providers?.ContainsKey(type) ?? true) - { - return; - } - - providers.Remove(type); - user.SetTwoFactorProviders(providers); - await SaveUserAsync(user); - await _eventService.LogUserEventAsync(user.Id, EventType.User_Disabled2fa); - - if (!await TwoFactorIsEnabledAsync(user)) - { - await CheckPoliciesOnTwoFactorRemovalAsync(user, organizationService); - } - } - - public async Task RecoverTwoFactorAsync(string email, string secret, string recoveryCode, - IOrganizationService organizationService) - { - var user = await _userRepository.GetByEmailAsync(email); - if (user == null) - { - // No user exists. Do we want to send an email telling them this in the future? - return false; - } - - if (!await VerifySecretAsync(user, secret)) - { - return false; - } - - if (!CoreHelpers.FixedTimeEquals(user.TwoFactorRecoveryCode, recoveryCode)) - { - return false; - } - - user.TwoFactorProviders = null; - user.TwoFactorRecoveryCode = CoreHelpers.SecureRandomString(32, upper: false, special: false); - await SaveUserAsync(user); - await _mailService.SendRecoverTwoFactorEmail(user.Email, DateTime.UtcNow, _currentContext.IpAddress); - await _eventService.LogUserEventAsync(user.Id, EventType.User_Recovered2fa); - await CheckPoliciesOnTwoFactorRemovalAsync(user, organizationService); - - return true; - } - - public async Task> SignUpPremiumAsync(User user, string paymentToken, - PaymentMethodType paymentMethodType, short additionalStorageGb, UserLicense license, - TaxInfo taxInfo) - { - if (user.Premium) - { - throw new BadRequestException("Already a premium user."); - } - - if (additionalStorageGb < 0) - { - throw new BadRequestException("You can't subtract storage!"); - } - - if ((paymentMethodType == PaymentMethodType.GoogleInApp || - paymentMethodType == PaymentMethodType.AppleInApp) && additionalStorageGb > 0) - { - throw new BadRequestException("You cannot add storage with this payment method."); - } - - string paymentIntentClientSecret = null; - IPaymentService paymentService = null; - if (_globalSettings.SelfHosted) - { - if (license == null || !_licenseService.VerifyLicense(license)) - { - throw new BadRequestException("Invalid license."); - } - - if (!license.CanUse(user)) - { - throw new BadRequestException("This license is not valid for this user."); - } - - var dir = $"{_globalSettings.LicenseDirectory}/user"; - Directory.CreateDirectory(dir); - using var fs = File.OpenWrite(Path.Combine(dir, $"{user.Id}.json")); - await JsonSerializer.SerializeAsync(fs, license, JsonHelpers.Indented); - } - else - { - paymentIntentClientSecret = await _paymentService.PurchasePremiumAsync(user, paymentMethodType, - paymentToken, additionalStorageGb, taxInfo); - } - - user.Premium = true; - user.RevisionDate = DateTime.UtcNow; - - if (_globalSettings.SelfHosted) - { - user.MaxStorageGb = 10240; // 10 TB - user.LicenseKey = license.LicenseKey; - user.PremiumExpirationDate = license.Expires; - } - else - { - user.MaxStorageGb = (short)(1 + additionalStorageGb); - user.LicenseKey = CoreHelpers.SecureRandomString(20); - } try { - await SaveUserAsync(user); - await _pushService.PushSyncVaultAsync(user.Id); - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.UpgradePlan, user) - { - Storage = user.MaxStorageGb, - PlanName = PremiumPlanId, - }); + await _stripeSyncService.UpdateCustomerEmailAddress(user.GatewayCustomerId, + user.BillingEmailAddress()); } - catch when (!_globalSettings.SelfHosted) + catch (Exception ex) { - await paymentService.CancelAndRecoverChargesAsync(user); - throw; - } - return new Tuple(string.IsNullOrWhiteSpace(paymentIntentClientSecret), - paymentIntentClientSecret); - } + //if sync to strip fails, update email and securityStamp to previous + user.Key = previousState.Key; + user.Email = previousState.Email; + user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; + user.MasterPassword = previousState.MasterPassword; + user.SecurityStamp = previousState.SecurityStamp; - public async Task IapCheckAsync(User user, PaymentMethodType paymentMethodType) - { - if (paymentMethodType != PaymentMethodType.AppleInApp) - { - throw new BadRequestException("Payment method not supported for in-app purchases."); - } - - if (user.Premium) - { - throw new BadRequestException("Already a premium user."); - } - - if (!string.IsNullOrWhiteSpace(user.GatewayCustomerId)) - { - var customerService = new Stripe.CustomerService(); - var customer = await customerService.GetAsync(user.GatewayCustomerId); - if (customer != null && customer.Balance != 0) + await _userRepository.ReplaceAsync(user); + return IdentityResult.Failed(new IdentityError { - throw new BadRequestException("Customer balance cannot exist when using in-app purchases."); - } + Description = ex.Message + }); } } - public async Task UpdateLicenseAsync(User user, UserLicense license) + await _pushService.PushLogOutAsync(user.Id); + + return IdentityResult.Success; + } + + public override Task ChangePasswordAsync(User user, string masterPassword, string newMasterPassword) + { + throw new NotImplementedException(); + } + + public async Task ChangePasswordAsync(User user, string masterPassword, string newMasterPassword, string passwordHint, + string key) + { + if (user == null) { - if (!_globalSettings.SelfHosted) + throw new ArgumentNullException(nameof(user)); + } + + if (await CheckPasswordAsync(user, masterPassword)) + { + var result = await UpdatePasswordHash(user, newMasterPassword); + if (!result.Succeeded) { - throw new InvalidOperationException("Licenses require self hosting."); + return result; } - if (license?.LicenseType != null && license.LicenseType != LicenseType.User) + user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; + user.Key = key; + user.MasterPasswordHint = passwordHint; + + await _userRepository.ReplaceAsync(user); + await _eventService.LogUserEventAsync(user.Id, EventType.User_ChangedPassword); + await _pushService.PushLogOutAsync(user.Id); + + return IdentityResult.Success; + } + + Logger.LogWarning("Change password failed for user {userId}.", user.Id); + return IdentityResult.Failed(_identityErrorDescriber.PasswordMismatch()); + } + + public async Task SetPasswordAsync(User user, string masterPassword, string key, + string orgIdentifier = null) + { + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + if (!string.IsNullOrWhiteSpace(user.MasterPassword)) + { + Logger.LogWarning("Change password failed for user {userId} - already has password.", user.Id); + return IdentityResult.Failed(_identityErrorDescriber.UserAlreadyHasPassword()); + } + + var result = await UpdatePasswordHash(user, masterPassword, true, false); + if (!result.Succeeded) + { + return result; + } + + user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; + user.Key = key; + + await _userRepository.ReplaceAsync(user); + await _eventService.LogUserEventAsync(user.Id, EventType.User_ChangedPassword); + + if (!string.IsNullOrWhiteSpace(orgIdentifier)) + { + await _organizationService.AcceptUserAsync(orgIdentifier, user, this); + } + + return IdentityResult.Success; + } + + public async Task SetKeyConnectorKeyAsync(User user, string key, string orgIdentifier) + { + var identityResult = CheckCanUseKeyConnector(user); + if (identityResult != null) + { + return identityResult; + } + + user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; + user.Key = key; + user.UsesKeyConnector = true; + + await _userRepository.ReplaceAsync(user); + await _eventService.LogUserEventAsync(user.Id, EventType.User_MigratedKeyToKeyConnector); + + await _organizationService.AcceptUserAsync(orgIdentifier, user, this); + + return IdentityResult.Success; + } + + public async Task ConvertToKeyConnectorAsync(User user) + { + var identityResult = CheckCanUseKeyConnector(user); + if (identityResult != null) + { + return identityResult; + } + + user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; + user.MasterPassword = null; + user.UsesKeyConnector = true; + + await _userRepository.ReplaceAsync(user); + await _eventService.LogUserEventAsync(user.Id, EventType.User_MigratedKeyToKeyConnector); + + return IdentityResult.Success; + } + + private IdentityResult CheckCanUseKeyConnector(User user) + { + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + if (user.UsesKeyConnector) + { + Logger.LogWarning("Already uses Key Connector."); + return IdentityResult.Failed(_identityErrorDescriber.UserAlreadyHasPassword()); + } + + if (_currentContext.Organizations.Any(u => + u.Type is OrganizationUserType.Owner or OrganizationUserType.Admin)) + { + throw new BadRequestException("Cannot use Key Connector when admin or owner of an organization."); + } + + return null; + } + + public async Task AdminResetPasswordAsync(OrganizationUserType callingUserType, Guid orgId, Guid id, string newMasterPassword, string key) + { + // Org must be able to use reset password + var org = await _organizationRepository.GetByIdAsync(orgId); + if (org == null || !org.UseResetPassword) + { + throw new BadRequestException("Organization does not allow password reset."); + } + + // Enterprise policy must be enabled + var resetPasswordPolicy = + await _policyRepository.GetByOrganizationIdTypeAsync(orgId, PolicyType.ResetPassword); + if (resetPasswordPolicy == null || !resetPasswordPolicy.Enabled) + { + throw new BadRequestException("Organization does not have the password reset policy enabled."); + } + + // Org User must be confirmed and have a ResetPasswordKey + var orgUser = await _organizationUserRepository.GetByIdAsync(id); + if (orgUser == null || orgUser.Status != OrganizationUserStatusType.Confirmed || + orgUser.OrganizationId != orgId || string.IsNullOrEmpty(orgUser.ResetPasswordKey) || + !orgUser.UserId.HasValue) + { + throw new BadRequestException("Organization User not valid"); + } + + // Calling User must be of higher/equal user type to reset user's password + var canAdjustPassword = false; + switch (callingUserType) + { + case OrganizationUserType.Owner: + canAdjustPassword = true; + break; + case OrganizationUserType.Admin: + canAdjustPassword = orgUser.Type != OrganizationUserType.Owner; + break; + case OrganizationUserType.Custom: + canAdjustPassword = orgUser.Type != OrganizationUserType.Owner && + orgUser.Type != OrganizationUserType.Admin; + break; + } + + if (!canAdjustPassword) + { + throw new BadRequestException("Calling user does not have permission to reset this user's master password"); + } + + var user = await GetUserByIdAsync(orgUser.UserId.Value); + if (user == null) + { + throw new NotFoundException(); + } + + if (user.UsesKeyConnector) + { + throw new BadRequestException("Cannot reset password of a user with Key Connector."); + } + + var result = await UpdatePasswordHash(user, newMasterPassword); + if (!result.Succeeded) + { + return result; + } + + user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; + user.Key = key; + user.ForcePasswordReset = true; + + await _userRepository.ReplaceAsync(user); + await _mailService.SendAdminResetPasswordEmailAsync(user.Email, user.Name, org.Name); + await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_AdminResetPassword); + await _pushService.PushLogOutAsync(user.Id); + + return IdentityResult.Success; + } + + public async Task UpdateTempPasswordAsync(User user, string newMasterPassword, string key, string hint) + { + if (!user.ForcePasswordReset) + { + throw new BadRequestException("User does not have a temporary password to update."); + } + + var result = await UpdatePasswordHash(user, newMasterPassword); + if (!result.Succeeded) + { + return result; + } + + user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; + user.ForcePasswordReset = false; + user.Key = key; + user.MasterPasswordHint = hint; + + await _userRepository.ReplaceAsync(user); + await _mailService.SendUpdatedTempPasswordEmailAsync(user.Email, user.Name); + await _eventService.LogUserEventAsync(user.Id, EventType.User_UpdatedTempPassword); + await _pushService.PushLogOutAsync(user.Id); + + return IdentityResult.Success; + } + + public async Task ChangeKdfAsync(User user, string masterPassword, string newMasterPassword, + string key, KdfType kdf, int kdfIterations) + { + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + if (await CheckPasswordAsync(user, masterPassword)) + { + var result = await UpdatePasswordHash(user, newMasterPassword); + if (!result.Succeeded) { - throw new BadRequestException("Organization licenses cannot be applied to a user. " - + "Upload this license from the Organization settings page."); + return result; } + user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; + user.Key = key; + user.Kdf = kdf; + user.KdfIterations = kdfIterations; + await _userRepository.ReplaceAsync(user); + await _pushService.PushLogOutAsync(user.Id); + return IdentityResult.Success; + } + + Logger.LogWarning("Change KDF failed for user {userId}.", user.Id); + return IdentityResult.Failed(_identityErrorDescriber.PasswordMismatch()); + } + + public async Task UpdateKeyAsync(User user, string masterPassword, string key, string privateKey, + IEnumerable ciphers, IEnumerable folders, IEnumerable sends) + { + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + if (await CheckPasswordAsync(user, masterPassword)) + { + user.RevisionDate = user.AccountRevisionDate = DateTime.UtcNow; + user.SecurityStamp = Guid.NewGuid().ToString(); + user.Key = key; + user.PrivateKey = privateKey; + if (ciphers.Any() || folders.Any() || sends.Any()) + { + await _cipherRepository.UpdateUserKeysAndCiphersAsync(user, ciphers, folders, sends); + } + else + { + await _userRepository.ReplaceAsync(user); + } + + await _pushService.PushLogOutAsync(user.Id); + return IdentityResult.Success; + } + + Logger.LogWarning("Update key failed for user {userId}.", user.Id); + return IdentityResult.Failed(_identityErrorDescriber.PasswordMismatch()); + } + + public async Task RefreshSecurityStampAsync(User user, string secret) + { + if (user == null) + { + throw new ArgumentNullException(nameof(user)); + } + + if (await VerifySecretAsync(user, secret)) + { + var result = await base.UpdateSecurityStampAsync(user); + if (!result.Succeeded) + { + return result; + } + + await SaveUserAsync(user); + await _pushService.PushLogOutAsync(user.Id); + return IdentityResult.Success; + } + + Logger.LogWarning("Refresh security stamp failed for user {userId}.", user.Id); + return IdentityResult.Failed(_identityErrorDescriber.PasswordMismatch()); + } + + public async Task UpdateTwoFactorProviderAsync(User user, TwoFactorProviderType type, bool setEnabled = true, bool logEvent = true) + { + SetTwoFactorProvider(user, type, setEnabled); + await SaveUserAsync(user); + if (logEvent) + { + await _eventService.LogUserEventAsync(user.Id, EventType.User_Updated2fa); + } + } + + public async Task DisableTwoFactorProviderAsync(User user, TwoFactorProviderType type, + IOrganizationService organizationService) + { + var providers = user.GetTwoFactorProviders(); + if (!providers?.ContainsKey(type) ?? true) + { + return; + } + + providers.Remove(type); + user.SetTwoFactorProviders(providers); + await SaveUserAsync(user); + await _eventService.LogUserEventAsync(user.Id, EventType.User_Disabled2fa); + + if (!await TwoFactorIsEnabledAsync(user)) + { + await CheckPoliciesOnTwoFactorRemovalAsync(user, organizationService); + } + } + + public async Task RecoverTwoFactorAsync(string email, string secret, string recoveryCode, + IOrganizationService organizationService) + { + var user = await _userRepository.GetByEmailAsync(email); + if (user == null) + { + // No user exists. Do we want to send an email telling them this in the future? + return false; + } + + if (!await VerifySecretAsync(user, secret)) + { + return false; + } + + if (!CoreHelpers.FixedTimeEquals(user.TwoFactorRecoveryCode, recoveryCode)) + { + return false; + } + + user.TwoFactorProviders = null; + user.TwoFactorRecoveryCode = CoreHelpers.SecureRandomString(32, upper: false, special: false); + await SaveUserAsync(user); + await _mailService.SendRecoverTwoFactorEmail(user.Email, DateTime.UtcNow, _currentContext.IpAddress); + await _eventService.LogUserEventAsync(user.Id, EventType.User_Recovered2fa); + await CheckPoliciesOnTwoFactorRemovalAsync(user, organizationService); + + return true; + } + + public async Task> SignUpPremiumAsync(User user, string paymentToken, + PaymentMethodType paymentMethodType, short additionalStorageGb, UserLicense license, + TaxInfo taxInfo) + { + if (user.Premium) + { + throw new BadRequestException("Already a premium user."); + } + + if (additionalStorageGb < 0) + { + throw new BadRequestException("You can't subtract storage!"); + } + + if ((paymentMethodType == PaymentMethodType.GoogleInApp || + paymentMethodType == PaymentMethodType.AppleInApp) && additionalStorageGb > 0) + { + throw new BadRequestException("You cannot add storage with this payment method."); + } + + string paymentIntentClientSecret = null; + IPaymentService paymentService = null; + if (_globalSettings.SelfHosted) + { if (license == null || !_licenseService.VerifyLicense(license)) { throw new BadRequestException("Invalid license."); @@ -1097,401 +1005,492 @@ namespace Bit.Core.Services Directory.CreateDirectory(dir); using var fs = File.OpenWrite(Path.Combine(dir, $"{user.Id}.json")); await JsonSerializer.SerializeAsync(fs, license, JsonHelpers.Indented); + } + else + { + paymentIntentClientSecret = await _paymentService.PurchasePremiumAsync(user, paymentMethodType, + paymentToken, additionalStorageGb, taxInfo); + } - user.Premium = license.Premium; - user.RevisionDate = DateTime.UtcNow; - user.MaxStorageGb = _globalSettings.SelfHosted ? 10240 : license.MaxStorageGb; // 10 TB + user.Premium = true; + user.RevisionDate = DateTime.UtcNow; + + if (_globalSettings.SelfHosted) + { + user.MaxStorageGb = 10240; // 10 TB user.LicenseKey = license.LicenseKey; user.PremiumExpirationDate = license.Expires; + } + else + { + user.MaxStorageGb = (short)(1 + additionalStorageGb); + user.LicenseKey = CoreHelpers.SecureRandomString(20); + } + + try + { await SaveUserAsync(user); - } - - public async Task AdjustStorageAsync(User user, short storageAdjustmentGb) - { - if (user == null) - { - throw new ArgumentNullException(nameof(user)); - } - - if (!user.Premium) - { - throw new BadRequestException("Not a premium user."); - } - - var secret = await BillingHelpers.AdjustStorageAsync(_paymentService, user, storageAdjustmentGb, - StoragePlanId); + await _pushService.PushSyncVaultAsync(user.Id); await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.AdjustStorage, user) + new ReferenceEvent(ReferenceEventType.UpgradePlan, user) { - Storage = storageAdjustmentGb, - PlanName = StoragePlanId, - }); - await SaveUserAsync(user); - return secret; - } - - public async Task ReplacePaymentMethodAsync(User user, string paymentToken, PaymentMethodType paymentMethodType, TaxInfo taxInfo) - { - if (paymentToken.StartsWith("btok_")) - { - throw new BadRequestException("Invalid token."); - } - - var updated = await _paymentService.UpdatePaymentMethodAsync(user, paymentMethodType, paymentToken, taxInfo: taxInfo); - if (updated) - { - await SaveUserAsync(user); - } - } - - public async Task CancelPremiumAsync(User user, bool? endOfPeriod = null, bool accountDelete = false) - { - var eop = endOfPeriod.GetValueOrDefault(true); - if (!endOfPeriod.HasValue && user.PremiumExpirationDate.HasValue && - user.PremiumExpirationDate.Value < DateTime.UtcNow) - { - eop = false; - } - await _paymentService.CancelSubscriptionAsync(user, eop, accountDelete); - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.CancelSubscription, user) - { - EndOfPeriod = eop, + Storage = user.MaxStorageGb, + PlanName = PremiumPlanId, }); } - - public async Task ReinstatePremiumAsync(User user) + catch when (!_globalSettings.SelfHosted) { - await _paymentService.ReinstateSubscriptionAsync(user); - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.ReinstateSubscription, user)); + await paymentService.CancelAndRecoverChargesAsync(user); + throw; + } + return new Tuple(string.IsNullOrWhiteSpace(paymentIntentClientSecret), + paymentIntentClientSecret); + } + + public async Task IapCheckAsync(User user, PaymentMethodType paymentMethodType) + { + if (paymentMethodType != PaymentMethodType.AppleInApp) + { + throw new BadRequestException("Payment method not supported for in-app purchases."); } - public async Task EnablePremiumAsync(Guid userId, DateTime? expirationDate) + if (user.Premium) { - var user = await _userRepository.GetByIdAsync(userId); - await EnablePremiumAsync(user, expirationDate); + throw new BadRequestException("Already a premium user."); } - public async Task EnablePremiumAsync(User user, DateTime? expirationDate) + if (!string.IsNullOrWhiteSpace(user.GatewayCustomerId)) { - if (user != null && !user.Premium && user.Gateway.HasValue) + var customerService = new Stripe.CustomerService(); + var customer = await customerService.GetAsync(user.GatewayCustomerId); + if (customer != null && customer.Balance != 0) { - user.Premium = true; - user.PremiumExpirationDate = expirationDate; - user.RevisionDate = DateTime.UtcNow; - await _userRepository.ReplaceAsync(user); + throw new BadRequestException("Customer balance cannot exist when using in-app purchases."); } } + } - public async Task DisablePremiumAsync(Guid userId, DateTime? expirationDate) + public async Task UpdateLicenseAsync(User user, UserLicense license) + { + if (!_globalSettings.SelfHosted) { - var user = await _userRepository.GetByIdAsync(userId); - await DisablePremiumAsync(user, expirationDate); + throw new InvalidOperationException("Licenses require self hosting."); } - public async Task DisablePremiumAsync(User user, DateTime? expirationDate) + if (license?.LicenseType != null && license.LicenseType != LicenseType.User) { - if (user != null && user.Premium) - { - user.Premium = false; - user.PremiumExpirationDate = expirationDate; - user.RevisionDate = DateTime.UtcNow; - await _userRepository.ReplaceAsync(user); - } + throw new BadRequestException("Organization licenses cannot be applied to a user. " + + "Upload this license from the Organization settings page."); } - public async Task UpdatePremiumExpirationAsync(Guid userId, DateTime? expirationDate) + if (license == null || !_licenseService.VerifyLicense(license)) { - var user = await _userRepository.GetByIdAsync(userId); - if (user != null) - { - user.PremiumExpirationDate = expirationDate; - user.RevisionDate = DateTime.UtcNow; - await _userRepository.ReplaceAsync(user); - } + throw new BadRequestException("Invalid license."); } - public async Task GenerateLicenseAsync(User user, SubscriptionInfo subscriptionInfo = null, - int? version = null) + if (!license.CanUse(user)) { - if (user == null) - { - throw new NotFoundException(); - } - - if (subscriptionInfo == null && user.Gateway != null) - { - subscriptionInfo = await _paymentService.GetSubscriptionAsync(user); - } - - return subscriptionInfo == null ? new UserLicense(user, _licenseService) : - new UserLicense(user, subscriptionInfo, _licenseService); + throw new BadRequestException("This license is not valid for this user."); } - public override async Task CheckPasswordAsync(User user, string password) + var dir = $"{_globalSettings.LicenseDirectory}/user"; + Directory.CreateDirectory(dir); + using var fs = File.OpenWrite(Path.Combine(dir, $"{user.Id}.json")); + await JsonSerializer.SerializeAsync(fs, license, JsonHelpers.Indented); + + user.Premium = license.Premium; + user.RevisionDate = DateTime.UtcNow; + user.MaxStorageGb = _globalSettings.SelfHosted ? 10240 : license.MaxStorageGb; // 10 TB + user.LicenseKey = license.LicenseKey; + user.PremiumExpirationDate = license.Expires; + await SaveUserAsync(user); + } + + public async Task AdjustStorageAsync(User user, short storageAdjustmentGb) + { + if (user == null) { - if (user == null) - { - return false; - } - - var result = await base.VerifyPasswordAsync(Store as IUserPasswordStore, user, password); - if (result == PasswordVerificationResult.SuccessRehashNeeded) - { - await UpdatePasswordHash(user, password, false, false); - user.RevisionDate = DateTime.UtcNow; - await _userRepository.ReplaceAsync(user); - } - - var success = result != PasswordVerificationResult.Failed; - if (!success) - { - Logger.LogWarning(0, "Invalid password for user {userId}.", user.Id); - } - return success; + throw new ArgumentNullException(nameof(user)); } - public async Task CanAccessPremium(ITwoFactorProvidersUser user) + if (!user.Premium) { - var userId = user.GetUserId(); - if (!userId.HasValue) - { - return false; - } - - return user.GetPremium() || await this.HasPremiumFromOrganization(user); + throw new BadRequestException("Not a premium user."); } - public async Task HasPremiumFromOrganization(ITwoFactorProvidersUser user) + var secret = await BillingHelpers.AdjustStorageAsync(_paymentService, user, storageAdjustmentGb, + StoragePlanId); + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.AdjustStorage, user) + { + Storage = storageAdjustmentGb, + PlanName = StoragePlanId, + }); + await SaveUserAsync(user); + return secret; + } + + public async Task ReplacePaymentMethodAsync(User user, string paymentToken, PaymentMethodType paymentMethodType, TaxInfo taxInfo) + { + if (paymentToken.StartsWith("btok_")) { - var userId = user.GetUserId(); - if (!userId.HasValue) - { - return false; - } - - // orgUsers in the Invited status are not associated with a userId yet, so this will get - // orgUsers in Accepted and Confirmed states only - var orgUsers = await _organizationUserRepository.GetManyByUserAsync(userId.Value); - - if (!orgUsers.Any()) - { - return false; - } - - var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); - return orgUsers.Any(ou => - orgAbilities.TryGetValue(ou.OrganizationId, out var orgAbility) && - orgAbility.UsersGetPremium && - orgAbility.Enabled); + throw new BadRequestException("Invalid token."); } - public async Task TwoFactorIsEnabledAsync(ITwoFactorProvidersUser user) + var updated = await _paymentService.UpdatePaymentMethodAsync(user, paymentMethodType, paymentToken, taxInfo: taxInfo); + if (updated) { - var providers = user.GetTwoFactorProviders(); - if (providers == null) - { - return false; - } + await SaveUserAsync(user); + } + } - foreach (var p in providers) + public async Task CancelPremiumAsync(User user, bool? endOfPeriod = null, bool accountDelete = false) + { + var eop = endOfPeriod.GetValueOrDefault(true); + if (!endOfPeriod.HasValue && user.PremiumExpirationDate.HasValue && + user.PremiumExpirationDate.Value < DateTime.UtcNow) + { + eop = false; + } + await _paymentService.CancelSubscriptionAsync(user, eop, accountDelete); + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.CancelSubscription, user) { - if (p.Value?.Enabled ?? false) - { - if (!TwoFactorProvider.RequiresPremium(p.Key)) - { - return true; - } - if (await CanAccessPremium(user)) - { - return true; - } - } - } + EndOfPeriod = eop, + }); + } + + public async Task ReinstatePremiumAsync(User user) + { + await _paymentService.ReinstateSubscriptionAsync(user); + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.ReinstateSubscription, user)); + } + + public async Task EnablePremiumAsync(Guid userId, DateTime? expirationDate) + { + var user = await _userRepository.GetByIdAsync(userId); + await EnablePremiumAsync(user, expirationDate); + } + + public async Task EnablePremiumAsync(User user, DateTime? expirationDate) + { + if (user != null && !user.Premium && user.Gateway.HasValue) + { + user.Premium = true; + user.PremiumExpirationDate = expirationDate; + user.RevisionDate = DateTime.UtcNow; + await _userRepository.ReplaceAsync(user); + } + } + + public async Task DisablePremiumAsync(Guid userId, DateTime? expirationDate) + { + var user = await _userRepository.GetByIdAsync(userId); + await DisablePremiumAsync(user, expirationDate); + } + + public async Task DisablePremiumAsync(User user, DateTime? expirationDate) + { + if (user != null && user.Premium) + { + user.Premium = false; + user.PremiumExpirationDate = expirationDate; + user.RevisionDate = DateTime.UtcNow; + await _userRepository.ReplaceAsync(user); + } + } + + public async Task UpdatePremiumExpirationAsync(Guid userId, DateTime? expirationDate) + { + var user = await _userRepository.GetByIdAsync(userId); + if (user != null) + { + user.PremiumExpirationDate = expirationDate; + user.RevisionDate = DateTime.UtcNow; + await _userRepository.ReplaceAsync(user); + } + } + + public async Task GenerateLicenseAsync(User user, SubscriptionInfo subscriptionInfo = null, + int? version = null) + { + if (user == null) + { + throw new NotFoundException(); + } + + if (subscriptionInfo == null && user.Gateway != null) + { + subscriptionInfo = await _paymentService.GetSubscriptionAsync(user); + } + + return subscriptionInfo == null ? new UserLicense(user, _licenseService) : + new UserLicense(user, subscriptionInfo, _licenseService); + } + + public override async Task CheckPasswordAsync(User user, string password) + { + if (user == null) + { return false; } - public async Task TwoFactorProviderIsEnabledAsync(TwoFactorProviderType provider, ITwoFactorProvidersUser user) + var result = await base.VerifyPasswordAsync(Store as IUserPasswordStore, user, password); + if (result == PasswordVerificationResult.SuccessRehashNeeded) { - var providers = user.GetTwoFactorProviders(); - if (providers == null || !providers.ContainsKey(provider) || !providers[provider].Enabled) - { - return false; - } - - if (!TwoFactorProvider.RequiresPremium(provider)) - { - return true; - } - - return await CanAccessPremium(user); - } - - public async Task GenerateSignInTokenAsync(User user, string purpose) - { - var token = await GenerateUserTokenAsync(user, Options.Tokens.PasswordResetTokenProvider, - purpose); - return token; - } - - private async Task UpdatePasswordHash(User user, string newPassword, - bool validatePassword = true, bool refreshStamp = true) - { - if (validatePassword) - { - var validate = await ValidatePasswordInternal(user, newPassword); - if (!validate.Succeeded) - { - return validate; - } - } - - user.MasterPassword = _passwordHasher.HashPassword(user, newPassword); - if (refreshStamp) - { - user.SecurityStamp = Guid.NewGuid().ToString(); - } - - return IdentityResult.Success; - } - - private async Task ValidatePasswordInternal(User user, string password) - { - var errors = new List(); - foreach (var v in _passwordValidators) - { - var result = await v.ValidateAsync(this, user, password); - if (!result.Succeeded) - { - errors.AddRange(result.Errors); - } - } - - if (errors.Count > 0) - { - Logger.LogWarning("User {userId} password validation failed: {errors}.", await GetUserIdAsync(user), - string.Join(";", errors.Select(e => e.Code))); - return IdentityResult.Failed(errors.ToArray()); - } - - return IdentityResult.Success; - } - - public void SetTwoFactorProvider(User user, TwoFactorProviderType type, bool setEnabled = true) - { - var providers = user.GetTwoFactorProviders(); - if (!providers?.ContainsKey(type) ?? true) - { - return; - } - - if (setEnabled) - { - providers[type].Enabled = true; - } - user.SetTwoFactorProviders(providers); - - if (string.IsNullOrWhiteSpace(user.TwoFactorRecoveryCode)) - { - user.TwoFactorRecoveryCode = CoreHelpers.SecureRandomString(32, upper: false, special: false); - } - } - - private async Task CheckPoliciesOnTwoFactorRemovalAsync(User user, IOrganizationService organizationService) - { - var twoFactorPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(user.Id, - PolicyType.TwoFactorAuthentication); - - var removeOrgUserTasks = twoFactorPolicies.Select(async p => - { - await organizationService.DeleteUserAsync(p.OrganizationId, user.Id); - var organization = await _organizationRepository.GetByIdAsync(p.OrganizationId); - await _mailService.SendOrganizationUserRemovedForPolicyTwoStepEmailAsync( - organization.Name, user.Email); - }).ToArray(); - - await Task.WhenAll(removeOrgUserTasks); - } - - public override async Task ConfirmEmailAsync(User user, string token) - { - var result = await base.ConfirmEmailAsync(user, token); - if (result.Succeeded) - { - await _referenceEventService.RaiseEventAsync( - new ReferenceEvent(ReferenceEventType.ConfirmEmailAddress, user)); - } - return result; - } - - public async Task RotateApiKeyAsync(User user) - { - user.ApiKey = CoreHelpers.SecureRandomString(30); + await UpdatePasswordHash(user, password, false, false); user.RevisionDate = DateTime.UtcNow; await _userRepository.ReplaceAsync(user); } - public async Task SendOTPAsync(User user) + var success = result != PasswordVerificationResult.Failed; + if (!success) { - if (user.Email == null) + Logger.LogWarning(0, "Invalid password for user {userId}.", user.Id); + } + return success; + } + + public async Task CanAccessPremium(ITwoFactorProvidersUser user) + { + var userId = user.GetUserId(); + if (!userId.HasValue) + { + return false; + } + + return user.GetPremium() || await this.HasPremiumFromOrganization(user); + } + + public async Task HasPremiumFromOrganization(ITwoFactorProvidersUser user) + { + var userId = user.GetUserId(); + if (!userId.HasValue) + { + return false; + } + + // orgUsers in the Invited status are not associated with a userId yet, so this will get + // orgUsers in Accepted and Confirmed states only + var orgUsers = await _organizationUserRepository.GetManyByUserAsync(userId.Value); + + if (!orgUsers.Any()) + { + return false; + } + + var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + return orgUsers.Any(ou => + orgAbilities.TryGetValue(ou.OrganizationId, out var orgAbility) && + orgAbility.UsersGetPremium && + orgAbility.Enabled); + } + + public async Task TwoFactorIsEnabledAsync(ITwoFactorProvidersUser user) + { + var providers = user.GetTwoFactorProviders(); + if (providers == null) + { + return false; + } + + foreach (var p in providers) + { + if (p.Value?.Enabled ?? false) { - throw new BadRequestException("No user email."); + if (!TwoFactorProvider.RequiresPremium(p.Key)) + { + return true; + } + if (await CanAccessPremium(user)) + { + return true; + } } + } + return false; + } - if (!user.UsesKeyConnector) + public async Task TwoFactorProviderIsEnabledAsync(TwoFactorProviderType provider, ITwoFactorProvidersUser user) + { + var providers = user.GetTwoFactorProviders(); + if (providers == null || !providers.ContainsKey(provider) || !providers[provider].Enabled) + { + return false; + } + + if (!TwoFactorProvider.RequiresPremium(provider)) + { + return true; + } + + return await CanAccessPremium(user); + } + + public async Task GenerateSignInTokenAsync(User user, string purpose) + { + var token = await GenerateUserTokenAsync(user, Options.Tokens.PasswordResetTokenProvider, + purpose); + return token; + } + + private async Task UpdatePasswordHash(User user, string newPassword, + bool validatePassword = true, bool refreshStamp = true) + { + if (validatePassword) + { + var validate = await ValidatePasswordInternal(user, newPassword); + if (!validate.Succeeded) { - throw new BadRequestException("Not using Key Connector."); + return validate; } - - var token = await base.GenerateUserTokenAsync(user, TokenOptions.DefaultEmailProvider, - "otp:" + user.Email); - await _mailService.SendOTPEmailAsync(user.Email, token); } - public Task VerifyOTPAsync(User user, string token) + user.MasterPassword = _passwordHasher.HashPassword(user, newPassword); + if (refreshStamp) { - return base.VerifyUserTokenAsync(user, TokenOptions.DefaultEmailProvider, - "otp:" + user.Email, token); + user.SecurityStamp = Guid.NewGuid().ToString(); } - public async Task VerifySecretAsync(User user, string secret) - { - return user.UsesKeyConnector - ? await VerifyOTPAsync(user, secret) - : await CheckPasswordAsync(user, secret); - } + return IdentityResult.Success; + } - public async Task Needs2FABecauseNewDeviceAsync(User user, string deviceIdentifier, string grantType) + private async Task ValidatePasswordInternal(User user, string password) + { + var errors = new List(); + foreach (var v in _passwordValidators) { - return CanEditDeviceVerificationSettings(user) - && user.UnknownDeviceVerificationEnabled - && grantType != "authorization_code" - && await IsNewDeviceAndNotTheFirstOneAsync(user, deviceIdentifier); - } - - public bool CanEditDeviceVerificationSettings(User user) - { - return _globalSettings.TwoFactorAuth.EmailOnNewDeviceLogin - && user.EmailVerified - && !user.UsesKeyConnector - && !(user.GetTwoFactorProviders()?.Any() ?? false); - } - - private async Task IsNewDeviceAndNotTheFirstOneAsync(User user, string deviceIdentifier) - { - if (user == null) + var result = await v.ValidateAsync(this, user, password); + if (!result.Succeeded) { - return default; + errors.AddRange(result.Errors); } + } - var devices = await _deviceRepository.GetManyByUserIdAsync(user.Id); - if (!devices.Any()) - { - return false; - } + if (errors.Count > 0) + { + Logger.LogWarning("User {userId} password validation failed: {errors}.", await GetUserIdAsync(user), + string.Join(";", errors.Select(e => e.Code))); + return IdentityResult.Failed(errors.ToArray()); + } - return !devices.Any(d => d.Identifier == deviceIdentifier); + return IdentityResult.Success; + } + + public void SetTwoFactorProvider(User user, TwoFactorProviderType type, bool setEnabled = true) + { + var providers = user.GetTwoFactorProviders(); + if (!providers?.ContainsKey(type) ?? true) + { + return; + } + + if (setEnabled) + { + providers[type].Enabled = true; + } + user.SetTwoFactorProviders(providers); + + if (string.IsNullOrWhiteSpace(user.TwoFactorRecoveryCode)) + { + user.TwoFactorRecoveryCode = CoreHelpers.SecureRandomString(32, upper: false, special: false); } } + + private async Task CheckPoliciesOnTwoFactorRemovalAsync(User user, IOrganizationService organizationService) + { + var twoFactorPolicies = await _policyRepository.GetManyByTypeApplicableToUserIdAsync(user.Id, + PolicyType.TwoFactorAuthentication); + + var removeOrgUserTasks = twoFactorPolicies.Select(async p => + { + await organizationService.DeleteUserAsync(p.OrganizationId, user.Id); + var organization = await _organizationRepository.GetByIdAsync(p.OrganizationId); + await _mailService.SendOrganizationUserRemovedForPolicyTwoStepEmailAsync( + organization.Name, user.Email); + }).ToArray(); + + await Task.WhenAll(removeOrgUserTasks); + } + + public override async Task ConfirmEmailAsync(User user, string token) + { + var result = await base.ConfirmEmailAsync(user, token); + if (result.Succeeded) + { + await _referenceEventService.RaiseEventAsync( + new ReferenceEvent(ReferenceEventType.ConfirmEmailAddress, user)); + } + return result; + } + + public async Task RotateApiKeyAsync(User user) + { + user.ApiKey = CoreHelpers.SecureRandomString(30); + user.RevisionDate = DateTime.UtcNow; + await _userRepository.ReplaceAsync(user); + } + + public async Task SendOTPAsync(User user) + { + if (user.Email == null) + { + throw new BadRequestException("No user email."); + } + + if (!user.UsesKeyConnector) + { + throw new BadRequestException("Not using Key Connector."); + } + + var token = await base.GenerateUserTokenAsync(user, TokenOptions.DefaultEmailProvider, + "otp:" + user.Email); + await _mailService.SendOTPEmailAsync(user.Email, token); + } + + public Task VerifyOTPAsync(User user, string token) + { + return base.VerifyUserTokenAsync(user, TokenOptions.DefaultEmailProvider, + "otp:" + user.Email, token); + } + + public async Task VerifySecretAsync(User user, string secret) + { + return user.UsesKeyConnector + ? await VerifyOTPAsync(user, secret) + : await CheckPasswordAsync(user, secret); + } + + public async Task Needs2FABecauseNewDeviceAsync(User user, string deviceIdentifier, string grantType) + { + return CanEditDeviceVerificationSettings(user) + && user.UnknownDeviceVerificationEnabled + && grantType != "authorization_code" + && await IsNewDeviceAndNotTheFirstOneAsync(user, deviceIdentifier); + } + + public bool CanEditDeviceVerificationSettings(User user) + { + return _globalSettings.TwoFactorAuth.EmailOnNewDeviceLogin + && user.EmailVerified + && !user.UsesKeyConnector + && !(user.GetTwoFactorProviders()?.Any() ?? false); + } + + private async Task IsNewDeviceAndNotTheFirstOneAsync(User user, string deviceIdentifier) + { + if (user == null) + { + return default; + } + + var devices = await _deviceRepository.GetManyByUserIdAsync(user.Id); + if (!devices.Any()) + { + return false; + } + + return !devices.Any(d => d.Identifier == deviceIdentifier); + } } diff --git a/src/Core/Services/NoopImplementations/NoopAttachmentStorageService.cs b/src/Core/Services/NoopImplementations/NoopAttachmentStorageService.cs index 7643fc43c..24f669c36 100644 --- a/src/Core/Services/NoopImplementations/NoopAttachmentStorageService.cs +++ b/src/Core/Services/NoopImplementations/NoopAttachmentStorageService.cs @@ -2,69 +2,68 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class NoopAttachmentStorageService : IAttachmentStorageService { - public class NoopAttachmentStorageService : IAttachmentStorageService + public FileUploadType FileUploadType => FileUploadType.Direct; + + public Task CleanupAsync(Guid cipherId) { - public FileUploadType FileUploadType => FileUploadType.Direct; + return Task.FromResult(0); + } - public Task CleanupAsync(Guid cipherId) - { - return Task.FromResult(0); - } + public Task DeleteAttachmentAsync(Guid cipherId, CipherAttachment.MetaData attachmentData) + { + return Task.FromResult(0); + } - public Task DeleteAttachmentAsync(Guid cipherId, CipherAttachment.MetaData attachmentData) - { - return Task.FromResult(0); - } + public Task DeleteAttachmentsForCipherAsync(Guid cipherId) + { + return Task.FromResult(0); + } - public Task DeleteAttachmentsForCipherAsync(Guid cipherId) - { - return Task.FromResult(0); - } + public Task DeleteAttachmentsForOrganizationAsync(Guid organizationId) + { + return Task.FromResult(0); + } - public Task DeleteAttachmentsForOrganizationAsync(Guid organizationId) - { - return Task.FromResult(0); - } + public Task DeleteAttachmentsForUserAsync(Guid userId) + { + return Task.FromResult(0); + } - public Task DeleteAttachmentsForUserAsync(Guid userId) - { - return Task.FromResult(0); - } + public Task RollbackShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData, string originalContainer) + { + return Task.FromResult(0); + } - public Task RollbackShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData, string originalContainer) - { - return Task.FromResult(0); - } + public Task StartShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData) + { + return Task.FromResult(0); + } - public Task StartShareAttachmentAsync(Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData) - { - return Task.FromResult(0); - } + public Task UploadNewAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentData) + { + return Task.FromResult(0); + } - public Task UploadNewAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentData) - { - return Task.FromResult(0); - } + public Task UploadShareAttachmentAsync(Stream stream, Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData) + { + return Task.FromResult(0); + } - public Task UploadShareAttachmentAsync(Stream stream, Guid cipherId, Guid organizationId, CipherAttachment.MetaData attachmentData) - { - return Task.FromResult(0); - } + public Task GetAttachmentDownloadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) + { + return Task.FromResult((string)null); + } - public Task GetAttachmentDownloadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) - { - return Task.FromResult((string)null); - } - - public Task GetAttachmentUploadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) - { - return Task.FromResult(default(string)); - } - public Task<(bool, long?)> ValidateFileAsync(Cipher cipher, CipherAttachment.MetaData attachmentData, long leeway) - { - return Task.FromResult((false, (long?)null)); - } + public Task GetAttachmentUploadUrlAsync(Cipher cipher, CipherAttachment.MetaData attachmentData) + { + return Task.FromResult(default(string)); + } + public Task<(bool, long?)> ValidateFileAsync(Cipher cipher, CipherAttachment.MetaData attachmentData, long leeway) + { + return Task.FromResult((false, (long?)null)); } } diff --git a/src/Core/Services/NoopImplementations/NoopBlockIpService.cs b/src/Core/Services/NoopImplementations/NoopBlockIpService.cs index 4ec59f09d..fd034325e 100644 --- a/src/Core/Services/NoopImplementations/NoopBlockIpService.cs +++ b/src/Core/Services/NoopImplementations/NoopBlockIpService.cs @@ -1,11 +1,10 @@ -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class NoopBlockIpService : IBlockIpService { - public class NoopBlockIpService : IBlockIpService + public Task BlockIpAsync(string ipAddress, bool permanentBlock) { - public Task BlockIpAsync(string ipAddress, bool permanentBlock) - { - // Do nothing - return Task.FromResult(0); - } + // Do nothing + return Task.FromResult(0); } } diff --git a/src/Core/Services/NoopImplementations/NoopCaptchaValidationService.cs b/src/Core/Services/NoopImplementations/NoopCaptchaValidationService.cs index 6e680227a..ef5e3366d 100644 --- a/src/Core/Services/NoopImplementations/NoopCaptchaValidationService.cs +++ b/src/Core/Services/NoopImplementations/NoopCaptchaValidationService.cs @@ -2,18 +2,17 @@ using Bit.Core.Entities; using Bit.Core.Models.Business; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class NoopCaptchaValidationService : ICaptchaValidationService { - public class NoopCaptchaValidationService : ICaptchaValidationService + public string SiteKeyResponseKeyName => null; + public string SiteKey => null; + public bool RequireCaptchaValidation(ICurrentContext currentContext, User user = null) => false; + public string GenerateCaptchaBypassToken(User user) => ""; + public Task ValidateCaptchaResponseAsync(string captchaResponse, string clientIpAddress, + User user = null) { - public string SiteKeyResponseKeyName => null; - public string SiteKey => null; - public bool RequireCaptchaValidation(ICurrentContext currentContext, User user = null) => false; - public string GenerateCaptchaBypassToken(User user) => ""; - public Task ValidateCaptchaResponseAsync(string captchaResponse, string clientIpAddress, - User user = null) - { - return Task.FromResult(new CaptchaResponse { Success = true }); - } + return Task.FromResult(new CaptchaResponse { Success = true }); } } diff --git a/src/Core/Services/NoopImplementations/NoopEventService.cs b/src/Core/Services/NoopImplementations/NoopEventService.cs index 976657bf3..7c596717e 100644 --- a/src/Core/Services/NoopImplementations/NoopEventService.cs +++ b/src/Core/Services/NoopImplementations/NoopEventService.cs @@ -2,70 +2,69 @@ using Bit.Core.Entities.Provider; using Bit.Core.Enums; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class NoopEventService : IEventService { - public class NoopEventService : IEventService + public Task LogCipherEventAsync(Cipher cipher, EventType type, DateTime? date = null) { - public Task LogCipherEventAsync(Cipher cipher, EventType type, DateTime? date = null) - { - return Task.FromResult(0); - } + return Task.FromResult(0); + } - public Task LogCipherEventsAsync(IEnumerable> events) - { - return Task.FromResult(0); - } + public Task LogCipherEventsAsync(IEnumerable> events) + { + return Task.FromResult(0); + } - public Task LogCollectionEventAsync(Collection collection, EventType type, DateTime? date = null) - { - return Task.FromResult(0); - } + public Task LogCollectionEventAsync(Collection collection, EventType type, DateTime? date = null) + { + return Task.FromResult(0); + } - public Task LogPolicyEventAsync(Policy policy, EventType type, DateTime? date = null) - { - return Task.FromResult(0); - } + public Task LogPolicyEventAsync(Policy policy, EventType type, DateTime? date = null) + { + return Task.FromResult(0); + } - public Task LogGroupEventAsync(Group group, EventType type, DateTime? date = null) - { - return Task.FromResult(0); - } + public Task LogGroupEventAsync(Group group, EventType type, DateTime? date = null) + { + return Task.FromResult(0); + } - public Task LogOrganizationEventAsync(Organization organization, EventType type, DateTime? date = null) - { - return Task.FromResult(0); - } + public Task LogOrganizationEventAsync(Organization organization, EventType type, DateTime? date = null) + { + return Task.FromResult(0); + } - public Task LogProviderUserEventAsync(ProviderUser providerUser, EventType type, DateTime? date = null) - { - return Task.FromResult(0); - } + public Task LogProviderUserEventAsync(ProviderUser providerUser, EventType type, DateTime? date = null) + { + return Task.FromResult(0); + } - public Task LogProviderUsersEventAsync(IEnumerable<(ProviderUser, EventType, DateTime?)> events) - { - return Task.FromResult(0); - } + public Task LogProviderUsersEventAsync(IEnumerable<(ProviderUser, EventType, DateTime?)> events) + { + return Task.FromResult(0); + } - public Task LogProviderOrganizationEventAsync(ProviderOrganization providerOrganization, EventType type, - DateTime? date = null) - { - return Task.FromResult(0); - } + public Task LogProviderOrganizationEventAsync(ProviderOrganization providerOrganization, EventType type, + DateTime? date = null) + { + return Task.FromResult(0); + } - public Task LogOrganizationUserEventAsync(OrganizationUser organizationUser, EventType type, - DateTime? date = null) - { - return Task.FromResult(0); - } + public Task LogOrganizationUserEventAsync(OrganizationUser organizationUser, EventType type, + DateTime? date = null) + { + return Task.FromResult(0); + } - public Task LogOrganizationUserEventsAsync(IEnumerable<(OrganizationUser, EventType, DateTime?)> events) - { - return Task.FromResult(0); - } + public Task LogOrganizationUserEventsAsync(IEnumerable<(OrganizationUser, EventType, DateTime?)> events) + { + return Task.FromResult(0); + } - public Task LogUserEventAsync(Guid userId, EventType type, DateTime? date = null) - { - return Task.FromResult(0); - } + public Task LogUserEventAsync(Guid userId, EventType type, DateTime? date = null) + { + return Task.FromResult(0); } } diff --git a/src/Core/Services/NoopImplementations/NoopEventWriteService.cs b/src/Core/Services/NoopImplementations/NoopEventWriteService.cs index 94be40b20..d7288389f 100644 --- a/src/Core/Services/NoopImplementations/NoopEventWriteService.cs +++ b/src/Core/Services/NoopImplementations/NoopEventWriteService.cs @@ -1,17 +1,16 @@ using Bit.Core.Models.Data; -namespace Bit.Core.Services -{ - public class NoopEventWriteService : IEventWriteService - { - public Task CreateAsync(IEvent e) - { - return Task.FromResult(0); - } +namespace Bit.Core.Services; - public Task CreateManyAsync(IEnumerable e) - { - return Task.FromResult(0); - } +public class NoopEventWriteService : IEventWriteService +{ + public Task CreateAsync(IEvent e) + { + return Task.FromResult(0); + } + + public Task CreateManyAsync(IEnumerable e) + { + return Task.FromResult(0); } } diff --git a/src/Core/Services/NoopImplementations/NoopLicensingService.cs b/src/Core/Services/NoopImplementations/NoopLicensingService.cs index ef5cb9b85..c79be8009 100644 --- a/src/Core/Services/NoopImplementations/NoopLicensingService.cs +++ b/src/Core/Services/NoopImplementations/NoopLicensingService.cs @@ -4,53 +4,52 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Hosting; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class NoopLicensingService : ILicensingService { - public class NoopLicensingService : ILicensingService + public NoopLicensingService( + IWebHostEnvironment environment, + GlobalSettings globalSettings) { - public NoopLicensingService( - IWebHostEnvironment environment, - GlobalSettings globalSettings) + if (!environment.IsDevelopment() && globalSettings.SelfHosted) { - if (!environment.IsDevelopment() && globalSettings.SelfHosted) - { - throw new Exception($"{nameof(NoopLicensingService)} cannot be used for self hosted instances."); - } - } - - public Task ValidateOrganizationsAsync() - { - return Task.FromResult(0); - } - - public Task ValidateUsersAsync() - { - return Task.FromResult(0); - } - - public Task ValidateUserPremiumAsync(User user) - { - return Task.FromResult(user.Premium); - } - - public bool VerifyLicense(ILicense license) - { - return true; - } - - public byte[] SignLicense(ILicense license) - { - return new byte[0]; - } - - public Task ReadOrganizationLicenseAsync(Organization organization) - { - return Task.FromResult(null); - } - - public Task ReadOrganizationLicenseAsync(Guid organizationId) - { - return Task.FromResult(null); + throw new Exception($"{nameof(NoopLicensingService)} cannot be used for self hosted instances."); } } + + public Task ValidateOrganizationsAsync() + { + return Task.FromResult(0); + } + + public Task ValidateUsersAsync() + { + return Task.FromResult(0); + } + + public Task ValidateUserPremiumAsync(User user) + { + return Task.FromResult(user.Premium); + } + + public bool VerifyLicense(ILicense license) + { + return true; + } + + public byte[] SignLicense(ILicense license) + { + return new byte[0]; + } + + public Task ReadOrganizationLicenseAsync(Organization organization) + { + return Task.FromResult(null); + } + + public Task ReadOrganizationLicenseAsync(Guid organizationId) + { + return Task.FromResult(null); + } } diff --git a/src/Core/Services/NoopImplementations/NoopMailDeliveryService.cs b/src/Core/Services/NoopImplementations/NoopMailDeliveryService.cs index dc8ef6b60..96b97b14f 100644 --- a/src/Core/Services/NoopImplementations/NoopMailDeliveryService.cs +++ b/src/Core/Services/NoopImplementations/NoopMailDeliveryService.cs @@ -1,12 +1,11 @@ using Bit.Core.Models.Mail; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class NoopMailDeliveryService : IMailDeliveryService { - public class NoopMailDeliveryService : IMailDeliveryService + public Task SendEmailAsync(MailMessage message) { - public Task SendEmailAsync(MailMessage message) - { - return Task.FromResult(0); - } + return Task.FromResult(0); } } diff --git a/src/Core/Services/NoopImplementations/NoopMailService.cs b/src/Core/Services/NoopImplementations/NoopMailService.cs index 910516ab5..cee8c91f4 100644 --- a/src/Core/Services/NoopImplementations/NoopMailService.cs +++ b/src/Core/Services/NoopImplementations/NoopMailService.cs @@ -3,239 +3,238 @@ using Bit.Core.Entities.Provider; using Bit.Core.Models.Business; using Bit.Core.Models.Mail; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class NoopMailService : IMailService { - public class NoopMailService : IMailService + public Task SendChangeEmailAlreadyExistsEmailAsync(string fromEmail, string toEmail) { - public Task SendChangeEmailAlreadyExistsEmailAsync(string fromEmail, string toEmail) - { - return Task.FromResult(0); - } + return Task.FromResult(0); + } - public Task SendVerifyEmailEmailAsync(string email, Guid userId, string hint) - { - return Task.FromResult(0); - } + public Task SendVerifyEmailEmailAsync(string email, Guid userId, string hint) + { + return Task.FromResult(0); + } - public Task SendChangeEmailEmailAsync(string newEmailAddress, string token) - { - return Task.FromResult(0); - } + public Task SendChangeEmailEmailAsync(string newEmailAddress, string token) + { + return Task.FromResult(0); + } - public Task SendMasterPasswordHintEmailAsync(string email, string hint) - { - return Task.FromResult(0); - } + public Task SendMasterPasswordHintEmailAsync(string email, string hint) + { + return Task.FromResult(0); + } - public Task SendNoMasterPasswordHintEmailAsync(string email) - { - return Task.FromResult(0); - } + public Task SendNoMasterPasswordHintEmailAsync(string email) + { + return Task.FromResult(0); + } - public Task SendOrganizationMaxSeatLimitReachedEmailAsync(Organization organization, int maxSeatCount, IEnumerable ownerEmails) - { - return Task.FromResult(0); - } + public Task SendOrganizationMaxSeatLimitReachedEmailAsync(Organization organization, int maxSeatCount, IEnumerable ownerEmails) + { + return Task.FromResult(0); + } - public Task SendOrganizationAutoscaledEmailAsync(Organization organization, int initialSeatCount, IEnumerable ownerEmails) - { - return Task.FromResult(0); - } + public Task SendOrganizationAutoscaledEmailAsync(Organization organization, int initialSeatCount, IEnumerable ownerEmails) + { + return Task.FromResult(0); + } - public Task SendOrganizationAcceptedEmailAsync(Organization organization, string userIdentifier, IEnumerable adminEmails) - { - return Task.FromResult(0); - } + public Task SendOrganizationAcceptedEmailAsync(Organization organization, string userIdentifier, IEnumerable adminEmails) + { + return Task.FromResult(0); + } - public Task SendOrganizationConfirmedEmailAsync(string organizationName, string email) - { - return Task.FromResult(0); - } + public Task SendOrganizationConfirmedEmailAsync(string organizationName, string email) + { + return Task.FromResult(0); + } - public Task SendOrganizationInviteEmailAsync(string organizationName, OrganizationUser orgUser, ExpiringToken token) - { - return Task.FromResult(0); - } + public Task SendOrganizationInviteEmailAsync(string organizationName, OrganizationUser orgUser, ExpiringToken token) + { + return Task.FromResult(0); + } - public Task BulkSendOrganizationInviteEmailAsync(string organizationName, IEnumerable<(OrganizationUser orgUser, ExpiringToken token)> invites) - { - return Task.FromResult(0); - } + public Task BulkSendOrganizationInviteEmailAsync(string organizationName, IEnumerable<(OrganizationUser orgUser, ExpiringToken token)> invites) + { + return Task.FromResult(0); + } - public Task SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(string organizationName, string email) - { - return Task.FromResult(0); - } + public Task SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(string organizationName, string email) + { + return Task.FromResult(0); + } - public Task SendTwoFactorEmailAsync(string email, string token) - { - return Task.FromResult(0); - } + public Task SendTwoFactorEmailAsync(string email, string token) + { + return Task.FromResult(0); + } - public Task SendNewDeviceLoginTwoFactorEmailAsync(string email, string token) - { - return Task.CompletedTask; - } + public Task SendNewDeviceLoginTwoFactorEmailAsync(string email, string token) + { + return Task.CompletedTask; + } - public Task SendWelcomeEmailAsync(User user) - { - return Task.FromResult(0); - } + public Task SendWelcomeEmailAsync(User user) + { + return Task.FromResult(0); + } - public Task SendVerifyDeleteEmailAsync(string email, Guid userId, string token) - { - return Task.FromResult(0); - } + public Task SendVerifyDeleteEmailAsync(string email, Guid userId, string token) + { + return Task.FromResult(0); + } - public Task SendPasswordlessSignInAsync(string returnUrl, string token, string email) - { - return Task.FromResult(0); - } + public Task SendPasswordlessSignInAsync(string returnUrl, string token, string email) + { + return Task.FromResult(0); + } - public Task SendInvoiceUpcomingAsync(string email, decimal amount, DateTime dueDate, - List items, bool mentionInvoices) - { - return Task.FromResult(0); - } + public Task SendInvoiceUpcomingAsync(string email, decimal amount, DateTime dueDate, + List items, bool mentionInvoices) + { + return Task.FromResult(0); + } - public Task SendPaymentFailedAsync(string email, decimal amount, bool mentionInvoices) - { - return Task.FromResult(0); - } + public Task SendPaymentFailedAsync(string email, decimal amount, bool mentionInvoices) + { + return Task.FromResult(0); + } - public Task SendAddedCreditAsync(string email, decimal amount) - { - return Task.FromResult(0); - } + public Task SendAddedCreditAsync(string email, decimal amount) + { + return Task.FromResult(0); + } - public Task SendLicenseExpiredAsync(IEnumerable emails, string organizationName = null) - { - return Task.FromResult(0); - } + public Task SendLicenseExpiredAsync(IEnumerable emails, string organizationName = null) + { + return Task.FromResult(0); + } - public Task SendNewDeviceLoggedInEmail(string email, string deviceType, DateTime timestamp, string ip) - { - return Task.FromResult(0); - } + public Task SendNewDeviceLoggedInEmail(string email, string deviceType, DateTime timestamp, string ip) + { + return Task.FromResult(0); + } - public Task SendRecoverTwoFactorEmail(string email, DateTime timestamp, string ip) - { - return Task.FromResult(0); - } + public Task SendRecoverTwoFactorEmail(string email, DateTime timestamp, string ip) + { + return Task.FromResult(0); + } - public Task SendOrganizationUserRemovedForPolicySingleOrgEmailAsync(string organizationName, string email) - { - return Task.FromResult(0); - } + public Task SendOrganizationUserRemovedForPolicySingleOrgEmailAsync(string organizationName, string email) + { + return Task.FromResult(0); + } - public Task SendEmergencyAccessInviteEmailAsync(EmergencyAccess emergencyAccess, string name, string token) - { - return Task.FromResult(0); - } + public Task SendEmergencyAccessInviteEmailAsync(EmergencyAccess emergencyAccess, string name, string token) + { + return Task.FromResult(0); + } - public Task SendEmergencyAccessAcceptedEmailAsync(string granteeEmail, string email) - { - return Task.FromResult(0); - } + public Task SendEmergencyAccessAcceptedEmailAsync(string granteeEmail, string email) + { + return Task.FromResult(0); + } - public Task SendEmergencyAccessConfirmedEmailAsync(string grantorName, string email) - { - return Task.FromResult(0); - } + public Task SendEmergencyAccessConfirmedEmailAsync(string grantorName, string email) + { + return Task.FromResult(0); + } - public Task SendEmergencyAccessRecoveryInitiated(EmergencyAccess emergencyAccess, string initiatingName, string email) - { - return Task.FromResult(0); - } + public Task SendEmergencyAccessRecoveryInitiated(EmergencyAccess emergencyAccess, string initiatingName, string email) + { + return Task.FromResult(0); + } - public Task SendEmergencyAccessRecoveryApproved(EmergencyAccess emergencyAccess, string approvingName, string email) - { - return Task.FromResult(0); - } + public Task SendEmergencyAccessRecoveryApproved(EmergencyAccess emergencyAccess, string approvingName, string email) + { + return Task.FromResult(0); + } - public Task SendEmergencyAccessRecoveryRejected(EmergencyAccess emergencyAccess, string rejectingName, string email) - { - return Task.FromResult(0); - } + public Task SendEmergencyAccessRecoveryRejected(EmergencyAccess emergencyAccess, string rejectingName, string email) + { + return Task.FromResult(0); + } - public Task SendEmergencyAccessRecoveryReminder(EmergencyAccess emergencyAccess, string initiatingName, string email) - { - return Task.FromResult(0); - } + public Task SendEmergencyAccessRecoveryReminder(EmergencyAccess emergencyAccess, string initiatingName, string email) + { + return Task.FromResult(0); + } - public Task SendEmergencyAccessRecoveryTimedOut(EmergencyAccess ea, string initiatingName, string email) - { - return Task.FromResult(0); - } + public Task SendEmergencyAccessRecoveryTimedOut(EmergencyAccess ea, string initiatingName, string email) + { + return Task.FromResult(0); + } - public Task SendEnqueuedMailMessageAsync(IMailQueueMessage queueMessage) - { - return Task.FromResult(0); - } + public Task SendEnqueuedMailMessageAsync(IMailQueueMessage queueMessage) + { + return Task.FromResult(0); + } - public Task SendAdminResetPasswordEmailAsync(string email, string userName, string orgName) - { - return Task.FromResult(0); - } + public Task SendAdminResetPasswordEmailAsync(string email, string userName, string orgName) + { + return Task.FromResult(0); + } - public Task SendProviderSetupInviteEmailAsync(Provider provider, string token, string email) - { - return Task.FromResult(0); - } + public Task SendProviderSetupInviteEmailAsync(Provider provider, string token, string email) + { + return Task.FromResult(0); + } - public Task SendProviderInviteEmailAsync(string providerName, ProviderUser providerUser, string token, string email) - { - return Task.FromResult(0); - } + public Task SendProviderInviteEmailAsync(string providerName, ProviderUser providerUser, string token, string email) + { + return Task.FromResult(0); + } - public Task SendProviderConfirmedEmailAsync(string providerName, string email) - { - return Task.FromResult(0); - } + public Task SendProviderConfirmedEmailAsync(string providerName, string email) + { + return Task.FromResult(0); + } - public Task SendProviderUserRemoved(string providerName, string email) - { - return Task.FromResult(0); - } + public Task SendProviderUserRemoved(string providerName, string email) + { + return Task.FromResult(0); + } - public Task SendUpdatedTempPasswordEmailAsync(string email, string userName) - { - return Task.FromResult(0); - } + public Task SendUpdatedTempPasswordEmailAsync(string email, string userName) + { + return Task.FromResult(0); + } - public Task SendFamiliesForEnterpriseOfferEmailAsync(string SponsorOrgName, string email, bool existingAccount, string token) - { - return Task.FromResult(0); - } + public Task SendFamiliesForEnterpriseOfferEmailAsync(string SponsorOrgName, string email, bool existingAccount, string token) + { + return Task.FromResult(0); + } - public Task BulkSendFamiliesForEnterpriseOfferEmailAsync(string SponsorOrgName, IEnumerable<(string Email, bool ExistingAccount, string Token)> invites) - { - return Task.FromResult(0); - } + public Task BulkSendFamiliesForEnterpriseOfferEmailAsync(string SponsorOrgName, IEnumerable<(string Email, bool ExistingAccount, string Token)> invites) + { + return Task.FromResult(0); + } - public Task SendFamiliesForEnterpriseRedeemedEmailsAsync(string familyUserEmail, string sponsorEmail) - { - return Task.FromResult(0); - } + public Task SendFamiliesForEnterpriseRedeemedEmailsAsync(string familyUserEmail, string sponsorEmail) + { + return Task.FromResult(0); + } - public Task SendFamiliesForEnterpriseSponsorshipRevertingEmailAsync(string email, DateTime expirationDate) - { - return Task.FromResult(0); - } + public Task SendFamiliesForEnterpriseSponsorshipRevertingEmailAsync(string email, DateTime expirationDate) + { + return Task.FromResult(0); + } - public Task SendOTPEmailAsync(string email, string token) - { - return Task.FromResult(0); - } + public Task SendOTPEmailAsync(string email, string token) + { + return Task.FromResult(0); + } - public Task SendFailedLoginAttemptsEmailAsync(string email, DateTime utcNow, string ip) - { - return Task.FromResult(0); - } + public Task SendFailedLoginAttemptsEmailAsync(string email, DateTime utcNow, string ip) + { + return Task.FromResult(0); + } - public Task SendFailedTwoFactorAttemptsEmailAsync(string email, DateTime utcNow, string ip) - { - return Task.FromResult(0); - } + public Task SendFailedTwoFactorAttemptsEmailAsync(string email, DateTime utcNow, string ip) + { + return Task.FromResult(0); } } diff --git a/src/Core/Services/NoopImplementations/NoopProviderService.cs b/src/Core/Services/NoopImplementations/NoopProviderService.cs index efa574144..478c5c6c1 100644 --- a/src/Core/Services/NoopImplementations/NoopProviderService.cs +++ b/src/Core/Services/NoopImplementations/NoopProviderService.cs @@ -3,36 +3,35 @@ using Bit.Core.Entities.Provider; using Bit.Core.Models.Business; using Bit.Core.Models.Business.Provider; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class NoopProviderService : IProviderService { - public class NoopProviderService : IProviderService - { - public Task CreateAsync(string ownerEmail) => throw new NotImplementedException(); + public Task CreateAsync(string ownerEmail) => throw new NotImplementedException(); - public Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key) => throw new NotImplementedException(); + public Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key) => throw new NotImplementedException(); - public Task UpdateAsync(Provider provider, bool updateBilling = false) => throw new NotImplementedException(); + public Task UpdateAsync(Provider provider, bool updateBilling = false) => throw new NotImplementedException(); - public Task> InviteUserAsync(ProviderUserInvite invite) => throw new NotImplementedException(); + public Task> InviteUserAsync(ProviderUserInvite invite) => throw new NotImplementedException(); - public Task>> ResendInvitesAsync(ProviderUserInvite invite) => throw new NotImplementedException(); + public Task>> ResendInvitesAsync(ProviderUserInvite invite) => throw new NotImplementedException(); - public Task AcceptUserAsync(Guid providerUserId, User user, string token) => throw new NotImplementedException(); + public Task AcceptUserAsync(Guid providerUserId, User user, string token) => throw new NotImplementedException(); - public Task>> ConfirmUsersAsync(Guid providerId, Dictionary keys, Guid confirmingUserId) => throw new NotImplementedException(); + public Task>> ConfirmUsersAsync(Guid providerId, Dictionary keys, Guid confirmingUserId) => throw new NotImplementedException(); - public Task SaveUserAsync(ProviderUser user, Guid savingUserId) => throw new NotImplementedException(); + public Task SaveUserAsync(ProviderUser user, Guid savingUserId) => throw new NotImplementedException(); - public Task>> DeleteUsersAsync(Guid providerId, IEnumerable providerUserIds, Guid deletingUserId) => throw new NotImplementedException(); + public Task>> DeleteUsersAsync(Guid providerId, IEnumerable providerUserIds, Guid deletingUserId) => throw new NotImplementedException(); - public Task AddOrganization(Guid providerId, Guid organizationId, Guid addingUserId, string key) => throw new NotImplementedException(); + public Task AddOrganization(Guid providerId, Guid organizationId, Guid addingUserId, string key) => throw new NotImplementedException(); - public Task CreateOrganizationAsync(Guid providerId, OrganizationSignup organizationSignup, string clientOwnerEmail, User user) => throw new NotImplementedException(); + public Task CreateOrganizationAsync(Guid providerId, OrganizationSignup organizationSignup, string clientOwnerEmail, User user) => throw new NotImplementedException(); - public Task RemoveOrganizationAsync(Guid providerId, Guid providerOrganizationId, Guid removingUserId) => throw new NotImplementedException(); + public Task RemoveOrganizationAsync(Guid providerId, Guid providerOrganizationId, Guid removingUserId) => throw new NotImplementedException(); - public Task LogProviderAccessToOrganizationAsync(Guid organizationId) => throw new NotImplementedException(); + public Task LogProviderAccessToOrganizationAsync(Guid organizationId) => throw new NotImplementedException(); - public Task ResendProviderSetupInviteEmailAsync(Guid providerId, Guid userId) => throw new NotImplementedException(); - } + public Task ResendProviderSetupInviteEmailAsync(Guid providerId, Guid userId) => throw new NotImplementedException(); } diff --git a/src/Core/Services/NoopImplementations/NoopPushNotificationService.cs b/src/Core/Services/NoopImplementations/NoopPushNotificationService.cs index 8d9f1117e..ee2c6a498 100644 --- a/src/Core/Services/NoopImplementations/NoopPushNotificationService.cs +++ b/src/Core/Services/NoopImplementations/NoopPushNotificationService.cs @@ -1,90 +1,89 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class NoopPushNotificationService : IPushNotificationService { - public class NoopPushNotificationService : IPushNotificationService + public Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) { - public Task PushSyncCipherCreateAsync(Cipher cipher, IEnumerable collectionIds) - { - return Task.FromResult(0); - } + return Task.FromResult(0); + } - public Task PushSyncCipherDeleteAsync(Cipher cipher) - { - return Task.FromResult(0); - } + public Task PushSyncCipherDeleteAsync(Cipher cipher) + { + return Task.FromResult(0); + } - public Task PushSyncCiphersAsync(Guid userId) - { - return Task.FromResult(0); - } + public Task PushSyncCiphersAsync(Guid userId) + { + return Task.FromResult(0); + } - public Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds) - { - return Task.FromResult(0); - } + public Task PushSyncCipherUpdateAsync(Cipher cipher, IEnumerable collectionIds) + { + return Task.FromResult(0); + } - public Task PushSyncFolderCreateAsync(Folder folder) - { - return Task.FromResult(0); - } + public Task PushSyncFolderCreateAsync(Folder folder) + { + return Task.FromResult(0); + } - public Task PushSyncFolderDeleteAsync(Folder folder) - { - return Task.FromResult(0); - } + public Task PushSyncFolderDeleteAsync(Folder folder) + { + return Task.FromResult(0); + } - public Task PushSyncFolderUpdateAsync(Folder folder) - { - return Task.FromResult(0); - } + public Task PushSyncFolderUpdateAsync(Folder folder) + { + return Task.FromResult(0); + } - public Task PushSyncOrgKeysAsync(Guid userId) - { - return Task.FromResult(0); - } + public Task PushSyncOrgKeysAsync(Guid userId) + { + return Task.FromResult(0); + } - public Task PushSyncSettingsAsync(Guid userId) - { - return Task.FromResult(0); - } + public Task PushSyncSettingsAsync(Guid userId) + { + return Task.FromResult(0); + } - public Task PushSyncVaultAsync(Guid userId) - { - return Task.FromResult(0); - } + public Task PushSyncVaultAsync(Guid userId) + { + return Task.FromResult(0); + } - public Task PushLogOutAsync(Guid userId) - { - return Task.FromResult(0); - } + public Task PushLogOutAsync(Guid userId) + { + return Task.FromResult(0); + } - public Task PushSyncSendCreateAsync(Send send) - { - return Task.FromResult(0); - } + public Task PushSyncSendCreateAsync(Send send) + { + return Task.FromResult(0); + } - public Task PushSyncSendDeleteAsync(Send send) - { - return Task.FromResult(0); - } + public Task PushSyncSendDeleteAsync(Send send) + { + return Task.FromResult(0); + } - public Task PushSyncSendUpdateAsync(Send send) - { - return Task.FromResult(0); - } + public Task PushSyncSendUpdateAsync(Send send) + { + return Task.FromResult(0); + } - public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, - string deviceId = null) - { - return Task.FromResult(0); - } + public Task SendPayloadToOrganizationAsync(string orgId, PushType type, object payload, string identifier, + string deviceId = null) + { + return Task.FromResult(0); + } - public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, - string deviceId = null) - { - return Task.FromResult(0); - } + public Task SendPayloadToUserAsync(string userId, PushType type, object payload, string identifier, + string deviceId = null) + { + return Task.FromResult(0); } } diff --git a/src/Core/Services/NoopImplementations/NoopPushRegistrationService.cs b/src/Core/Services/NoopImplementations/NoopPushRegistrationService.cs index c574314e0..f6279c946 100644 --- a/src/Core/Services/NoopImplementations/NoopPushRegistrationService.cs +++ b/src/Core/Services/NoopImplementations/NoopPushRegistrationService.cs @@ -1,28 +1,27 @@ using Bit.Core.Enums; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class NoopPushRegistrationService : IPushRegistrationService { - public class NoopPushRegistrationService : IPushRegistrationService + public Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) { - public Task AddUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) - { - return Task.FromResult(0); - } + return Task.FromResult(0); + } - public Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId, - string identifier, DeviceType type) - { - return Task.FromResult(0); - } + public Task CreateOrUpdateRegistrationAsync(string pushToken, string deviceId, string userId, + string identifier, DeviceType type) + { + return Task.FromResult(0); + } - public Task DeleteRegistrationAsync(string deviceId) - { - return Task.FromResult(0); - } + public Task DeleteRegistrationAsync(string deviceId) + { + return Task.FromResult(0); + } - public Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) - { - return Task.FromResult(0); - } + public Task DeleteUserRegistrationOrganizationAsync(IEnumerable deviceIds, string organizationId) + { + return Task.FromResult(0); } } diff --git a/src/Core/Services/NoopImplementations/NoopReferenceEventService.cs b/src/Core/Services/NoopImplementations/NoopReferenceEventService.cs index fa15ce727..a32001e85 100644 --- a/src/Core/Services/NoopImplementations/NoopReferenceEventService.cs +++ b/src/Core/Services/NoopImplementations/NoopReferenceEventService.cs @@ -1,12 +1,11 @@ using Bit.Core.Models.Business; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class NoopReferenceEventService : IReferenceEventService { - public class NoopReferenceEventService : IReferenceEventService + public Task RaiseEventAsync(ReferenceEvent referenceEvent) { - public Task RaiseEventAsync(ReferenceEvent referenceEvent) - { - return Task.CompletedTask; - } + return Task.CompletedTask; } } diff --git a/src/Core/Services/NoopImplementations/NoopSendFileStorageService.cs b/src/Core/Services/NoopImplementations/NoopSendFileStorageService.cs index 407e3976f..08602ef9f 100644 --- a/src/Core/Services/NoopImplementations/NoopSendFileStorageService.cs +++ b/src/Core/Services/NoopImplementations/NoopSendFileStorageService.cs @@ -1,45 +1,44 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Services +namespace Bit.Core.Services; + +public class NoopSendFileStorageService : ISendFileStorageService { - public class NoopSendFileStorageService : ISendFileStorageService + public FileUploadType FileUploadType => FileUploadType.Direct; + + public Task UploadNewFileAsync(Stream stream, Send send, string attachmentId) { - public FileUploadType FileUploadType => FileUploadType.Direct; + return Task.FromResult(0); + } - public Task UploadNewFileAsync(Stream stream, Send send, string attachmentId) - { - return Task.FromResult(0); - } + public Task DeleteFileAsync(Send send, string fileId) + { + return Task.FromResult(0); + } - public Task DeleteFileAsync(Send send, string fileId) - { - return Task.FromResult(0); - } + public Task DeleteFilesForOrganizationAsync(Guid organizationId) + { + return Task.FromResult(0); + } - public Task DeleteFilesForOrganizationAsync(Guid organizationId) - { - return Task.FromResult(0); - } + public Task DeleteFilesForUserAsync(Guid userId) + { + return Task.FromResult(0); + } - public Task DeleteFilesForUserAsync(Guid userId) - { - return Task.FromResult(0); - } + public Task GetSendFileDownloadUrlAsync(Send send, string fileId) + { + return Task.FromResult((string)null); + } - public Task GetSendFileDownloadUrlAsync(Send send, string fileId) - { - return Task.FromResult((string)null); - } + public Task GetSendFileUploadUrlAsync(Send send, string fileId) + { + return Task.FromResult((string)null); + } - public Task GetSendFileUploadUrlAsync(Send send, string fileId) - { - return Task.FromResult((string)null); - } - - public Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway) - { - return Task.FromResult((false, default(long?))); - } + public Task<(bool, long?)> ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway) + { + return Task.FromResult((false, default(long?))); } } diff --git a/src/Core/Settings/GlobalSettings.cs b/src/Core/Settings/GlobalSettings.cs index f0bdca4ef..bd4087f3a 100644 --- a/src/Core/Settings/GlobalSettings.cs +++ b/src/Core/Settings/GlobalSettings.cs @@ -1,502 +1,501 @@ -namespace Bit.Core.Settings +namespace Bit.Core.Settings; + +public class GlobalSettings : IGlobalSettings { - public class GlobalSettings : IGlobalSettings + private string _logDirectory; + private string _licenseDirectory; + + public GlobalSettings() { - private string _logDirectory; - private string _licenseDirectory; + BaseServiceUri = new BaseServiceUriSettings(this); + Attachment = new FileStorageSettings(this, "attachments", "attachments"); + Send = new FileStorageSettings(this, "attachments/send", "attachments/send"); + DataProtection = new DataProtectionSettings(this); + } - public GlobalSettings() + public bool SelfHosted { get; set; } + public virtual string KnownProxies { get; set; } + public virtual string SiteName { get; set; } + public virtual string ProjectName { get; set; } + public virtual string LogDirectory + { + get => BuildDirectory(_logDirectory, "/logs"); + set => _logDirectory = value; + } + public virtual long? LogRollBySizeLimit { get; set; } + public virtual string LicenseDirectory + { + get => BuildDirectory(_licenseDirectory, "/core/licenses"); + set => _licenseDirectory = value; + } + public string LicenseCertificatePassword { get; set; } + public virtual string PushRelayBaseUri { get; set; } + public virtual string InternalIdentityKey { get; set; } + public virtual string OidcIdentityClientKey { get; set; } + public virtual string HibpApiKey { get; set; } + public virtual bool DisableUserRegistration { get; set; } + public virtual bool DisableEmailNewDevice { get; set; } + public virtual bool EnableCloudCommunication { get; set; } = false; + public virtual int OrganizationInviteExpirationHours { get; set; } = 120; // 5 days + public virtual string EventGridKey { get; set; } + public virtual CaptchaSettings Captcha { get; set; } = new CaptchaSettings(); + public virtual IInstallationSettings Installation { get; set; } = new InstallationSettings(); + public virtual IBaseServiceUriSettings BaseServiceUri { get; set; } + public virtual string DatabaseProvider { get; set; } + public virtual SqlSettings SqlServer { get; set; } = new SqlSettings(); + public virtual SqlSettings PostgreSql { get; set; } = new SqlSettings(); + public virtual SqlSettings MySql { get; set; } = new SqlSettings(); + public virtual SqlSettings Sqlite { get; set; } = new SqlSettings(); + public virtual MailSettings Mail { get; set; } = new MailSettings(); + public virtual IConnectionStringSettings Storage { get; set; } = new ConnectionStringSettings(); + public virtual ConnectionStringSettings Events { get; set; } = new ConnectionStringSettings(); + public virtual IConnectionStringSettings Redis { get; set; } = new ConnectionStringSettings(); + public virtual NotificationsSettings Notifications { get; set; } = new NotificationsSettings(); + public virtual IFileStorageSettings Attachment { get; set; } + public virtual FileStorageSettings Send { get; set; } + public virtual IdentityServerSettings IdentityServer { get; set; } = new IdentityServerSettings(); + public virtual DataProtectionSettings DataProtection { get; set; } + public virtual DocumentDbSettings DocumentDb { get; set; } = new DocumentDbSettings(); + public virtual SentrySettings Sentry { get; set; } = new SentrySettings(); + public virtual SyslogSettings Syslog { get; set; } = new SyslogSettings(); + public virtual NotificationHubSettings NotificationHub { get; set; } = new NotificationHubSettings(); + public virtual YubicoSettings Yubico { get; set; } = new YubicoSettings(); + public virtual DuoSettings Duo { get; set; } = new DuoSettings(); + public virtual BraintreeSettings Braintree { get; set; } = new BraintreeSettings(); + public virtual BitPaySettings BitPay { get; set; } = new BitPaySettings(); + public virtual AmazonSettings Amazon { get; set; } = new AmazonSettings(); + public virtual ServiceBusSettings ServiceBus { get; set; } = new ServiceBusSettings(); + public virtual AppleIapSettings AppleIap { get; set; } = new AppleIapSettings(); + public virtual ISsoSettings Sso { get; set; } = new SsoSettings(); + public virtual StripeSettings Stripe { get; set; } = new StripeSettings(); + public virtual ITwoFactorAuthSettings TwoFactorAuth { get; set; } = new TwoFactorAuthSettings(); + + public string BuildExternalUri(string explicitValue, string name) + { + if (!string.IsNullOrWhiteSpace(explicitValue)) { - BaseServiceUri = new BaseServiceUriSettings(this); - Attachment = new FileStorageSettings(this, "attachments", "attachments"); - Send = new FileStorageSettings(this, "attachments/send", "attachments/send"); - DataProtection = new DataProtectionSettings(this); + return explicitValue; + } + if (!SelfHosted) + { + return null; + } + return string.Format("{0}/{1}", BaseServiceUri.Vault, name); + } + + public string BuildInternalUri(string explicitValue, string name) + { + if (!string.IsNullOrWhiteSpace(explicitValue)) + { + return explicitValue; + } + if (!SelfHosted) + { + return null; + } + return string.Format("http://{0}:5000", name); + } + + public string BuildDirectory(string explicitValue, string appendedPath) + { + if (!string.IsNullOrWhiteSpace(explicitValue)) + { + return explicitValue; + } + if (!SelfHosted) + { + return null; + } + return string.Concat("/etc/bitwarden", appendedPath); + } + + public class BaseServiceUriSettings : IBaseServiceUriSettings + { + private readonly GlobalSettings _globalSettings; + + private string _api; + private string _identity; + private string _admin; + private string _notifications; + private string _sso; + private string _scim; + private string _internalApi; + private string _internalIdentity; + private string _internalAdmin; + private string _internalNotifications; + private string _internalSso; + private string _internalVault; + private string _internalScim; + + public BaseServiceUriSettings(GlobalSettings globalSettings) + { + _globalSettings = globalSettings; } - public bool SelfHosted { get; set; } - public virtual string KnownProxies { get; set; } - public virtual string SiteName { get; set; } - public virtual string ProjectName { get; set; } - public virtual string LogDirectory - { - get => BuildDirectory(_logDirectory, "/logs"); - set => _logDirectory = value; - } - public virtual long? LogRollBySizeLimit { get; set; } - public virtual string LicenseDirectory - { - get => BuildDirectory(_licenseDirectory, "/core/licenses"); - set => _licenseDirectory = value; - } - public string LicenseCertificatePassword { get; set; } - public virtual string PushRelayBaseUri { get; set; } - public virtual string InternalIdentityKey { get; set; } - public virtual string OidcIdentityClientKey { get; set; } - public virtual string HibpApiKey { get; set; } - public virtual bool DisableUserRegistration { get; set; } - public virtual bool DisableEmailNewDevice { get; set; } - public virtual bool EnableCloudCommunication { get; set; } = false; - public virtual int OrganizationInviteExpirationHours { get; set; } = 120; // 5 days - public virtual string EventGridKey { get; set; } - public virtual CaptchaSettings Captcha { get; set; } = new CaptchaSettings(); - public virtual IInstallationSettings Installation { get; set; } = new InstallationSettings(); - public virtual IBaseServiceUriSettings BaseServiceUri { get; set; } - public virtual string DatabaseProvider { get; set; } - public virtual SqlSettings SqlServer { get; set; } = new SqlSettings(); - public virtual SqlSettings PostgreSql { get; set; } = new SqlSettings(); - public virtual SqlSettings MySql { get; set; } = new SqlSettings(); - public virtual SqlSettings Sqlite { get; set; } = new SqlSettings(); - public virtual MailSettings Mail { get; set; } = new MailSettings(); - public virtual IConnectionStringSettings Storage { get; set; } = new ConnectionStringSettings(); - public virtual ConnectionStringSettings Events { get; set; } = new ConnectionStringSettings(); - public virtual IConnectionStringSettings Redis { get; set; } = new ConnectionStringSettings(); - public virtual NotificationsSettings Notifications { get; set; } = new NotificationsSettings(); - public virtual IFileStorageSettings Attachment { get; set; } - public virtual FileStorageSettings Send { get; set; } - public virtual IdentityServerSettings IdentityServer { get; set; } = new IdentityServerSettings(); - public virtual DataProtectionSettings DataProtection { get; set; } - public virtual DocumentDbSettings DocumentDb { get; set; } = new DocumentDbSettings(); - public virtual SentrySettings Sentry { get; set; } = new SentrySettings(); - public virtual SyslogSettings Syslog { get; set; } = new SyslogSettings(); - public virtual NotificationHubSettings NotificationHub { get; set; } = new NotificationHubSettings(); - public virtual YubicoSettings Yubico { get; set; } = new YubicoSettings(); - public virtual DuoSettings Duo { get; set; } = new DuoSettings(); - public virtual BraintreeSettings Braintree { get; set; } = new BraintreeSettings(); - public virtual BitPaySettings BitPay { get; set; } = new BitPaySettings(); - public virtual AmazonSettings Amazon { get; set; } = new AmazonSettings(); - public virtual ServiceBusSettings ServiceBus { get; set; } = new ServiceBusSettings(); - public virtual AppleIapSettings AppleIap { get; set; } = new AppleIapSettings(); - public virtual ISsoSettings Sso { get; set; } = new SsoSettings(); - public virtual StripeSettings Stripe { get; set; } = new StripeSettings(); - public virtual ITwoFactorAuthSettings TwoFactorAuth { get; set; } = new TwoFactorAuthSettings(); + public string Vault { get; set; } + public string VaultWithHash => $"{Vault}/#"; - public string BuildExternalUri(string explicitValue, string name) + public string Api { - if (!string.IsNullOrWhiteSpace(explicitValue)) - { - return explicitValue; - } - if (!SelfHosted) - { - return null; - } - return string.Format("{0}/{1}", BaseServiceUri.Vault, name); + get => _globalSettings.BuildExternalUri(_api, "api"); + set => _api = value; + } + public string Identity + { + get => _globalSettings.BuildExternalUri(_identity, "identity"); + set => _identity = value; + } + public string Admin + { + get => _globalSettings.BuildExternalUri(_admin, "admin"); + set => _admin = value; + } + public string Notifications + { + get => _globalSettings.BuildExternalUri(_notifications, "notifications"); + set => _notifications = value; + } + public string Sso + { + get => _globalSettings.BuildExternalUri(_sso, "sso"); + set => _sso = value; + } + public string Scim + { + get => _globalSettings.BuildExternalUri(_scim, "scim"); + set => _scim = value; } - public string BuildInternalUri(string explicitValue, string name) + public string InternalNotifications { - if (!string.IsNullOrWhiteSpace(explicitValue)) - { - return explicitValue; - } - if (!SelfHosted) - { - return null; - } - return string.Format("http://{0}:5000", name); + get => _globalSettings.BuildInternalUri(_internalNotifications, "notifications"); + set => _internalNotifications = value; } - - public string BuildDirectory(string explicitValue, string appendedPath) + public string InternalAdmin { - if (!string.IsNullOrWhiteSpace(explicitValue)) - { - return explicitValue; - } - if (!SelfHosted) - { - return null; - } - return string.Concat("/etc/bitwarden", appendedPath); + get => _globalSettings.BuildInternalUri(_internalAdmin, "admin"); + set => _internalAdmin = value; } - - public class BaseServiceUriSettings : IBaseServiceUriSettings + public string InternalIdentity { - private readonly GlobalSettings _globalSettings; - - private string _api; - private string _identity; - private string _admin; - private string _notifications; - private string _sso; - private string _scim; - private string _internalApi; - private string _internalIdentity; - private string _internalAdmin; - private string _internalNotifications; - private string _internalSso; - private string _internalVault; - private string _internalScim; - - public BaseServiceUriSettings(GlobalSettings globalSettings) - { - _globalSettings = globalSettings; - } - - public string Vault { get; set; } - public string VaultWithHash => $"{Vault}/#"; - - public string Api - { - get => _globalSettings.BuildExternalUri(_api, "api"); - set => _api = value; - } - public string Identity - { - get => _globalSettings.BuildExternalUri(_identity, "identity"); - set => _identity = value; - } - public string Admin - { - get => _globalSettings.BuildExternalUri(_admin, "admin"); - set => _admin = value; - } - public string Notifications - { - get => _globalSettings.BuildExternalUri(_notifications, "notifications"); - set => _notifications = value; - } - public string Sso - { - get => _globalSettings.BuildExternalUri(_sso, "sso"); - set => _sso = value; - } - public string Scim - { - get => _globalSettings.BuildExternalUri(_scim, "scim"); - set => _scim = value; - } - - public string InternalNotifications - { - get => _globalSettings.BuildInternalUri(_internalNotifications, "notifications"); - set => _internalNotifications = value; - } - public string InternalAdmin - { - get => _globalSettings.BuildInternalUri(_internalAdmin, "admin"); - set => _internalAdmin = value; - } - public string InternalIdentity - { - get => _globalSettings.BuildInternalUri(_internalIdentity, "identity"); - set => _internalIdentity = value; - } - public string InternalApi - { - get => _globalSettings.BuildInternalUri(_internalApi, "api"); - set => _internalApi = value; - } - public string InternalVault - { - get => _globalSettings.BuildInternalUri(_internalVault, "web"); - set => _internalVault = value; - } - public string InternalSso - { - get => _globalSettings.BuildInternalUri(_internalSso, "sso"); - set => _internalSso = value; - } - public string InternalScim - { - get => _globalSettings.BuildInternalUri(_scim, "scim"); - set => _internalScim = value; - } + get => _globalSettings.BuildInternalUri(_internalIdentity, "identity"); + set => _internalIdentity = value; } - - public class SqlSettings + public string InternalApi { - private string _connectionString; - private string _readOnlyConnectionString; - private string _jobSchedulerConnectionString; - - public string ConnectionString - { - get => _connectionString; - set => _connectionString = value.Trim('"'); - } - - public string ReadOnlyConnectionString - { - get => string.IsNullOrWhiteSpace(_readOnlyConnectionString) ? - _connectionString : _readOnlyConnectionString; - set => _readOnlyConnectionString = value.Trim('"'); - } - - public string JobSchedulerConnectionString - { - get => _jobSchedulerConnectionString; - set => _jobSchedulerConnectionString = value.Trim('"'); - } + get => _globalSettings.BuildInternalUri(_internalApi, "api"); + set => _internalApi = value; } - - public class ConnectionStringSettings : IConnectionStringSettings + public string InternalVault { - private string _connectionString; - - public string ConnectionString - { - get => _connectionString; - set => _connectionString = value.Trim('"'); - } + get => _globalSettings.BuildInternalUri(_internalVault, "web"); + set => _internalVault = value; } - - public class FileStorageSettings : IFileStorageSettings + public string InternalSso { - private readonly GlobalSettings _globalSettings; - private readonly string _urlName; - private readonly string _directoryName; - private string _connectionString; - private string _baseDirectory; - private string _baseUrl; - - public FileStorageSettings(GlobalSettings globalSettings, string urlName, string directoryName) - { - _globalSettings = globalSettings; - _urlName = urlName; - _directoryName = directoryName; - } - - public string ConnectionString - { - get => _connectionString; - set => _connectionString = value.Trim('"'); - } - - public string BaseDirectory - { - get => _globalSettings.BuildDirectory(_baseDirectory, string.Concat("/core/", _directoryName)); - set => _baseDirectory = value; - } - - public string BaseUrl - { - get => _globalSettings.BuildExternalUri(_baseUrl, _urlName); - set => _baseUrl = value; - } + get => _globalSettings.BuildInternalUri(_internalSso, "sso"); + set => _internalSso = value; } - - public class MailSettings + public string InternalScim { - private ConnectionStringSettings _connectionStringSettings; - public string ConnectionString - { - get => _connectionStringSettings?.ConnectionString; - set - { - if (_connectionStringSettings == null) - { - _connectionStringSettings = new ConnectionStringSettings(); - } - _connectionStringSettings.ConnectionString = value; - } - } - public string ReplyToEmail { get; set; } - public string AmazonConfigSetName { get; set; } - public SmtpSettings Smtp { get; set; } = new SmtpSettings(); - public string SendGridApiKey { get; set; } - public int? SendGridPercentage { get; set; } - - public class SmtpSettings - { - public string Host { get; set; } - public int Port { get; set; } = 25; - public bool StartTls { get; set; } = false; - public bool Ssl { get; set; } = false; - public bool SslOverride { get; set; } = false; - public string Username { get; set; } - public string Password { get; set; } - public bool TrustServer { get; set; } = false; - } - } - - public class IdentityServerSettings - { - public string CertificateThumbprint { get; set; } - public string CertificatePassword { get; set; } - public string RedisConnectionString { get; set; } - } - - public class DataProtectionSettings - { - private readonly GlobalSettings _globalSettings; - - private string _directory; - - public DataProtectionSettings(GlobalSettings globalSettings) - { - _globalSettings = globalSettings; - } - - public string CertificateThumbprint { get; set; } - public string CertificatePassword { get; set; } - public string Directory - { - get => _globalSettings.BuildDirectory(_directory, "/core/aspnet-dataprotection"); - set => _directory = value; - } - } - - public class DocumentDbSettings - { - public string Uri { get; set; } - public string Key { get; set; } - } - - public class SentrySettings - { - public string Dsn { get; set; } - } - - public class NotificationsSettings : ConnectionStringSettings - { - public string RedisConnectionString { get; set; } - } - - public class SyslogSettings - { - /// - /// The connection string used to connect to a remote syslog server over TCP or UDP, or to connect locally. - /// - /// - /// The connection string will be parsed using to extract the protocol, host name and port number. - /// - /// - /// Supported protocols are: - /// - /// UDP (use udp://) - /// TCP (use tcp://) - /// TLS over TCP (use tls://) - /// - /// - /// - /// - /// A remote server (logging.dev.example.com) is listening on UDP (port 514): - /// - /// udp://logging.dev.example.com:514. - /// - public string Destination { get; set; } - /// - /// The absolute path to a Certificate (DER or Base64 encoded with private key). - /// - /// - /// The certificate path and are passed into the . - /// The file format of the certificate may be binary encded (DER) or base64. If the private key is encrypted, provide the password in , - /// - public string CertificatePath { get; set; } - /// - /// The password for the encrypted private key in the certificate supplied in . - /// - /// - public string CertificatePassword { get; set; } - /// - /// The thumbprint of the certificate in the X.509 certificate store for personal certificates for the user account running Bitwarden. - /// - /// - public string CertificateThumbprint { get; set; } - } - - public class NotificationHubSettings - { - private string _connectionString; - - public string ConnectionString - { - get => _connectionString; - set => _connectionString = value.Trim('"'); - } - public string HubName { get; set; } - } - - public class YubicoSettings - { - public string ClientId { get; set; } - public string Key { get; set; } - public string[] ValidationUrls { get; set; } - } - - public class DuoSettings - { - public string AKey { get; set; } - } - - public class BraintreeSettings - { - public bool Production { get; set; } - public string MerchantId { get; set; } - public string PublicKey { get; set; } - public string PrivateKey { get; set; } - } - - public class BitPaySettings - { - public bool Production { get; set; } - public string Token { get; set; } - public string NotificationUrl { get; set; } - } - - public class InstallationSettings : IInstallationSettings - { - private string _identityUri; - private string _apiUri; - - public Guid Id { get; set; } - public string Key { get; set; } - public string IdentityUri - { - get => string.IsNullOrWhiteSpace(_identityUri) ? "https://identity.bitwarden.com" : _identityUri; - set => _identityUri = value; - } - public string ApiUri - { - get => string.IsNullOrWhiteSpace(_apiUri) ? "https://api.bitwarden.com" : _apiUri; - set => _apiUri = value; - } - } - - public class AmazonSettings - { - public string AccessKeyId { get; set; } - public string AccessKeySecret { get; set; } - public string Region { get; set; } - } - - public class ServiceBusSettings : ConnectionStringSettings - { - public string ApplicationCacheTopicName { get; set; } - public string ApplicationCacheSubscriptionName { get; set; } - } - - public class AppleIapSettings - { - public string Password { get; set; } - public bool AppInReview { get; set; } - } - - public class SsoSettings : ISsoSettings - { - public int CacheLifetimeInSeconds { get; set; } = 60; - public double SsoTokenLifetimeInSeconds { get; set; } = 5; - } - - public class CaptchaSettings - { - public bool ForceCaptchaRequired { get; set; } = false; - public string HCaptchaSecretKey { get; set; } - public string HCaptchaSiteKey { get; set; } - public int MaximumFailedLoginAttempts { get; set; } - public double MaybeBotScoreThreshold { get; set; } = double.MaxValue; - public double IsBotScoreThreshold { get; set; } = double.MaxValue; - } - - public class StripeSettings - { - public string ApiKey { get; set; } - public int MaxNetworkRetries { get; set; } = 2; - } - - public class TwoFactorAuthSettings : ITwoFactorAuthSettings - { - public bool EmailOnNewDeviceLogin { get; set; } = false; + get => _globalSettings.BuildInternalUri(_scim, "scim"); + set => _internalScim = value; } } + + public class SqlSettings + { + private string _connectionString; + private string _readOnlyConnectionString; + private string _jobSchedulerConnectionString; + + public string ConnectionString + { + get => _connectionString; + set => _connectionString = value.Trim('"'); + } + + public string ReadOnlyConnectionString + { + get => string.IsNullOrWhiteSpace(_readOnlyConnectionString) ? + _connectionString : _readOnlyConnectionString; + set => _readOnlyConnectionString = value.Trim('"'); + } + + public string JobSchedulerConnectionString + { + get => _jobSchedulerConnectionString; + set => _jobSchedulerConnectionString = value.Trim('"'); + } + } + + public class ConnectionStringSettings : IConnectionStringSettings + { + private string _connectionString; + + public string ConnectionString + { + get => _connectionString; + set => _connectionString = value.Trim('"'); + } + } + + public class FileStorageSettings : IFileStorageSettings + { + private readonly GlobalSettings _globalSettings; + private readonly string _urlName; + private readonly string _directoryName; + private string _connectionString; + private string _baseDirectory; + private string _baseUrl; + + public FileStorageSettings(GlobalSettings globalSettings, string urlName, string directoryName) + { + _globalSettings = globalSettings; + _urlName = urlName; + _directoryName = directoryName; + } + + public string ConnectionString + { + get => _connectionString; + set => _connectionString = value.Trim('"'); + } + + public string BaseDirectory + { + get => _globalSettings.BuildDirectory(_baseDirectory, string.Concat("/core/", _directoryName)); + set => _baseDirectory = value; + } + + public string BaseUrl + { + get => _globalSettings.BuildExternalUri(_baseUrl, _urlName); + set => _baseUrl = value; + } + } + + public class MailSettings + { + private ConnectionStringSettings _connectionStringSettings; + public string ConnectionString + { + get => _connectionStringSettings?.ConnectionString; + set + { + if (_connectionStringSettings == null) + { + _connectionStringSettings = new ConnectionStringSettings(); + } + _connectionStringSettings.ConnectionString = value; + } + } + public string ReplyToEmail { get; set; } + public string AmazonConfigSetName { get; set; } + public SmtpSettings Smtp { get; set; } = new SmtpSettings(); + public string SendGridApiKey { get; set; } + public int? SendGridPercentage { get; set; } + + public class SmtpSettings + { + public string Host { get; set; } + public int Port { get; set; } = 25; + public bool StartTls { get; set; } = false; + public bool Ssl { get; set; } = false; + public bool SslOverride { get; set; } = false; + public string Username { get; set; } + public string Password { get; set; } + public bool TrustServer { get; set; } = false; + } + } + + public class IdentityServerSettings + { + public string CertificateThumbprint { get; set; } + public string CertificatePassword { get; set; } + public string RedisConnectionString { get; set; } + } + + public class DataProtectionSettings + { + private readonly GlobalSettings _globalSettings; + + private string _directory; + + public DataProtectionSettings(GlobalSettings globalSettings) + { + _globalSettings = globalSettings; + } + + public string CertificateThumbprint { get; set; } + public string CertificatePassword { get; set; } + public string Directory + { + get => _globalSettings.BuildDirectory(_directory, "/core/aspnet-dataprotection"); + set => _directory = value; + } + } + + public class DocumentDbSettings + { + public string Uri { get; set; } + public string Key { get; set; } + } + + public class SentrySettings + { + public string Dsn { get; set; } + } + + public class NotificationsSettings : ConnectionStringSettings + { + public string RedisConnectionString { get; set; } + } + + public class SyslogSettings + { + /// + /// The connection string used to connect to a remote syslog server over TCP or UDP, or to connect locally. + /// + /// + /// The connection string will be parsed using to extract the protocol, host name and port number. + /// + /// + /// Supported protocols are: + /// + /// UDP (use udp://) + /// TCP (use tcp://) + /// TLS over TCP (use tls://) + /// + /// + /// + /// + /// A remote server (logging.dev.example.com) is listening on UDP (port 514): + /// + /// udp://logging.dev.example.com:514. + /// + public string Destination { get; set; } + /// + /// The absolute path to a Certificate (DER or Base64 encoded with private key). + /// + /// + /// The certificate path and are passed into the . + /// The file format of the certificate may be binary encded (DER) or base64. If the private key is encrypted, provide the password in , + /// + public string CertificatePath { get; set; } + /// + /// The password for the encrypted private key in the certificate supplied in . + /// + /// + public string CertificatePassword { get; set; } + /// + /// The thumbprint of the certificate in the X.509 certificate store for personal certificates for the user account running Bitwarden. + /// + /// + public string CertificateThumbprint { get; set; } + } + + public class NotificationHubSettings + { + private string _connectionString; + + public string ConnectionString + { + get => _connectionString; + set => _connectionString = value.Trim('"'); + } + public string HubName { get; set; } + } + + public class YubicoSettings + { + public string ClientId { get; set; } + public string Key { get; set; } + public string[] ValidationUrls { get; set; } + } + + public class DuoSettings + { + public string AKey { get; set; } + } + + public class BraintreeSettings + { + public bool Production { get; set; } + public string MerchantId { get; set; } + public string PublicKey { get; set; } + public string PrivateKey { get; set; } + } + + public class BitPaySettings + { + public bool Production { get; set; } + public string Token { get; set; } + public string NotificationUrl { get; set; } + } + + public class InstallationSettings : IInstallationSettings + { + private string _identityUri; + private string _apiUri; + + public Guid Id { get; set; } + public string Key { get; set; } + public string IdentityUri + { + get => string.IsNullOrWhiteSpace(_identityUri) ? "https://identity.bitwarden.com" : _identityUri; + set => _identityUri = value; + } + public string ApiUri + { + get => string.IsNullOrWhiteSpace(_apiUri) ? "https://api.bitwarden.com" : _apiUri; + set => _apiUri = value; + } + } + + public class AmazonSettings + { + public string AccessKeyId { get; set; } + public string AccessKeySecret { get; set; } + public string Region { get; set; } + } + + public class ServiceBusSettings : ConnectionStringSettings + { + public string ApplicationCacheTopicName { get; set; } + public string ApplicationCacheSubscriptionName { get; set; } + } + + public class AppleIapSettings + { + public string Password { get; set; } + public bool AppInReview { get; set; } + } + + public class SsoSettings : ISsoSettings + { + public int CacheLifetimeInSeconds { get; set; } = 60; + public double SsoTokenLifetimeInSeconds { get; set; } = 5; + } + + public class CaptchaSettings + { + public bool ForceCaptchaRequired { get; set; } = false; + public string HCaptchaSecretKey { get; set; } + public string HCaptchaSiteKey { get; set; } + public int MaximumFailedLoginAttempts { get; set; } + public double MaybeBotScoreThreshold { get; set; } = double.MaxValue; + public double IsBotScoreThreshold { get; set; } = double.MaxValue; + } + + public class StripeSettings + { + public string ApiKey { get; set; } + public int MaxNetworkRetries { get; set; } = 2; + } + + public class TwoFactorAuthSettings : ITwoFactorAuthSettings + { + public bool EmailOnNewDeviceLogin { get; set; } = false; + } } diff --git a/src/Core/Settings/IBaseServiceUriSettings.cs b/src/Core/Settings/IBaseServiceUriSettings.cs index 0dfdaf0b9..0550ae3e6 100644 --- a/src/Core/Settings/IBaseServiceUriSettings.cs +++ b/src/Core/Settings/IBaseServiceUriSettings.cs @@ -1,22 +1,21 @@  -namespace Bit.Core.Settings +namespace Bit.Core.Settings; + +public interface IBaseServiceUriSettings { - public interface IBaseServiceUriSettings - { - string Vault { get; set; } - string VaultWithHash { get; } - string Api { get; set; } - public string Identity { get; set; } - public string Admin { get; set; } - public string Notifications { get; set; } - public string Sso { get; set; } - public string Scim { get; set; } - public string InternalNotifications { get; set; } - public string InternalAdmin { get; set; } - public string InternalIdentity { get; set; } - public string InternalApi { get; set; } - public string InternalVault { get; set; } - public string InternalSso { get; set; } - public string InternalScim { get; set; } - } + string Vault { get; set; } + string VaultWithHash { get; } + string Api { get; set; } + public string Identity { get; set; } + public string Admin { get; set; } + public string Notifications { get; set; } + public string Sso { get; set; } + public string Scim { get; set; } + public string InternalNotifications { get; set; } + public string InternalAdmin { get; set; } + public string InternalIdentity { get; set; } + public string InternalApi { get; set; } + public string InternalVault { get; set; } + public string InternalSso { get; set; } + public string InternalScim { get; set; } } diff --git a/src/Core/Settings/IConnectionStringSettings.cs b/src/Core/Settings/IConnectionStringSettings.cs index aff2b0627..5b67dc9ca 100644 --- a/src/Core/Settings/IConnectionStringSettings.cs +++ b/src/Core/Settings/IConnectionStringSettings.cs @@ -1,8 +1,6 @@ -namespace Bit.Core.Settings -{ +namespace Bit.Core.Settings; - public interface IConnectionStringSettings - { - string ConnectionString { get; set; } - } +public interface IConnectionStringSettings +{ + string ConnectionString { get; set; } } diff --git a/src/Core/Settings/IFileStorageSettings.cs b/src/Core/Settings/IFileStorageSettings.cs index 45e44802d..44546042d 100644 --- a/src/Core/Settings/IFileStorageSettings.cs +++ b/src/Core/Settings/IFileStorageSettings.cs @@ -1,9 +1,8 @@ -namespace Bit.Core.Settings +namespace Bit.Core.Settings; + +public interface IFileStorageSettings { - public interface IFileStorageSettings - { - string ConnectionString { get; set; } - string BaseDirectory { get; set; } - string BaseUrl { get; set; } - } + string ConnectionString { get; set; } + string BaseDirectory { get; set; } + string BaseUrl { get; set; } } diff --git a/src/Core/Settings/IGlobalSettings.cs b/src/Core/Settings/IGlobalSettings.cs index ec648384e..1929da1f3 100644 --- a/src/Core/Settings/IGlobalSettings.cs +++ b/src/Core/Settings/IGlobalSettings.cs @@ -1,19 +1,18 @@ -namespace Bit.Core.Settings +namespace Bit.Core.Settings; + +public interface IGlobalSettings { - public interface IGlobalSettings - { - // This interface exists for testing. Add settings here as needed for testing - bool SelfHosted { get; set; } - bool EnableCloudCommunication { get; set; } - string LicenseDirectory { get; set; } - string LicenseCertificatePassword { get; set; } - int OrganizationInviteExpirationHours { get; set; } - bool DisableUserRegistration { get; set; } - IInstallationSettings Installation { get; set; } - IFileStorageSettings Attachment { get; set; } - IConnectionStringSettings Storage { get; set; } - IBaseServiceUriSettings BaseServiceUri { get; set; } - ITwoFactorAuthSettings TwoFactorAuth { get; set; } - ISsoSettings Sso { get; set; } - } + // This interface exists for testing. Add settings here as needed for testing + bool SelfHosted { get; set; } + bool EnableCloudCommunication { get; set; } + string LicenseDirectory { get; set; } + string LicenseCertificatePassword { get; set; } + int OrganizationInviteExpirationHours { get; set; } + bool DisableUserRegistration { get; set; } + IInstallationSettings Installation { get; set; } + IFileStorageSettings Attachment { get; set; } + IConnectionStringSettings Storage { get; set; } + IBaseServiceUriSettings BaseServiceUri { get; set; } + ITwoFactorAuthSettings TwoFactorAuth { get; set; } + ISsoSettings Sso { get; set; } } diff --git a/src/Core/Settings/IInstallationSettings.cs b/src/Core/Settings/IInstallationSettings.cs index dbc966d54..6f56a3fa0 100644 --- a/src/Core/Settings/IInstallationSettings.cs +++ b/src/Core/Settings/IInstallationSettings.cs @@ -1,10 +1,9 @@ -namespace Bit.Core.Settings +namespace Bit.Core.Settings; + +public interface IInstallationSettings { - public interface IInstallationSettings - { - public Guid Id { get; set; } - public string Key { get; set; } - public string IdentityUri { get; set; } - public string ApiUri { get; } - } + public Guid Id { get; set; } + public string Key { get; set; } + public string IdentityUri { get; set; } + public string ApiUri { get; } } diff --git a/src/Core/Settings/ISsoSettings.cs b/src/Core/Settings/ISsoSettings.cs index de5193cef..c7429baef 100644 --- a/src/Core/Settings/ISsoSettings.cs +++ b/src/Core/Settings/ISsoSettings.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Settings +namespace Bit.Core.Settings; + +public interface ISsoSettings { - public interface ISsoSettings - { - int CacheLifetimeInSeconds { get; set; } - double SsoTokenLifetimeInSeconds { get; set; } - } + int CacheLifetimeInSeconds { get; set; } + double SsoTokenLifetimeInSeconds { get; set; } } diff --git a/src/Core/Settings/ITwoFactorAuthSettings.cs b/src/Core/Settings/ITwoFactorAuthSettings.cs index 06dced0f8..2e11e6507 100644 --- a/src/Core/Settings/ITwoFactorAuthSettings.cs +++ b/src/Core/Settings/ITwoFactorAuthSettings.cs @@ -1,7 +1,6 @@ -namespace Bit.Core.Settings +namespace Bit.Core.Settings; + +public interface ITwoFactorAuthSettings { - public interface ITwoFactorAuthSettings - { - bool EmailOnNewDeviceLogin { get; set; } - } + bool EmailOnNewDeviceLogin { get; set; } } diff --git a/src/Core/Sso/SamlSigningAlgorithms.cs b/src/Core/Sso/SamlSigningAlgorithms.cs index fba67a4ab..68ad8e5fa 100644 --- a/src/Core/Sso/SamlSigningAlgorithms.cs +++ b/src/Core/Sso/SamlSigningAlgorithms.cs @@ -1,19 +1,18 @@ -namespace Bit.Core.Sso -{ - public static class SamlSigningAlgorithms - { - public const string Default = Sha256; - public const string Sha256 = "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256"; - public const string Sha384 = "http://www.w3.org/2000/09/xmldsig#rsa-sha384"; - public const string Sha512 = "http://www.w3.org/2000/09/xmldsig#rsa-sha512"; - public const string Sha1 = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"; +namespace Bit.Core.Sso; - public static IEnumerable GetEnumerable() - { - yield return Sha256; - yield return Sha384; - yield return Sha512; - yield return Sha1; - } +public static class SamlSigningAlgorithms +{ + public const string Default = Sha256; + public const string Sha256 = "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256"; + public const string Sha384 = "http://www.w3.org/2000/09/xmldsig#rsa-sha384"; + public const string Sha512 = "http://www.w3.org/2000/09/xmldsig#rsa-sha512"; + public const string Sha1 = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"; + + public static IEnumerable GetEnumerable() + { + yield return Sha256; + yield return Sha384; + yield return Sha512; + yield return Sha1; } } diff --git a/src/Core/Tokens/BadTokenException.cs b/src/Core/Tokens/BadTokenException.cs index ffd9cb520..ca2dcac49 100644 --- a/src/Core/Tokens/BadTokenException.cs +++ b/src/Core/Tokens/BadTokenException.cs @@ -1,13 +1,12 @@ -namespace Bit.Core.Tokens -{ - public class BadTokenException : Exception - { - public BadTokenException() - { - } +namespace Bit.Core.Tokens; - public BadTokenException(string message) : base(message) - { - } +public class BadTokenException : Exception +{ + public BadTokenException() + { + } + + public BadTokenException(string message) : base(message) + { } } diff --git a/src/Core/Tokens/DataProtectorTokenFactory.cs b/src/Core/Tokens/DataProtectorTokenFactory.cs index 8029b3554..e0ec9811f 100644 --- a/src/Core/Tokens/DataProtectorTokenFactory.cs +++ b/src/Core/Tokens/DataProtectorTokenFactory.cs @@ -1,55 +1,54 @@ using Microsoft.AspNetCore.DataProtection; -namespace Bit.Core.Tokens +namespace Bit.Core.Tokens; + +public class DataProtectorTokenFactory : IDataProtectorTokenFactory where T : Tokenable { - public class DataProtectorTokenFactory : IDataProtectorTokenFactory where T : Tokenable + private readonly IDataProtector _dataProtector; + private readonly string _clearTextPrefix; + + public DataProtectorTokenFactory(string clearTextPrefix, string purpose, IDataProtectionProvider dataProtectionProvider) { - private readonly IDataProtector _dataProtector; - private readonly string _clearTextPrefix; + _dataProtector = dataProtectionProvider.CreateProtector(purpose); + _clearTextPrefix = clearTextPrefix; + } - public DataProtectorTokenFactory(string clearTextPrefix, string purpose, IDataProtectionProvider dataProtectionProvider) + public string Protect(T data) => + data.ToToken().ProtectWith(_dataProtector).WithPrefix(_clearTextPrefix).ToString(); + + /// + /// Unprotect token + /// + /// The token to parse + /// The tokenable type to parse to + /// The parsed tokenable + /// Throws CryptographicException if fails to unprotect + public T Unprotect(string token) => + Tokenable.FromToken(new Token(token).RemovePrefix(_clearTextPrefix).UnprotectWith(_dataProtector).ToString()); + + public bool TokenValid(string token) + { + try { - _dataProtector = dataProtectionProvider.CreateProtector(purpose); - _clearTextPrefix = clearTextPrefix; + return Unprotect(token).Valid; } - - public string Protect(T data) => - data.ToToken().ProtectWith(_dataProtector).WithPrefix(_clearTextPrefix).ToString(); - - /// - /// Unprotect token - /// - /// The token to parse - /// The tokenable type to parse to - /// The parsed tokenable - /// Throws CryptographicException if fails to unprotect - public T Unprotect(string token) => - Tokenable.FromToken(new Token(token).RemovePrefix(_clearTextPrefix).UnprotectWith(_dataProtector).ToString()); - - public bool TokenValid(string token) + catch { - try - { - return Unprotect(token).Valid; - } - catch - { - return false; - } + return false; } + } - public bool TryUnprotect(string token, out T data) + public bool TryUnprotect(string token, out T data) + { + try { - try - { - data = Unprotect(token); - return true; - } - catch - { - data = default; - return false; - } + data = Unprotect(token); + return true; + } + catch + { + data = default; + return false; } } } diff --git a/src/Core/Tokens/ExpiringTokenable.cs b/src/Core/Tokens/ExpiringTokenable.cs index 37907bbe3..089405e53 100644 --- a/src/Core/Tokens/ExpiringTokenable.cs +++ b/src/Core/Tokens/ExpiringTokenable.cs @@ -1,14 +1,13 @@ using System.Text.Json.Serialization; using Bit.Core.Utilities; -namespace Bit.Core.Tokens -{ - public abstract class ExpiringTokenable : Tokenable - { - [JsonConverter(typeof(EpochDateTimeJsonConverter))] - public DateTime ExpirationDate { get; set; } - public override bool Valid => ExpirationDate > DateTime.UtcNow && TokenIsValid(); +namespace Bit.Core.Tokens; - protected abstract bool TokenIsValid(); - } +public abstract class ExpiringTokenable : Tokenable +{ + [JsonConverter(typeof(EpochDateTimeJsonConverter))] + public DateTime ExpirationDate { get; set; } + public override bool Valid => ExpirationDate > DateTime.UtcNow && TokenIsValid(); + + protected abstract bool TokenIsValid(); } diff --git a/src/Core/Tokens/IBillingSyncTokenable.cs b/src/Core/Tokens/IBillingSyncTokenable.cs index a9fdc06bd..d63df0cc7 100644 --- a/src/Core/Tokens/IBillingSyncTokenable.cs +++ b/src/Core/Tokens/IBillingSyncTokenable.cs @@ -1,8 +1,7 @@ -namespace Bit.Core.Tokens +namespace Bit.Core.Tokens; + +public interface IBillingSyncTokenable { - public interface IBillingSyncTokenable - { - public Guid OrganizationId { get; set; } - public string BillingSyncKey { get; set; } - } + public Guid OrganizationId { get; set; } + public string BillingSyncKey { get; set; } } diff --git a/src/Core/Tokens/IDataProtectorTokenFactory.cs b/src/Core/Tokens/IDataProtectorTokenFactory.cs index 038eff0f7..3809c40da 100644 --- a/src/Core/Tokens/IDataProtectorTokenFactory.cs +++ b/src/Core/Tokens/IDataProtectorTokenFactory.cs @@ -1,10 +1,9 @@ -namespace Bit.Core.Tokens +namespace Bit.Core.Tokens; + +public interface IDataProtectorTokenFactory where T : Tokenable { - public interface IDataProtectorTokenFactory where T : Tokenable - { - string Protect(T data); - T Unprotect(string token); - bool TryUnprotect(string token, out T data); - bool TokenValid(string token); - } + string Protect(T data); + T Unprotect(string token); + bool TryUnprotect(string token, out T data); + bool TokenValid(string token); } diff --git a/src/Core/Tokens/Token.cs b/src/Core/Tokens/Token.cs index 396b8747d..a50b81fbb 100644 --- a/src/Core/Tokens/Token.cs +++ b/src/Core/Tokens/Token.cs @@ -1,37 +1,36 @@ using Microsoft.AspNetCore.DataProtection; -namespace Bit.Core.Tokens +namespace Bit.Core.Tokens; + +public class Token { - public class Token + private readonly string _token; + + public Token(string token) { - private readonly string _token; - - public Token(string token) - { - _token = token; - } - - public Token WithPrefix(string prefix) - { - return new Token($"{prefix}{_token}"); - } - - public Token RemovePrefix(string expectedPrefix) - { - if (!_token.StartsWith(expectedPrefix)) - { - throw new BadTokenException($"Expected prefix, {expectedPrefix}, was not present."); - } - - return new Token(_token[expectedPrefix.Length..]); - } - - public Token ProtectWith(IDataProtector dataProtector) => - new(dataProtector.Protect(ToString())); - - public Token UnprotectWith(IDataProtector dataProtector) => - new(dataProtector.Unprotect(ToString())); - - public override string ToString() => _token; + _token = token; } + + public Token WithPrefix(string prefix) + { + return new Token($"{prefix}{_token}"); + } + + public Token RemovePrefix(string expectedPrefix) + { + if (!_token.StartsWith(expectedPrefix)) + { + throw new BadTokenException($"Expected prefix, {expectedPrefix}, was not present."); + } + + return new Token(_token[expectedPrefix.Length..]); + } + + public Token ProtectWith(IDataProtector dataProtector) => + new(dataProtector.Protect(ToString())); + + public Token UnprotectWith(IDataProtector dataProtector) => + new(dataProtector.Unprotect(ToString())); + + public override string ToString() => _token; } diff --git a/src/Core/Tokens/Tokenable.cs b/src/Core/Tokens/Tokenable.cs index c5c57c2f7..a145e64bb 100644 --- a/src/Core/Tokens/Tokenable.cs +++ b/src/Core/Tokens/Tokenable.cs @@ -1,20 +1,19 @@ using System.Text.Json; -namespace Bit.Core.Tokens +namespace Bit.Core.Tokens; + +public abstract class Tokenable { - public abstract class Tokenable + public abstract bool Valid { get; } + + public Token ToToken() { - public abstract bool Valid { get; } + return new Token(JsonSerializer.Serialize(this, this.GetType())); + } - public Token ToToken() - { - return new Token(JsonSerializer.Serialize(this, this.GetType())); - } - - public static T FromToken(string token) => FromToken(new Token(token)); - public static T FromToken(Token token) - { - return JsonSerializer.Deserialize(token.ToString()); - } + public static T FromToken(string token) => FromToken(new Token(token)); + public static T FromToken(Token token) + { + return JsonSerializer.Deserialize(token.ToString()); } } diff --git a/src/Core/Utilities/BillingHelpers.cs b/src/Core/Utilities/BillingHelpers.cs index 41202a2b4..e7ccfc354 100644 --- a/src/Core/Utilities/BillingHelpers.cs +++ b/src/Core/Utilities/BillingHelpers.cs @@ -2,57 +2,56 @@ using Bit.Core.Exceptions; using Bit.Core.Services; -namespace Bit.Core.Utilities +namespace Bit.Core.Utilities; + +public static class BillingHelpers { - public static class BillingHelpers + internal static async Task AdjustStorageAsync(IPaymentService paymentService, IStorableSubscriber storableSubscriber, + short storageAdjustmentGb, string storagePlanId) { - internal static async Task AdjustStorageAsync(IPaymentService paymentService, IStorableSubscriber storableSubscriber, - short storageAdjustmentGb, string storagePlanId) + if (storableSubscriber == null) { - if (storableSubscriber == null) - { - throw new ArgumentNullException(nameof(storableSubscriber)); - } - - if (string.IsNullOrWhiteSpace(storableSubscriber.GatewayCustomerId)) - { - throw new BadRequestException("No payment method found."); - } - - if (string.IsNullOrWhiteSpace(storableSubscriber.GatewaySubscriptionId)) - { - throw new BadRequestException("No subscription found."); - } - - if (!storableSubscriber.MaxStorageGb.HasValue) - { - throw new BadRequestException("No access to storage."); - } - - var newStorageGb = (short)(storableSubscriber.MaxStorageGb.Value + storageAdjustmentGb); - if (newStorageGb < 1) - { - newStorageGb = 1; - } - - if (newStorageGb > 100) - { - throw new BadRequestException("Maximum storage is 100 GB."); - } - - var remainingStorage = storableSubscriber.StorageBytesRemaining(newStorageGb); - if (remainingStorage < 0) - { - throw new BadRequestException("You are currently using " + - $"{CoreHelpers.ReadableBytesSize(storableSubscriber.Storage.GetValueOrDefault(0))} of storage. " + - "Delete some stored data first."); - } - - var additionalStorage = newStorageGb - 1; - var paymentIntentClientSecret = await paymentService.AdjustStorageAsync(storableSubscriber, - additionalStorage, storagePlanId); - storableSubscriber.MaxStorageGb = newStorageGb; - return paymentIntentClientSecret; + throw new ArgumentNullException(nameof(storableSubscriber)); } + + if (string.IsNullOrWhiteSpace(storableSubscriber.GatewayCustomerId)) + { + throw new BadRequestException("No payment method found."); + } + + if (string.IsNullOrWhiteSpace(storableSubscriber.GatewaySubscriptionId)) + { + throw new BadRequestException("No subscription found."); + } + + if (!storableSubscriber.MaxStorageGb.HasValue) + { + throw new BadRequestException("No access to storage."); + } + + var newStorageGb = (short)(storableSubscriber.MaxStorageGb.Value + storageAdjustmentGb); + if (newStorageGb < 1) + { + newStorageGb = 1; + } + + if (newStorageGb > 100) + { + throw new BadRequestException("Maximum storage is 100 GB."); + } + + var remainingStorage = storableSubscriber.StorageBytesRemaining(newStorageGb); + if (remainingStorage < 0) + { + throw new BadRequestException("You are currently using " + + $"{CoreHelpers.ReadableBytesSize(storableSubscriber.Storage.GetValueOrDefault(0))} of storage. " + + "Delete some stored data first."); + } + + var additionalStorage = newStorageGb - 1; + var paymentIntentClientSecret = await paymentService.AdjustStorageAsync(storableSubscriber, + additionalStorage, storagePlanId); + storableSubscriber.MaxStorageGb = newStorageGb; + return paymentIntentClientSecret; } } diff --git a/src/Core/Utilities/BitPayClient.cs b/src/Core/Utilities/BitPayClient.cs index 2532e8476..35a078998 100644 --- a/src/Core/Utilities/BitPayClient.cs +++ b/src/Core/Utilities/BitPayClient.cs @@ -1,28 +1,27 @@ using Bit.Core.Settings; -namespace Bit.Core.Utilities +namespace Bit.Core.Utilities; + +public class BitPayClient { - public class BitPayClient + private readonly BitPayLight.BitPay _bpClient; + + public BitPayClient(GlobalSettings globalSettings) { - private readonly BitPayLight.BitPay _bpClient; - - public BitPayClient(GlobalSettings globalSettings) + if (CoreHelpers.SettingHasValue(globalSettings.BitPay.Token)) { - if (CoreHelpers.SettingHasValue(globalSettings.BitPay.Token)) - { - _bpClient = new BitPayLight.BitPay(globalSettings.BitPay.Token, - globalSettings.BitPay.Production ? BitPayLight.Env.Prod : BitPayLight.Env.Test); - } - } - - public Task GetInvoiceAsync(string id) - { - return _bpClient.GetInvoice(id); - } - - public Task CreateInvoiceAsync(BitPayLight.Models.Invoice.Invoice invoice) - { - return _bpClient.CreateInvoice(invoice); + _bpClient = new BitPayLight.BitPay(globalSettings.BitPay.Token, + globalSettings.BitPay.Production ? BitPayLight.Env.Prod : BitPayLight.Env.Test); } } + + public Task GetInvoiceAsync(string id) + { + return _bpClient.GetInvoice(id); + } + + public Task CreateInvoiceAsync(BitPayLight.Models.Invoice.Invoice invoice) + { + return _bpClient.CreateInvoice(invoice); + } } diff --git a/src/Core/Utilities/CaptchaProtectedAttribute.cs b/src/Core/Utilities/CaptchaProtectedAttribute.cs index 102f1f175..6a5de6a9d 100644 --- a/src/Core/Utilities/CaptchaProtectedAttribute.cs +++ b/src/Core/Utilities/CaptchaProtectedAttribute.cs @@ -5,32 +5,31 @@ using Bit.Core.Services; using Microsoft.AspNetCore.Mvc.Filters; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Core.Utilities +namespace Bit.Core.Utilities; + +public class CaptchaProtectedAttribute : ActionFilterAttribute { - public class CaptchaProtectedAttribute : ActionFilterAttribute + public string ModelParameterName { get; set; } = "model"; + + public override void OnActionExecuting(ActionExecutingContext context) { - public string ModelParameterName { get; set; } = "model"; + var currentContext = context.HttpContext.RequestServices.GetRequiredService(); + var captchaValidationService = context.HttpContext.RequestServices.GetRequiredService(); - public override void OnActionExecuting(ActionExecutingContext context) + if (captchaValidationService.RequireCaptchaValidation(currentContext, null)) { - var currentContext = context.HttpContext.RequestServices.GetRequiredService(); - var captchaValidationService = context.HttpContext.RequestServices.GetRequiredService(); + var captchaResponse = (context.ActionArguments[ModelParameterName] as ICaptchaProtectedModel)?.CaptchaResponse; - if (captchaValidationService.RequireCaptchaValidation(currentContext, null)) + if (string.IsNullOrWhiteSpace(captchaResponse)) { - var captchaResponse = (context.ActionArguments[ModelParameterName] as ICaptchaProtectedModel)?.CaptchaResponse; + throw new BadRequestException(captchaValidationService.SiteKeyResponseKeyName, captchaValidationService.SiteKey); + } - if (string.IsNullOrWhiteSpace(captchaResponse)) - { - throw new BadRequestException(captchaValidationService.SiteKeyResponseKeyName, captchaValidationService.SiteKey); - } - - var captchaValidationResponse = captchaValidationService.ValidateCaptchaResponseAsync(captchaResponse, - currentContext.IpAddress, null).GetAwaiter().GetResult(); - if (!captchaValidationResponse.Success || captchaValidationResponse.IsBot) - { - throw new BadRequestException("Captcha is invalid. Please refresh and try again"); - } + var captchaValidationResponse = captchaValidationService.ValidateCaptchaResponseAsync(captchaResponse, + currentContext.IpAddress, null).GetAwaiter().GetResult(); + if (!captchaValidationResponse.Success || captchaValidationResponse.IsBot) + { + throw new BadRequestException("Captcha is invalid. Please refresh and try again"); } } } diff --git a/src/Core/Utilities/ClaimsExtensions.cs b/src/Core/Utilities/ClaimsExtensions.cs index ef25d1483..75478869e 100644 --- a/src/Core/Utilities/ClaimsExtensions.cs +++ b/src/Core/Utilities/ClaimsExtensions.cs @@ -1,12 +1,11 @@ using System.Security.Claims; -namespace Bit.Core.Utilities +namespace Bit.Core.Utilities; + +public static class ClaimsExtensions { - public static class ClaimsExtensions + public static bool HasSsoIdP(this IEnumerable claims) { - public static bool HasSsoIdP(this IEnumerable claims) - { - return claims.Any(c => c.Type == "idp" && c.Value == "sso"); - } + return claims.Any(c => c.Type == "idp" && c.Value == "sso"); } } diff --git a/src/Core/Utilities/CoreHelpers.cs b/src/Core/Utilities/CoreHelpers.cs index 7ad850a26..ef6848cf1 100644 --- a/src/Core/Utilities/CoreHelpers.cs +++ b/src/Core/Utilities/CoreHelpers.cs @@ -19,835 +19,834 @@ using IdentityModel; using Microsoft.AspNetCore.DataProtection; using MimeKit; -namespace Bit.Core.Utilities +namespace Bit.Core.Utilities; + +public static class CoreHelpers { - public static class CoreHelpers + private static readonly long _baseDateTicks = new DateTime(1900, 1, 1).Ticks; + private static readonly DateTime _epoc = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc); + private static readonly DateTime _max = new DateTime(9999, 1, 1, 0, 0, 0, DateTimeKind.Utc); + private static readonly Random _random = new Random(); + private static string _version; + private static readonly string CloudFlareConnectingIp = "CF-Connecting-IP"; + private static readonly string RealIp = "X-Real-IP"; + + /// + /// Generate sequential Guid for Sql Server. + /// ref: https://github.com/nhibernate/nhibernate-core/blob/master/src/NHibernate/Id/GuidCombGenerator.cs + /// + /// A comb Guid. + public static Guid GenerateComb() + => GenerateComb(Guid.NewGuid(), DateTime.UtcNow); + + /// + /// Implementation of with input parameters to remove randomness. + /// This should NOT be used outside of testing. + /// + /// + /// You probably don't want to use this method and instead want to use with no parameters + /// + internal static Guid GenerateComb(Guid startingGuid, DateTime time) { - private static readonly long _baseDateTicks = new DateTime(1900, 1, 1).Ticks; - private static readonly DateTime _epoc = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc); - private static readonly DateTime _max = new DateTime(9999, 1, 1, 0, 0, 0, DateTimeKind.Utc); - private static readonly Random _random = new Random(); - private static string _version; - private static readonly string CloudFlareConnectingIp = "CF-Connecting-IP"; - private static readonly string RealIp = "X-Real-IP"; + var guidArray = startingGuid.ToByteArray(); - /// - /// Generate sequential Guid for Sql Server. - /// ref: https://github.com/nhibernate/nhibernate-core/blob/master/src/NHibernate/Id/GuidCombGenerator.cs - /// - /// A comb Guid. - public static Guid GenerateComb() - => GenerateComb(Guid.NewGuid(), DateTime.UtcNow); + // Get the days and milliseconds which will be used to build the byte string + var days = new TimeSpan(time.Ticks - _baseDateTicks); + var msecs = time.TimeOfDay; - /// - /// Implementation of with input parameters to remove randomness. - /// This should NOT be used outside of testing. - /// - /// - /// You probably don't want to use this method and instead want to use with no parameters - /// - internal static Guid GenerateComb(Guid startingGuid, DateTime time) + // Convert to a byte array + // Note that SQL Server is accurate to 1/300th of a millisecond so we divide by 3.333333 + var daysArray = BitConverter.GetBytes(days.Days); + var msecsArray = BitConverter.GetBytes((long)(msecs.TotalMilliseconds / 3.333333)); + + // Reverse the bytes to match SQL Servers ordering + Array.Reverse(daysArray); + Array.Reverse(msecsArray); + + // Copy the bytes into the guid + Array.Copy(daysArray, daysArray.Length - 2, guidArray, guidArray.Length - 6, 2); + Array.Copy(msecsArray, msecsArray.Length - 4, guidArray, guidArray.Length - 4, 4); + + return new Guid(guidArray); + } + + public static IEnumerable> Batch(this IEnumerable source, int size) + { + T[] bucket = null; + var count = 0; + foreach (var item in source) { - var guidArray = startingGuid.ToByteArray(); + if (bucket == null) + { + bucket = new T[size]; + } + bucket[count++] = item; + if (count != size) + { + continue; + } + yield return bucket.Select(x => x); + bucket = null; + count = 0; + } + // Return the last bucket with all remaining elements + if (bucket != null && count > 0) + { + yield return bucket.Take(count); + } + } - // Get the days and milliseconds which will be used to build the byte string - var days = new TimeSpan(time.Ticks - _baseDateTicks); - var msecs = time.TimeOfDay; + public static string CleanCertificateThumbprint(string thumbprint) + { + // Clean possible garbage characters from thumbprint copy/paste + // ref http://stackoverflow.com/questions/8448147/problems-with-x509store-certificates-find-findbythumbprint + return Regex.Replace(thumbprint, @"[^\da-fA-F]", string.Empty).ToUpper(); + } - // Convert to a byte array - // Note that SQL Server is accurate to 1/300th of a millisecond so we divide by 3.333333 - var daysArray = BitConverter.GetBytes(days.Days); - var msecsArray = BitConverter.GetBytes((long)(msecs.TotalMilliseconds / 3.333333)); + public static X509Certificate2 GetCertificate(string thumbprint) + { + thumbprint = CleanCertificateThumbprint(thumbprint); - // Reverse the bytes to match SQL Servers ordering - Array.Reverse(daysArray); - Array.Reverse(msecsArray); - - // Copy the bytes into the guid - Array.Copy(daysArray, daysArray.Length - 2, guidArray, guidArray.Length - 6, 2); - Array.Copy(msecsArray, msecsArray.Length - 4, guidArray, guidArray.Length - 4, 4); - - return new Guid(guidArray); + X509Certificate2 cert = null; + var certStore = new X509Store(StoreName.My, StoreLocation.CurrentUser); + certStore.Open(OpenFlags.ReadOnly); + var certCollection = certStore.Certificates.Find(X509FindType.FindByThumbprint, thumbprint, false); + if (certCollection.Count > 0) + { + cert = certCollection[0]; } - public static IEnumerable> Batch(this IEnumerable source, int size) + certStore.Close(); + return cert; + } + + public static X509Certificate2 GetCertificate(string file, string password) + { + return new X509Certificate2(file, password); + } + + public async static Task GetEmbeddedCertificateAsync(string file, string password) + { + var assembly = typeof(CoreHelpers).GetTypeInfo().Assembly; + using (var s = assembly.GetManifestResourceStream($"Bit.Core.{file}")) + using (var ms = new MemoryStream()) { - T[] bucket = null; - var count = 0; - foreach (var item in source) + await s.CopyToAsync(ms); + return new X509Certificate2(ms.ToArray(), password); + } + } + + public static string GetEmbeddedResourceContentsAsync(string file) + { + var assembly = Assembly.GetCallingAssembly(); + var resourceName = assembly.GetManifestResourceNames().Single(n => n.EndsWith(file)); + using (var stream = assembly.GetManifestResourceStream(resourceName)) + using (var reader = new StreamReader(stream)) + { + return reader.ReadToEnd(); + } + } + + public async static Task GetBlobCertificateAsync(string connectionString, string container, string file, string password) + { + try + { + var blobServiceClient = new BlobServiceClient(connectionString); + var containerRef2 = blobServiceClient.GetBlobContainerClient(container); + var blobRef = containerRef2.GetBlobClient(file); + + using var memStream = new MemoryStream(); + await blobRef.DownloadToAsync(memStream).ConfigureAwait(false); + return new X509Certificate2(memStream.ToArray(), password); + } + catch (RequestFailedException ex) + when (ex.ErrorCode == BlobErrorCode.ContainerNotFound || ex.ErrorCode == BlobErrorCode.BlobNotFound) + { + return null; + } + catch (Exception) + { + return null; + } + } + + public static long ToEpocMilliseconds(DateTime date) + { + return (long)Math.Round((date - _epoc).TotalMilliseconds, 0); + } + + public static DateTime FromEpocMilliseconds(long milliseconds) + { + return _epoc.AddMilliseconds(milliseconds); + } + + public static long ToEpocSeconds(DateTime date) + { + return (long)Math.Round((date - _epoc).TotalSeconds, 0); + } + + public static DateTime FromEpocSeconds(long seconds) + { + return _epoc.AddSeconds(seconds); + } + + public static string U2fAppIdUrl(GlobalSettings globalSettings) + { + return string.Concat(globalSettings.BaseServiceUri.Vault, "/app-id.json"); + } + + public static string RandomString(int length, bool alpha = true, bool upper = true, bool lower = true, + bool numeric = true, bool special = false) + { + return RandomString(length, RandomStringCharacters(alpha, upper, lower, numeric, special)); + } + + public static string RandomString(int length, string characters) + { + return new string(Enumerable.Repeat(characters, length).Select(s => s[_random.Next(s.Length)]).ToArray()); + } + + public static string SecureRandomString(int length, bool alpha = true, bool upper = true, bool lower = true, + bool numeric = true, bool special = false) + { + return SecureRandomString(length, RandomStringCharacters(alpha, upper, lower, numeric, special)); + } + + // ref https://stackoverflow.com/a/8996788/1090359 with modifications + public static string SecureRandomString(int length, string characters) + { + if (length < 0) + { + throw new ArgumentOutOfRangeException(nameof(length), "length cannot be less than zero."); + } + + if ((characters?.Length ?? 0) == 0) + { + throw new ArgumentOutOfRangeException(nameof(characters), "characters invalid."); + } + + const int byteSize = 0x100; + if (byteSize < characters.Length) + { + throw new ArgumentException( + string.Format("{0} may contain no more than {1} characters.", nameof(characters), byteSize), + nameof(characters)); + } + + var outOfRangeStart = byteSize - (byteSize % characters.Length); + using (var rng = RandomNumberGenerator.Create()) + { + var sb = new StringBuilder(); + var buffer = new byte[128]; + while (sb.Length < length) { - if (bucket == null) + rng.GetBytes(buffer); + for (var i = 0; i < buffer.Length && sb.Length < length; ++i) { - bucket = new T[size]; - } - bucket[count++] = item; - if (count != size) - { - continue; - } - yield return bucket.Select(x => x); - bucket = null; - count = 0; - } - // Return the last bucket with all remaining elements - if (bucket != null && count > 0) - { - yield return bucket.Take(count); - } - } - - public static string CleanCertificateThumbprint(string thumbprint) - { - // Clean possible garbage characters from thumbprint copy/paste - // ref http://stackoverflow.com/questions/8448147/problems-with-x509store-certificates-find-findbythumbprint - return Regex.Replace(thumbprint, @"[^\da-fA-F]", string.Empty).ToUpper(); - } - - public static X509Certificate2 GetCertificate(string thumbprint) - { - thumbprint = CleanCertificateThumbprint(thumbprint); - - X509Certificate2 cert = null; - var certStore = new X509Store(StoreName.My, StoreLocation.CurrentUser); - certStore.Open(OpenFlags.ReadOnly); - var certCollection = certStore.Certificates.Find(X509FindType.FindByThumbprint, thumbprint, false); - if (certCollection.Count > 0) - { - cert = certCollection[0]; - } - - certStore.Close(); - return cert; - } - - public static X509Certificate2 GetCertificate(string file, string password) - { - return new X509Certificate2(file, password); - } - - public async static Task GetEmbeddedCertificateAsync(string file, string password) - { - var assembly = typeof(CoreHelpers).GetTypeInfo().Assembly; - using (var s = assembly.GetManifestResourceStream($"Bit.Core.{file}")) - using (var ms = new MemoryStream()) - { - await s.CopyToAsync(ms); - return new X509Certificate2(ms.ToArray(), password); - } - } - - public static string GetEmbeddedResourceContentsAsync(string file) - { - var assembly = Assembly.GetCallingAssembly(); - var resourceName = assembly.GetManifestResourceNames().Single(n => n.EndsWith(file)); - using (var stream = assembly.GetManifestResourceStream(resourceName)) - using (var reader = new StreamReader(stream)) - { - return reader.ReadToEnd(); - } - } - - public async static Task GetBlobCertificateAsync(string connectionString, string container, string file, string password) - { - try - { - var blobServiceClient = new BlobServiceClient(connectionString); - var containerRef2 = blobServiceClient.GetBlobContainerClient(container); - var blobRef = containerRef2.GetBlobClient(file); - - using var memStream = new MemoryStream(); - await blobRef.DownloadToAsync(memStream).ConfigureAwait(false); - return new X509Certificate2(memStream.ToArray(), password); - } - catch (RequestFailedException ex) - when (ex.ErrorCode == BlobErrorCode.ContainerNotFound || ex.ErrorCode == BlobErrorCode.BlobNotFound) - { - return null; - } - catch (Exception) - { - return null; - } - } - - public static long ToEpocMilliseconds(DateTime date) - { - return (long)Math.Round((date - _epoc).TotalMilliseconds, 0); - } - - public static DateTime FromEpocMilliseconds(long milliseconds) - { - return _epoc.AddMilliseconds(milliseconds); - } - - public static long ToEpocSeconds(DateTime date) - { - return (long)Math.Round((date - _epoc).TotalSeconds, 0); - } - - public static DateTime FromEpocSeconds(long seconds) - { - return _epoc.AddSeconds(seconds); - } - - public static string U2fAppIdUrl(GlobalSettings globalSettings) - { - return string.Concat(globalSettings.BaseServiceUri.Vault, "/app-id.json"); - } - - public static string RandomString(int length, bool alpha = true, bool upper = true, bool lower = true, - bool numeric = true, bool special = false) - { - return RandomString(length, RandomStringCharacters(alpha, upper, lower, numeric, special)); - } - - public static string RandomString(int length, string characters) - { - return new string(Enumerable.Repeat(characters, length).Select(s => s[_random.Next(s.Length)]).ToArray()); - } - - public static string SecureRandomString(int length, bool alpha = true, bool upper = true, bool lower = true, - bool numeric = true, bool special = false) - { - return SecureRandomString(length, RandomStringCharacters(alpha, upper, lower, numeric, special)); - } - - // ref https://stackoverflow.com/a/8996788/1090359 with modifications - public static string SecureRandomString(int length, string characters) - { - if (length < 0) - { - throw new ArgumentOutOfRangeException(nameof(length), "length cannot be less than zero."); - } - - if ((characters?.Length ?? 0) == 0) - { - throw new ArgumentOutOfRangeException(nameof(characters), "characters invalid."); - } - - const int byteSize = 0x100; - if (byteSize < characters.Length) - { - throw new ArgumentException( - string.Format("{0} may contain no more than {1} characters.", nameof(characters), byteSize), - nameof(characters)); - } - - var outOfRangeStart = byteSize - (byteSize % characters.Length); - using (var rng = RandomNumberGenerator.Create()) - { - var sb = new StringBuilder(); - var buffer = new byte[128]; - while (sb.Length < length) - { - rng.GetBytes(buffer); - for (var i = 0; i < buffer.Length && sb.Length < length; ++i) + // Divide the byte into charSet-sized groups. If the random value falls into the last group and the + // last group is too small to choose from the entire allowedCharSet, ignore the value in order to + // avoid biasing the result. + if (outOfRangeStart <= buffer[i]) { - // Divide the byte into charSet-sized groups. If the random value falls into the last group and the - // last group is too small to choose from the entire allowedCharSet, ignore the value in order to - // avoid biasing the result. - if (outOfRangeStart <= buffer[i]) - { - continue; - } - - sb.Append(characters[buffer[i] % characters.Length]); + continue; } - } - return sb.ToString(); - } - } - - private static string RandomStringCharacters(bool alpha, bool upper, bool lower, bool numeric, bool special) - { - var characters = string.Empty; - if (alpha) - { - if (upper) - { - characters += "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; - } - - if (lower) - { - characters += "abcdefghijklmnopqrstuvwxyz"; + sb.Append(characters[buffer[i] % characters.Length]); } } - if (numeric) + return sb.ToString(); + } + } + + private static string RandomStringCharacters(bool alpha, bool upper, bool lower, bool numeric, bool special) + { + var characters = string.Empty; + if (alpha) + { + if (upper) { - characters += "0123456789"; + characters += "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; } - if (special) + if (lower) { - characters += "!@#$%^*&"; - } - - return characters; - } - - // ref: https://stackoverflow.com/a/11124118/1090359 - // Returns the human-readable file size for an arbitrary 64-bit file size . - // The format is "0.## XB", ex: "4.2 KB" or "1.43 GB" - public static string ReadableBytesSize(long size) - { - // Get absolute value - var absoluteSize = (size < 0 ? -size : size); - - // Determine the suffix and readable value - string suffix; - double readable; - if (absoluteSize >= 0x40000000) // 1 Gigabyte - { - suffix = "GB"; - readable = (size >> 20); - } - else if (absoluteSize >= 0x100000) // 1 Megabyte - { - suffix = "MB"; - readable = (size >> 10); - } - else if (absoluteSize >= 0x400) // 1 Kilobyte - { - suffix = "KB"; - readable = size; - } - else - { - return size.ToString("0 Bytes"); // Byte - } - - // Divide by 1024 to get fractional value - readable = (readable / 1024); - - // Return formatted number with suffix - return readable.ToString("0.## ") + suffix; - } - - /// - /// Creates a clone of the given object through serializing to json and deserializing. - /// This method is subject to the limitations of System.Text.Json. For example, properties with - /// inaccessible setters will not be set. - /// - public static T CloneObject(T obj) - { - return JsonSerializer.Deserialize(JsonSerializer.Serialize(obj)); - } - - public static bool SettingHasValue(string setting) - { - var normalizedSetting = setting?.ToLowerInvariant(); - return !string.IsNullOrWhiteSpace(normalizedSetting) && !normalizedSetting.Equals("secret") && - !normalizedSetting.Equals("replace"); - } - - public static string Base64EncodeString(string input) - { - return Convert.ToBase64String(Encoding.UTF8.GetBytes(input)); - } - - public static string Base64DecodeString(string input) - { - return Encoding.UTF8.GetString(Convert.FromBase64String(input)); - } - - public static string Base64UrlEncodeString(string input) - { - return Base64UrlEncode(Encoding.UTF8.GetBytes(input)); - } - - public static string Base64UrlDecodeString(string input) - { - return Encoding.UTF8.GetString(Base64UrlDecode(input)); - } - - public static string Base64UrlEncode(byte[] input) - { - var output = Convert.ToBase64String(input) - .Replace('+', '-') - .Replace('/', '_') - .Replace("=", string.Empty); - return output; - } - - public static byte[] Base64UrlDecode(string input) - { - var output = input; - // 62nd char of encoding - output = output.Replace('-', '+'); - // 63rd char of encoding - output = output.Replace('_', '/'); - // Pad with trailing '='s - switch (output.Length % 4) - { - case 0: - // No pad chars in this case - break; - case 2: - // Two pad chars - output += "=="; break; - case 3: - // One pad char - output += "="; break; - default: - throw new InvalidOperationException("Illegal base64url string!"); - } - - // Standard base64 decoder - return Convert.FromBase64String(output); - } - - public static string PunyEncode(string text) - { - if (text == "") - { - return ""; - } - - if (text == null) - { - return null; - } - - if (!text.Contains("@")) - { - // Assume domain name or non-email address - var idn = new IdnMapping(); - return idn.GetAscii(text); - } - else - { - // Assume email address - return MailboxAddress.EncodeAddrspec(text); + characters += "abcdefghijklmnopqrstuvwxyz"; } } - public static string FormatLicenseSignatureValue(object val) + if (numeric) { - if (val == null) - { - return string.Empty; - } - - if (val.GetType() == typeof(DateTime)) - { - return ToEpocSeconds((DateTime)val).ToString(); - } - - if (val.GetType() == typeof(bool)) - { - return val.ToString().ToLowerInvariant(); - } - - if (val is PlanType planType) - { - return planType switch - { - PlanType.Free => "Free", - PlanType.FamiliesAnnually2019 => "FamiliesAnnually", - PlanType.TeamsMonthly2019 => "TeamsMonthly", - PlanType.TeamsAnnually2019 => "TeamsAnnually", - PlanType.EnterpriseMonthly2019 => "EnterpriseMonthly", - PlanType.EnterpriseAnnually2019 => "EnterpriseAnnually", - PlanType.Custom => "Custom", - _ => ((byte)planType).ToString(), - }; - } - - return val.ToString(); + characters += "0123456789"; } - public static string GetVersion() + if (special) { - if (string.IsNullOrWhiteSpace(_version)) - { - _version = Assembly.GetEntryAssembly() - .GetCustomAttribute() - .InformationalVersion; - } - - return _version; + characters += "!@#$%^*&"; } - public static string SanitizeForEmail(string value, bool htmlEncode = true) + return characters; + } + + // ref: https://stackoverflow.com/a/11124118/1090359 + // Returns the human-readable file size for an arbitrary 64-bit file size . + // The format is "0.## XB", ex: "4.2 KB" or "1.43 GB" + public static string ReadableBytesSize(long size) + { + // Get absolute value + var absoluteSize = (size < 0 ? -size : size); + + // Determine the suffix and readable value + string suffix; + double readable; + if (absoluteSize >= 0x40000000) // 1 Gigabyte { - var cleanedValue = value.Replace("@", "[at]"); - var regexOptions = RegexOptions.CultureInvariant | - RegexOptions.Singleline | - RegexOptions.IgnoreCase; - cleanedValue = Regex.Replace(cleanedValue, @"(\.\w)", - m => string.Concat("[dot]", m.ToString().Last()), regexOptions); - while (Regex.IsMatch(cleanedValue, @"((^|\b)(\w*)://)", regexOptions)) - { - cleanedValue = Regex.Replace(cleanedValue, @"((^|\b)(\w*)://)", - string.Empty, regexOptions); - } - return htmlEncode ? HttpUtility.HtmlEncode(cleanedValue) : cleanedValue; + suffix = "GB"; + readable = (size >> 20); + } + else if (absoluteSize >= 0x100000) // 1 Megabyte + { + suffix = "MB"; + readable = (size >> 10); + } + else if (absoluteSize >= 0x400) // 1 Kilobyte + { + suffix = "KB"; + readable = size; + } + else + { + return size.ToString("0 Bytes"); // Byte } - public static string DateTimeToTableStorageKey(DateTime? date = null) - { - if (date.HasValue) - { - date = date.Value.ToUniversalTime(); - } - else - { - date = DateTime.UtcNow; - } + // Divide by 1024 to get fractional value + readable = (readable / 1024); - return _max.Subtract(date.Value).TotalMilliseconds.ToString(CultureInfo.InvariantCulture); + // Return formatted number with suffix + return readable.ToString("0.## ") + suffix; + } + + /// + /// Creates a clone of the given object through serializing to json and deserializing. + /// This method is subject to the limitations of System.Text.Json. For example, properties with + /// inaccessible setters will not be set. + /// + public static T CloneObject(T obj) + { + return JsonSerializer.Deserialize(JsonSerializer.Serialize(obj)); + } + + public static bool SettingHasValue(string setting) + { + var normalizedSetting = setting?.ToLowerInvariant(); + return !string.IsNullOrWhiteSpace(normalizedSetting) && !normalizedSetting.Equals("secret") && + !normalizedSetting.Equals("replace"); + } + + public static string Base64EncodeString(string input) + { + return Convert.ToBase64String(Encoding.UTF8.GetBytes(input)); + } + + public static string Base64DecodeString(string input) + { + return Encoding.UTF8.GetString(Convert.FromBase64String(input)); + } + + public static string Base64UrlEncodeString(string input) + { + return Base64UrlEncode(Encoding.UTF8.GetBytes(input)); + } + + public static string Base64UrlDecodeString(string input) + { + return Encoding.UTF8.GetString(Base64UrlDecode(input)); + } + + public static string Base64UrlEncode(byte[] input) + { + var output = Convert.ToBase64String(input) + .Replace('+', '-') + .Replace('/', '_') + .Replace("=", string.Empty); + return output; + } + + public static byte[] Base64UrlDecode(string input) + { + var output = input; + // 62nd char of encoding + output = output.Replace('-', '+'); + // 63rd char of encoding + output = output.Replace('_', '/'); + // Pad with trailing '='s + switch (output.Length % 4) + { + case 0: + // No pad chars in this case + break; + case 2: + // Two pad chars + output += "=="; break; + case 3: + // One pad char + output += "="; break; + default: + throw new InvalidOperationException("Illegal base64url string!"); } - // ref: https://stackoverflow.com/a/27545010/1090359 - public static Uri ExtendQuery(Uri uri, IDictionary values) + // Standard base64 decoder + return Convert.FromBase64String(output); + } + + public static string PunyEncode(string text) + { + if (text == "") { - var baseUri = uri.ToString(); - var queryString = string.Empty; - if (baseUri.Contains("?")) - { - var urlSplit = baseUri.Split('?'); - baseUri = urlSplit[0]; - queryString = urlSplit.Length > 1 ? urlSplit[1] : string.Empty; - } - - var queryCollection = HttpUtility.ParseQueryString(queryString); - foreach (var kvp in values ?? new Dictionary()) - { - queryCollection[kvp.Key] = kvp.Value; - } - - var uriKind = uri.IsAbsoluteUri ? UriKind.Absolute : UriKind.Relative; - if (queryCollection.Count == 0) - { - return new Uri(baseUri, uriKind); - } - return new Uri(string.Format("{0}?{1}", baseUri, queryCollection), uriKind); + return ""; } - public static string CustomProviderName(TwoFactorProviderType type) + if (text == null) { - return string.Concat("Custom_", type.ToString()); - } - - public static bool UserInviteTokenIsValid(IDataProtector protector, string token, string userEmail, - Guid orgUserId, IGlobalSettings globalSettings) - { - return TokenIsValid("OrganizationUserInvite", protector, token, userEmail, orgUserId, - globalSettings.OrganizationInviteExpirationHours); - } - - public static bool TokenIsValid(string firstTokenPart, IDataProtector protector, string token, string userEmail, - Guid id, double expirationInHours) - { - var invalid = true; - try - { - var unprotectedData = protector.Unprotect(token); - var dataParts = unprotectedData.Split(' '); - if (dataParts.Length == 4 && dataParts[0] == firstTokenPart && - new Guid(dataParts[1]) == id && - dataParts[2].Equals(userEmail, StringComparison.InvariantCultureIgnoreCase)) - { - var creationTime = FromEpocMilliseconds(Convert.ToInt64(dataParts[3])); - var expTime = creationTime.AddHours(expirationInHours); - invalid = expTime < DateTime.UtcNow; - } - } - catch - { - invalid = true; - } - - return !invalid; - } - - public static string GetApplicationCacheServiceBusSubcriptionName(GlobalSettings globalSettings) - { - var subName = globalSettings.ServiceBus.ApplicationCacheSubscriptionName; - if (string.IsNullOrWhiteSpace(subName)) - { - var websiteInstanceId = Environment.GetEnvironmentVariable("WEBSITE_INSTANCE_ID"); - if (string.IsNullOrWhiteSpace(websiteInstanceId)) - { - throw new Exception("No service bus subscription name available."); - } - else - { - subName = $"{globalSettings.ProjectName.ToLower()}_{websiteInstanceId}"; - if (subName.Length > 50) - { - subName = subName.Substring(0, 50); - } - } - } - return subName; - } - - public static string GetIpAddress(this Microsoft.AspNetCore.Http.HttpContext httpContext, - GlobalSettings globalSettings) - { - if (httpContext == null) - { - return null; - } - - if (!globalSettings.SelfHosted && httpContext.Request.Headers.ContainsKey(CloudFlareConnectingIp)) - { - return httpContext.Request.Headers[CloudFlareConnectingIp].ToString(); - } - if (globalSettings.SelfHosted && httpContext.Request.Headers.ContainsKey(RealIp)) - { - return httpContext.Request.Headers[RealIp].ToString(); - } - - return httpContext.Connection?.RemoteIpAddress?.ToString(); - } - - public static bool IsCorsOriginAllowed(string origin, GlobalSettings globalSettings) - { - return - // Web vault - origin == globalSettings.BaseServiceUri.Vault || - // Safari extension origin - origin == "file://" || - // Product website - (!globalSettings.SelfHosted && origin == "https://bitwarden.com"); - } - - public static X509Certificate2 GetIdentityServerCertificate(GlobalSettings globalSettings) - { - if (globalSettings.SelfHosted && - SettingHasValue(globalSettings.IdentityServer.CertificatePassword) - && File.Exists("identity.pfx")) - { - return GetCertificate("identity.pfx", - globalSettings.IdentityServer.CertificatePassword); - } - else if (SettingHasValue(globalSettings.IdentityServer.CertificateThumbprint)) - { - return GetCertificate( - globalSettings.IdentityServer.CertificateThumbprint); - } - else if (!globalSettings.SelfHosted && - SettingHasValue(globalSettings.Storage?.ConnectionString) && - SettingHasValue(globalSettings.IdentityServer.CertificatePassword)) - { - return GetBlobCertificateAsync(globalSettings.Storage.ConnectionString, "certificates", - "identity.pfx", globalSettings.IdentityServer.CertificatePassword).GetAwaiter().GetResult(); - } return null; } - public static Dictionary AdjustIdentityServerConfig(Dictionary configDict, - string publicServiceUri, string internalServiceUri) + if (!text.Contains("@")) { - var dictReplace = new Dictionary(); - foreach (var item in configDict) - { - if (item.Key == "authorization_endpoint" && item.Value is string val) - { - var uri = new Uri(val); - dictReplace.Add(item.Key, string.Concat(publicServiceUri, uri.LocalPath)); - } - else if ((item.Key == "jwks_uri" || item.Key.EndsWith("_endpoint")) && item.Value is string val2) - { - var uri = new Uri(val2); - dictReplace.Add(item.Key, string.Concat(internalServiceUri, uri.LocalPath)); - } - } - foreach (var replace in dictReplace) - { - configDict[replace.Key] = replace.Value; - } - return configDict; + // Assume domain name or non-email address + var idn = new IdnMapping(); + return idn.GetAscii(text); } - - public static List> BuildIdentityClaims(User user, ICollection orgs, - ICollection providers, bool isPremium) + else { - var claims = new List>() - { - new KeyValuePair("premium", isPremium ? "true" : "false"), - new KeyValuePair(JwtClaimTypes.Email, user.Email), - new KeyValuePair(JwtClaimTypes.EmailVerified, user.EmailVerified ? "true" : "false"), - new KeyValuePair("sstamp", user.SecurityStamp) - }; - - if (!string.IsNullOrWhiteSpace(user.Name)) - { - claims.Add(new KeyValuePair(JwtClaimTypes.Name, user.Name)); - } - - // Orgs that this user belongs to - if (orgs.Any()) - { - foreach (var group in orgs.GroupBy(o => o.Type)) - { - switch (group.Key) - { - case Enums.OrganizationUserType.Owner: - foreach (var org in group) - { - claims.Add(new KeyValuePair("orgowner", org.Id.ToString())); - } - break; - case Enums.OrganizationUserType.Admin: - foreach (var org in group) - { - claims.Add(new KeyValuePair("orgadmin", org.Id.ToString())); - } - break; - case Enums.OrganizationUserType.Manager: - foreach (var org in group) - { - claims.Add(new KeyValuePair("orgmanager", org.Id.ToString())); - } - break; - case Enums.OrganizationUserType.User: - foreach (var org in group) - { - claims.Add(new KeyValuePair("orguser", org.Id.ToString())); - } - break; - case Enums.OrganizationUserType.Custom: - foreach (var org in group) - { - claims.Add(new KeyValuePair("orgcustom", org.Id.ToString())); - foreach (var (permission, claimName) in org.Permissions.ClaimsMap) - { - if (!permission) - { - continue; - } - - claims.Add(new KeyValuePair(claimName, org.Id.ToString())); - } - } - break; - default: - break; - } - } - } - - if (providers.Any()) - { - foreach (var group in providers.GroupBy(o => o.Type)) - { - switch (group.Key) - { - case ProviderUserType.ProviderAdmin: - foreach (var provider in group) - { - claims.Add(new KeyValuePair("providerprovideradmin", provider.Id.ToString())); - } - break; - case ProviderUserType.ServiceUser: - foreach (var provider in group) - { - claims.Add(new KeyValuePair("providerserviceuser", provider.Id.ToString())); - } - break; - } - } - } - - return claims; - } - - public static T LoadClassFromJsonData(string jsonData) where T : new() - { - if (string.IsNullOrWhiteSpace(jsonData)) - { - return new T(); - } - - var options = new JsonSerializerOptions - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - }; - - return System.Text.Json.JsonSerializer.Deserialize(jsonData, options); - } - - public static string ClassToJsonData(T data) - { - var options = new JsonSerializerOptions - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - }; - - return System.Text.Json.JsonSerializer.Serialize(data, options); - } - - public static ICollection AddIfNotExists(this ICollection list, T item) - { - if (list.Contains(item)) - { - return list; - } - list.Add(item); - return list; - } - - public static string DecodeMessageText(this QueueMessage message) - { - var text = message?.MessageText; - if (string.IsNullOrWhiteSpace(text)) - { - return text; - } - try - { - return Base64DecodeString(text); - } - catch - { - return text; - } - } - - public static bool FixedTimeEquals(string input1, string input2) - { - return CryptographicOperations.FixedTimeEquals( - Encoding.UTF8.GetBytes(input1), Encoding.UTF8.GetBytes(input2)); - } - - public static string ObfuscateEmail(string email) - { - if (email == null) - { - return email; - } - - var emailParts = email.Split('@', StringSplitOptions.RemoveEmptyEntries); - - if (emailParts.Length != 2) - { - return email; - } - - var username = emailParts[0]; - - if (username.Length < 2) - { - return email; - } - - var sb = new StringBuilder(); - sb.Append(emailParts[0][..2]); - for (var i = 2; i < emailParts[0].Length; i++) - { - sb.Append('*'); - } - - return sb.Append('@') - .Append(emailParts[1]) - .ToString(); - + // Assume email address + return MailboxAddress.EncodeAddrspec(text); } } + + public static string FormatLicenseSignatureValue(object val) + { + if (val == null) + { + return string.Empty; + } + + if (val.GetType() == typeof(DateTime)) + { + return ToEpocSeconds((DateTime)val).ToString(); + } + + if (val.GetType() == typeof(bool)) + { + return val.ToString().ToLowerInvariant(); + } + + if (val is PlanType planType) + { + return planType switch + { + PlanType.Free => "Free", + PlanType.FamiliesAnnually2019 => "FamiliesAnnually", + PlanType.TeamsMonthly2019 => "TeamsMonthly", + PlanType.TeamsAnnually2019 => "TeamsAnnually", + PlanType.EnterpriseMonthly2019 => "EnterpriseMonthly", + PlanType.EnterpriseAnnually2019 => "EnterpriseAnnually", + PlanType.Custom => "Custom", + _ => ((byte)planType).ToString(), + }; + } + + return val.ToString(); + } + + public static string GetVersion() + { + if (string.IsNullOrWhiteSpace(_version)) + { + _version = Assembly.GetEntryAssembly() + .GetCustomAttribute() + .InformationalVersion; + } + + return _version; + } + + public static string SanitizeForEmail(string value, bool htmlEncode = true) + { + var cleanedValue = value.Replace("@", "[at]"); + var regexOptions = RegexOptions.CultureInvariant | + RegexOptions.Singleline | + RegexOptions.IgnoreCase; + cleanedValue = Regex.Replace(cleanedValue, @"(\.\w)", + m => string.Concat("[dot]", m.ToString().Last()), regexOptions); + while (Regex.IsMatch(cleanedValue, @"((^|\b)(\w*)://)", regexOptions)) + { + cleanedValue = Regex.Replace(cleanedValue, @"((^|\b)(\w*)://)", + string.Empty, regexOptions); + } + return htmlEncode ? HttpUtility.HtmlEncode(cleanedValue) : cleanedValue; + } + + public static string DateTimeToTableStorageKey(DateTime? date = null) + { + if (date.HasValue) + { + date = date.Value.ToUniversalTime(); + } + else + { + date = DateTime.UtcNow; + } + + return _max.Subtract(date.Value).TotalMilliseconds.ToString(CultureInfo.InvariantCulture); + } + + // ref: https://stackoverflow.com/a/27545010/1090359 + public static Uri ExtendQuery(Uri uri, IDictionary values) + { + var baseUri = uri.ToString(); + var queryString = string.Empty; + if (baseUri.Contains("?")) + { + var urlSplit = baseUri.Split('?'); + baseUri = urlSplit[0]; + queryString = urlSplit.Length > 1 ? urlSplit[1] : string.Empty; + } + + var queryCollection = HttpUtility.ParseQueryString(queryString); + foreach (var kvp in values ?? new Dictionary()) + { + queryCollection[kvp.Key] = kvp.Value; + } + + var uriKind = uri.IsAbsoluteUri ? UriKind.Absolute : UriKind.Relative; + if (queryCollection.Count == 0) + { + return new Uri(baseUri, uriKind); + } + return new Uri(string.Format("{0}?{1}", baseUri, queryCollection), uriKind); + } + + public static string CustomProviderName(TwoFactorProviderType type) + { + return string.Concat("Custom_", type.ToString()); + } + + public static bool UserInviteTokenIsValid(IDataProtector protector, string token, string userEmail, + Guid orgUserId, IGlobalSettings globalSettings) + { + return TokenIsValid("OrganizationUserInvite", protector, token, userEmail, orgUserId, + globalSettings.OrganizationInviteExpirationHours); + } + + public static bool TokenIsValid(string firstTokenPart, IDataProtector protector, string token, string userEmail, + Guid id, double expirationInHours) + { + var invalid = true; + try + { + var unprotectedData = protector.Unprotect(token); + var dataParts = unprotectedData.Split(' '); + if (dataParts.Length == 4 && dataParts[0] == firstTokenPart && + new Guid(dataParts[1]) == id && + dataParts[2].Equals(userEmail, StringComparison.InvariantCultureIgnoreCase)) + { + var creationTime = FromEpocMilliseconds(Convert.ToInt64(dataParts[3])); + var expTime = creationTime.AddHours(expirationInHours); + invalid = expTime < DateTime.UtcNow; + } + } + catch + { + invalid = true; + } + + return !invalid; + } + + public static string GetApplicationCacheServiceBusSubcriptionName(GlobalSettings globalSettings) + { + var subName = globalSettings.ServiceBus.ApplicationCacheSubscriptionName; + if (string.IsNullOrWhiteSpace(subName)) + { + var websiteInstanceId = Environment.GetEnvironmentVariable("WEBSITE_INSTANCE_ID"); + if (string.IsNullOrWhiteSpace(websiteInstanceId)) + { + throw new Exception("No service bus subscription name available."); + } + else + { + subName = $"{globalSettings.ProjectName.ToLower()}_{websiteInstanceId}"; + if (subName.Length > 50) + { + subName = subName.Substring(0, 50); + } + } + } + return subName; + } + + public static string GetIpAddress(this Microsoft.AspNetCore.Http.HttpContext httpContext, + GlobalSettings globalSettings) + { + if (httpContext == null) + { + return null; + } + + if (!globalSettings.SelfHosted && httpContext.Request.Headers.ContainsKey(CloudFlareConnectingIp)) + { + return httpContext.Request.Headers[CloudFlareConnectingIp].ToString(); + } + if (globalSettings.SelfHosted && httpContext.Request.Headers.ContainsKey(RealIp)) + { + return httpContext.Request.Headers[RealIp].ToString(); + } + + return httpContext.Connection?.RemoteIpAddress?.ToString(); + } + + public static bool IsCorsOriginAllowed(string origin, GlobalSettings globalSettings) + { + return + // Web vault + origin == globalSettings.BaseServiceUri.Vault || + // Safari extension origin + origin == "file://" || + // Product website + (!globalSettings.SelfHosted && origin == "https://bitwarden.com"); + } + + public static X509Certificate2 GetIdentityServerCertificate(GlobalSettings globalSettings) + { + if (globalSettings.SelfHosted && + SettingHasValue(globalSettings.IdentityServer.CertificatePassword) + && File.Exists("identity.pfx")) + { + return GetCertificate("identity.pfx", + globalSettings.IdentityServer.CertificatePassword); + } + else if (SettingHasValue(globalSettings.IdentityServer.CertificateThumbprint)) + { + return GetCertificate( + globalSettings.IdentityServer.CertificateThumbprint); + } + else if (!globalSettings.SelfHosted && + SettingHasValue(globalSettings.Storage?.ConnectionString) && + SettingHasValue(globalSettings.IdentityServer.CertificatePassword)) + { + return GetBlobCertificateAsync(globalSettings.Storage.ConnectionString, "certificates", + "identity.pfx", globalSettings.IdentityServer.CertificatePassword).GetAwaiter().GetResult(); + } + return null; + } + + public static Dictionary AdjustIdentityServerConfig(Dictionary configDict, + string publicServiceUri, string internalServiceUri) + { + var dictReplace = new Dictionary(); + foreach (var item in configDict) + { + if (item.Key == "authorization_endpoint" && item.Value is string val) + { + var uri = new Uri(val); + dictReplace.Add(item.Key, string.Concat(publicServiceUri, uri.LocalPath)); + } + else if ((item.Key == "jwks_uri" || item.Key.EndsWith("_endpoint")) && item.Value is string val2) + { + var uri = new Uri(val2); + dictReplace.Add(item.Key, string.Concat(internalServiceUri, uri.LocalPath)); + } + } + foreach (var replace in dictReplace) + { + configDict[replace.Key] = replace.Value; + } + return configDict; + } + + public static List> BuildIdentityClaims(User user, ICollection orgs, + ICollection providers, bool isPremium) + { + var claims = new List>() + { + new KeyValuePair("premium", isPremium ? "true" : "false"), + new KeyValuePair(JwtClaimTypes.Email, user.Email), + new KeyValuePair(JwtClaimTypes.EmailVerified, user.EmailVerified ? "true" : "false"), + new KeyValuePair("sstamp", user.SecurityStamp) + }; + + if (!string.IsNullOrWhiteSpace(user.Name)) + { + claims.Add(new KeyValuePair(JwtClaimTypes.Name, user.Name)); + } + + // Orgs that this user belongs to + if (orgs.Any()) + { + foreach (var group in orgs.GroupBy(o => o.Type)) + { + switch (group.Key) + { + case Enums.OrganizationUserType.Owner: + foreach (var org in group) + { + claims.Add(new KeyValuePair("orgowner", org.Id.ToString())); + } + break; + case Enums.OrganizationUserType.Admin: + foreach (var org in group) + { + claims.Add(new KeyValuePair("orgadmin", org.Id.ToString())); + } + break; + case Enums.OrganizationUserType.Manager: + foreach (var org in group) + { + claims.Add(new KeyValuePair("orgmanager", org.Id.ToString())); + } + break; + case Enums.OrganizationUserType.User: + foreach (var org in group) + { + claims.Add(new KeyValuePair("orguser", org.Id.ToString())); + } + break; + case Enums.OrganizationUserType.Custom: + foreach (var org in group) + { + claims.Add(new KeyValuePair("orgcustom", org.Id.ToString())); + foreach (var (permission, claimName) in org.Permissions.ClaimsMap) + { + if (!permission) + { + continue; + } + + claims.Add(new KeyValuePair(claimName, org.Id.ToString())); + } + } + break; + default: + break; + } + } + } + + if (providers.Any()) + { + foreach (var group in providers.GroupBy(o => o.Type)) + { + switch (group.Key) + { + case ProviderUserType.ProviderAdmin: + foreach (var provider in group) + { + claims.Add(new KeyValuePair("providerprovideradmin", provider.Id.ToString())); + } + break; + case ProviderUserType.ServiceUser: + foreach (var provider in group) + { + claims.Add(new KeyValuePair("providerserviceuser", provider.Id.ToString())); + } + break; + } + } + } + + return claims; + } + + public static T LoadClassFromJsonData(string jsonData) where T : new() + { + if (string.IsNullOrWhiteSpace(jsonData)) + { + return new T(); + } + + var options = new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + }; + + return System.Text.Json.JsonSerializer.Deserialize(jsonData, options); + } + + public static string ClassToJsonData(T data) + { + var options = new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + }; + + return System.Text.Json.JsonSerializer.Serialize(data, options); + } + + public static ICollection AddIfNotExists(this ICollection list, T item) + { + if (list.Contains(item)) + { + return list; + } + list.Add(item); + return list; + } + + public static string DecodeMessageText(this QueueMessage message) + { + var text = message?.MessageText; + if (string.IsNullOrWhiteSpace(text)) + { + return text; + } + try + { + return Base64DecodeString(text); + } + catch + { + return text; + } + } + + public static bool FixedTimeEquals(string input1, string input2) + { + return CryptographicOperations.FixedTimeEquals( + Encoding.UTF8.GetBytes(input1), Encoding.UTF8.GetBytes(input2)); + } + + public static string ObfuscateEmail(string email) + { + if (email == null) + { + return email; + } + + var emailParts = email.Split('@', StringSplitOptions.RemoveEmptyEntries); + + if (emailParts.Length != 2) + { + return email; + } + + var username = emailParts[0]; + + if (username.Length < 2) + { + return email; + } + + var sb = new StringBuilder(); + sb.Append(emailParts[0][..2]); + for (var i = 2; i < emailParts[0].Length; i++) + { + sb.Append('*'); + } + + return sb.Append('@') + .Append(emailParts[1]) + .ToString(); + + } } diff --git a/src/Core/Utilities/CurrentContextMiddleware.cs b/src/Core/Utilities/CurrentContextMiddleware.cs index bfba894dd..c1ac9322c 100644 --- a/src/Core/Utilities/CurrentContextMiddleware.cs +++ b/src/Core/Utilities/CurrentContextMiddleware.cs @@ -2,21 +2,20 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Http; -namespace Bit.Core.Utilities +namespace Bit.Core.Utilities; + +public class CurrentContextMiddleware { - public class CurrentContextMiddleware + private readonly RequestDelegate _next; + + public CurrentContextMiddleware(RequestDelegate next) { - private readonly RequestDelegate _next; + _next = next; + } - public CurrentContextMiddleware(RequestDelegate next) - { - _next = next; - } - - public async Task Invoke(HttpContext httpContext, ICurrentContext currentContext, GlobalSettings globalSettings) - { - await currentContext.BuildAsync(httpContext, globalSettings); - await _next.Invoke(httpContext); - } + public async Task Invoke(HttpContext httpContext, ICurrentContext currentContext, GlobalSettings globalSettings) + { + await currentContext.BuildAsync(httpContext, globalSettings); + await _next.Invoke(httpContext); } } diff --git a/src/Core/Utilities/CustomIpRateLimitMiddleware.cs b/src/Core/Utilities/CustomIpRateLimitMiddleware.cs index 529495e09..5fb82cac0 100644 --- a/src/Core/Utilities/CustomIpRateLimitMiddleware.cs +++ b/src/Core/Utilities/CustomIpRateLimitMiddleware.cs @@ -6,86 +6,85 @@ using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -namespace Bit.Core.Utilities +namespace Bit.Core.Utilities; + +public class CustomIpRateLimitMiddleware : IpRateLimitMiddleware { - public class CustomIpRateLimitMiddleware : IpRateLimitMiddleware + private readonly IBlockIpService _blockIpService; + private readonly ILogger _logger; + private readonly IDistributedCache _distributedCache; + private readonly IpRateLimitOptions _options; + + public CustomIpRateLimitMiddleware( + IDistributedCache distributedCache, + IBlockIpService blockIpService, + RequestDelegate next, + IProcessingStrategy processingStrategy, + IRateLimitConfiguration rateLimitConfiguration, + IOptions options, + IIpPolicyStore policyStore, + ILogger logger) + : base(next, processingStrategy, options, policyStore, rateLimitConfiguration, logger) { - private readonly IBlockIpService _blockIpService; - private readonly ILogger _logger; - private readonly IDistributedCache _distributedCache; - private readonly IpRateLimitOptions _options; + _distributedCache = distributedCache; + _blockIpService = blockIpService; + _options = options.Value; + _logger = logger; + } - public CustomIpRateLimitMiddleware( - IDistributedCache distributedCache, - IBlockIpService blockIpService, - RequestDelegate next, - IProcessingStrategy processingStrategy, - IRateLimitConfiguration rateLimitConfiguration, - IOptions options, - IIpPolicyStore policyStore, - ILogger logger) - : base(next, processingStrategy, options, policyStore, rateLimitConfiguration, logger) + public override Task ReturnQuotaExceededResponse(HttpContext httpContext, RateLimitRule rule, string retryAfter) + { + var message = string.IsNullOrWhiteSpace(_options.QuotaExceededMessage) + ? $"Slow down! Too many requests. Try again in {rule.Period}." + : _options.QuotaExceededMessage; + httpContext.Response.Headers["Retry-After"] = retryAfter; + httpContext.Response.StatusCode = _options.HttpStatusCode; + var errorModel = new ErrorResponseModel { Message = message }; + return httpContext.Response.WriteAsJsonAsync(errorModel, httpContext.RequestAborted); + } + + protected override void LogBlockedRequest(HttpContext httpContext, ClientRequestIdentity identity, + RateLimitCounter counter, RateLimitRule rule) + { + base.LogBlockedRequest(httpContext, identity, counter, rule); + var key = $"blockedIp_{identity.ClientIp}"; + + _distributedCache.TryGetValue(key, out int blockedCount); + + blockedCount++; + if (blockedCount > 10) { - _distributedCache = distributedCache; - _blockIpService = blockIpService; - _options = options.Value; - _logger = logger; + _blockIpService.BlockIpAsync(identity.ClientIp, false); + _logger.LogInformation(Constants.BypassFiltersEventId, null, + "Banned {0}. \nInfo: \n{1}", identity.ClientIp, GetRequestInfo(httpContext)); } - - public override Task ReturnQuotaExceededResponse(HttpContext httpContext, RateLimitRule rule, string retryAfter) + else { - var message = string.IsNullOrWhiteSpace(_options.QuotaExceededMessage) - ? $"Slow down! Too many requests. Try again in {rule.Period}." - : _options.QuotaExceededMessage; - httpContext.Response.Headers["Retry-After"] = retryAfter; - httpContext.Response.StatusCode = _options.HttpStatusCode; - var errorModel = new ErrorResponseModel { Message = message }; - return httpContext.Response.WriteAsJsonAsync(errorModel, httpContext.RequestAborted); - } - - protected override void LogBlockedRequest(HttpContext httpContext, ClientRequestIdentity identity, - RateLimitCounter counter, RateLimitRule rule) - { - base.LogBlockedRequest(httpContext, identity, counter, rule); - var key = $"blockedIp_{identity.ClientIp}"; - - _distributedCache.TryGetValue(key, out int blockedCount); - - blockedCount++; - if (blockedCount > 10) - { - _blockIpService.BlockIpAsync(identity.ClientIp, false); - _logger.LogInformation(Constants.BypassFiltersEventId, null, - "Banned {0}. \nInfo: \n{1}", identity.ClientIp, GetRequestInfo(httpContext)); - } - else - { - _logger.LogInformation(Constants.BypassFiltersEventId, null, - "Request blocked {0}. \nInfo: \n{1}", identity.ClientIp, GetRequestInfo(httpContext)); - _distributedCache.Set(key, blockedCount, - new DistributedCacheEntryOptions().SetSlidingExpiration(new TimeSpan(0, 5, 0))); - } - } - - private string GetRequestInfo(HttpContext httpContext) - { - if (httpContext == null || httpContext.Request == null) - { - return null; - } - - var s = string.Empty; - foreach (var header in httpContext.Request.Headers) - { - s += $"Header \"{header.Key}\": {header.Value} \n"; - } - - foreach (var query in httpContext.Request.Query) - { - s += $"Query \"{query.Key}\": {query.Value} \n"; - } - - return s; + _logger.LogInformation(Constants.BypassFiltersEventId, null, + "Request blocked {0}. \nInfo: \n{1}", identity.ClientIp, GetRequestInfo(httpContext)); + _distributedCache.Set(key, blockedCount, + new DistributedCacheEntryOptions().SetSlidingExpiration(new TimeSpan(0, 5, 0))); } } + + private string GetRequestInfo(HttpContext httpContext) + { + if (httpContext == null || httpContext.Request == null) + { + return null; + } + + var s = string.Empty; + foreach (var header in httpContext.Request.Headers) + { + s += $"Header \"{header.Key}\": {header.Value} \n"; + } + + foreach (var query in httpContext.Request.Query) + { + s += $"Query \"{query.Key}\": {query.Value} \n"; + } + + return s; + } } diff --git a/src/Core/Utilities/DistributedCacheExtensions.cs b/src/Core/Utilities/DistributedCacheExtensions.cs index d27d0ee46..28282b6a4 100644 --- a/src/Core/Utilities/DistributedCacheExtensions.cs +++ b/src/Core/Utilities/DistributedCacheExtensions.cs @@ -1,48 +1,47 @@ using System.Text.Json; using Microsoft.Extensions.Caching.Distributed; -namespace Bit.Core.Utilities +namespace Bit.Core.Utilities; + +public static class DistributedCacheExtensions { - public static class DistributedCacheExtensions + public static void Set(this IDistributedCache cache, string key, T value) { - public static void Set(this IDistributedCache cache, string key, T value) - { - Set(cache, key, value, new DistributedCacheEntryOptions()); - } + Set(cache, key, value, new DistributedCacheEntryOptions()); + } - public static void Set(this IDistributedCache cache, string key, T value, - DistributedCacheEntryOptions options) - { - var bytes = JsonSerializer.SerializeToUtf8Bytes(value); - cache.Set(key, bytes, options); - } + public static void Set(this IDistributedCache cache, string key, T value, + DistributedCacheEntryOptions options) + { + var bytes = JsonSerializer.SerializeToUtf8Bytes(value); + cache.Set(key, bytes, options); + } - public static Task SetAsync(this IDistributedCache cache, string key, T value) - { - return SetAsync(cache, key, value, new DistributedCacheEntryOptions()); - } + public static Task SetAsync(this IDistributedCache cache, string key, T value) + { + return SetAsync(cache, key, value, new DistributedCacheEntryOptions()); + } - public static Task SetAsync(this IDistributedCache cache, string key, T value, - DistributedCacheEntryOptions options) - { - var bytes = JsonSerializer.SerializeToUtf8Bytes(value); - return cache.SetAsync(key, bytes, options); - } + public static Task SetAsync(this IDistributedCache cache, string key, T value, + DistributedCacheEntryOptions options) + { + var bytes = JsonSerializer.SerializeToUtf8Bytes(value); + return cache.SetAsync(key, bytes, options); + } - public static bool TryGetValue(this IDistributedCache cache, string key, out T value) + public static bool TryGetValue(this IDistributedCache cache, string key, out T value) + { + var val = cache.Get(key); + value = default; + if (val == null) return false; + try { - var val = cache.Get(key); - value = default; - if (val == null) return false; - try - { - value = JsonSerializer.Deserialize(val); - } - catch - { - return false; - } - return true; + value = JsonSerializer.Deserialize(val); } + catch + { + return false; + } + return true; } } diff --git a/src/Core/Utilities/DuoApi.cs b/src/Core/Utilities/DuoApi.cs index c98c98938..b5a3f040d 100644 --- a/src/Core/Utilities/DuoApi.cs +++ b/src/Core/Utilities/DuoApi.cs @@ -16,294 +16,293 @@ using System.Text.Json; using System.Text.RegularExpressions; using System.Web; -namespace Bit.Core.Utilities.Duo +namespace Bit.Core.Utilities.Duo; + +public class DuoApi { - public class DuoApi + private const string UrlScheme = "https"; + private const string UserAgent = "Bitwarden_DuoAPICSharp/1.0 (.NET Core)"; + + private readonly string _host; + private readonly string _ikey; + private readonly string _skey; + + public DuoApi(string ikey, string skey, string host) { - private const string UrlScheme = "https"; - private const string UserAgent = "Bitwarden_DuoAPICSharp/1.0 (.NET Core)"; + _ikey = ikey; + _skey = skey; + _host = host; - private readonly string _host; - private readonly string _ikey; - private readonly string _skey; - - public DuoApi(string ikey, string skey, string host) + if (!ValidHost(host)) { - _ikey = ikey; - _skey = skey; - _host = host; + throw new DuoException("Invalid Duo host configured.", new ArgumentException(nameof(host))); + } + } - if (!ValidHost(host)) - { - throw new DuoException("Invalid Duo host configured.", new ArgumentException(nameof(host))); - } + public static bool ValidHost(string host) + { + if (Uri.TryCreate($"https://{host}", UriKind.Absolute, out var uri)) + { + return (string.IsNullOrWhiteSpace(uri.PathAndQuery) || uri.PathAndQuery == "/") && + uri.Host.StartsWith("api-") && + (uri.Host.EndsWith(".duosecurity.com") || uri.Host.EndsWith(".duofederal.com")); + } + return false; + } + + public static string CanonicalizeParams(Dictionary parameters) + { + var ret = new List(); + foreach (var pair in parameters) + { + var p = string.Format("{0}={1}", HttpUtility.UrlEncode(pair.Key), HttpUtility.UrlEncode(pair.Value)); + // Signatures require upper-case hex digits. + p = Regex.Replace(p, "(%[0-9A-Fa-f][0-9A-Fa-f])", c => c.Value.ToUpperInvariant()); + // Escape only the expected characters. + p = Regex.Replace(p, "([!'()*])", c => "%" + Convert.ToByte(c.Value[0]).ToString("X")); + p = p.Replace("%7E", "~"); + // UrlEncode converts space (" ") to "+". The + // signature algorithm requires "%20" instead. Actual + // + has already been replaced with %2B. + p = p.Replace("+", "%20"); + ret.Add(p); } - public static bool ValidHost(string host) + ret.Sort(StringComparer.Ordinal); + return string.Join("&", ret.ToArray()); + } + + protected string CanonicalizeRequest(string method, string path, string canonParams, string date) + { + string[] lines = { + date, + method.ToUpperInvariant(), + _host.ToLower(), + path, + canonParams, + }; + return string.Join("\n", lines); + } + + public string Sign(string method, string path, string canonParams, string date) + { + var canon = CanonicalizeRequest(method, path, canonParams, date); + var sig = HmacSign(canon); + var auth = string.Concat(_ikey, ':', sig); + return string.Concat("Basic ", Encode64(auth)); + } + + public string ApiCall(string method, string path, Dictionary parameters = null) + { + return ApiCall(method, path, parameters, 0, out var statusCode); + } + + /// The request timeout, in milliseconds. + /// Specify 0 to use the system-default timeout. Use caution if + /// you choose to specify a custom timeout - some API + /// calls (particularly in the Auth APIs) will not + /// return a response until an out-of-band authentication process + /// has completed. In some cases, this may take as much as a + /// small number of minutes. + public string ApiCall(string method, string path, Dictionary parameters, int timeout, + out HttpStatusCode statusCode) + { + if (parameters == null) { - if (Uri.TryCreate($"https://{host}", UriKind.Absolute, out var uri)) - { - return (string.IsNullOrWhiteSpace(uri.PathAndQuery) || uri.PathAndQuery == "/") && - uri.Host.StartsWith("api-") && - (uri.Host.EndsWith(".duosecurity.com") || uri.Host.EndsWith(".duofederal.com")); - } - return false; + parameters = new Dictionary(); } - public static string CanonicalizeParams(Dictionary parameters) + var canonParams = CanonicalizeParams(parameters); + var query = string.Empty; + if (!method.Equals("POST") && !method.Equals("PUT")) { - var ret = new List(); - foreach (var pair in parameters) + if (parameters.Count > 0) { - var p = string.Format("{0}={1}", HttpUtility.UrlEncode(pair.Key), HttpUtility.UrlEncode(pair.Value)); - // Signatures require upper-case hex digits. - p = Regex.Replace(p, "(%[0-9A-Fa-f][0-9A-Fa-f])", c => c.Value.ToUpperInvariant()); - // Escape only the expected characters. - p = Regex.Replace(p, "([!'()*])", c => "%" + Convert.ToByte(c.Value[0]).ToString("X")); - p = p.Replace("%7E", "~"); - // UrlEncode converts space (" ") to "+". The - // signature algorithm requires "%20" instead. Actual - // + has already been replaced with %2B. - p = p.Replace("+", "%20"); - ret.Add(p); - } - - ret.Sort(StringComparer.Ordinal); - return string.Join("&", ret.ToArray()); - } - - protected string CanonicalizeRequest(string method, string path, string canonParams, string date) - { - string[] lines = { - date, - method.ToUpperInvariant(), - _host.ToLower(), - path, - canonParams, - }; - return string.Join("\n", lines); - } - - public string Sign(string method, string path, string canonParams, string date) - { - var canon = CanonicalizeRequest(method, path, canonParams, date); - var sig = HmacSign(canon); - var auth = string.Concat(_ikey, ':', sig); - return string.Concat("Basic ", Encode64(auth)); - } - - public string ApiCall(string method, string path, Dictionary parameters = null) - { - return ApiCall(method, path, parameters, 0, out var statusCode); - } - - /// The request timeout, in milliseconds. - /// Specify 0 to use the system-default timeout. Use caution if - /// you choose to specify a custom timeout - some API - /// calls (particularly in the Auth APIs) will not - /// return a response until an out-of-band authentication process - /// has completed. In some cases, this may take as much as a - /// small number of minutes. - public string ApiCall(string method, string path, Dictionary parameters, int timeout, - out HttpStatusCode statusCode) - { - if (parameters == null) - { - parameters = new Dictionary(); - } - - var canonParams = CanonicalizeParams(parameters); - var query = string.Empty; - if (!method.Equals("POST") && !method.Equals("PUT")) - { - if (parameters.Count > 0) - { - query = "?" + canonParams; - } - } - var url = string.Format("{0}://{1}{2}{3}", UrlScheme, _host, path, query); - - var dateString = RFC822UtcNow(); - var auth = Sign(method, path, canonParams, dateString); - - var request = (HttpWebRequest)WebRequest.Create(url); - request.Method = method; - request.Accept = "application/json"; - request.Headers.Add("Authorization", auth); - request.Headers.Add("X-Duo-Date", dateString); - request.UserAgent = UserAgent; - - if (method.Equals("POST") || method.Equals("PUT")) - { - var data = Encoding.UTF8.GetBytes(canonParams); - request.ContentType = "application/x-www-form-urlencoded"; - request.ContentLength = data.Length; - using (var requestStream = request.GetRequestStream()) - { - requestStream.Write(data, 0, data.Length); - } - } - if (timeout > 0) - { - request.Timeout = timeout; - } - - // Do the request and process the result. - HttpWebResponse response; - try - { - response = (HttpWebResponse)request.GetResponse(); - } - catch (WebException ex) - { - response = (HttpWebResponse)ex.Response; - if (response == null) - { - throw; - } - } - using (var reader = new StreamReader(response.GetResponseStream())) - { - statusCode = response.StatusCode; - return reader.ReadToEnd(); + query = "?" + canonParams; } } + var url = string.Format("{0}://{1}{2}{3}", UrlScheme, _host, path, query); - public T JSONApiCall(string method, string path, Dictionary parameters = null) - where T : class + var dateString = RFC822UtcNow(); + var auth = Sign(method, path, canonParams, dateString); + + var request = (HttpWebRequest)WebRequest.Create(url); + request.Method = method; + request.Accept = "application/json"; + request.Headers.Add("Authorization", auth); + request.Headers.Add("X-Duo-Date", dateString); + request.UserAgent = UserAgent; + + if (method.Equals("POST") || method.Equals("PUT")) { - return JSONApiCall(method, path, parameters, 0); + var data = Encoding.UTF8.GetBytes(canonParams); + request.ContentType = "application/x-www-form-urlencoded"; + request.ContentLength = data.Length; + using (var requestStream = request.GetRequestStream()) + { + requestStream.Write(data, 0, data.Length); + } + } + if (timeout > 0) + { + request.Timeout = timeout; } - /// The request timeout, in milliseconds. - /// Specify 0 to use the system-default timeout. Use caution if - /// you choose to specify a custom timeout - some API - /// calls (particularly in the Auth APIs) will not - /// return a response until an out-of-band authentication process - /// has completed. In some cases, this may take as much as a - /// small number of minutes. - public T JSONApiCall(string method, string path, Dictionary parameters, int timeout) - where T : class + // Do the request and process the result. + HttpWebResponse response; + try { - var res = ApiCall(method, path, parameters, timeout, out var statusCode); - try - { - // TODO: We should deserialize this into our own DTO and not work on dictionaries. - var dict = JsonSerializer.Deserialize>(res); - if (dict["stat"].ToString() == "OK") - { - return JsonSerializer.Deserialize(dict["response"].ToString()); - } - - var check = ToNullableInt(dict["code"].ToString()); - var code = check.GetValueOrDefault(0); - var messageDetail = string.Empty; - if (dict.ContainsKey("message_detail")) - { - messageDetail = dict["message_detail"].ToString(); - } - throw new ApiException(code, (int)statusCode, dict["message"].ToString(), messageDetail); - } - catch (ApiException) + response = (HttpWebResponse)request.GetResponse(); + } + catch (WebException ex) + { + response = (HttpWebResponse)ex.Response; + if (response == null) { throw; } - catch (Exception e) - { - throw new BadResponseException((int)statusCode, e); - } } - - private int? ToNullableInt(string s) + using (var reader = new StreamReader(response.GetResponseStream())) { - int i; - if (int.TryParse(s, out i)) - { - return i; - } - return null; - } - - private string HmacSign(string data) - { - var keyBytes = Encoding.ASCII.GetBytes(_skey); - var dataBytes = Encoding.ASCII.GetBytes(data); - - using (var hmac = new HMACSHA1(keyBytes)) - { - var hash = hmac.ComputeHash(dataBytes); - var hex = BitConverter.ToString(hash); - return hex.Replace("-", string.Empty).ToLower(); - } - } - - private static string Encode64(string plaintext) - { - var plaintextBytes = Encoding.ASCII.GetBytes(plaintext); - return Convert.ToBase64String(plaintextBytes); - } - - private static string RFC822UtcNow() - { - // Can't use the "zzzz" format because it adds a ":" - // between the offset's hours and minutes. - var dateString = DateTime.UtcNow.ToString("ddd, dd MMM yyyy HH:mm:ss", CultureInfo.InvariantCulture); - var offset = 0; - var zone = "+" + offset.ToString(CultureInfo.InvariantCulture).PadLeft(2, '0'); - dateString += " " + zone.PadRight(5, '0'); - return dateString; + statusCode = response.StatusCode; + return reader.ReadToEnd(); } } - public class DuoException : Exception + public T JSONApiCall(string method, string path, Dictionary parameters = null) + where T : class { - public int HttpStatus { get; private set; } - - public DuoException(string message, Exception inner) - : base(message, inner) - { } - - public DuoException(int httpStatus, string message, Exception inner) - : base(message, inner) - { - HttpStatus = httpStatus; - } + return JSONApiCall(method, path, parameters, 0); } - public class ApiException : DuoException + /// The request timeout, in milliseconds. + /// Specify 0 to use the system-default timeout. Use caution if + /// you choose to specify a custom timeout - some API + /// calls (particularly in the Auth APIs) will not + /// return a response until an out-of-band authentication process + /// has completed. In some cases, this may take as much as a + /// small number of minutes. + public T JSONApiCall(string method, string path, Dictionary parameters, int timeout) + where T : class { - public int Code { get; private set; } - public string ApiMessage { get; private set; } - public string ApiMessageDetail { get; private set; } - - public ApiException(int code, int httpStatus, string apiMessage, string apiMessageDetail) - : base(httpStatus, FormatMessage(code, apiMessage, apiMessageDetail), null) + var res = ApiCall(method, path, parameters, timeout, out var statusCode); + try { - Code = code; - ApiMessage = apiMessage; - ApiMessageDetail = apiMessageDetail; - } - - private static string FormatMessage(int code, string apiMessage, string apiMessageDetail) - { - return string.Format("Duo API Error {0}: '{1}' ('{2}')", code, apiMessage, apiMessageDetail); - } - } - - public class BadResponseException : DuoException - { - public BadResponseException(int httpStatus, Exception inner) - : base(httpStatus, FormatMessage(httpStatus, inner), inner) - { } - - private static string FormatMessage(int httpStatus, Exception inner) - { - var innerMessage = "(null)"; - if (inner != null) + // TODO: We should deserialize this into our own DTO and not work on dictionaries. + var dict = JsonSerializer.Deserialize>(res); + if (dict["stat"].ToString() == "OK") { - innerMessage = string.Format("'{0}'", inner.Message); + return JsonSerializer.Deserialize(dict["response"].ToString()); } - return string.Format("Got error {0} with HTTP Status {1}", innerMessage, httpStatus); + + var check = ToNullableInt(dict["code"].ToString()); + var code = check.GetValueOrDefault(0); + var messageDetail = string.Empty; + if (dict.ContainsKey("message_detail")) + { + messageDetail = dict["message_detail"].ToString(); + } + throw new ApiException(code, (int)statusCode, dict["message"].ToString(), messageDetail); } + catch (ApiException) + { + throw; + } + catch (Exception e) + { + throw new BadResponseException((int)statusCode, e); + } + } + + private int? ToNullableInt(string s) + { + int i; + if (int.TryParse(s, out i)) + { + return i; + } + return null; + } + + private string HmacSign(string data) + { + var keyBytes = Encoding.ASCII.GetBytes(_skey); + var dataBytes = Encoding.ASCII.GetBytes(data); + + using (var hmac = new HMACSHA1(keyBytes)) + { + var hash = hmac.ComputeHash(dataBytes); + var hex = BitConverter.ToString(hash); + return hex.Replace("-", string.Empty).ToLower(); + } + } + + private static string Encode64(string plaintext) + { + var plaintextBytes = Encoding.ASCII.GetBytes(plaintext); + return Convert.ToBase64String(plaintextBytes); + } + + private static string RFC822UtcNow() + { + // Can't use the "zzzz" format because it adds a ":" + // between the offset's hours and minutes. + var dateString = DateTime.UtcNow.ToString("ddd, dd MMM yyyy HH:mm:ss", CultureInfo.InvariantCulture); + var offset = 0; + var zone = "+" + offset.ToString(CultureInfo.InvariantCulture).PadLeft(2, '0'); + dateString += " " + zone.PadRight(5, '0'); + return dateString; + } +} + +public class DuoException : Exception +{ + public int HttpStatus { get; private set; } + + public DuoException(string message, Exception inner) + : base(message, inner) + { } + + public DuoException(int httpStatus, string message, Exception inner) + : base(message, inner) + { + HttpStatus = httpStatus; + } +} + +public class ApiException : DuoException +{ + public int Code { get; private set; } + public string ApiMessage { get; private set; } + public string ApiMessageDetail { get; private set; } + + public ApiException(int code, int httpStatus, string apiMessage, string apiMessageDetail) + : base(httpStatus, FormatMessage(code, apiMessage, apiMessageDetail), null) + { + Code = code; + ApiMessage = apiMessage; + ApiMessageDetail = apiMessageDetail; + } + + private static string FormatMessage(int code, string apiMessage, string apiMessageDetail) + { + return string.Format("Duo API Error {0}: '{1}' ('{2}')", code, apiMessage, apiMessageDetail); + } +} + +public class BadResponseException : DuoException +{ + public BadResponseException(int httpStatus, Exception inner) + : base(httpStatus, FormatMessage(httpStatus, inner), inner) + { } + + private static string FormatMessage(int httpStatus, Exception inner) + { + var innerMessage = "(null)"; + if (inner != null) + { + innerMessage = string.Format("'{0}'", inner.Message); + } + return string.Format("Got error {0} with HTTP Status {1}", innerMessage, httpStatus); } } diff --git a/src/Core/Utilities/DuoWeb.cs b/src/Core/Utilities/DuoWeb.cs index f8259d200..151f71a15 100644 --- a/src/Core/Utilities/DuoWeb.cs +++ b/src/Core/Utilities/DuoWeb.cs @@ -36,206 +36,205 @@ THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. using System.Security.Cryptography; using System.Text; -namespace Bit.Core.Utilities.Duo +namespace Bit.Core.Utilities.Duo; + +public static class DuoWeb { - public static class DuoWeb + private const string DuoProfix = "TX"; + private const string AppPrefix = "APP"; + private const string AuthPrefix = "AUTH"; + private const int DuoExpire = 300; + private const int AppExpire = 3600; + private const int IKeyLength = 20; + private const int SKeyLength = 40; + private const int AKeyLength = 40; + + public static string ErrorUser = "ERR|The username passed to sign_request() is invalid."; + public static string ErrorIKey = "ERR|The Duo integration key passed to sign_request() is invalid."; + public static string ErrorSKey = "ERR|The Duo secret key passed to sign_request() is invalid."; + public static string ErrorAKey = "ERR|The application secret key passed to sign_request() must be at least " + + "40 characters."; + public static string ErrorUnknown = "ERR|An unknown error has occurred."; + + // throw on invalid bytes + private static Encoding _encoding = new UTF8Encoding(false, true); + private static DateTime _epoc = new DateTime(1970, 1, 1); + + /// + /// Generate a signed request for Duo authentication. + /// The returned value should be passed into the Duo.init() call + /// in the rendered web page used for Duo authentication. + /// + /// Duo integration key + /// Duo secret key + /// Application secret key + /// Primary-authenticated username + /// (optional) The current UTC time + /// signed request + public static string SignRequest(string ikey, string skey, string akey, string username, + DateTime? currentTime = null) { - private const string DuoProfix = "TX"; - private const string AppPrefix = "APP"; - private const string AuthPrefix = "AUTH"; - private const int DuoExpire = 300; - private const int AppExpire = 3600; - private const int IKeyLength = 20; - private const int SKeyLength = 40; - private const int AKeyLength = 40; + string duoSig; + string appSig; - public static string ErrorUser = "ERR|The username passed to sign_request() is invalid."; - public static string ErrorIKey = "ERR|The Duo integration key passed to sign_request() is invalid."; - public static string ErrorSKey = "ERR|The Duo secret key passed to sign_request() is invalid."; - public static string ErrorAKey = "ERR|The application secret key passed to sign_request() must be at least " + - "40 characters."; - public static string ErrorUnknown = "ERR|An unknown error has occurred."; + var currentTimeValue = currentTime ?? DateTime.UtcNow; - // throw on invalid bytes - private static Encoding _encoding = new UTF8Encoding(false, true); - private static DateTime _epoc = new DateTime(1970, 1, 1); - - /// - /// Generate a signed request for Duo authentication. - /// The returned value should be passed into the Duo.init() call - /// in the rendered web page used for Duo authentication. - /// - /// Duo integration key - /// Duo secret key - /// Application secret key - /// Primary-authenticated username - /// (optional) The current UTC time - /// signed request - public static string SignRequest(string ikey, string skey, string akey, string username, - DateTime? currentTime = null) + if (username == string.Empty) { - string duoSig; - string appSig; - - var currentTimeValue = currentTime ?? DateTime.UtcNow; - - if (username == string.Empty) - { - return ErrorUser; - } - if (username.Contains("|")) - { - return ErrorUser; - } - if (ikey.Length != IKeyLength) - { - return ErrorIKey; - } - if (skey.Length != SKeyLength) - { - return ErrorSKey; - } - if (akey.Length < AKeyLength) - { - return ErrorAKey; - } - - try - { - duoSig = SignVals(skey, username, ikey, DuoProfix, DuoExpire, currentTimeValue); - appSig = SignVals(akey, username, ikey, AppPrefix, AppExpire, currentTimeValue); - } - catch - { - return ErrorUnknown; - } - - return $"{duoSig}:{appSig}"; + return ErrorUser; + } + if (username.Contains("|")) + { + return ErrorUser; + } + if (ikey.Length != IKeyLength) + { + return ErrorIKey; + } + if (skey.Length != SKeyLength) + { + return ErrorSKey; + } + if (akey.Length < AKeyLength) + { + return ErrorAKey; } - /// - /// Validate the signed response returned from Duo. - /// Returns the username of the authenticated user, or null. - /// - /// Duo integration key - /// Duo secret key - /// Application secret key - /// The signed response POST'ed to the server - /// (optional) The current UTC time - /// authenticated username, or null - public static string VerifyResponse(string ikey, string skey, string akey, string sigResponse, - DateTime? currentTime = null) + try { - string authUser = null; - string appUser = null; - var currentTimeValue = currentTime ?? DateTime.UtcNow; - - try - { - var sigs = sigResponse.Split(':'); - var authSig = sigs[0]; - var appSig = sigs[1]; - - authUser = ParseVals(skey, authSig, AuthPrefix, ikey, currentTimeValue); - appUser = ParseVals(akey, appSig, AppPrefix, ikey, currentTimeValue); - } - catch - { - return null; - } - - if (authUser != appUser) - { - return null; - } - - return authUser; + duoSig = SignVals(skey, username, ikey, DuoProfix, DuoExpire, currentTimeValue); + appSig = SignVals(akey, username, ikey, AppPrefix, AppExpire, currentTimeValue); + } + catch + { + return ErrorUnknown; } - private static string SignVals(string key, string username, string ikey, string prefix, long expire, - DateTime currentTime) + return $"{duoSig}:{appSig}"; + } + + /// + /// Validate the signed response returned from Duo. + /// Returns the username of the authenticated user, or null. + /// + /// Duo integration key + /// Duo secret key + /// Application secret key + /// The signed response POST'ed to the server + /// (optional) The current UTC time + /// authenticated username, or null + public static string VerifyResponse(string ikey, string skey, string akey, string sigResponse, + DateTime? currentTime = null) + { + string authUser = null; + string appUser = null; + var currentTimeValue = currentTime ?? DateTime.UtcNow; + + try { - var ts = (long)(currentTime - _epoc).TotalSeconds; - expire = ts + expire; - var val = $"{username}|{ikey}|{expire.ToString()}"; - var cookie = $"{prefix}|{Encode64(val)}"; - var sig = Sign(key, cookie); - return $"{cookie}|{sig}"; + var sigs = sigResponse.Split(':'); + var authSig = sigs[0]; + var appSig = sigs[1]; + + authUser = ParseVals(skey, authSig, AuthPrefix, ikey, currentTimeValue); + appUser = ParseVals(akey, appSig, AppPrefix, ikey, currentTimeValue); + } + catch + { + return null; } - private static string ParseVals(string key, string val, string prefix, string ikey, DateTime currentTime) + if (authUser != appUser) { - var ts = (long)(currentTime - _epoc).TotalSeconds; - - var parts = val.Split('|'); - if (parts.Length != 3) - { - return null; - } - - var uPrefix = parts[0]; - var uB64 = parts[1]; - var uSig = parts[2]; - - var sig = Sign(key, $"{uPrefix}|{uB64}"); - if (Sign(key, sig) != Sign(key, uSig)) - { - return null; - } - - if (uPrefix != prefix) - { - return null; - } - - var cookie = Decode64(uB64); - var cookieParts = cookie.Split('|'); - if (cookieParts.Length != 3) - { - return null; - } - - var username = cookieParts[0]; - var uIKey = cookieParts[1]; - var expire = cookieParts[2]; - - if (uIKey != ikey) - { - return null; - } - - var expireTs = Convert.ToInt32(expire); - if (ts >= expireTs) - { - return null; - } - - return username; + return null; } - private static string Sign(string skey, string data) - { - var keyBytes = Encoding.ASCII.GetBytes(skey); - var dataBytes = Encoding.ASCII.GetBytes(data); + return authUser; + } - using (var hmac = new HMACSHA1(keyBytes)) - { - var hash = hmac.ComputeHash(dataBytes); - var hex = BitConverter.ToString(hash); - return hex.Replace("-", "").ToLower(); - } + private static string SignVals(string key, string username, string ikey, string prefix, long expire, + DateTime currentTime) + { + var ts = (long)(currentTime - _epoc).TotalSeconds; + expire = ts + expire; + var val = $"{username}|{ikey}|{expire.ToString()}"; + var cookie = $"{prefix}|{Encode64(val)}"; + var sig = Sign(key, cookie); + return $"{cookie}|{sig}"; + } + + private static string ParseVals(string key, string val, string prefix, string ikey, DateTime currentTime) + { + var ts = (long)(currentTime - _epoc).TotalSeconds; + + var parts = val.Split('|'); + if (parts.Length != 3) + { + return null; } - private static string Encode64(string plaintext) + var uPrefix = parts[0]; + var uB64 = parts[1]; + var uSig = parts[2]; + + var sig = Sign(key, $"{uPrefix}|{uB64}"); + if (Sign(key, sig) != Sign(key, uSig)) { - var plaintextBytes = _encoding.GetBytes(plaintext); - return Convert.ToBase64String(plaintextBytes); + return null; } - private static string Decode64(string encoded) + if (uPrefix != prefix) { - var plaintextBytes = Convert.FromBase64String(encoded); - return _encoding.GetString(plaintextBytes); + return null; + } + + var cookie = Decode64(uB64); + var cookieParts = cookie.Split('|'); + if (cookieParts.Length != 3) + { + return null; + } + + var username = cookieParts[0]; + var uIKey = cookieParts[1]; + var expire = cookieParts[2]; + + if (uIKey != ikey) + { + return null; + } + + var expireTs = Convert.ToInt32(expire); + if (ts >= expireTs) + { + return null; + } + + return username; + } + + private static string Sign(string skey, string data) + { + var keyBytes = Encoding.ASCII.GetBytes(skey); + var dataBytes = Encoding.ASCII.GetBytes(data); + + using (var hmac = new HMACSHA1(keyBytes)) + { + var hash = hmac.ComputeHash(dataBytes); + var hex = BitConverter.ToString(hash); + return hex.Replace("-", "").ToLower(); } } + + private static string Encode64(string plaintext) + { + var plaintextBytes = _encoding.GetBytes(plaintext); + return Convert.ToBase64String(plaintextBytes); + } + + private static string Decode64(string encoded) + { + var plaintextBytes = Convert.FromBase64String(encoded); + return _encoding.GetString(plaintextBytes); + } } diff --git a/src/Core/Utilities/EncryptedStringLengthAttribute.cs b/src/Core/Utilities/EncryptedStringLengthAttribute.cs index 46170487d..d7a8ffaec 100644 --- a/src/Core/Utilities/EncryptedStringLengthAttribute.cs +++ b/src/Core/Utilities/EncryptedStringLengthAttribute.cs @@ -1,17 +1,16 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Utilities -{ - public class EncryptedStringLengthAttribute : StringLengthAttribute - { - public EncryptedStringLengthAttribute(int maximumLength) - : base(maximumLength) - { } +namespace Bit.Core.Utilities; - public override string FormatErrorMessage(string name) - { - return string.Format("The field {0} exceeds the maximum encrypted value length of {1} characters.", - name, MaximumLength); - } +public class EncryptedStringLengthAttribute : StringLengthAttribute +{ + public EncryptedStringLengthAttribute(int maximumLength) + : base(maximumLength) + { } + + public override string FormatErrorMessage(string name) + { + return string.Format("The field {0} exceeds the maximum encrypted value length of {1} characters.", + name, MaximumLength); } } diff --git a/src/Core/Utilities/EncryptedValueAttribute.cs b/src/Core/Utilities/EncryptedValueAttribute.cs index 9ae43110b..ec0b218c5 100644 --- a/src/Core/Utilities/EncryptedValueAttribute.cs +++ b/src/Core/Utilities/EncryptedValueAttribute.cs @@ -1,138 +1,137 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Utilities +namespace Bit.Core.Utilities; + +/// +/// Validates a string that is in encrypted form: "head.b64iv=|b64ct=|b64mac=" +/// +public class EncryptedStringAttribute : ValidationAttribute { - /// - /// Validates a string that is in encrypted form: "head.b64iv=|b64ct=|b64mac=" - /// - public class EncryptedStringAttribute : ValidationAttribute + public EncryptedStringAttribute() + : base("{0} is not a valid encrypted string.") + { } + + public override bool IsValid(object value) { - public EncryptedStringAttribute() - : base("{0} is not a valid encrypted string.") - { } - - public override bool IsValid(object value) + if (value == null) { - if (value == null) - { - return true; - } + return true; + } - try - { - var encString = value?.ToString(); - if (string.IsNullOrWhiteSpace(encString)) - { - return false; - } - - var headerPieces = encString.Split('.'); - string[] encStringPieces = null; - var encType = Enums.EncryptionType.AesCbc256_B64; - - if (headerPieces.Length == 1) - { - encStringPieces = headerPieces[0].Split('|'); - if (encStringPieces.Length == 3) - { - encType = Enums.EncryptionType.AesCbc128_HmacSha256_B64; - } - else - { - encType = Enums.EncryptionType.AesCbc256_B64; - } - } - else if (headerPieces.Length == 2) - { - encStringPieces = headerPieces[1].Split('|'); - if (!Enum.TryParse(headerPieces[0], out encType)) - { - return false; - } - } - - switch (encType) - { - case Enums.EncryptionType.AesCbc256_B64: - case Enums.EncryptionType.Rsa2048_OaepSha1_HmacSha256_B64: - case Enums.EncryptionType.Rsa2048_OaepSha256_HmacSha256_B64: - if (encStringPieces.Length != 2) - { - return false; - } - break; - case Enums.EncryptionType.AesCbc128_HmacSha256_B64: - case Enums.EncryptionType.AesCbc256_HmacSha256_B64: - if (encStringPieces.Length != 3) - { - return false; - } - break; - case Enums.EncryptionType.Rsa2048_OaepSha256_B64: - case Enums.EncryptionType.Rsa2048_OaepSha1_B64: - if (encStringPieces.Length != 1) - { - return false; - } - break; - default: - return false; - } - - switch (encType) - { - case Enums.EncryptionType.AesCbc256_B64: - case Enums.EncryptionType.AesCbc128_HmacSha256_B64: - case Enums.EncryptionType.AesCbc256_HmacSha256_B64: - var iv = Convert.FromBase64String(encStringPieces[0]); - var ct = Convert.FromBase64String(encStringPieces[1]); - if (iv.Length < 1 || ct.Length < 1) - { - return false; - } - - if (encType == Enums.EncryptionType.AesCbc128_HmacSha256_B64 || - encType == Enums.EncryptionType.AesCbc256_HmacSha256_B64) - { - var mac = Convert.FromBase64String(encStringPieces[2]); - if (mac.Length < 1) - { - return false; - } - } - - break; - case Enums.EncryptionType.Rsa2048_OaepSha256_B64: - case Enums.EncryptionType.Rsa2048_OaepSha1_B64: - case Enums.EncryptionType.Rsa2048_OaepSha1_HmacSha256_B64: - case Enums.EncryptionType.Rsa2048_OaepSha256_HmacSha256_B64: - var rsaCt = Convert.FromBase64String(encStringPieces[0]); - if (rsaCt.Length < 1) - { - return false; - } - - if (encType == Enums.EncryptionType.Rsa2048_OaepSha1_HmacSha256_B64 || - encType == Enums.EncryptionType.Rsa2048_OaepSha256_HmacSha256_B64) - { - var mac = Convert.FromBase64String(encStringPieces[1]); - if (mac.Length < 1) - { - return false; - } - } - - break; - default: - return false; - } - } - catch + try + { + var encString = value?.ToString(); + if (string.IsNullOrWhiteSpace(encString)) { return false; } - return true; + var headerPieces = encString.Split('.'); + string[] encStringPieces = null; + var encType = Enums.EncryptionType.AesCbc256_B64; + + if (headerPieces.Length == 1) + { + encStringPieces = headerPieces[0].Split('|'); + if (encStringPieces.Length == 3) + { + encType = Enums.EncryptionType.AesCbc128_HmacSha256_B64; + } + else + { + encType = Enums.EncryptionType.AesCbc256_B64; + } + } + else if (headerPieces.Length == 2) + { + encStringPieces = headerPieces[1].Split('|'); + if (!Enum.TryParse(headerPieces[0], out encType)) + { + return false; + } + } + + switch (encType) + { + case Enums.EncryptionType.AesCbc256_B64: + case Enums.EncryptionType.Rsa2048_OaepSha1_HmacSha256_B64: + case Enums.EncryptionType.Rsa2048_OaepSha256_HmacSha256_B64: + if (encStringPieces.Length != 2) + { + return false; + } + break; + case Enums.EncryptionType.AesCbc128_HmacSha256_B64: + case Enums.EncryptionType.AesCbc256_HmacSha256_B64: + if (encStringPieces.Length != 3) + { + return false; + } + break; + case Enums.EncryptionType.Rsa2048_OaepSha256_B64: + case Enums.EncryptionType.Rsa2048_OaepSha1_B64: + if (encStringPieces.Length != 1) + { + return false; + } + break; + default: + return false; + } + + switch (encType) + { + case Enums.EncryptionType.AesCbc256_B64: + case Enums.EncryptionType.AesCbc128_HmacSha256_B64: + case Enums.EncryptionType.AesCbc256_HmacSha256_B64: + var iv = Convert.FromBase64String(encStringPieces[0]); + var ct = Convert.FromBase64String(encStringPieces[1]); + if (iv.Length < 1 || ct.Length < 1) + { + return false; + } + + if (encType == Enums.EncryptionType.AesCbc128_HmacSha256_B64 || + encType == Enums.EncryptionType.AesCbc256_HmacSha256_B64) + { + var mac = Convert.FromBase64String(encStringPieces[2]); + if (mac.Length < 1) + { + return false; + } + } + + break; + case Enums.EncryptionType.Rsa2048_OaepSha256_B64: + case Enums.EncryptionType.Rsa2048_OaepSha1_B64: + case Enums.EncryptionType.Rsa2048_OaepSha1_HmacSha256_B64: + case Enums.EncryptionType.Rsa2048_OaepSha256_HmacSha256_B64: + var rsaCt = Convert.FromBase64String(encStringPieces[0]); + if (rsaCt.Length < 1) + { + return false; + } + + if (encType == Enums.EncryptionType.Rsa2048_OaepSha1_HmacSha256_B64 || + encType == Enums.EncryptionType.Rsa2048_OaepSha256_HmacSha256_B64) + { + var mac = Convert.FromBase64String(encStringPieces[1]); + if (mac.Length < 1) + { + return false; + } + } + + break; + default: + return false; + } } + catch + { + return false; + } + + return true; } } diff --git a/src/Core/Utilities/EpochDateTimeJsonConverter.cs b/src/Core/Utilities/EpochDateTimeJsonConverter.cs index a9354fa6f..035da04a7 100644 --- a/src/Core/Utilities/EpochDateTimeJsonConverter.cs +++ b/src/Core/Utilities/EpochDateTimeJsonConverter.cs @@ -1,17 +1,16 @@ using System.Text.Json; using System.Text.Json.Serialization; -namespace Bit.Core.Utilities +namespace Bit.Core.Utilities; + +public class EpochDateTimeJsonConverter : JsonConverter { - public class EpochDateTimeJsonConverter : JsonConverter + public override DateTime Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { - public override DateTime Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) - { - return CoreHelpers.FromEpocMilliseconds(reader.GetInt64()); - } - public override void Write(Utf8JsonWriter writer, DateTime value, JsonSerializerOptions options) - { - writer.WriteNumberValue(CoreHelpers.ToEpocMilliseconds(value)); - } + return CoreHelpers.FromEpocMilliseconds(reader.GetInt64()); + } + public override void Write(Utf8JsonWriter writer, DateTime value, JsonSerializerOptions options) + { + writer.WriteNumberValue(CoreHelpers.ToEpocMilliseconds(value)); } } diff --git a/src/Core/Utilities/HandlebarsObjectJsonConverter.cs b/src/Core/Utilities/HandlebarsObjectJsonConverter.cs index 2ba1d4002..5651da4dc 100644 --- a/src/Core/Utilities/HandlebarsObjectJsonConverter.cs +++ b/src/Core/Utilities/HandlebarsObjectJsonConverter.cs @@ -1,18 +1,17 @@ using System.Text.Json; using System.Text.Json.Serialization; -namespace Bit.Core.Utilities +namespace Bit.Core.Utilities; + +public class HandlebarsObjectJsonConverter : JsonConverter { - public class HandlebarsObjectJsonConverter : JsonConverter + public override bool CanConvert(Type typeToConvert) => true; + public override object Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { - public override bool CanConvert(Type typeToConvert) => true; - public override object Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) - { - return JsonSerializer.Deserialize>(ref reader, options); - } - public override void Write(Utf8JsonWriter writer, object value, JsonSerializerOptions options) - { - JsonSerializer.Serialize(writer, value, options); - } + return JsonSerializer.Deserialize>(ref reader, options); + } + public override void Write(Utf8JsonWriter writer, object value, JsonSerializerOptions options) + { + JsonSerializer.Serialize(writer, value, options); } } diff --git a/src/Core/Utilities/HostBuilderExtensions.cs b/src/Core/Utilities/HostBuilderExtensions.cs index 2d54545ed..4806c4032 100644 --- a/src/Core/Utilities/HostBuilderExtensions.cs +++ b/src/Core/Utilities/HostBuilderExtensions.cs @@ -2,42 +2,41 @@ using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Hosting; -namespace Bit.Core.Utilities +namespace Bit.Core.Utilities; + +public static class HostBuilderExtensions { - public static class HostBuilderExtensions + public static IHostBuilder ConfigureCustomAppConfiguration(this IHostBuilder hostBuilder, string[] args) { - public static IHostBuilder ConfigureCustomAppConfiguration(this IHostBuilder hostBuilder, string[] args) + // Reload app configuration with SelfHosted overrides. + return hostBuilder.ConfigureAppConfiguration((hostingContext, config) => { - // Reload app configuration with SelfHosted overrides. - return hostBuilder.ConfigureAppConfiguration((hostingContext, config) => + if (Environment.GetEnvironmentVariable("globalSettings__selfHosted")?.ToLower() != "true") { - if (Environment.GetEnvironmentVariable("globalSettings__selfHosted")?.ToLower() != "true") + return; + } + + var env = hostingContext.HostingEnvironment; + + config.AddJsonFile("appsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile($"appsettings.{env.EnvironmentName}.json", optional: true, reloadOnChange: true) + .AddJsonFile("appsettings.SelfHosted.json", optional: true, reloadOnChange: true); + + if (env.IsDevelopment()) + { + var appAssembly = Assembly.Load(new AssemblyName(env.ApplicationName)); + if (appAssembly != null) { - return; + config.AddUserSecrets(appAssembly, optional: true); } + } - var env = hostingContext.HostingEnvironment; + config.AddEnvironmentVariables(); - config.AddJsonFile("appsettings.json", optional: true, reloadOnChange: true) - .AddJsonFile($"appsettings.{env.EnvironmentName}.json", optional: true, reloadOnChange: true) - .AddJsonFile("appsettings.SelfHosted.json", optional: true, reloadOnChange: true); - - if (env.IsDevelopment()) - { - var appAssembly = Assembly.Load(new AssemblyName(env.ApplicationName)); - if (appAssembly != null) - { - config.AddUserSecrets(appAssembly, optional: true); - } - } - - config.AddEnvironmentVariables(); - - if (args != null) - { - config.AddCommandLine(args); - } - }); - } + if (args != null) + { + config.AddCommandLine(args); + } + }); } } diff --git a/src/Core/Utilities/JsonHelpers.cs b/src/Core/Utilities/JsonHelpers.cs index b6a9481b6..ad7aefd25 100644 --- a/src/Core/Utilities/JsonHelpers.cs +++ b/src/Core/Utilities/JsonHelpers.cs @@ -3,205 +3,204 @@ using System.Text.Json; using System.Text.Json.Serialization; using NS = Newtonsoft.Json; -namespace Bit.Core.Utilities +namespace Bit.Core.Utilities; + +public static class JsonHelpers { - public static class JsonHelpers + public static JsonSerializerOptions Default { get; } + public static JsonSerializerOptions Indented { get; } + public static JsonSerializerOptions IgnoreCase { get; } + public static JsonSerializerOptions IgnoreWritingNull { get; } + public static JsonSerializerOptions CamelCase { get; } + public static JsonSerializerOptions IgnoreWritingNullAndCamelCase { get; } + + static JsonHelpers() { - public static JsonSerializerOptions Default { get; } - public static JsonSerializerOptions Indented { get; } - public static JsonSerializerOptions IgnoreCase { get; } - public static JsonSerializerOptions IgnoreWritingNull { get; } - public static JsonSerializerOptions CamelCase { get; } - public static JsonSerializerOptions IgnoreWritingNullAndCamelCase { get; } + Default = new JsonSerializerOptions(); - static JsonHelpers() + Indented = new JsonSerializerOptions { - Default = new JsonSerializerOptions(); - - Indented = new JsonSerializerOptions - { - WriteIndented = true, - }; - - IgnoreCase = new JsonSerializerOptions - { - PropertyNameCaseInsensitive = true, - }; - - IgnoreWritingNull = new JsonSerializerOptions - { - DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, - }; - - CamelCase = new JsonSerializerOptions - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - }; - - IgnoreWritingNullAndCamelCase = new JsonSerializerOptions - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, - }; - } - - [Obsolete("This is built into .NET 6, it SHOULD be removed when we upgrade")] - public static T ToObject(this JsonElement element, JsonSerializerOptions options = null) - { - return JsonSerializer.Deserialize(element.GetRawText(), options ?? Default); - } - - [Obsolete("This is built into .NET 6, it SHOULD be removed when we upgrade")] - public static T ToObject(this JsonDocument document, JsonSerializerOptions options = null) - { - return JsonSerializer.Deserialize(document.RootElement.GetRawText(), options ?? default); - } - - public static T DeserializeOrNew(string json, JsonSerializerOptions options = null) - where T : new() - { - if (string.IsNullOrWhiteSpace(json)) - { - return new T(); - } - - return JsonSerializer.Deserialize(json, options); - } - - #region Legacy Newtonsoft.Json usage - private const string LegacyMessage = "Usage of Newtonsoft.Json should be kept to a minimum and will further be removed when we move to .NET 6"; - - [Obsolete(LegacyMessage)] - public static NS.JsonSerializerSettings LegacyEnumKeyResolver { get; } = new NS.JsonSerializerSettings - { - ContractResolver = new EnumKeyResolver(), + WriteIndented = true, }; - [Obsolete(LegacyMessage)] - public static string LegacySerialize(object value, NS.JsonSerializerSettings settings = null) + IgnoreCase = new JsonSerializerOptions { - return NS.JsonConvert.SerializeObject(value, settings); - } + PropertyNameCaseInsensitive = true, + }; - [Obsolete(LegacyMessage)] - public static T LegacyDeserialize(string value, NS.JsonSerializerSettings settings = null) + IgnoreWritingNull = new JsonSerializerOptions { - return NS.JsonConvert.DeserializeObject(value, settings); - } - #endregion + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + }; + + CamelCase = new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + }; + + IgnoreWritingNullAndCamelCase = new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + }; } - public class EnumKeyResolver : NS.Serialization.DefaultContractResolver - where T : struct + [Obsolete("This is built into .NET 6, it SHOULD be removed when we upgrade")] + public static T ToObject(this JsonElement element, JsonSerializerOptions options = null) { - protected override NS.Serialization.JsonDictionaryContract CreateDictionaryContract(Type objectType) - { - var contract = base.CreateDictionaryContract(objectType); - var keyType = contract.DictionaryKeyType; - - if (keyType.BaseType == typeof(Enum)) - { - contract.DictionaryKeyResolver = propName => ((T)Enum.Parse(keyType, propName)).ToString(); - } - - return contract; - } + return JsonSerializer.Deserialize(element.GetRawText(), options ?? Default); } - public class MsEpochConverter : JsonConverter + [Obsolete("This is built into .NET 6, it SHOULD be removed when we upgrade")] + public static T ToObject(this JsonDocument document, JsonSerializerOptions options = null) { - public override DateTime? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) - { - if (reader.TokenType == JsonTokenType.Null) - { - return null; - } - - if (!long.TryParse(reader.GetString(), out var milliseconds)) - { - return null; - } - - return CoreHelpers.FromEpocMilliseconds(milliseconds); - } - - public override void Write(Utf8JsonWriter writer, DateTime? value, JsonSerializerOptions options) - { - if (!value.HasValue) - { - writer.WriteNullValue(); - } - - writer.WriteStringValue(CoreHelpers.ToEpocMilliseconds(value.Value).ToString()); - } + return JsonSerializer.Deserialize(document.RootElement.GetRawText(), options ?? default); } - /// - /// Allows reading a string from a JSON number or string, should only be used on properties - /// - public class PermissiveStringConverter : JsonConverter + public static T DeserializeOrNew(string json, JsonSerializerOptions options = null) + where T : new() { - internal static readonly PermissiveStringConverter Instance = new(); - private static readonly CultureInfo _cultureInfo = new("en-US"); - - public override string Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + if (string.IsNullOrWhiteSpace(json)) { - return reader.TokenType switch - { - JsonTokenType.String => reader.GetString(), - JsonTokenType.Number => reader.GetDecimal().ToString(_cultureInfo), - JsonTokenType.True => bool.TrueString, - JsonTokenType.False => bool.FalseString, - _ => throw new JsonException($"Unsupported TokenType: {reader.TokenType}"), - }; + return new T(); } - public override void Write(Utf8JsonWriter writer, string value, JsonSerializerOptions options) - { - writer.WriteStringValue(value); - } + return JsonSerializer.Deserialize(json, options); } - /// - /// Allows reading a JSON array of number or string, should only be used on whose generic type is - /// - public class PermissiveStringEnumerableConverter : JsonConverter> + #region Legacy Newtonsoft.Json usage + private const string LegacyMessage = "Usage of Newtonsoft.Json should be kept to a minimum and will further be removed when we move to .NET 6"; + + [Obsolete(LegacyMessage)] + public static NS.JsonSerializerSettings LegacyEnumKeyResolver { get; } = new NS.JsonSerializerSettings { - public override IEnumerable Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + ContractResolver = new EnumKeyResolver(), + }; + + [Obsolete(LegacyMessage)] + public static string LegacySerialize(object value, NS.JsonSerializerSettings settings = null) + { + return NS.JsonConvert.SerializeObject(value, settings); + } + + [Obsolete(LegacyMessage)] + public static T LegacyDeserialize(string value, NS.JsonSerializerSettings settings = null) + { + return NS.JsonConvert.DeserializeObject(value, settings); + } + #endregion +} + +public class EnumKeyResolver : NS.Serialization.DefaultContractResolver + where T : struct +{ + protected override NS.Serialization.JsonDictionaryContract CreateDictionaryContract(Type objectType) + { + var contract = base.CreateDictionaryContract(objectType); + var keyType = contract.DictionaryKeyType; + + if (keyType.BaseType == typeof(Enum)) { - var stringList = new List(); + contract.DictionaryKeyResolver = propName => ((T)Enum.Parse(keyType, propName)).ToString(); + } - // Handle special cases or throw - if (reader.TokenType != JsonTokenType.StartArray) + return contract; + } +} + +public class MsEpochConverter : JsonConverter +{ + public override DateTime? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + if (reader.TokenType == JsonTokenType.Null) + { + return null; + } + + if (!long.TryParse(reader.GetString(), out var milliseconds)) + { + return null; + } + + return CoreHelpers.FromEpocMilliseconds(milliseconds); + } + + public override void Write(Utf8JsonWriter writer, DateTime? value, JsonSerializerOptions options) + { + if (!value.HasValue) + { + writer.WriteNullValue(); + } + + writer.WriteStringValue(CoreHelpers.ToEpocMilliseconds(value.Value).ToString()); + } +} + +/// +/// Allows reading a string from a JSON number or string, should only be used on properties +/// +public class PermissiveStringConverter : JsonConverter +{ + internal static readonly PermissiveStringConverter Instance = new(); + private static readonly CultureInfo _cultureInfo = new("en-US"); + + public override string Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + return reader.TokenType switch + { + JsonTokenType.String => reader.GetString(), + JsonTokenType.Number => reader.GetDecimal().ToString(_cultureInfo), + JsonTokenType.True => bool.TrueString, + JsonTokenType.False => bool.FalseString, + _ => throw new JsonException($"Unsupported TokenType: {reader.TokenType}"), + }; + } + + public override void Write(Utf8JsonWriter writer, string value, JsonSerializerOptions options) + { + writer.WriteStringValue(value); + } +} + +/// +/// Allows reading a JSON array of number or string, should only be used on whose generic type is +/// +public class PermissiveStringEnumerableConverter : JsonConverter> +{ + public override IEnumerable Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + var stringList = new List(); + + // Handle special cases or throw + if (reader.TokenType != JsonTokenType.StartArray) + { + // An array was expected but to be extra permissive allow reading from anything other than an object + if (reader.TokenType == JsonTokenType.StartObject) { - // An array was expected but to be extra permissive allow reading from anything other than an object - if (reader.TokenType == JsonTokenType.StartObject) - { - throw new JsonException("Cannot read JSON Object to an IEnumerable."); - } - - stringList.Add(PermissiveStringConverter.Instance.Read(ref reader, typeof(string), options)); - return stringList; - } - - while (reader.Read() && reader.TokenType != JsonTokenType.EndArray) - { - stringList.Add(PermissiveStringConverter.Instance.Read(ref reader, typeof(string), options)); + throw new JsonException("Cannot read JSON Object to an IEnumerable."); } + stringList.Add(PermissiveStringConverter.Instance.Read(ref reader, typeof(string), options)); return stringList; } - public override void Write(Utf8JsonWriter writer, IEnumerable value, JsonSerializerOptions options) + while (reader.Read() && reader.TokenType != JsonTokenType.EndArray) { - writer.WriteStartArray(); - - foreach (var str in value) - { - PermissiveStringConverter.Instance.Write(writer, str, options); - } - - writer.WriteEndArray(); + stringList.Add(PermissiveStringConverter.Instance.Read(ref reader, typeof(string), options)); } + + return stringList; + } + + public override void Write(Utf8JsonWriter writer, IEnumerable value, JsonSerializerOptions options) + { + writer.WriteStartArray(); + + foreach (var str in value) + { + PermissiveStringConverter.Instance.Write(writer, str, options); + } + + writer.WriteEndArray(); } } diff --git a/src/Core/Utilities/LoggerFactoryExtensions.cs b/src/Core/Utilities/LoggerFactoryExtensions.cs index 98896c56e..792225cdf 100644 --- a/src/Core/Utilities/LoggerFactoryExtensions.cs +++ b/src/Core/Utilities/LoggerFactoryExtensions.cs @@ -10,137 +10,136 @@ using Serilog; using Serilog.Events; using Serilog.Sinks.Syslog; -namespace Bit.Core.Utilities -{ - public static class LoggerFactoryExtensions - { - public static void UseSerilog( - this IApplicationBuilder appBuilder, - IWebHostEnvironment env, - IHostApplicationLifetime applicationLifetime, - GlobalSettings globalSettings) - { - if (env.IsDevelopment()) - { - return; - } +namespace Bit.Core.Utilities; - applicationLifetime.ApplicationStopped.Register(Log.CloseAndFlush); +public static class LoggerFactoryExtensions +{ + public static void UseSerilog( + this IApplicationBuilder appBuilder, + IWebHostEnvironment env, + IHostApplicationLifetime applicationLifetime, + GlobalSettings globalSettings) + { + if (env.IsDevelopment()) + { + return; } - public static ILoggingBuilder AddSerilog( - this ILoggingBuilder builder, - WebHostBuilderContext context, - Func filter = null) + applicationLifetime.ApplicationStopped.Register(Log.CloseAndFlush); + } + + public static ILoggingBuilder AddSerilog( + this ILoggingBuilder builder, + WebHostBuilderContext context, + Func filter = null) + { + if (context.HostingEnvironment.IsDevelopment()) { - if (context.HostingEnvironment.IsDevelopment()) - { - return builder; - } - - bool inclusionPredicate(LogEvent e) - { - if (filter == null) - { - return true; - } - var eventId = e.Properties.ContainsKey("EventId") ? e.Properties["EventId"].ToString() : null; - if (eventId?.Contains(Constants.BypassFiltersEventId.ToString()) ?? false) - { - return true; - } - return filter(e); - } - - var globalSettings = new GlobalSettings(); - ConfigurationBinder.Bind(context.Configuration.GetSection("GlobalSettings"), globalSettings); - - var config = new LoggerConfiguration() - .Enrich.FromLogContext() - .Filter.ByIncludingOnly(inclusionPredicate); - - if (CoreHelpers.SettingHasValue(globalSettings?.DocumentDb.Uri) && - CoreHelpers.SettingHasValue(globalSettings?.DocumentDb.Key)) - { - config.WriteTo.AzureCosmosDB(new Uri(globalSettings.DocumentDb.Uri), - globalSettings.DocumentDb.Key, timeToLive: TimeSpan.FromDays(7), - partitionKey: "_partitionKey") - .Enrich.FromLogContext() - .Enrich.WithProperty("Project", globalSettings.ProjectName); - } - else if (CoreHelpers.SettingHasValue(globalSettings?.Sentry.Dsn)) - { - config.WriteTo.Sentry(globalSettings.Sentry.Dsn) - .Enrich.FromLogContext() - .Enrich.WithProperty("Project", globalSettings.ProjectName); - } - else if (CoreHelpers.SettingHasValue(globalSettings?.Syslog.Destination)) - { - // appending sitename to project name to allow eaiser identification in syslog. - var appName = $"{globalSettings.SiteName}-{globalSettings.ProjectName}"; - if (globalSettings.Syslog.Destination.Equals("local", StringComparison.OrdinalIgnoreCase)) - { - config.WriteTo.LocalSyslog(appName); - } - else if (Uri.TryCreate(globalSettings.Syslog.Destination, UriKind.Absolute, out var syslogAddress)) - { - // Syslog's standard port is 514 (both UDP and TCP). TLS does not have a standard port, so assume 514. - int port = syslogAddress.Port >= 0 - ? syslogAddress.Port - : 514; - - if (syslogAddress.Scheme.Equals("udp")) - { - config.WriteTo.UdpSyslog(syslogAddress.Host, port, appName); - } - else if (syslogAddress.Scheme.Equals("tcp")) - { - config.WriteTo.TcpSyslog(syslogAddress.Host, port, appName); - } - else if (syslogAddress.Scheme.Equals("tls")) - { - // TLS v1.1, v1.2 and v1.3 are explicitly selected (leaving out TLS v1.0) - const SslProtocols protocols = SslProtocols.Tls11 | SslProtocols.Tls12 | SslProtocols.Tls13; - - if (CoreHelpers.SettingHasValue(globalSettings.Syslog.CertificateThumbprint)) - { - config.WriteTo.TcpSyslog(syslogAddress.Host, port, appName, - secureProtocols: protocols, - certProvider: new CertificateStoreProvider(StoreName.My, StoreLocation.CurrentUser, - globalSettings.Syslog.CertificateThumbprint)); - } - else - { - config.WriteTo.TcpSyslog(syslogAddress.Host, port, appName, - secureProtocols: protocols, - certProvider: new CertificateFileProvider(globalSettings.Syslog.CertificatePath, - globalSettings.Syslog?.CertificatePassword ?? string.Empty)); - } - - } - } - } - else if (CoreHelpers.SettingHasValue(globalSettings.LogDirectory)) - { - if (globalSettings.LogRollBySizeLimit.HasValue) - { - config.WriteTo.File($"{globalSettings.LogDirectory}/{globalSettings.ProjectName}/log.txt", - rollOnFileSizeLimit: true, fileSizeLimitBytes: globalSettings.LogRollBySizeLimit); - } - else - { - config.WriteTo - .RollingFile($"{globalSettings.LogDirectory}/{globalSettings.ProjectName}/{{Date}}.txt"); - } - config - .Enrich.FromLogContext() - .Enrich.WithProperty("Project", globalSettings.ProjectName); - } - - var serilog = config.CreateLogger(); - builder.AddSerilog(serilog); - return builder; } + + bool inclusionPredicate(LogEvent e) + { + if (filter == null) + { + return true; + } + var eventId = e.Properties.ContainsKey("EventId") ? e.Properties["EventId"].ToString() : null; + if (eventId?.Contains(Constants.BypassFiltersEventId.ToString()) ?? false) + { + return true; + } + return filter(e); + } + + var globalSettings = new GlobalSettings(); + ConfigurationBinder.Bind(context.Configuration.GetSection("GlobalSettings"), globalSettings); + + var config = new LoggerConfiguration() + .Enrich.FromLogContext() + .Filter.ByIncludingOnly(inclusionPredicate); + + if (CoreHelpers.SettingHasValue(globalSettings?.DocumentDb.Uri) && + CoreHelpers.SettingHasValue(globalSettings?.DocumentDb.Key)) + { + config.WriteTo.AzureCosmosDB(new Uri(globalSettings.DocumentDb.Uri), + globalSettings.DocumentDb.Key, timeToLive: TimeSpan.FromDays(7), + partitionKey: "_partitionKey") + .Enrich.FromLogContext() + .Enrich.WithProperty("Project", globalSettings.ProjectName); + } + else if (CoreHelpers.SettingHasValue(globalSettings?.Sentry.Dsn)) + { + config.WriteTo.Sentry(globalSettings.Sentry.Dsn) + .Enrich.FromLogContext() + .Enrich.WithProperty("Project", globalSettings.ProjectName); + } + else if (CoreHelpers.SettingHasValue(globalSettings?.Syslog.Destination)) + { + // appending sitename to project name to allow eaiser identification in syslog. + var appName = $"{globalSettings.SiteName}-{globalSettings.ProjectName}"; + if (globalSettings.Syslog.Destination.Equals("local", StringComparison.OrdinalIgnoreCase)) + { + config.WriteTo.LocalSyslog(appName); + } + else if (Uri.TryCreate(globalSettings.Syslog.Destination, UriKind.Absolute, out var syslogAddress)) + { + // Syslog's standard port is 514 (both UDP and TCP). TLS does not have a standard port, so assume 514. + int port = syslogAddress.Port >= 0 + ? syslogAddress.Port + : 514; + + if (syslogAddress.Scheme.Equals("udp")) + { + config.WriteTo.UdpSyslog(syslogAddress.Host, port, appName); + } + else if (syslogAddress.Scheme.Equals("tcp")) + { + config.WriteTo.TcpSyslog(syslogAddress.Host, port, appName); + } + else if (syslogAddress.Scheme.Equals("tls")) + { + // TLS v1.1, v1.2 and v1.3 are explicitly selected (leaving out TLS v1.0) + const SslProtocols protocols = SslProtocols.Tls11 | SslProtocols.Tls12 | SslProtocols.Tls13; + + if (CoreHelpers.SettingHasValue(globalSettings.Syslog.CertificateThumbprint)) + { + config.WriteTo.TcpSyslog(syslogAddress.Host, port, appName, + secureProtocols: protocols, + certProvider: new CertificateStoreProvider(StoreName.My, StoreLocation.CurrentUser, + globalSettings.Syslog.CertificateThumbprint)); + } + else + { + config.WriteTo.TcpSyslog(syslogAddress.Host, port, appName, + secureProtocols: protocols, + certProvider: new CertificateFileProvider(globalSettings.Syslog.CertificatePath, + globalSettings.Syslog?.CertificatePassword ?? string.Empty)); + } + + } + } + } + else if (CoreHelpers.SettingHasValue(globalSettings.LogDirectory)) + { + if (globalSettings.LogRollBySizeLimit.HasValue) + { + config.WriteTo.File($"{globalSettings.LogDirectory}/{globalSettings.ProjectName}/log.txt", + rollOnFileSizeLimit: true, fileSizeLimitBytes: globalSettings.LogRollBySizeLimit); + } + else + { + config.WriteTo + .RollingFile($"{globalSettings.LogDirectory}/{globalSettings.ProjectName}/{{Date}}.txt"); + } + config + .Enrich.FromLogContext() + .Enrich.WithProperty("Project", globalSettings.ProjectName); + } + + var serilog = config.CreateLogger(); + builder.AddSerilog(serilog); + + return builder; } } diff --git a/src/Core/Utilities/LoggingExceptionHandlerFilterAttribute.cs b/src/Core/Utilities/LoggingExceptionHandlerFilterAttribute.cs index 8df51b1e5..6709bbb27 100644 --- a/src/Core/Utilities/LoggingExceptionHandlerFilterAttribute.cs +++ b/src/Core/Utilities/LoggingExceptionHandlerFilterAttribute.cs @@ -2,22 +2,21 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; -namespace Bit.Core.Utilities -{ - public class LoggingExceptionHandlerFilterAttribute : ExceptionFilterAttribute - { - public override void OnException(ExceptionContext context) - { - var exception = context.Exception; - if (exception == null) - { - // Should never happen. - return; - } +namespace Bit.Core.Utilities; - var logger = context.HttpContext.RequestServices - .GetRequiredService>(); - logger.LogError(0, exception, exception.Message); +public class LoggingExceptionHandlerFilterAttribute : ExceptionFilterAttribute +{ + public override void OnException(ExceptionContext context) + { + var exception = context.Exception; + if (exception == null) + { + // Should never happen. + return; } + + var logger = context.HttpContext.RequestServices + .GetRequiredService>(); + logger.LogError(0, exception, exception.Message); } } diff --git a/src/Core/Utilities/SecurityHeadersMiddleware.cs b/src/Core/Utilities/SecurityHeadersMiddleware.cs index 3a1cc477e..19616e8a7 100644 --- a/src/Core/Utilities/SecurityHeadersMiddleware.cs +++ b/src/Core/Utilities/SecurityHeadersMiddleware.cs @@ -1,29 +1,28 @@ using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Primitives; -namespace Bit.Core.Utilities +namespace Bit.Core.Utilities; + +public sealed class SecurityHeadersMiddleware { - public sealed class SecurityHeadersMiddleware + private readonly RequestDelegate _next; + + public SecurityHeadersMiddleware(RequestDelegate next) { - private readonly RequestDelegate _next; + _next = next; + } - public SecurityHeadersMiddleware(RequestDelegate next) - { - _next = next; - } + public Task Invoke(HttpContext context) + { + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Frame-Options + context.Response.Headers.Add("x-frame-options", new StringValues("SAMEORIGIN")); - public Task Invoke(HttpContext context) - { - // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Frame-Options - context.Response.Headers.Add("x-frame-options", new StringValues("SAMEORIGIN")); + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-XSS-Protection + context.Response.Headers.Add("x-xss-protection", new StringValues("1; mode=block")); - // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-XSS-Protection - context.Response.Headers.Add("x-xss-protection", new StringValues("1; mode=block")); + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Content-Type-Options + context.Response.Headers.Add("x-content-type-options", new StringValues("nosniff")); - // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Content-Type-Options - context.Response.Headers.Add("x-content-type-options", new StringValues("nosniff")); - - return _next(context); - } + return _next(context); } } diff --git a/src/Core/Utilities/SelfHostedAttribute.cs b/src/Core/Utilities/SelfHostedAttribute.cs index 13dc83fa7..f4ea83592 100644 --- a/src/Core/Utilities/SelfHostedAttribute.cs +++ b/src/Core/Utilities/SelfHostedAttribute.cs @@ -3,24 +3,23 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Mvc.Filters; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Core.Utilities -{ - public class SelfHostedAttribute : ActionFilterAttribute - { - public bool SelfHostedOnly { get; set; } - public bool NotSelfHostedOnly { get; set; } +namespace Bit.Core.Utilities; - public override void OnActionExecuting(ActionExecutingContext context) +public class SelfHostedAttribute : ActionFilterAttribute +{ + public bool SelfHostedOnly { get; set; } + public bool NotSelfHostedOnly { get; set; } + + public override void OnActionExecuting(ActionExecutingContext context) + { + var globalSettings = context.HttpContext.RequestServices.GetRequiredService(); + if (SelfHostedOnly && !globalSettings.SelfHosted) { - var globalSettings = context.HttpContext.RequestServices.GetRequiredService(); - if (SelfHostedOnly && !globalSettings.SelfHosted) - { - throw new BadRequestException("Only allowed when self hosted."); - } - else if (NotSelfHostedOnly && globalSettings.SelfHosted) - { - throw new BadRequestException("Only allowed when not self hosted."); - } + throw new BadRequestException("Only allowed when self hosted."); + } + else if (NotSelfHostedOnly && globalSettings.SelfHosted) + { + throw new BadRequestException("Only allowed when not self hosted."); } } } diff --git a/src/Core/Utilities/StaticStore.cs b/src/Core/Utilities/StaticStore.cs index 0b8cb61bf..053f7ed45 100644 --- a/src/Core/Utilities/StaticStore.cs +++ b/src/Core/Utilities/StaticStore.cs @@ -2,501 +2,500 @@ using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Utilities +namespace Bit.Core.Utilities; + +public class StaticStore { - public class StaticStore + static StaticStore() { - static StaticStore() + #region Global Domains + + GlobalDomains = new Dictionary>(); + + GlobalDomains.Add(GlobalEquivalentDomainsType.Ameritrade, new List { "ameritrade.com", "tdameritrade.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.BoA, new List { "bankofamerica.com", "bofa.com", "mbna.com", "usecfo.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Sprint, new List { "sprint.com", "sprintpcs.com", "nextel.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Google, new List { "youtube.com", "google.com", "gmail.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Apple, new List { "apple.com", "icloud.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.WellsFargo, new List { "wellsfargo.com", "wf.com", "wellsfargoadvisors.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Merrill, new List { "mymerrill.com", "ml.com", "merrilledge.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Citi, new List { "accountonline.com", "citi.com", "citibank.com", "citicards.com", "citibankonline.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Cnet, new List { "cnet.com", "cnettv.com", "com.com", "download.com", "news.com", "search.com", "upload.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Gap, new List { "bananarepublic.com", "gap.com", "oldnavy.com", "piperlime.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Microsoft, new List { "bing.com", "hotmail.com", "live.com", "microsoft.com", "msn.com", "passport.net", "windows.com", "microsoftonline.com", "office.com", "office365.com", "microsoftstore.com", "xbox.com", "azure.com", "windowsazure.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.United, new List { "ua2go.com", "ual.com", "united.com", "unitedwifi.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Yahoo, new List { "overture.com", "yahoo.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Zonelabs, new List { "zonealarm.com", "zonelabs.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.PayPal, new List { "paypal.com", "paypal-search.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Avon, new List { "avon.com", "youravon.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Diapers, new List { "diapers.com", "soap.com", "wag.com", "yoyo.com", "beautybar.com", "casa.com", "afterschool.com", "vine.com", "bookworm.com", "look.com", "vinemarket.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Contacts, new List { "1800contacts.com", "800contacts.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Amazon, new List { "amazon.com", "amazon.ae", "amazon.ca", "amazon.co.uk", "amazon.com.au", "amazon.com.br", "amazon.com.mx", "amazon.com.tr", "amazon.de", "amazon.es", "amazon.fr", "amazon.in", "amazon.it", "amazon.nl", "amazon.pl", "amazon.sa", "amazon.se", "amazon.sg" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Cox, new List { "cox.com", "cox.net", "coxbusiness.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Norton, new List { "mynortonaccount.com", "norton.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Verizon, new List { "verizon.com", "verizon.net" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Buy, new List { "rakuten.com", "buy.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Sirius, new List { "siriusxm.com", "sirius.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Ea, new List { "ea.com", "origin.com", "play4free.com", "tiberiumalliance.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Basecamp, new List { "37signals.com", "basecamp.com", "basecamphq.com", "highrisehq.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Steam, new List { "steampowered.com", "steamcommunity.com", "steamgames.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Chart, new List { "chart.io", "chartio.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Gotomeeting, new List { "gotomeeting.com", "citrixonline.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Gogo, new List { "gogoair.com", "gogoinflight.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Oracle, new List { "mysql.com", "oracle.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Discover, new List { "discover.com", "discovercard.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Dcu, new List { "dcu.org", "dcu-online.org" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Healthcare, new List { "healthcare.gov", "cuidadodesalud.gov", "cms.gov" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Pepco, new List { "pepco.com", "pepcoholdings.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Century21, new List { "century21.com", "21online.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Comcast, new List { "comcast.com", "comcast.net", "xfinity.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Cricket, new List { "cricketwireless.com", "aiowireless.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Mtb, new List { "mandtbank.com", "mtb.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Dropbox, new List { "dropbox.com", "getdropbox.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Snapfish, new List { "snapfish.com", "snapfish.ca" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Alibaba, new List { "alibaba.com", "aliexpress.com", "aliyun.com", "net.cn" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Playstation, new List { "playstation.com", "sonyentertainmentnetwork.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Mercado, new List { "mercadolivre.com", "mercadolivre.com.br", "mercadolibre.com", "mercadolibre.com.ar", "mercadolibre.com.mx" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Zendesk, new List { "zendesk.com", "zopim.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Autodesk, new List { "autodesk.com", "tinkercad.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.RailNation, new List { "railnation.ru", "railnation.de", "rail-nation.com", "railnation.gr", "railnation.us", "trucknation.de", "traviangames.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Wpcu, new List { "wpcu.coop", "wpcuonline.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Mathletics, new List { "mathletics.com", "mathletics.com.au", "mathletics.co.uk" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Discountbank, new List { "discountbank.co.il", "telebank.co.il" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Mi, new List { "mi.com", "xiaomi.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Postepay, new List { "postepay.it", "poste.it" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Facebook, new List { "facebook.com", "messenger.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Skysports, new List { "skysports.com", "skybet.com", "skyvegas.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Disney, new List { "disneymoviesanywhere.com", "go.com", "disney.com", "dadt.com", "disneyplus.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Pokemon, new List { "pokemon-gl.com", "pokemon.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Uv, new List { "myuv.com", "uvvu.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Mdsol, new List { "mdsol.com", "imedidata.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Yahavo, new List { "bank-yahav.co.il", "bankhapoalim.co.il" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Sears, new List { "sears.com", "shld.net" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Xiami, new List { "xiami.com", "alipay.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Belkin, new List { "belkin.com", "seedonk.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Turbotax, new List { "turbotax.com", "intuit.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Shopify, new List { "shopify.com", "myshopify.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Ebay, new List { "ebay.com", "ebay.at", "ebay.be", "ebay.ca", "ebay.ch", "ebay.cn", "ebay.co.jp", "ebay.co.th", "ebay.co.uk", "ebay.com.au", "ebay.com.hk", "ebay.com.my", "ebay.com.sg", "ebay.com.tw", "ebay.de", "ebay.es", "ebay.fr", "ebay.ie", "ebay.in", "ebay.it", "ebay.nl", "ebay.ph", "ebay.pl" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Techdata, new List { "techdata.com", "techdata.ch" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Schwab, new List { "schwab.com", "schwabplan.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Tesla, new List { "tesla.com", "teslamotors.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.MorganStanley, new List { "morganstanley.com", "morganstanleyclientserv.com", "stockplanconnect.com", "ms.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.TaxAct, new List { "taxact.com", "taxactonline.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Wikimedia, new List { "mediawiki.org", "wikibooks.org", "wikidata.org", "wikimedia.org", "wikinews.org", "wikipedia.org", "wikiquote.org", "wikisource.org", "wikiversity.org", "wikivoyage.org", "wiktionary.org" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Airbnb, new List { "airbnb.at", "airbnb.be", "airbnb.ca", "airbnb.ch", "airbnb.cl", "airbnb.co.cr", "airbnb.co.id", "airbnb.co.in", "airbnb.co.kr", "airbnb.co.nz", "airbnb.co.uk", "airbnb.co.ve", "airbnb.com", "airbnb.com.ar", "airbnb.com.au", "airbnb.com.bo", "airbnb.com.br", "airbnb.com.bz", "airbnb.com.co", "airbnb.com.ec", "airbnb.com.gt", "airbnb.com.hk", "airbnb.com.hn", "airbnb.com.mt", "airbnb.com.my", "airbnb.com.ni", "airbnb.com.pa", "airbnb.com.pe", "airbnb.com.py", "airbnb.com.sg", "airbnb.com.sv", "airbnb.com.tr", "airbnb.com.tw", "airbnb.cz", "airbnb.de", "airbnb.dk", "airbnb.es", "airbnb.fi", "airbnb.fr", "airbnb.gr", "airbnb.gy", "airbnb.hu", "airbnb.ie", "airbnb.is", "airbnb.it", "airbnb.jp", "airbnb.mx", "airbnb.nl", "airbnb.no", "airbnb.pl", "airbnb.pt", "airbnb.ru", "airbnb.se" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Eventbrite, new List { "eventbrite.at", "eventbrite.be", "eventbrite.ca", "eventbrite.ch", "eventbrite.cl", "eventbrite.co", "eventbrite.co.nz", "eventbrite.co.uk", "eventbrite.com", "eventbrite.com.ar", "eventbrite.com.au", "eventbrite.com.br", "eventbrite.com.mx", "eventbrite.com.pe", "eventbrite.de", "eventbrite.dk", "eventbrite.es", "eventbrite.fi", "eventbrite.fr", "eventbrite.hk", "eventbrite.ie", "eventbrite.it", "eventbrite.nl", "eventbrite.pt", "eventbrite.se", "eventbrite.sg" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.StackExchange, new List { "stackexchange.com", "superuser.com", "stackoverflow.com", "serverfault.com", "mathoverflow.net", "askubuntu.com", "stackapps.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Docusign, new List { "docusign.com", "docusign.net" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Envato, new List { "envato.com", "themeforest.net", "codecanyon.net", "videohive.net", "audiojungle.net", "graphicriver.net", "photodune.net", "3docean.net" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.X10Hosting, new List { "x10hosting.com", "x10premium.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Cisco, new List { "dnsomatic.com", "opendns.com", "umbrella.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.CedarFair, new List { "cagreatamerica.com", "canadaswonderland.com", "carowinds.com", "cedarfair.com", "cedarpoint.com", "dorneypark.com", "kingsdominion.com", "knotts.com", "miadventure.com", "schlitterbahn.com", "valleyfair.com", "visitkingsisland.com", "worldsoffun.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Ubiquiti, new List { "ubnt.com", "ui.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Discord, new List { "discordapp.com", "discord.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Netcup, new List { "netcup.de", "netcup.eu", "customercontrolpanel.de" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Yandex, new List { "yandex.com", "ya.ru", "yandex.az", "yandex.by", "yandex.co.il", "yandex.com.am", "yandex.com.ge", "yandex.com.tr", "yandex.ee", "yandex.fi", "yandex.fr", "yandex.kg", "yandex.kz", "yandex.lt", "yandex.lv", "yandex.md", "yandex.pl", "yandex.ru", "yandex.tj", "yandex.tm", "yandex.ua", "yandex.uz" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Sony, new List { "sonyentertainmentnetwork.com", "sony.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Proton, new List { "proton.me", "protonmail.com", "protonvpn.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.Ubisoft, new List { "ubisoft.com", "ubi.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.TransferWise, new List { "transferwise.com", "wise.com" }); + GlobalDomains.Add(GlobalEquivalentDomainsType.TakeawayEU, new List { "takeaway.com", "just-eat.dk", "just-eat.no", "just-eat.fr", "just-eat.ch", "lieferando.de", "lieferando.at", "thuisbezorgd.nl", "pyszne.pl" }); + #endregion + + #region Plans + + Plans = new List { - #region Global Domains - - GlobalDomains = new Dictionary>(); - - GlobalDomains.Add(GlobalEquivalentDomainsType.Ameritrade, new List { "ameritrade.com", "tdameritrade.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.BoA, new List { "bankofamerica.com", "bofa.com", "mbna.com", "usecfo.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Sprint, new List { "sprint.com", "sprintpcs.com", "nextel.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Google, new List { "youtube.com", "google.com", "gmail.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Apple, new List { "apple.com", "icloud.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.WellsFargo, new List { "wellsfargo.com", "wf.com", "wellsfargoadvisors.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Merrill, new List { "mymerrill.com", "ml.com", "merrilledge.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Citi, new List { "accountonline.com", "citi.com", "citibank.com", "citicards.com", "citibankonline.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Cnet, new List { "cnet.com", "cnettv.com", "com.com", "download.com", "news.com", "search.com", "upload.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Gap, new List { "bananarepublic.com", "gap.com", "oldnavy.com", "piperlime.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Microsoft, new List { "bing.com", "hotmail.com", "live.com", "microsoft.com", "msn.com", "passport.net", "windows.com", "microsoftonline.com", "office.com", "office365.com", "microsoftstore.com", "xbox.com", "azure.com", "windowsazure.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.United, new List { "ua2go.com", "ual.com", "united.com", "unitedwifi.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Yahoo, new List { "overture.com", "yahoo.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Zonelabs, new List { "zonealarm.com", "zonelabs.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.PayPal, new List { "paypal.com", "paypal-search.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Avon, new List { "avon.com", "youravon.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Diapers, new List { "diapers.com", "soap.com", "wag.com", "yoyo.com", "beautybar.com", "casa.com", "afterschool.com", "vine.com", "bookworm.com", "look.com", "vinemarket.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Contacts, new List { "1800contacts.com", "800contacts.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Amazon, new List { "amazon.com", "amazon.ae", "amazon.ca", "amazon.co.uk", "amazon.com.au", "amazon.com.br", "amazon.com.mx", "amazon.com.tr", "amazon.de", "amazon.es", "amazon.fr", "amazon.in", "amazon.it", "amazon.nl", "amazon.pl", "amazon.sa", "amazon.se", "amazon.sg" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Cox, new List { "cox.com", "cox.net", "coxbusiness.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Norton, new List { "mynortonaccount.com", "norton.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Verizon, new List { "verizon.com", "verizon.net" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Buy, new List { "rakuten.com", "buy.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Sirius, new List { "siriusxm.com", "sirius.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Ea, new List { "ea.com", "origin.com", "play4free.com", "tiberiumalliance.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Basecamp, new List { "37signals.com", "basecamp.com", "basecamphq.com", "highrisehq.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Steam, new List { "steampowered.com", "steamcommunity.com", "steamgames.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Chart, new List { "chart.io", "chartio.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Gotomeeting, new List { "gotomeeting.com", "citrixonline.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Gogo, new List { "gogoair.com", "gogoinflight.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Oracle, new List { "mysql.com", "oracle.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Discover, new List { "discover.com", "discovercard.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Dcu, new List { "dcu.org", "dcu-online.org" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Healthcare, new List { "healthcare.gov", "cuidadodesalud.gov", "cms.gov" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Pepco, new List { "pepco.com", "pepcoholdings.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Century21, new List { "century21.com", "21online.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Comcast, new List { "comcast.com", "comcast.net", "xfinity.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Cricket, new List { "cricketwireless.com", "aiowireless.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Mtb, new List { "mandtbank.com", "mtb.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Dropbox, new List { "dropbox.com", "getdropbox.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Snapfish, new List { "snapfish.com", "snapfish.ca" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Alibaba, new List { "alibaba.com", "aliexpress.com", "aliyun.com", "net.cn" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Playstation, new List { "playstation.com", "sonyentertainmentnetwork.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Mercado, new List { "mercadolivre.com", "mercadolivre.com.br", "mercadolibre.com", "mercadolibre.com.ar", "mercadolibre.com.mx" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Zendesk, new List { "zendesk.com", "zopim.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Autodesk, new List { "autodesk.com", "tinkercad.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.RailNation, new List { "railnation.ru", "railnation.de", "rail-nation.com", "railnation.gr", "railnation.us", "trucknation.de", "traviangames.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Wpcu, new List { "wpcu.coop", "wpcuonline.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Mathletics, new List { "mathletics.com", "mathletics.com.au", "mathletics.co.uk" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Discountbank, new List { "discountbank.co.il", "telebank.co.il" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Mi, new List { "mi.com", "xiaomi.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Postepay, new List { "postepay.it", "poste.it" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Facebook, new List { "facebook.com", "messenger.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Skysports, new List { "skysports.com", "skybet.com", "skyvegas.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Disney, new List { "disneymoviesanywhere.com", "go.com", "disney.com", "dadt.com", "disneyplus.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Pokemon, new List { "pokemon-gl.com", "pokemon.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Uv, new List { "myuv.com", "uvvu.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Mdsol, new List { "mdsol.com", "imedidata.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Yahavo, new List { "bank-yahav.co.il", "bankhapoalim.co.il" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Sears, new List { "sears.com", "shld.net" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Xiami, new List { "xiami.com", "alipay.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Belkin, new List { "belkin.com", "seedonk.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Turbotax, new List { "turbotax.com", "intuit.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Shopify, new List { "shopify.com", "myshopify.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Ebay, new List { "ebay.com", "ebay.at", "ebay.be", "ebay.ca", "ebay.ch", "ebay.cn", "ebay.co.jp", "ebay.co.th", "ebay.co.uk", "ebay.com.au", "ebay.com.hk", "ebay.com.my", "ebay.com.sg", "ebay.com.tw", "ebay.de", "ebay.es", "ebay.fr", "ebay.ie", "ebay.in", "ebay.it", "ebay.nl", "ebay.ph", "ebay.pl" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Techdata, new List { "techdata.com", "techdata.ch" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Schwab, new List { "schwab.com", "schwabplan.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Tesla, new List { "tesla.com", "teslamotors.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.MorganStanley, new List { "morganstanley.com", "morganstanleyclientserv.com", "stockplanconnect.com", "ms.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.TaxAct, new List { "taxact.com", "taxactonline.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Wikimedia, new List { "mediawiki.org", "wikibooks.org", "wikidata.org", "wikimedia.org", "wikinews.org", "wikipedia.org", "wikiquote.org", "wikisource.org", "wikiversity.org", "wikivoyage.org", "wiktionary.org" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Airbnb, new List { "airbnb.at", "airbnb.be", "airbnb.ca", "airbnb.ch", "airbnb.cl", "airbnb.co.cr", "airbnb.co.id", "airbnb.co.in", "airbnb.co.kr", "airbnb.co.nz", "airbnb.co.uk", "airbnb.co.ve", "airbnb.com", "airbnb.com.ar", "airbnb.com.au", "airbnb.com.bo", "airbnb.com.br", "airbnb.com.bz", "airbnb.com.co", "airbnb.com.ec", "airbnb.com.gt", "airbnb.com.hk", "airbnb.com.hn", "airbnb.com.mt", "airbnb.com.my", "airbnb.com.ni", "airbnb.com.pa", "airbnb.com.pe", "airbnb.com.py", "airbnb.com.sg", "airbnb.com.sv", "airbnb.com.tr", "airbnb.com.tw", "airbnb.cz", "airbnb.de", "airbnb.dk", "airbnb.es", "airbnb.fi", "airbnb.fr", "airbnb.gr", "airbnb.gy", "airbnb.hu", "airbnb.ie", "airbnb.is", "airbnb.it", "airbnb.jp", "airbnb.mx", "airbnb.nl", "airbnb.no", "airbnb.pl", "airbnb.pt", "airbnb.ru", "airbnb.se" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Eventbrite, new List { "eventbrite.at", "eventbrite.be", "eventbrite.ca", "eventbrite.ch", "eventbrite.cl", "eventbrite.co", "eventbrite.co.nz", "eventbrite.co.uk", "eventbrite.com", "eventbrite.com.ar", "eventbrite.com.au", "eventbrite.com.br", "eventbrite.com.mx", "eventbrite.com.pe", "eventbrite.de", "eventbrite.dk", "eventbrite.es", "eventbrite.fi", "eventbrite.fr", "eventbrite.hk", "eventbrite.ie", "eventbrite.it", "eventbrite.nl", "eventbrite.pt", "eventbrite.se", "eventbrite.sg" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.StackExchange, new List { "stackexchange.com", "superuser.com", "stackoverflow.com", "serverfault.com", "mathoverflow.net", "askubuntu.com", "stackapps.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Docusign, new List { "docusign.com", "docusign.net" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Envato, new List { "envato.com", "themeforest.net", "codecanyon.net", "videohive.net", "audiojungle.net", "graphicriver.net", "photodune.net", "3docean.net" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.X10Hosting, new List { "x10hosting.com", "x10premium.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Cisco, new List { "dnsomatic.com", "opendns.com", "umbrella.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.CedarFair, new List { "cagreatamerica.com", "canadaswonderland.com", "carowinds.com", "cedarfair.com", "cedarpoint.com", "dorneypark.com", "kingsdominion.com", "knotts.com", "miadventure.com", "schlitterbahn.com", "valleyfair.com", "visitkingsisland.com", "worldsoffun.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Ubiquiti, new List { "ubnt.com", "ui.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Discord, new List { "discordapp.com", "discord.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Netcup, new List { "netcup.de", "netcup.eu", "customercontrolpanel.de" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Yandex, new List { "yandex.com", "ya.ru", "yandex.az", "yandex.by", "yandex.co.il", "yandex.com.am", "yandex.com.ge", "yandex.com.tr", "yandex.ee", "yandex.fi", "yandex.fr", "yandex.kg", "yandex.kz", "yandex.lt", "yandex.lv", "yandex.md", "yandex.pl", "yandex.ru", "yandex.tj", "yandex.tm", "yandex.ua", "yandex.uz" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Sony, new List { "sonyentertainmentnetwork.com", "sony.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Proton, new List { "proton.me", "protonmail.com", "protonvpn.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.Ubisoft, new List { "ubisoft.com", "ubi.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.TransferWise, new List { "transferwise.com", "wise.com" }); - GlobalDomains.Add(GlobalEquivalentDomainsType.TakeawayEU, new List { "takeaway.com", "just-eat.dk", "just-eat.no", "just-eat.fr", "just-eat.ch", "lieferando.de", "lieferando.at", "thuisbezorgd.nl", "pyszne.pl" }); - #endregion - - #region Plans - - Plans = new List + new Plan { - new Plan - { - Type = PlanType.Free, - Product = ProductType.Free, - Name = "Free", - NameLocalizationKey = "planNameFree", - DescriptionLocalizationKey = "planDescFree", - BaseSeats = 2, - MaxCollections = 2, - MaxUsers = 2, + Type = PlanType.Free, + Product = ProductType.Free, + Name = "Free", + NameLocalizationKey = "planNameFree", + DescriptionLocalizationKey = "planDescFree", + BaseSeats = 2, + MaxCollections = 2, + MaxUsers = 2, - UpgradeSortOrder = -1, // Always the lowest plan, cannot be upgraded to - DisplaySortOrder = -1, + UpgradeSortOrder = -1, // Always the lowest plan, cannot be upgraded to + DisplaySortOrder = -1, - AllowSeatAutoscale = false, - }, - new Plan - { - Type = PlanType.FamiliesAnnually2019, - Product = ProductType.Families, - Name = "Families 2019", - IsAnnual = true, - NameLocalizationKey = "planNameFamilies", - DescriptionLocalizationKey = "planDescFamilies", - BaseSeats = 5, - BaseStorageGb = 1, - MaxUsers = 5, - - HasAdditionalStorageOption = true, - HasPremiumAccessOption = true, - TrialPeriodDays = 7, - - HasSelfHost = true, - HasTotp = true, - - UpgradeSortOrder = 1, - DisplaySortOrder = 1, - LegacyYear = 2020, - - StripePlanId = "personal-org-annually", - StripeStoragePlanId = "storage-gb-annually", - StripePremiumAccessPlanId = "personal-org-premium-access-annually", - BasePrice = 12, - AdditionalStoragePricePerGb = 4, - PremiumAccessOptionPrice = 40, - - AllowSeatAutoscale = false, - }, - new Plan - { - Type = PlanType.TeamsAnnually2019, - Product = ProductType.Teams, - Name = "Teams (Annually) 2019", - IsAnnual = true, - NameLocalizationKey = "planNameTeams", - DescriptionLocalizationKey = "planDescTeams", - CanBeUsedByBusiness = true, - BaseSeats = 5, - BaseStorageGb = 1, - - HasAdditionalSeatsOption = true, - HasAdditionalStorageOption = true, - TrialPeriodDays = 7, - - HasTotp = true, - - UpgradeSortOrder = 2, - DisplaySortOrder = 2, - LegacyYear = 2020, - - StripePlanId = "teams-org-annually", - StripeSeatPlanId = "teams-org-seat-annually", - StripeStoragePlanId = "storage-gb-annually", - BasePrice = 60, - SeatPrice = 24, - AdditionalStoragePricePerGb = 4, - - AllowSeatAutoscale = true, - }, - new Plan - { - Type = PlanType.TeamsMonthly2019, - Product = ProductType.Teams, - Name = "Teams (Monthly) 2019", - NameLocalizationKey = "planNameTeams", - DescriptionLocalizationKey = "planDescTeams", - CanBeUsedByBusiness = true, - BaseSeats = 5, - BaseStorageGb = 1, - - HasAdditionalSeatsOption = true, - HasAdditionalStorageOption = true, - TrialPeriodDays = 7, - - HasTotp = true, - - UpgradeSortOrder = 2, - DisplaySortOrder = 2, - LegacyYear = 2020, - - StripePlanId = "teams-org-monthly", - StripeSeatPlanId = "teams-org-seat-monthly", - StripeStoragePlanId = "storage-gb-monthly", - BasePrice = 8, - SeatPrice = 2.5M, - AdditionalStoragePricePerGb = 0.5M, - - AllowSeatAutoscale = true, - }, - new Plan - { - Type = PlanType.EnterpriseAnnually2019, - Name = "Enterprise (Annually) 2019", - IsAnnual = true, - Product = ProductType.Enterprise, - NameLocalizationKey = "planNameEnterprise", - DescriptionLocalizationKey = "planDescEnterprise", - CanBeUsedByBusiness = true, - BaseSeats = 0, - BaseStorageGb = 1, - - HasAdditionalSeatsOption = true, - HasAdditionalStorageOption = true, - TrialPeriodDays = 7, - - HasPolicies = true, - HasSelfHost = true, - HasGroups = true, - HasDirectory = true, - HasEvents = true, - HasTotp = true, - Has2fa = true, - HasApi = true, - UsersGetPremium = true, - - UpgradeSortOrder = 3, - DisplaySortOrder = 3, - LegacyYear = 2020, - - StripePlanId = null, - StripeSeatPlanId = "enterprise-org-seat-annually", - StripeStoragePlanId = "storage-gb-annually", - BasePrice = 0, - SeatPrice = 36, - AdditionalStoragePricePerGb = 4, - - AllowSeatAutoscale = true, - }, - new Plan - { - Type = PlanType.EnterpriseMonthly2019, - Product = ProductType.Enterprise, - Name = "Enterprise (Monthly) 2019", - NameLocalizationKey = "planNameEnterprise", - DescriptionLocalizationKey = "planDescEnterprise", - CanBeUsedByBusiness = true, - BaseSeats = 0, - BaseStorageGb = 1, - - HasAdditionalSeatsOption = true, - HasAdditionalStorageOption = true, - TrialPeriodDays = 7, - - HasPolicies = true, - HasGroups = true, - HasDirectory = true, - HasEvents = true, - HasTotp = true, - Has2fa = true, - HasApi = true, - HasSelfHost = true, - UsersGetPremium = true, - - UpgradeSortOrder = 3, - DisplaySortOrder = 3, - LegacyYear = 2020, - - StripePlanId = null, - StripeSeatPlanId = "enterprise-org-seat-monthly", - StripeStoragePlanId = "storage-gb-monthly", - BasePrice = 0, - SeatPrice = 4M, - AdditionalStoragePricePerGb = 0.5M, - - AllowSeatAutoscale = true, - }, - new Plan - { - Type = PlanType.FamiliesAnnually, - Product = ProductType.Families, - Name = "Families", - IsAnnual = true, - NameLocalizationKey = "planNameFamilies", - DescriptionLocalizationKey = "planDescFamilies", - BaseSeats = 6, - BaseStorageGb = 1, - MaxUsers = 6, - - HasAdditionalStorageOption = true, - TrialPeriodDays = 7, - - HasSelfHost = true, - HasTotp = true, - UsersGetPremium = true, - - UpgradeSortOrder = 1, - DisplaySortOrder = 1, - - StripePlanId = "2020-families-org-annually", - StripeStoragePlanId = "storage-gb-annually", - BasePrice = 40, - AdditionalStoragePricePerGb = 4, - - AllowSeatAutoscale = false, - }, - new Plan - { - Type = PlanType.TeamsAnnually, - Product = ProductType.Teams, - Name = "Teams (Annually)", - IsAnnual = true, - NameLocalizationKey = "planNameTeams", - DescriptionLocalizationKey = "planDescTeams", - CanBeUsedByBusiness = true, - BaseStorageGb = 1, - BaseSeats = 0, - - HasAdditionalSeatsOption = true, - HasAdditionalStorageOption = true, - TrialPeriodDays = 7, - - Has2fa = true, - HasApi = true, - HasDirectory = true, - HasEvents = true, - HasGroups = true, - HasTotp = true, - UsersGetPremium = true, - - UpgradeSortOrder = 2, - DisplaySortOrder = 2, - - StripeSeatPlanId = "2020-teams-org-seat-annually", - StripeStoragePlanId = "storage-gb-annually", - SeatPrice = 36, - AdditionalStoragePricePerGb = 4, - - AllowSeatAutoscale = true, - }, - new Plan - { - Type = PlanType.TeamsMonthly, - Product = ProductType.Teams, - Name = "Teams (Monthly)", - NameLocalizationKey = "planNameTeams", - DescriptionLocalizationKey = "planDescTeams", - CanBeUsedByBusiness = true, - BaseStorageGb = 1, - BaseSeats = 0, - - HasAdditionalSeatsOption = true, - HasAdditionalStorageOption = true, - TrialPeriodDays = 7, - - Has2fa = true, - HasApi = true, - HasDirectory = true, - HasEvents = true, - HasGroups = true, - HasTotp = true, - UsersGetPremium = true, - - UpgradeSortOrder = 2, - DisplaySortOrder = 2, - - StripeSeatPlanId = "2020-teams-org-seat-monthly", - StripeStoragePlanId = "storage-gb-monthly", - SeatPrice = 4, - AdditionalStoragePricePerGb = 0.5M, - - AllowSeatAutoscale = true, - }, - new Plan - { - Type = PlanType.EnterpriseAnnually, - Name = "Enterprise (Annually)", - Product = ProductType.Enterprise, - IsAnnual = true, - NameLocalizationKey = "planNameEnterprise", - DescriptionLocalizationKey = "planDescEnterprise", - CanBeUsedByBusiness = true, - BaseSeats = 0, - BaseStorageGb = 1, - - HasAdditionalSeatsOption = true, - HasAdditionalStorageOption = true, - TrialPeriodDays = 7, - - HasPolicies = true, - HasSelfHost = true, - HasGroups = true, - HasDirectory = true, - HasEvents = true, - HasTotp = true, - Has2fa = true, - HasApi = true, - HasSso = true, - HasKeyConnector = true, - HasScim = true, - HasResetPassword = true, - UsersGetPremium = true, - - UpgradeSortOrder = 3, - DisplaySortOrder = 3, - - StripeSeatPlanId = "2020-enterprise-org-seat-annually", - StripeStoragePlanId = "storage-gb-annually", - BasePrice = 0, - SeatPrice = 60, - AdditionalStoragePricePerGb = 4, - - AllowSeatAutoscale = true, - }, - new Plan - { - Type = PlanType.EnterpriseMonthly, - Product = ProductType.Enterprise, - Name = "Enterprise (Monthly)", - NameLocalizationKey = "planNameEnterprise", - DescriptionLocalizationKey = "planDescEnterprise", - CanBeUsedByBusiness = true, - BaseSeats = 0, - BaseStorageGb = 1, - - HasAdditionalSeatsOption = true, - HasAdditionalStorageOption = true, - TrialPeriodDays = 7, - - HasPolicies = true, - HasGroups = true, - HasDirectory = true, - HasEvents = true, - HasTotp = true, - Has2fa = true, - HasApi = true, - HasSelfHost = true, - HasSso = true, - HasKeyConnector = true, - HasScim = true, - HasResetPassword = true, - UsersGetPremium = true, - - UpgradeSortOrder = 3, - DisplaySortOrder = 3, - - StripeSeatPlanId = "2020-enterprise-org-seat-monthly", - StripeStoragePlanId = "storage-gb-monthly", - BasePrice = 0, - SeatPrice = 6, - AdditionalStoragePricePerGb = 0.5M, - - AllowSeatAutoscale = true, - }, - new Plan - { - Type = PlanType.Custom, - - AllowSeatAutoscale = true, - }, - }; - - #endregion - } - - public static IDictionary> GlobalDomains { get; set; } - public static IEnumerable Plans { get; set; } - public static IEnumerable SponsoredPlans { get; set; } = new[] + AllowSeatAutoscale = false, + }, + new Plan { - new SponsoredPlan - { - PlanSponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, - SponsoredProductType = ProductType.Families, - SponsoringProductType = ProductType.Enterprise, - StripePlanId = "2021-family-for-enterprise-annually", - UsersCanSponsor = (OrganizationUserOrganizationDetails org) => - GetPlan(org.PlanType).Product == ProductType.Enterprise, - } - }; - public static Plan GetPlan(PlanType planType) => - Plans.FirstOrDefault(p => p.Type == planType); - public static SponsoredPlan GetSponsoredPlan(PlanSponsorshipType planSponsorshipType) => - SponsoredPlans.FirstOrDefault(p => p.PlanSponsorshipType == planSponsorshipType); + Type = PlanType.FamiliesAnnually2019, + Product = ProductType.Families, + Name = "Families 2019", + IsAnnual = true, + NameLocalizationKey = "planNameFamilies", + DescriptionLocalizationKey = "planDescFamilies", + BaseSeats = 5, + BaseStorageGb = 1, + MaxUsers = 5, + + HasAdditionalStorageOption = true, + HasPremiumAccessOption = true, + TrialPeriodDays = 7, + + HasSelfHost = true, + HasTotp = true, + + UpgradeSortOrder = 1, + DisplaySortOrder = 1, + LegacyYear = 2020, + + StripePlanId = "personal-org-annually", + StripeStoragePlanId = "storage-gb-annually", + StripePremiumAccessPlanId = "personal-org-premium-access-annually", + BasePrice = 12, + AdditionalStoragePricePerGb = 4, + PremiumAccessOptionPrice = 40, + + AllowSeatAutoscale = false, + }, + new Plan + { + Type = PlanType.TeamsAnnually2019, + Product = ProductType.Teams, + Name = "Teams (Annually) 2019", + IsAnnual = true, + NameLocalizationKey = "planNameTeams", + DescriptionLocalizationKey = "planDescTeams", + CanBeUsedByBusiness = true, + BaseSeats = 5, + BaseStorageGb = 1, + + HasAdditionalSeatsOption = true, + HasAdditionalStorageOption = true, + TrialPeriodDays = 7, + + HasTotp = true, + + UpgradeSortOrder = 2, + DisplaySortOrder = 2, + LegacyYear = 2020, + + StripePlanId = "teams-org-annually", + StripeSeatPlanId = "teams-org-seat-annually", + StripeStoragePlanId = "storage-gb-annually", + BasePrice = 60, + SeatPrice = 24, + AdditionalStoragePricePerGb = 4, + + AllowSeatAutoscale = true, + }, + new Plan + { + Type = PlanType.TeamsMonthly2019, + Product = ProductType.Teams, + Name = "Teams (Monthly) 2019", + NameLocalizationKey = "planNameTeams", + DescriptionLocalizationKey = "planDescTeams", + CanBeUsedByBusiness = true, + BaseSeats = 5, + BaseStorageGb = 1, + + HasAdditionalSeatsOption = true, + HasAdditionalStorageOption = true, + TrialPeriodDays = 7, + + HasTotp = true, + + UpgradeSortOrder = 2, + DisplaySortOrder = 2, + LegacyYear = 2020, + + StripePlanId = "teams-org-monthly", + StripeSeatPlanId = "teams-org-seat-monthly", + StripeStoragePlanId = "storage-gb-monthly", + BasePrice = 8, + SeatPrice = 2.5M, + AdditionalStoragePricePerGb = 0.5M, + + AllowSeatAutoscale = true, + }, + new Plan + { + Type = PlanType.EnterpriseAnnually2019, + Name = "Enterprise (Annually) 2019", + IsAnnual = true, + Product = ProductType.Enterprise, + NameLocalizationKey = "planNameEnterprise", + DescriptionLocalizationKey = "planDescEnterprise", + CanBeUsedByBusiness = true, + BaseSeats = 0, + BaseStorageGb = 1, + + HasAdditionalSeatsOption = true, + HasAdditionalStorageOption = true, + TrialPeriodDays = 7, + + HasPolicies = true, + HasSelfHost = true, + HasGroups = true, + HasDirectory = true, + HasEvents = true, + HasTotp = true, + Has2fa = true, + HasApi = true, + UsersGetPremium = true, + + UpgradeSortOrder = 3, + DisplaySortOrder = 3, + LegacyYear = 2020, + + StripePlanId = null, + StripeSeatPlanId = "enterprise-org-seat-annually", + StripeStoragePlanId = "storage-gb-annually", + BasePrice = 0, + SeatPrice = 36, + AdditionalStoragePricePerGb = 4, + + AllowSeatAutoscale = true, + }, + new Plan + { + Type = PlanType.EnterpriseMonthly2019, + Product = ProductType.Enterprise, + Name = "Enterprise (Monthly) 2019", + NameLocalizationKey = "planNameEnterprise", + DescriptionLocalizationKey = "planDescEnterprise", + CanBeUsedByBusiness = true, + BaseSeats = 0, + BaseStorageGb = 1, + + HasAdditionalSeatsOption = true, + HasAdditionalStorageOption = true, + TrialPeriodDays = 7, + + HasPolicies = true, + HasGroups = true, + HasDirectory = true, + HasEvents = true, + HasTotp = true, + Has2fa = true, + HasApi = true, + HasSelfHost = true, + UsersGetPremium = true, + + UpgradeSortOrder = 3, + DisplaySortOrder = 3, + LegacyYear = 2020, + + StripePlanId = null, + StripeSeatPlanId = "enterprise-org-seat-monthly", + StripeStoragePlanId = "storage-gb-monthly", + BasePrice = 0, + SeatPrice = 4M, + AdditionalStoragePricePerGb = 0.5M, + + AllowSeatAutoscale = true, + }, + new Plan + { + Type = PlanType.FamiliesAnnually, + Product = ProductType.Families, + Name = "Families", + IsAnnual = true, + NameLocalizationKey = "planNameFamilies", + DescriptionLocalizationKey = "planDescFamilies", + BaseSeats = 6, + BaseStorageGb = 1, + MaxUsers = 6, + + HasAdditionalStorageOption = true, + TrialPeriodDays = 7, + + HasSelfHost = true, + HasTotp = true, + UsersGetPremium = true, + + UpgradeSortOrder = 1, + DisplaySortOrder = 1, + + StripePlanId = "2020-families-org-annually", + StripeStoragePlanId = "storage-gb-annually", + BasePrice = 40, + AdditionalStoragePricePerGb = 4, + + AllowSeatAutoscale = false, + }, + new Plan + { + Type = PlanType.TeamsAnnually, + Product = ProductType.Teams, + Name = "Teams (Annually)", + IsAnnual = true, + NameLocalizationKey = "planNameTeams", + DescriptionLocalizationKey = "planDescTeams", + CanBeUsedByBusiness = true, + BaseStorageGb = 1, + BaseSeats = 0, + + HasAdditionalSeatsOption = true, + HasAdditionalStorageOption = true, + TrialPeriodDays = 7, + + Has2fa = true, + HasApi = true, + HasDirectory = true, + HasEvents = true, + HasGroups = true, + HasTotp = true, + UsersGetPremium = true, + + UpgradeSortOrder = 2, + DisplaySortOrder = 2, + + StripeSeatPlanId = "2020-teams-org-seat-annually", + StripeStoragePlanId = "storage-gb-annually", + SeatPrice = 36, + AdditionalStoragePricePerGb = 4, + + AllowSeatAutoscale = true, + }, + new Plan + { + Type = PlanType.TeamsMonthly, + Product = ProductType.Teams, + Name = "Teams (Monthly)", + NameLocalizationKey = "planNameTeams", + DescriptionLocalizationKey = "planDescTeams", + CanBeUsedByBusiness = true, + BaseStorageGb = 1, + BaseSeats = 0, + + HasAdditionalSeatsOption = true, + HasAdditionalStorageOption = true, + TrialPeriodDays = 7, + + Has2fa = true, + HasApi = true, + HasDirectory = true, + HasEvents = true, + HasGroups = true, + HasTotp = true, + UsersGetPremium = true, + + UpgradeSortOrder = 2, + DisplaySortOrder = 2, + + StripeSeatPlanId = "2020-teams-org-seat-monthly", + StripeStoragePlanId = "storage-gb-monthly", + SeatPrice = 4, + AdditionalStoragePricePerGb = 0.5M, + + AllowSeatAutoscale = true, + }, + new Plan + { + Type = PlanType.EnterpriseAnnually, + Name = "Enterprise (Annually)", + Product = ProductType.Enterprise, + IsAnnual = true, + NameLocalizationKey = "planNameEnterprise", + DescriptionLocalizationKey = "planDescEnterprise", + CanBeUsedByBusiness = true, + BaseSeats = 0, + BaseStorageGb = 1, + + HasAdditionalSeatsOption = true, + HasAdditionalStorageOption = true, + TrialPeriodDays = 7, + + HasPolicies = true, + HasSelfHost = true, + HasGroups = true, + HasDirectory = true, + HasEvents = true, + HasTotp = true, + Has2fa = true, + HasApi = true, + HasSso = true, + HasKeyConnector = true, + HasScim = true, + HasResetPassword = true, + UsersGetPremium = true, + + UpgradeSortOrder = 3, + DisplaySortOrder = 3, + + StripeSeatPlanId = "2020-enterprise-org-seat-annually", + StripeStoragePlanId = "storage-gb-annually", + BasePrice = 0, + SeatPrice = 60, + AdditionalStoragePricePerGb = 4, + + AllowSeatAutoscale = true, + }, + new Plan + { + Type = PlanType.EnterpriseMonthly, + Product = ProductType.Enterprise, + Name = "Enterprise (Monthly)", + NameLocalizationKey = "planNameEnterprise", + DescriptionLocalizationKey = "planDescEnterprise", + CanBeUsedByBusiness = true, + BaseSeats = 0, + BaseStorageGb = 1, + + HasAdditionalSeatsOption = true, + HasAdditionalStorageOption = true, + TrialPeriodDays = 7, + + HasPolicies = true, + HasGroups = true, + HasDirectory = true, + HasEvents = true, + HasTotp = true, + Has2fa = true, + HasApi = true, + HasSelfHost = true, + HasSso = true, + HasKeyConnector = true, + HasScim = true, + HasResetPassword = true, + UsersGetPremium = true, + + UpgradeSortOrder = 3, + DisplaySortOrder = 3, + + StripeSeatPlanId = "2020-enterprise-org-seat-monthly", + StripeStoragePlanId = "storage-gb-monthly", + BasePrice = 0, + SeatPrice = 6, + AdditionalStoragePricePerGb = 0.5M, + + AllowSeatAutoscale = true, + }, + new Plan + { + Type = PlanType.Custom, + + AllowSeatAutoscale = true, + }, + }; + + #endregion } + + public static IDictionary> GlobalDomains { get; set; } + public static IEnumerable Plans { get; set; } + public static IEnumerable SponsoredPlans { get; set; } = new[] + { + new SponsoredPlan + { + PlanSponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, + SponsoredProductType = ProductType.Families, + SponsoringProductType = ProductType.Enterprise, + StripePlanId = "2021-family-for-enterprise-annually", + UsersCanSponsor = (OrganizationUserOrganizationDetails org) => + GetPlan(org.PlanType).Product == ProductType.Enterprise, + } + }; + public static Plan GetPlan(PlanType planType) => + Plans.FirstOrDefault(p => p.Type == planType); + public static SponsoredPlan GetSponsoredPlan(PlanSponsorshipType planSponsorshipType) => + SponsoredPlans.FirstOrDefault(p => p.PlanSponsorshipType == planSponsorshipType); } diff --git a/src/Core/Utilities/StrictEmailAddressAttribute.cs b/src/Core/Utilities/StrictEmailAddressAttribute.cs index 15347ab83..f84e41852 100644 --- a/src/Core/Utilities/StrictEmailAddressAttribute.cs +++ b/src/Core/Utilities/StrictEmailAddressAttribute.cs @@ -2,52 +2,51 @@ using System.Text.RegularExpressions; using MimeKit; -namespace Bit.Core.Utilities +namespace Bit.Core.Utilities; + +public class StrictEmailAddressAttribute : ValidationAttribute { - public class StrictEmailAddressAttribute : ValidationAttribute + public StrictEmailAddressAttribute() + : base("The {0} field is not a supported e-mail address format.") + { } + + public override bool IsValid(object value) { - public StrictEmailAddressAttribute() - : base("The {0} field is not a supported e-mail address format.") - { } - - public override bool IsValid(object value) + var emailAddress = value?.ToString(); + if (emailAddress == null) { - var emailAddress = value?.ToString(); - if (emailAddress == null) - { - return false; - } - - try - { - var parsedEmailAddress = MailboxAddress.Parse(emailAddress).Address; - if (parsedEmailAddress != emailAddress) - { - return false; - } - } - catch (ParseException) - { - return false; - } - - /** - The regex below is intended to catch edge cases that are not handled by the general parsing check above. - This enforces the following rules: - * Requires ASCII only in the local-part (code points 0-127) - * Requires an @ symbol - * Allows any char in second-level domain name, including unicode and symbols - * Requires at least one period (.) separating SLD from TLD - * Must end in a letter (including unicode) - See the unit tests for examples of what is allowed. - **/ - var emailFormat = @"^[\x00-\x7F]+@.+\.\p{L}+$"; - if (!Regex.IsMatch(emailAddress, emailFormat)) - { - return false; - } - - return new EmailAddressAttribute().IsValid(emailAddress); + return false; } + + try + { + var parsedEmailAddress = MailboxAddress.Parse(emailAddress).Address; + if (parsedEmailAddress != emailAddress) + { + return false; + } + } + catch (ParseException) + { + return false; + } + + /** + The regex below is intended to catch edge cases that are not handled by the general parsing check above. + This enforces the following rules: + * Requires ASCII only in the local-part (code points 0-127) + * Requires an @ symbol + * Allows any char in second-level domain name, including unicode and symbols + * Requires at least one period (.) separating SLD from TLD + * Must end in a letter (including unicode) + See the unit tests for examples of what is allowed. + **/ + var emailFormat = @"^[\x00-\x7F]+@.+\.\p{L}+$"; + if (!Regex.IsMatch(emailAddress, emailFormat)) + { + return false; + } + + return new EmailAddressAttribute().IsValid(emailAddress); } } diff --git a/src/Core/Utilities/StrictEmailAddressListAttribute.cs b/src/Core/Utilities/StrictEmailAddressListAttribute.cs index dcff171cd..456980397 100644 --- a/src/Core/Utilities/StrictEmailAddressListAttribute.cs +++ b/src/Core/Utilities/StrictEmailAddressListAttribute.cs @@ -1,39 +1,38 @@ using System.ComponentModel.DataAnnotations; -namespace Bit.Core.Utilities +namespace Bit.Core.Utilities; + +public class StrictEmailAddressListAttribute : ValidationAttribute { - public class StrictEmailAddressListAttribute : ValidationAttribute + protected override ValidationResult IsValid(object value, ValidationContext validationContext) { - protected override ValidationResult IsValid(object value, ValidationContext validationContext) + var strictEmailAttribute = new StrictEmailAddressAttribute(); + var emails = value as IList; + + if (!emails?.Any() ?? true) { - var strictEmailAttribute = new StrictEmailAddressAttribute(); - var emails = value as IList; - - if (!emails?.Any() ?? true) - { - return new ValidationResult("An email is required."); - } - - if (emails.Count() > 20) - { - return new ValidationResult("You can only submit up to 20 emails at a time."); - } - - for (var i = 0; i < emails.Count(); i++) - { - var email = emails.ElementAt(i); - if (!strictEmailAttribute.IsValid(email)) - { - return new ValidationResult($"Email #{i + 1} is not valid."); - } - - if (email.Length > 256) - { - return new ValidationResult($"Email #{i + 1} is longer than 256 characters."); - } - } - - return ValidationResult.Success; + return new ValidationResult("An email is required."); } + + if (emails.Count() > 20) + { + return new ValidationResult("You can only submit up to 20 emails at a time."); + } + + for (var i = 0; i < emails.Count(); i++) + { + var email = emails.ElementAt(i); + if (!strictEmailAttribute.IsValid(email)) + { + return new ValidationResult($"Email #{i + 1} is not valid."); + } + + if (email.Length > 256) + { + return new ValidationResult($"Email #{i + 1} is longer than 256 characters."); + } + } + + return ValidationResult.Success; } } diff --git a/src/Events/Controllers/CollectController.cs b/src/Events/Controllers/CollectController.cs index f2599d26e..aaed0b358 100644 --- a/src/Events/Controllers/CollectController.cs +++ b/src/Events/Controllers/CollectController.cs @@ -8,99 +8,98 @@ using Bit.Events.Models; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Events.Controllers +namespace Bit.Events.Controllers; + +[Route("collect")] +[Authorize("Application")] +public class CollectController : Controller { - [Route("collect")] - [Authorize("Application")] - public class CollectController : Controller + private readonly ICurrentContext _currentContext; + private readonly IEventService _eventService; + private readonly ICipherRepository _cipherRepository; + private readonly IOrganizationRepository _organizationRepository; + + public CollectController( + ICurrentContext currentContext, + IEventService eventService, + ICipherRepository cipherRepository, + IOrganizationRepository organizationRepository) { - private readonly ICurrentContext _currentContext; - private readonly IEventService _eventService; - private readonly ICipherRepository _cipherRepository; - private readonly IOrganizationRepository _organizationRepository; + _currentContext = currentContext; + _eventService = eventService; + _cipherRepository = cipherRepository; + _organizationRepository = organizationRepository; + } - public CollectController( - ICurrentContext currentContext, - IEventService eventService, - ICipherRepository cipherRepository, - IOrganizationRepository organizationRepository) + [HttpPost] + public async Task Post([FromBody] IEnumerable model) + { + if (model == null || !model.Any()) { - _currentContext = currentContext; - _eventService = eventService; - _cipherRepository = cipherRepository; - _organizationRepository = organizationRepository; + return new BadRequestResult(); } - - [HttpPost] - public async Task Post([FromBody] IEnumerable model) + var cipherEvents = new List>(); + var ciphersCache = new Dictionary(); + foreach (var eventModel in model) { - if (model == null || !model.Any()) + switch (eventModel.Type) { - return new BadRequestResult(); - } - var cipherEvents = new List>(); - var ciphersCache = new Dictionary(); - foreach (var eventModel in model) - { - switch (eventModel.Type) - { - // User events - case EventType.User_ClientExportedVault: - await _eventService.LogUserEventAsync(_currentContext.UserId.Value, eventModel.Type, eventModel.Date); - break; - // Cipher events - case EventType.Cipher_ClientAutofilled: - case EventType.Cipher_ClientCopiedHiddenField: - case EventType.Cipher_ClientCopiedPassword: - case EventType.Cipher_ClientCopiedCardCode: - case EventType.Cipher_ClientToggledCardCodeVisible: - case EventType.Cipher_ClientToggledHiddenFieldVisible: - case EventType.Cipher_ClientToggledPasswordVisible: - case EventType.Cipher_ClientViewed: - if (!eventModel.CipherId.HasValue) - { - continue; - } - Cipher cipher = null; - if (ciphersCache.ContainsKey(eventModel.CipherId.Value)) - { - cipher = ciphersCache[eventModel.CipherId.Value]; - } - else - { - cipher = await _cipherRepository.GetByIdAsync(eventModel.CipherId.Value, - _currentContext.UserId.Value); - } - if (cipher == null) - { - continue; - } - if (!ciphersCache.ContainsKey(eventModel.CipherId.Value)) - { - ciphersCache.Add(eventModel.CipherId.Value, cipher); - } - cipherEvents.Add(new Tuple(cipher, eventModel.Type, eventModel.Date)); - break; - case EventType.Organization_ClientExportedVault: - if (!eventModel.OrganizationId.HasValue) - { - continue; - } - var organization = await _organizationRepository.GetByIdAsync(eventModel.OrganizationId.Value); - await _eventService.LogOrganizationEventAsync(organization, eventModel.Type, eventModel.Date); - break; - default: + // User events + case EventType.User_ClientExportedVault: + await _eventService.LogUserEventAsync(_currentContext.UserId.Value, eventModel.Type, eventModel.Date); + break; + // Cipher events + case EventType.Cipher_ClientAutofilled: + case EventType.Cipher_ClientCopiedHiddenField: + case EventType.Cipher_ClientCopiedPassword: + case EventType.Cipher_ClientCopiedCardCode: + case EventType.Cipher_ClientToggledCardCodeVisible: + case EventType.Cipher_ClientToggledHiddenFieldVisible: + case EventType.Cipher_ClientToggledPasswordVisible: + case EventType.Cipher_ClientViewed: + if (!eventModel.CipherId.HasValue) + { continue; - } + } + Cipher cipher = null; + if (ciphersCache.ContainsKey(eventModel.CipherId.Value)) + { + cipher = ciphersCache[eventModel.CipherId.Value]; + } + else + { + cipher = await _cipherRepository.GetByIdAsync(eventModel.CipherId.Value, + _currentContext.UserId.Value); + } + if (cipher == null) + { + continue; + } + if (!ciphersCache.ContainsKey(eventModel.CipherId.Value)) + { + ciphersCache.Add(eventModel.CipherId.Value, cipher); + } + cipherEvents.Add(new Tuple(cipher, eventModel.Type, eventModel.Date)); + break; + case EventType.Organization_ClientExportedVault: + if (!eventModel.OrganizationId.HasValue) + { + continue; + } + var organization = await _organizationRepository.GetByIdAsync(eventModel.OrganizationId.Value); + await _eventService.LogOrganizationEventAsync(organization, eventModel.Type, eventModel.Date); + break; + default: + continue; } - if (cipherEvents.Any()) - { - foreach (var eventsBatch in cipherEvents.Batch(50)) - { - await _eventService.LogCipherEventsAsync(eventsBatch); - } - } - return new OkResult(); } + if (cipherEvents.Any()) + { + foreach (var eventsBatch in cipherEvents.Batch(50)) + { + await _eventService.LogCipherEventsAsync(eventsBatch); + } + } + return new OkResult(); } } diff --git a/src/Events/Controllers/InfoController.cs b/src/Events/Controllers/InfoController.cs index 23234c654..6d42f6757 100644 --- a/src/Events/Controllers/InfoController.cs +++ b/src/Events/Controllers/InfoController.cs @@ -1,21 +1,20 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; -namespace Bit.Events.Controllers -{ - public class InfoController : Controller - { - [HttpGet("~/alive")] - [HttpGet("~/now")] - public DateTime GetAlive() - { - return DateTime.UtcNow; - } +namespace Bit.Events.Controllers; - [HttpGet("~/version")] - public JsonResult GetVersion() - { - return Json(CoreHelpers.GetVersion()); - } +public class InfoController : Controller +{ + [HttpGet("~/alive")] + [HttpGet("~/now")] + public DateTime GetAlive() + { + return DateTime.UtcNow; + } + + [HttpGet("~/version")] + public JsonResult GetVersion() + { + return Json(CoreHelpers.GetVersion()); } } diff --git a/src/Events/Models/EventModel.cs b/src/Events/Models/EventModel.cs index 80b69398a..dc5cef084 100644 --- a/src/Events/Models/EventModel.cs +++ b/src/Events/Models/EventModel.cs @@ -1,12 +1,11 @@ using Bit.Core.Enums; -namespace Bit.Events.Models +namespace Bit.Events.Models; + +public class EventModel { - public class EventModel - { - public EventType Type { get; set; } - public Guid? CipherId { get; set; } - public DateTime Date { get; set; } - public Guid? OrganizationId { get; set; } - } + public EventType Type { get; set; } + public Guid? CipherId { get; set; } + public DateTime Date { get; set; } + public Guid? OrganizationId { get; set; } } diff --git a/src/Events/Program.cs b/src/Events/Program.cs index a6a95646a..74f82cd41 100644 --- a/src/Events/Program.cs +++ b/src/Events/Program.cs @@ -1,40 +1,39 @@ using Bit.Core.Utilities; using Serilog.Events; -namespace Bit.Events +namespace Bit.Events; + +public class Program { - public class Program + public static void Main(string[] args) { - public static void Main(string[] args) - { - Host - .CreateDefaultBuilder(args) - .ConfigureCustomAppConfiguration(args) - .ConfigureWebHostDefaults(webBuilder => - { - webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, e => + Host + .CreateDefaultBuilder(args) + .ConfigureCustomAppConfiguration(args) + .ConfigureWebHostDefaults(webBuilder => + { + webBuilder.UseStartup(); + webBuilder.ConfigureLogging((hostingContext, logging) => + logging.AddSerilog(hostingContext, e => + { + var context = e.Properties["SourceContext"].ToString(); + if (context.Contains("IdentityServer4.Validation.TokenValidator") || + context.Contains("IdentityServer4.Validation.TokenRequestValidator")) { - var context = e.Properties["SourceContext"].ToString(); - if (context.Contains("IdentityServer4.Validation.TokenValidator") || - context.Contains("IdentityServer4.Validation.TokenRequestValidator")) - { - return e.Level > LogEventLevel.Error; - } + return e.Level > LogEventLevel.Error; + } - if (e.Properties.ContainsKey("RequestPath") && - !string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) && - (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) - { - return false; - } + if (e.Properties.ContainsKey("RequestPath") && + !string.IsNullOrWhiteSpace(e.Properties["RequestPath"]?.ToString()) && + (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) + { + return false; + } - return e.Level >= LogEventLevel.Error; - })); - }) - .Build() - .Run(); - } + return e.Level >= LogEventLevel.Error; + })); + }) + .Build() + .Run(); } } diff --git a/src/Events/Startup.cs b/src/Events/Startup.cs index 7c777abb4..c44ca3c1a 100644 --- a/src/Events/Startup.cs +++ b/src/Events/Startup.cs @@ -6,113 +6,112 @@ using Bit.Core.Utilities; using Bit.SharedWeb.Utilities; using IdentityModel; -namespace Bit.Events +namespace Bit.Events; + +public class Startup { - public class Startup + public Startup(IWebHostEnvironment env, IConfiguration configuration) { - public Startup(IWebHostEnvironment env, IConfiguration configuration) + CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); + Configuration = configuration; + Environment = env; + } + + public IConfiguration Configuration { get; } + public IWebHostEnvironment Environment { get; set; } + + public void ConfigureServices(IServiceCollection services) + { + // Options + services.AddOptions(); + + // Settings + var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); + + // Repositories + services.AddSqlServerRepositories(globalSettings); + + // Context + services.AddScoped(); + + // Identity + services.AddIdentityAuthenticationServices(globalSettings, Environment, config => { - CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); - Configuration = configuration; - Environment = env; + config.AddPolicy("Application", policy => + { + policy.RequireAuthenticatedUser(); + policy.RequireClaim(JwtClaimTypes.AuthenticationMethod, "Application", "external"); + policy.RequireClaim(JwtClaimTypes.Scope, "api"); + }); + }); + + // Services + var usingServiceBusAppCache = CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ConnectionString) && + CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ApplicationCacheTopicName); + if (usingServiceBusAppCache) + { + services.AddSingleton(); + } + else + { + services.AddSingleton(); + } + services.AddScoped(); + if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Events.ConnectionString)) + { + services.AddSingleton(); + } + else + { + services.AddSingleton(); } - public IConfiguration Configuration { get; } - public IWebHostEnvironment Environment { get; set; } - - public void ConfigureServices(IServiceCollection services) + // Mvc + services.AddMvc(config => { - // Options - services.AddOptions(); + config.Filters.Add(new LoggingExceptionHandlerFilterAttribute()); + }); - // Settings - var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); - - // Repositories - services.AddSqlServerRepositories(globalSettings); - - // Context - services.AddScoped(); - - // Identity - services.AddIdentityAuthenticationServices(globalSettings, Environment, config => - { - config.AddPolicy("Application", policy => - { - policy.RequireAuthenticatedUser(); - policy.RequireClaim(JwtClaimTypes.AuthenticationMethod, "Application", "external"); - policy.RequireClaim(JwtClaimTypes.Scope, "api"); - }); - }); - - // Services - var usingServiceBusAppCache = CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ConnectionString) && - CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ApplicationCacheTopicName); - if (usingServiceBusAppCache) - { - services.AddSingleton(); - } - else - { - services.AddSingleton(); - } - services.AddScoped(); - if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Events.ConnectionString)) - { - services.AddSingleton(); - } - else - { - services.AddSingleton(); - } - - // Mvc - services.AddMvc(config => - { - config.Filters.Add(new LoggingExceptionHandlerFilterAttribute()); - }); - - if (usingServiceBusAppCache) - { - services.AddHostedService(); - } - } - - public void Configure( - IApplicationBuilder app, - IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, - GlobalSettings globalSettings) + if (usingServiceBusAppCache) { - app.UseSerilog(env, appLifetime, globalSettings); - - // Add general security headers - app.UseMiddleware(); - - if (env.IsDevelopment()) - { - app.UseDeveloperExceptionPage(); - } - - // Default Middleware - app.UseDefaultMiddleware(env, globalSettings); - - // Add routing - app.UseRouting(); - - // Add Cors - app.UseCors(policy => policy.SetIsOriginAllowed(o => CoreHelpers.IsCorsOriginAllowed(o, globalSettings)) - .AllowAnyMethod().AllowAnyHeader().AllowCredentials()); - - // Add authentication and authorization to the request pipeline. - app.UseAuthentication(); - app.UseAuthorization(); - - // Add current context - app.UseMiddleware(); - - // Add MVC to the request pipeline. - app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); + services.AddHostedService(); } } + + public void Configure( + IApplicationBuilder app, + IWebHostEnvironment env, + IHostApplicationLifetime appLifetime, + GlobalSettings globalSettings) + { + app.UseSerilog(env, appLifetime, globalSettings); + + // Add general security headers + app.UseMiddleware(); + + if (env.IsDevelopment()) + { + app.UseDeveloperExceptionPage(); + } + + // Default Middleware + app.UseDefaultMiddleware(env, globalSettings); + + // Add routing + app.UseRouting(); + + // Add Cors + app.UseCors(policy => policy.SetIsOriginAllowed(o => CoreHelpers.IsCorsOriginAllowed(o, globalSettings)) + .AllowAnyMethod().AllowAnyHeader().AllowCredentials()); + + // Add authentication and authorization to the request pipeline. + app.UseAuthentication(); + app.UseAuthorization(); + + // Add current context + app.UseMiddleware(); + + // Add MVC to the request pipeline. + app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); + } } diff --git a/src/EventsProcessor/AzureQueueHostedService.cs b/src/EventsProcessor/AzureQueueHostedService.cs index 41f203e9d..837e4ad14 100644 --- a/src/EventsProcessor/AzureQueueHostedService.cs +++ b/src/EventsProcessor/AzureQueueHostedService.cs @@ -5,122 +5,121 @@ using Bit.Core.Models.Data; using Bit.Core.Services; using Bit.Core.Utilities; -namespace Bit.EventsProcessor +namespace Bit.EventsProcessor; + +public class AzureQueueHostedService : IHostedService, IDisposable { - public class AzureQueueHostedService : IHostedService, IDisposable + private readonly ILogger _logger; + private readonly IConfiguration _configuration; + + private Task _executingTask; + private CancellationTokenSource _cts; + private QueueClient _queueClient; + private IEventWriteService _eventWriteService; + + public AzureQueueHostedService( + ILogger logger, + IConfiguration configuration) { - private readonly ILogger _logger; - private readonly IConfiguration _configuration; + _logger = logger; + _configuration = configuration; + } - private Task _executingTask; - private CancellationTokenSource _cts; - private QueueClient _queueClient; - private IEventWriteService _eventWriteService; + public Task StartAsync(CancellationToken cancellationToken) + { + _logger.LogInformation(Constants.BypassFiltersEventId, "Starting service."); + _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + _executingTask = ExecuteAsync(_cts.Token); + return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask; + } - public AzureQueueHostedService( - ILogger logger, - IConfiguration configuration) + public async Task StopAsync(CancellationToken cancellationToken) + { + if (_executingTask == null) { - _logger = logger; - _configuration = configuration; + return; + } + _logger.LogWarning("Stopping service."); + _cts.Cancel(); + await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); + cancellationToken.ThrowIfCancellationRequested(); + } + + public void Dispose() + { } + + private async Task ExecuteAsync(CancellationToken cancellationToken) + { + var storageConnectionString = _configuration["azureStorageConnectionString"]; + if (string.IsNullOrWhiteSpace(storageConnectionString)) + { + return; } - public Task StartAsync(CancellationToken cancellationToken) + var repo = new Core.Repositories.TableStorage.EventRepository(storageConnectionString); + _eventWriteService = new RepositoryEventWriteService(repo); + _queueClient = new QueueClient(storageConnectionString, "event"); + + while (!cancellationToken.IsCancellationRequested) { - _logger.LogInformation(Constants.BypassFiltersEventId, "Starting service."); - _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - _executingTask = ExecuteAsync(_cts.Token); - return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask; - } - - public async Task StopAsync(CancellationToken cancellationToken) - { - if (_executingTask == null) + try { - return; - } - _logger.LogWarning("Stopping service."); - _cts.Cancel(); - await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); - cancellationToken.ThrowIfCancellationRequested(); - } - - public void Dispose() - { } - - private async Task ExecuteAsync(CancellationToken cancellationToken) - { - var storageConnectionString = _configuration["azureStorageConnectionString"]; - if (string.IsNullOrWhiteSpace(storageConnectionString)) - { - return; - } - - var repo = new Core.Repositories.TableStorage.EventRepository(storageConnectionString); - _eventWriteService = new RepositoryEventWriteService(repo); - _queueClient = new QueueClient(storageConnectionString, "event"); - - while (!cancellationToken.IsCancellationRequested) - { - try + var messages = await _queueClient.ReceiveMessagesAsync(32); + if (messages.Value?.Any() ?? false) { - var messages = await _queueClient.ReceiveMessagesAsync(32); - if (messages.Value?.Any() ?? false) + foreach (var message in messages.Value) { - foreach (var message in messages.Value) - { - await ProcessQueueMessageAsync(message.DecodeMessageText(), cancellationToken); - await _queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); - } - } - else - { - await Task.Delay(TimeSpan.FromSeconds(5), cancellationToken); + await ProcessQueueMessageAsync(message.DecodeMessageText(), cancellationToken); + await _queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); } } - catch (Exception e) + else { - _logger.LogError(e, "Exception occurred: " + e.Message); await Task.Delay(TimeSpan.FromSeconds(5), cancellationToken); } } - - _logger.LogWarning("Done processing."); + catch (Exception e) + { + _logger.LogError(e, "Exception occurred: " + e.Message); + await Task.Delay(TimeSpan.FromSeconds(5), cancellationToken); + } } - public async Task ProcessQueueMessageAsync(string message, CancellationToken cancellationToken) + _logger.LogWarning("Done processing."); + } + + public async Task ProcessQueueMessageAsync(string message, CancellationToken cancellationToken) + { + if (_eventWriteService == null || message == null || message.Length == 0) { - if (_eventWriteService == null || message == null || message.Length == 0) + return; + } + + try + { + _logger.LogInformation("Processing message."); + var events = new List(); + + using var jsonDocument = JsonDocument.Parse(message); + var root = jsonDocument.RootElement; + if (root.ValueKind == JsonValueKind.Array) { - return; + var indexedEntities = root.ToObject>() + .SelectMany(e => EventTableEntity.IndexEvent(e)); + events.AddRange(indexedEntities); + } + else if (root.ValueKind == JsonValueKind.Object) + { + var eventMessage = root.ToObject(); + events.AddRange(EventTableEntity.IndexEvent(eventMessage)); } - try - { - _logger.LogInformation("Processing message."); - var events = new List(); - - using var jsonDocument = JsonDocument.Parse(message); - var root = jsonDocument.RootElement; - if (root.ValueKind == JsonValueKind.Array) - { - var indexedEntities = root.ToObject>() - .SelectMany(e => EventTableEntity.IndexEvent(e)); - events.AddRange(indexedEntities); - } - else if (root.ValueKind == JsonValueKind.Object) - { - var eventMessage = root.ToObject(); - events.AddRange(EventTableEntity.IndexEvent(eventMessage)); - } - - await _eventWriteService.CreateManyAsync(events); - _logger.LogInformation("Processed message."); - } - catch (JsonException) - { - _logger.LogError("JsonReaderException: Unable to parse message."); - } + await _eventWriteService.CreateManyAsync(events); + _logger.LogInformation("Processed message."); + } + catch (JsonException) + { + _logger.LogError("JsonReaderException: Unable to parse message."); } } } diff --git a/src/EventsProcessor/Program.cs b/src/EventsProcessor/Program.cs index a63c7742c..0cf2d17fa 100644 --- a/src/EventsProcessor/Program.cs +++ b/src/EventsProcessor/Program.cs @@ -1,22 +1,21 @@ using Bit.Core.Utilities; using Serilog.Events; -namespace Bit.EventsProcessor +namespace Bit.EventsProcessor; + +public class Program { - public class Program + public static void Main(string[] args) { - public static void Main(string[] args) - { - Host - .CreateDefaultBuilder(args) - .ConfigureWebHostDefaults(webBuilder => - { - webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, e => e.Level >= LogEventLevel.Warning)); - }) - .Build() - .Run(); - } + Host + .CreateDefaultBuilder(args) + .ConfigureWebHostDefaults(webBuilder => + { + webBuilder.UseStartup(); + webBuilder.ConfigureLogging((hostingContext, logging) => + logging.AddSerilog(hostingContext, e => e.Level >= LogEventLevel.Warning)); + }) + .Build() + .Run(); } } diff --git a/src/EventsProcessor/Startup.cs b/src/EventsProcessor/Startup.cs index e995816a0..d0a624f73 100644 --- a/src/EventsProcessor/Startup.cs +++ b/src/EventsProcessor/Startup.cs @@ -4,53 +4,52 @@ using Bit.Core.Utilities; using Bit.SharedWeb.Utilities; using Microsoft.IdentityModel.Logging; -namespace Bit.EventsProcessor +namespace Bit.EventsProcessor; + +public class Startup { - public class Startup + public Startup(IWebHostEnvironment env, IConfiguration configuration) { - public Startup(IWebHostEnvironment env, IConfiguration configuration) + CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); + Configuration = configuration; + Environment = env; + } + + public IConfiguration Configuration { get; } + public IWebHostEnvironment Environment { get; set; } + + public void ConfigureServices(IServiceCollection services) + { + // Options + services.AddOptions(); + + // Settings + services.AddGlobalSettingsServices(Configuration, Environment); + + // Hosted Services + services.AddHostedService(); + } + + public void Configure( + IApplicationBuilder app, + IWebHostEnvironment env, + IHostApplicationLifetime appLifetime, + GlobalSettings globalSettings) + { + IdentityModelEventSource.ShowPII = true; + app.UseSerilog(env, appLifetime, globalSettings); + // Add general security headers + app.UseMiddleware(); + app.UseRouting(); + app.UseEndpoints(endpoints => { - CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); - Configuration = configuration; - Environment = env; - } + endpoints.MapGet("/alive", + async context => await context.Response.WriteAsJsonAsync(System.DateTime.UtcNow)); + endpoints.MapGet("/now", + async context => await context.Response.WriteAsJsonAsync(System.DateTime.UtcNow)); + endpoints.MapGet("/version", + async context => await context.Response.WriteAsJsonAsync(CoreHelpers.GetVersion())); - public IConfiguration Configuration { get; } - public IWebHostEnvironment Environment { get; set; } - - public void ConfigureServices(IServiceCollection services) - { - // Options - services.AddOptions(); - - // Settings - services.AddGlobalSettingsServices(Configuration, Environment); - - // Hosted Services - services.AddHostedService(); - } - - public void Configure( - IApplicationBuilder app, - IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, - GlobalSettings globalSettings) - { - IdentityModelEventSource.ShowPII = true; - app.UseSerilog(env, appLifetime, globalSettings); - // Add general security headers - app.UseMiddleware(); - app.UseRouting(); - app.UseEndpoints(endpoints => - { - endpoints.MapGet("/alive", - async context => await context.Response.WriteAsJsonAsync(System.DateTime.UtcNow)); - endpoints.MapGet("/now", - async context => await context.Response.WriteAsJsonAsync(System.DateTime.UtcNow)); - endpoints.MapGet("/version", - async context => await context.Response.WriteAsJsonAsync(CoreHelpers.GetVersion())); - - }); - } + }); } } diff --git a/src/Icons/Controllers/IconsController.cs b/src/Icons/Controllers/IconsController.cs index 5e27ece56..ad9b6cfd4 100644 --- a/src/Icons/Controllers/IconsController.cs +++ b/src/Icons/Controllers/IconsController.cs @@ -3,106 +3,105 @@ using Bit.Icons.Services; using Microsoft.AspNetCore.Mvc; using Microsoft.Extensions.Caching.Memory; -namespace Bit.Icons.Controllers +namespace Bit.Icons.Controllers; + +[Route("")] +public class IconsController : Controller { - [Route("")] - public class IconsController : Controller + // Basic bwi-globe icon + private static readonly byte[] _notFoundImage = Convert.FromBase64String("iVBORw0KGgoAAAANSUhEUg" + + "AAABMAAAATCAQAAADYWf5HAAABu0lEQVR42nXSvWuTURTH8R+t0heI9Y04aJycdBLNJNrBFBU7OFgUER3q21I0bXK+JwZ" + + "pXISm/QdcRB3EgqBBsNihsUbbgODQQSKCuKSDOApJuuhj8tCYQj/jvYfD795z1MZ+nBKrNKhSwrMxbZTrtRnqlEjZkB/x" + + "C/xmhZrlc71qS0Up8yVzTCGucFNKD1JhORVd70SZNU4okNx5d4+U2UXRIpJFWLClsR79YzN88wQvLWNzzPKEeS/wkQGpW" + + "VhhqhW8TtDJD3Mm1x/23zLSrZCdpBY8BueTNjHSbc+8wC9HlHgU5Aj5AW5zPdcVdpq0UcknWBSr/pjixO4gfp899Kd23p" + + "M2qQCH7LkCnqAqGh73OK/8NPOcaibr90LrW/yWAnaUhqjaOSl9nFR2r5rsqo22ypn1B5IN8VOUMHVgOnNQIX+d62plcz6" + + "rg1/jskK8CMb4we4pG6OWHtR/LBJkC2E4a7ZPkuX5ntumAOM2xxveclEhLvGH6XCmLPs735Eetrw63NnOgr9P9q1viC3x" + + "lRUGOjImqFDuOBvrYYoaZU9z1uPpYae5NfdvbNVG2ZjDIlXq/oMi46lo++4vjjPBl2Dlg00AAAAASUVORK5CYII="); + + private readonly IMemoryCache _memoryCache; + private readonly IDomainMappingService _domainMappingService; + private readonly IIconFetchingService _iconFetchingService; + private readonly ILogger _logger; + private readonly IconsSettings _iconsSettings; + + public IconsController( + IMemoryCache memoryCache, + IDomainMappingService domainMappingService, + IIconFetchingService iconFetchingService, + ILogger logger, + IconsSettings iconsSettings) { - // Basic bwi-globe icon - private static readonly byte[] _notFoundImage = Convert.FromBase64String("iVBORw0KGgoAAAANSUhEUg" + - "AAABMAAAATCAQAAADYWf5HAAABu0lEQVR42nXSvWuTURTH8R+t0heI9Y04aJycdBLNJNrBFBU7OFgUER3q21I0bXK+JwZ" + - "pXISm/QdcRB3EgqBBsNihsUbbgODQQSKCuKSDOApJuuhj8tCYQj/jvYfD795z1MZ+nBKrNKhSwrMxbZTrtRnqlEjZkB/x" + - "C/xmhZrlc71qS0Up8yVzTCGucFNKD1JhORVd70SZNU4okNx5d4+U2UXRIpJFWLClsR79YzN88wQvLWNzzPKEeS/wkQGpW" + - "VhhqhW8TtDJD3Mm1x/23zLSrZCdpBY8BueTNjHSbc+8wC9HlHgU5Aj5AW5zPdcVdpq0UcknWBSr/pjixO4gfp899Kd23p" + - "M2qQCH7LkCnqAqGh73OK/8NPOcaibr90LrW/yWAnaUhqjaOSl9nFR2r5rsqo22ypn1B5IN8VOUMHVgOnNQIX+d62plcz6" + - "rg1/jskK8CMb4we4pG6OWHtR/LBJkC2E4a7ZPkuX5ntumAOM2xxveclEhLvGH6XCmLPs735Eetrw63NnOgr9P9q1viC3x" + - "lRUGOjImqFDuOBvrYYoaZU9z1uPpYae5NfdvbNVG2ZjDIlXq/oMi46lo++4vjjPBl2Dlg00AAAAASUVORK5CYII="); + _memoryCache = memoryCache; + _domainMappingService = domainMappingService; + _iconFetchingService = iconFetchingService; + _logger = logger; + _iconsSettings = iconsSettings; + } - private readonly IMemoryCache _memoryCache; - private readonly IDomainMappingService _domainMappingService; - private readonly IIconFetchingService _iconFetchingService; - private readonly ILogger _logger; - private readonly IconsSettings _iconsSettings; - - public IconsController( - IMemoryCache memoryCache, - IDomainMappingService domainMappingService, - IIconFetchingService iconFetchingService, - ILogger logger, - IconsSettings iconsSettings) + [HttpGet("~/config")] + public IActionResult GetConfig() + { + return new JsonResult(new { - _memoryCache = memoryCache; - _domainMappingService = domainMappingService; - _iconFetchingService = iconFetchingService; - _logger = logger; - _iconsSettings = iconsSettings; + CacheEnabled = _iconsSettings.CacheEnabled, + CacheHours = _iconsSettings.CacheHours, + CacheSizeLimit = _iconsSettings.CacheSizeLimit + }); + } + + [HttpGet("{hostname}/icon.png")] + public async Task Get(string hostname) + { + if (string.IsNullOrWhiteSpace(hostname) || !hostname.Contains(".")) + { + return new BadRequestResult(); } - [HttpGet("~/config")] - public IActionResult GetConfig() + var url = $"http://{hostname}"; + if (!Uri.TryCreate(url, UriKind.Absolute, out var uri)) { - return new JsonResult(new - { - CacheEnabled = _iconsSettings.CacheEnabled, - CacheHours = _iconsSettings.CacheHours, - CacheSizeLimit = _iconsSettings.CacheSizeLimit - }); + return new BadRequestResult(); } - [HttpGet("{hostname}/icon.png")] - public async Task Get(string hostname) + var domain = uri.Host; + // Convert sub.domain.com => domain.com + //if(DomainName.TryParseBaseDomain(domain, out var baseDomain)) + //{ + // domain = baseDomain; + //} + + var mappedDomain = _domainMappingService.MapDomain(domain); + if (!_iconsSettings.CacheEnabled || !_memoryCache.TryGetValue(mappedDomain, out Icon icon)) { - if (string.IsNullOrWhiteSpace(hostname) || !hostname.Contains(".")) + var result = await _iconFetchingService.GetIconAsync(domain); + if (result == null) { - return new BadRequestResult(); + _logger.LogWarning("Null result returned for {0}.", domain); + icon = null; + } + else + { + icon = result.Icon; } - var url = $"http://{hostname}"; - if (!Uri.TryCreate(url, UriKind.Absolute, out var uri)) + // Only cache not found and smaller images (<= 50kb) + if (_iconsSettings.CacheEnabled && (icon == null || icon.Image.Length <= 50012)) { - return new BadRequestResult(); - } - - var domain = uri.Host; - // Convert sub.domain.com => domain.com - //if(DomainName.TryParseBaseDomain(domain, out var baseDomain)) - //{ - // domain = baseDomain; - //} - - var mappedDomain = _domainMappingService.MapDomain(domain); - if (!_iconsSettings.CacheEnabled || !_memoryCache.TryGetValue(mappedDomain, out Icon icon)) - { - var result = await _iconFetchingService.GetIconAsync(domain); - if (result == null) + _logger.LogInformation("Cache icon for {0}.", domain); + _memoryCache.Set(mappedDomain, icon, new MemoryCacheEntryOptions { - _logger.LogWarning("Null result returned for {0}.", domain); - icon = null; - } - else - { - icon = result.Icon; - } - - // Only cache not found and smaller images (<= 50kb) - if (_iconsSettings.CacheEnabled && (icon == null || icon.Image.Length <= 50012)) - { - _logger.LogInformation("Cache icon for {0}.", domain); - _memoryCache.Set(mappedDomain, icon, new MemoryCacheEntryOptions - { - AbsoluteExpirationRelativeToNow = new TimeSpan(_iconsSettings.CacheHours, 0, 0), - Size = icon?.Image.Length ?? 0, - Priority = icon == null ? CacheItemPriority.High : CacheItemPriority.Normal - }); - } + AbsoluteExpirationRelativeToNow = new TimeSpan(_iconsSettings.CacheHours, 0, 0), + Size = icon?.Image.Length ?? 0, + Priority = icon == null ? CacheItemPriority.High : CacheItemPriority.Normal + }); } - - if (icon == null) - { - return new FileContentResult(_notFoundImage, "image/png"); - } - - return new FileContentResult(icon.Image, icon.Format); } + + if (icon == null) + { + return new FileContentResult(_notFoundImage, "image/png"); + } + + return new FileContentResult(icon.Image, icon.Format); } } diff --git a/src/Icons/Controllers/InfoController.cs b/src/Icons/Controllers/InfoController.cs index 47c6ca553..1ebbd473a 100644 --- a/src/Icons/Controllers/InfoController.cs +++ b/src/Icons/Controllers/InfoController.cs @@ -1,21 +1,20 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; -namespace Bit.Icons.Controllers -{ - public class InfoController : Controller - { - [HttpGet("~/alive")] - [HttpGet("~/now")] - public DateTime GetAlive() - { - return DateTime.UtcNow; - } +namespace Bit.Icons.Controllers; - [HttpGet("~/version")] - public JsonResult GetVersion() - { - return Json(CoreHelpers.GetVersion()); - } +public class InfoController : Controller +{ + [HttpGet("~/alive")] + [HttpGet("~/now")] + public DateTime GetAlive() + { + return DateTime.UtcNow; + } + + [HttpGet("~/version")] + public JsonResult GetVersion() + { + return Json(CoreHelpers.GetVersion()); } } diff --git a/src/Icons/IconsSettings.cs b/src/Icons/IconsSettings.cs index e6de86629..7cfd64d11 100644 --- a/src/Icons/IconsSettings.cs +++ b/src/Icons/IconsSettings.cs @@ -1,9 +1,8 @@ -namespace Bit.Icons +namespace Bit.Icons; + +public class IconsSettings { - public class IconsSettings - { - public virtual bool CacheEnabled { get; set; } - public virtual int CacheHours { get; set; } - public virtual long? CacheSizeLimit { get; set; } - } + public virtual bool CacheEnabled { get; set; } + public virtual int CacheHours { get; set; } + public virtual long? CacheSizeLimit { get; set; } } diff --git a/src/Icons/Models/DomainName.cs b/src/Icons/Models/DomainName.cs index ee5a5f0d4..b04011050 100644 --- a/src/Icons/Models/DomainName.cs +++ b/src/Icons/Models/DomainName.cs @@ -2,324 +2,323 @@ using System.Reflection; using System.Text.RegularExpressions; -namespace Bit.Icons.Models +namespace Bit.Icons.Models; + +// ref: https://github.com/danesparza/domainname-parser +public class DomainName { - // ref: https://github.com/danesparza/domainname-parser - public class DomainName + private const string IpRegex = "^(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\." + + "(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\." + + "(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\." + + "(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"; + + private string _subDomain = string.Empty; + private string _domain = string.Empty; + private string _tld = string.Empty; + private TLDRule _tldRule = null; + + public string SubDomain => _subDomain; + public string Domain => _domain; + public string SLD => _domain; + public string TLD => _tld; + public TLDRule Rule => _tldRule; + public string BaseDomain => $"{_domain}.{_tld}"; + + public DomainName(string TLD, string SLD, string SubDomain, TLDRule TLDRule) { - private const string IpRegex = "^(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\." + - "(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\." + - "(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\." + - "(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"; + _tld = TLD; + _domain = SLD; + _subDomain = SubDomain; + _tldRule = TLDRule; + } - private string _subDomain = string.Empty; - private string _domain = string.Empty; - private string _tld = string.Empty; - private TLDRule _tldRule = null; + public static bool TryParse(string domainString, out DomainName result) + { + var retval = false; - public string SubDomain => _subDomain; - public string Domain => _domain; - public string SLD => _domain; - public string TLD => _tld; - public TLDRule Rule => _tldRule; - public string BaseDomain => $"{_domain}.{_tld}"; + // Our temporary domain parts: + var tld = string.Empty; + var sld = string.Empty; + var subdomain = string.Empty; + TLDRule _tldrule = null; + result = null; - public DomainName(string TLD, string SLD, string SubDomain, TLDRule TLDRule) + try { - _tld = TLD; - _domain = SLD; - _subDomain = SubDomain; - _tldRule = TLDRule; + // Try parsing the domain name ... this might throw formatting exceptions + ParseDomainName(domainString, out tld, out sld, out subdomain, out _tldrule); + // Construct a new DomainName object and return it + result = new DomainName(tld, sld, subdomain, _tldrule); + // Return 'true' + retval = true; + } + catch + { + // Looks like something bad happened -- return 'false' + retval = false; } - public static bool TryParse(string domainString, out DomainName result) + return retval; + } + + public static bool TryParseBaseDomain(string domainString, out string result) + { + if (Regex.IsMatch(domainString, IpRegex)) { - var retval = false; - - // Our temporary domain parts: - var tld = string.Empty; - var sld = string.Empty; - var subdomain = string.Empty; - TLDRule _tldrule = null; - result = null; - - try - { - // Try parsing the domain name ... this might throw formatting exceptions - ParseDomainName(domainString, out tld, out sld, out subdomain, out _tldrule); - // Construct a new DomainName object and return it - result = new DomainName(tld, sld, subdomain, _tldrule); - // Return 'true' - retval = true; - } - catch - { - // Looks like something bad happened -- return 'false' - retval = false; - } - - return retval; + result = domainString; + return true; } - public static bool TryParseBaseDomain(string domainString, out string result) - { - if (Regex.IsMatch(domainString, IpRegex)) - { - result = domainString; - return true; - } + DomainName domain; + var retval = TryParse(domainString, out domain); + result = domain?.BaseDomain; + return retval; + } - DomainName domain; - var retval = TryParse(domainString, out domain); - result = domain?.BaseDomain; - return retval; + private static void ParseDomainName(string domainString, out string TLD, out string SLD, + out string SubDomain, out TLDRule MatchingRule) + { + // Make sure domain is all lowercase + domainString = domainString.ToLower(); + + TLD = string.Empty; + SLD = string.Empty; + SubDomain = string.Empty; + MatchingRule = null; + + // If the fqdn is empty, we have a problem already + if (domainString.Trim() == string.Empty) + { + throw new ArgumentException("The domain cannot be blank"); } - private static void ParseDomainName(string domainString, out string TLD, out string SLD, - out string SubDomain, out TLDRule MatchingRule) + // Next, find the matching rule: + MatchingRule = FindMatchingTLDRule(domainString); + + // At this point, no rules match, we have a problem + if (MatchingRule == null) { - // Make sure domain is all lowercase - domainString = domainString.ToLower(); + throw new FormatException("The domain does not have a recognized TLD"); + } - TLD = string.Empty; - SLD = string.Empty; - SubDomain = string.Empty; - MatchingRule = null; + // Based on the tld rule found, get the domain (and possibly the subdomain) + var tempSudomainAndDomain = string.Empty; + var tldIndex = 0; - // If the fqdn is empty, we have a problem already - if (domainString.Trim() == string.Empty) + // First, determine what type of rule we have, and set the TLD accordingly + switch (MatchingRule.Type) + { + case TLDRule.RuleType.Normal: + tldIndex = domainString.LastIndexOf("." + MatchingRule.Name); + tempSudomainAndDomain = domainString.Substring(0, tldIndex); + TLD = domainString.Substring(tldIndex + 1); + break; + case TLDRule.RuleType.Wildcard: + // This finds the last portion of the TLD... + tldIndex = domainString.LastIndexOf("." + MatchingRule.Name); + tempSudomainAndDomain = domainString.Substring(0, tldIndex); + + // But we need to find the wildcard portion of it: + tldIndex = tempSudomainAndDomain.LastIndexOf("."); + tempSudomainAndDomain = domainString.Substring(0, tldIndex); + TLD = domainString.Substring(tldIndex + 1); + break; + case TLDRule.RuleType.Exception: + tldIndex = domainString.LastIndexOf("."); + tempSudomainAndDomain = domainString.Substring(0, tldIndex); + TLD = domainString.Substring(tldIndex + 1); + break; + } + + // See if we have a subdomain: + List lstRemainingParts = new List(tempSudomainAndDomain.Split('.')); + + // If we have 0 parts left, there is just a tld and no domain or subdomain + // If we have 1 part, it's the domain, and there is no subdomain + // If we have 2+ parts, the last part is the domain, the other parts (combined) are the subdomain + if (lstRemainingParts.Count > 0) + { + // Set the domain: + SLD = lstRemainingParts[lstRemainingParts.Count - 1]; + + // Set the subdomain, if there is one to set: + if (lstRemainingParts.Count > 1) { - throw new ArgumentException("The domain cannot be blank"); + // We strip off the trailing period, too + SubDomain = tempSudomainAndDomain.Substring(0, tempSudomainAndDomain.Length - SLD.Length - 1); + } + } + } + + private static TLDRule FindMatchingTLDRule(string domainString) + { + // Split our domain into parts (based on the '.') + // ...Put these parts in a list + // ...Make sure these parts are in reverse order + // (we'll be checking rules from the right-most pat of the domain) + var lstDomainParts = domainString.Split('.').ToList(); + lstDomainParts.Reverse(); + + // Begin building our partial domain to check rules with: + var checkAgainst = string.Empty; + + // Our 'matches' collection: + var ruleMatches = new List(); + + foreach (string domainPart in lstDomainParts) + { + // Add on our next domain part: + checkAgainst = string.Format("{0}.{1}", domainPart, checkAgainst); + + // If we end in a period, strip it off: + if (checkAgainst.EndsWith(".")) + { + checkAgainst = checkAgainst.Substring(0, checkAgainst.Length - 1); } - // Next, find the matching rule: - MatchingRule = FindMatchingTLDRule(domainString); - - // At this point, no rules match, we have a problem - if (MatchingRule == null) + var rules = Enum.GetValues(typeof(TLDRule.RuleType)).Cast(); + foreach (var rule in rules) { - throw new FormatException("The domain does not have a recognized TLD"); - } - - // Based on the tld rule found, get the domain (and possibly the subdomain) - var tempSudomainAndDomain = string.Empty; - var tldIndex = 0; - - // First, determine what type of rule we have, and set the TLD accordingly - switch (MatchingRule.Type) - { - case TLDRule.RuleType.Normal: - tldIndex = domainString.LastIndexOf("." + MatchingRule.Name); - tempSudomainAndDomain = domainString.Substring(0, tldIndex); - TLD = domainString.Substring(tldIndex + 1); - break; - case TLDRule.RuleType.Wildcard: - // This finds the last portion of the TLD... - tldIndex = domainString.LastIndexOf("." + MatchingRule.Name); - tempSudomainAndDomain = domainString.Substring(0, tldIndex); - - // But we need to find the wildcard portion of it: - tldIndex = tempSudomainAndDomain.LastIndexOf("."); - tempSudomainAndDomain = domainString.Substring(0, tldIndex); - TLD = domainString.Substring(tldIndex + 1); - break; - case TLDRule.RuleType.Exception: - tldIndex = domainString.LastIndexOf("."); - tempSudomainAndDomain = domainString.Substring(0, tldIndex); - TLD = domainString.Substring(tldIndex + 1); - break; - } - - // See if we have a subdomain: - List lstRemainingParts = new List(tempSudomainAndDomain.Split('.')); - - // If we have 0 parts left, there is just a tld and no domain or subdomain - // If we have 1 part, it's the domain, and there is no subdomain - // If we have 2+ parts, the last part is the domain, the other parts (combined) are the subdomain - if (lstRemainingParts.Count > 0) - { - // Set the domain: - SLD = lstRemainingParts[lstRemainingParts.Count - 1]; - - // Set the subdomain, if there is one to set: - if (lstRemainingParts.Count > 1) + // Try to match rule: + TLDRule result; + if (TLDRulesCache.Instance.TLDRuleLists[rule].TryGetValue(checkAgainst, out result)) { - // We strip off the trailing period, too - SubDomain = tempSudomainAndDomain.Substring(0, tempSudomainAndDomain.Length - SLD.Length - 1); + ruleMatches.Add(result); } } } - private static TLDRule FindMatchingTLDRule(string domainString) + // Sort our matches list (longest rule wins, according to : + var results = from match in ruleMatches + orderby match.Name.Length descending + select match; + + // Take the top result (our primary match): + var primaryMatch = results.Take(1).SingleOrDefault(); + return primaryMatch; + } + + public class TLDRule : IComparable + { + public string Name { get; private set; } + public RuleType Type { get; private set; } + + public TLDRule(string RuleInfo) { - // Split our domain into parts (based on the '.') - // ...Put these parts in a list - // ...Make sure these parts are in reverse order - // (we'll be checking rules from the right-most pat of the domain) - var lstDomainParts = domainString.Split('.').ToList(); - lstDomainParts.Reverse(); - - // Begin building our partial domain to check rules with: - var checkAgainst = string.Empty; - - // Our 'matches' collection: - var ruleMatches = new List(); - - foreach (string domainPart in lstDomainParts) + // Parse the rule and set properties accordingly: + if (RuleInfo.StartsWith("*")) { - // Add on our next domain part: - checkAgainst = string.Format("{0}.{1}", domainPart, checkAgainst); + Type = RuleType.Wildcard; + Name = RuleInfo.Substring(2); + } + else if (RuleInfo.StartsWith("!")) + { + Type = RuleType.Exception; + Name = RuleInfo.Substring(1); + } + else + { + Type = RuleType.Normal; + Name = RuleInfo; + } + } - // If we end in a period, strip it off: - if (checkAgainst.EndsWith(".")) - { - checkAgainst = checkAgainst.Substring(0, checkAgainst.Length - 1); - } + public int CompareTo(TLDRule other) + { + if (other == null) + { + return -1; + } - var rules = Enum.GetValues(typeof(TLDRule.RuleType)).Cast(); - foreach (var rule in rules) + return Name.CompareTo(other.Name); + } + + public enum RuleType + { + Normal, + Wildcard, + Exception + } + } + + public class TLDRulesCache + { + private static volatile TLDRulesCache _uniqueInstance; + private static object _syncObj = new object(); + private static object _syncList = new object(); + + private TLDRulesCache() + { + // Initialize our internal list: + TLDRuleLists = GetTLDRules(); + } + + public static TLDRulesCache Instance + { + get + { + if (_uniqueInstance == null) { - // Try to match rule: - TLDRule result; - if (TLDRulesCache.Instance.TLDRuleLists[rule].TryGetValue(checkAgainst, out result)) + lock (_syncObj) { - ruleMatches.Add(result); - } - } - } - - // Sort our matches list (longest rule wins, according to : - var results = from match in ruleMatches - orderby match.Name.Length descending - select match; - - // Take the top result (our primary match): - var primaryMatch = results.Take(1).SingleOrDefault(); - return primaryMatch; - } - - public class TLDRule : IComparable - { - public string Name { get; private set; } - public RuleType Type { get; private set; } - - public TLDRule(string RuleInfo) - { - // Parse the rule and set properties accordingly: - if (RuleInfo.StartsWith("*")) - { - Type = RuleType.Wildcard; - Name = RuleInfo.Substring(2); - } - else if (RuleInfo.StartsWith("!")) - { - Type = RuleType.Exception; - Name = RuleInfo.Substring(1); - } - else - { - Type = RuleType.Normal; - Name = RuleInfo; - } - } - - public int CompareTo(TLDRule other) - { - if (other == null) - { - return -1; - } - - return Name.CompareTo(other.Name); - } - - public enum RuleType - { - Normal, - Wildcard, - Exception - } - } - - public class TLDRulesCache - { - private static volatile TLDRulesCache _uniqueInstance; - private static object _syncObj = new object(); - private static object _syncList = new object(); - - private TLDRulesCache() - { - // Initialize our internal list: - TLDRuleLists = GetTLDRules(); - } - - public static TLDRulesCache Instance - { - get - { - if (_uniqueInstance == null) - { - lock (_syncObj) + if (_uniqueInstance == null) { - if (_uniqueInstance == null) - { - _uniqueInstance = new TLDRulesCache(); - } + _uniqueInstance = new TLDRulesCache(); } } - return (_uniqueInstance); } + return (_uniqueInstance); + } + } + + public IDictionary> TLDRuleLists { get; set; } + + public static void Reset() + { + lock (_syncObj) + { + _uniqueInstance = null; + } + } + + private IDictionary> GetTLDRules() + { + var results = new Dictionary>(); + var rules = Enum.GetValues(typeof(TLDRule.RuleType)).Cast(); + foreach (var rule in rules) + { + results[rule] = new Dictionary(StringComparer.CurrentCultureIgnoreCase); } - public IDictionary> TLDRuleLists { get; set; } + var ruleStrings = ReadRulesData(); - public static void Reset() + // Strip out any lines that are: + // a.) A comment + // b.) Blank + var rulesStrings = ruleStrings + .Where(ruleString => !ruleString.StartsWith("//") && ruleString.Trim().Length != 0); + foreach (var ruleString in rulesStrings) { - lock (_syncObj) - { - _uniqueInstance = null; - } + var result = new TLDRule(ruleString); + results[result.Type][result.Name] = result; } - private IDictionary> GetTLDRules() + // Return our results: + Debug.WriteLine(string.Format("Loaded {0} rules into cache.", + results.Values.Sum(r => r.Values.Count))); + return results; + } + + private IEnumerable ReadRulesData() + { + var assembly = typeof(TLDRulesCache).GetTypeInfo().Assembly; + var stream = assembly.GetManifestResourceStream("Bit.Icons.Resources.public_suffix_list.dat"); + string line; + using (var reader = new StreamReader(stream)) { - var results = new Dictionary>(); - var rules = Enum.GetValues(typeof(TLDRule.RuleType)).Cast(); - foreach (var rule in rules) + while ((line = reader.ReadLine()) != null) { - results[rule] = new Dictionary(StringComparer.CurrentCultureIgnoreCase); - } - - var ruleStrings = ReadRulesData(); - - // Strip out any lines that are: - // a.) A comment - // b.) Blank - var rulesStrings = ruleStrings - .Where(ruleString => !ruleString.StartsWith("//") && ruleString.Trim().Length != 0); - foreach (var ruleString in rulesStrings) - { - var result = new TLDRule(ruleString); - results[result.Type][result.Name] = result; - } - - // Return our results: - Debug.WriteLine(string.Format("Loaded {0} rules into cache.", - results.Values.Sum(r => r.Values.Count))); - return results; - } - - private IEnumerable ReadRulesData() - { - var assembly = typeof(TLDRulesCache).GetTypeInfo().Assembly; - var stream = assembly.GetManifestResourceStream("Bit.Icons.Resources.public_suffix_list.dat"); - string line; - using (var reader = new StreamReader(stream)) - { - while ((line = reader.ReadLine()) != null) - { - yield return line; - } + yield return line; } } } diff --git a/src/Icons/Models/Icon.cs b/src/Icons/Models/Icon.cs index cca6d78d5..8bd23541f 100644 --- a/src/Icons/Models/Icon.cs +++ b/src/Icons/Models/Icon.cs @@ -1,8 +1,7 @@ -namespace Bit.Icons.Models +namespace Bit.Icons.Models; + +public class Icon { - public class Icon - { - public byte[] Image { get; set; } - public string Format { get; set; } - } + public byte[] Image { get; set; } + public string Format { get; set; } } diff --git a/src/Icons/Models/IconResult.cs b/src/Icons/Models/IconResult.cs index 104c2627a..ca1e6929e 100644 --- a/src/Icons/Models/IconResult.cs +++ b/src/Icons/Models/IconResult.cs @@ -1,66 +1,65 @@ -namespace Bit.Icons.Models -{ - public class IconResult - { - public IconResult(string href, string sizes) - { - Path = href; - if (!string.IsNullOrWhiteSpace(sizes)) - { - var sizeParts = sizes.Split('x'); - if (sizeParts.Length == 2 && int.TryParse(sizeParts[0].Trim(), out var width) && - int.TryParse(sizeParts[1].Trim(), out var height)) - { - DefinedWidth = width; - DefinedHeight = height; +namespace Bit.Icons.Models; - if (width == height) +public class IconResult +{ + public IconResult(string href, string sizes) + { + Path = href; + if (!string.IsNullOrWhiteSpace(sizes)) + { + var sizeParts = sizes.Split('x'); + if (sizeParts.Length == 2 && int.TryParse(sizeParts[0].Trim(), out var width) && + int.TryParse(sizeParts[1].Trim(), out var height)) + { + DefinedWidth = width; + DefinedHeight = height; + + if (width == height) + { + if (width == 32) { - if (width == 32) - { - Priority = 1; - } - else if (width == 64) - { - Priority = 2; - } - else if (width >= 24 && width <= 128) - { - Priority = 3; - } - else if (width == 16) - { - Priority = 4; - } - else - { - Priority = 100; - } + Priority = 1; + } + else if (width == 64) + { + Priority = 2; + } + else if (width >= 24 && width <= 128) + { + Priority = 3; + } + else if (width == 16) + { + Priority = 4; + } + else + { + Priority = 100; } } } - - if (Priority == 0) - { - Priority = 200; - } } - public IconResult(Uri uri, byte[] bytes, string format) + if (Priority == 0) { - Path = uri.ToString(); - Icon = new Icon - { - Image = bytes, - Format = format - }; - Priority = 10; + Priority = 200; } - - public string Path { get; set; } - public int? DefinedWidth { get; set; } - public int? DefinedHeight { get; set; } - public Icon Icon { get; set; } - public int Priority { get; set; } } + + public IconResult(Uri uri, byte[] bytes, string format) + { + Path = uri.ToString(); + Icon = new Icon + { + Image = bytes, + Format = format + }; + Priority = 10; + } + + public string Path { get; set; } + public int? DefinedWidth { get; set; } + public int? DefinedHeight { get; set; } + public Icon Icon { get; set; } + public int Priority { get; set; } } diff --git a/src/Icons/Program.cs b/src/Icons/Program.cs index 1f65ea406..d57a6fd1c 100644 --- a/src/Icons/Program.cs +++ b/src/Icons/Program.cs @@ -1,22 +1,21 @@ using Bit.Core.Utilities; using Serilog.Events; -namespace Bit.Icons +namespace Bit.Icons; + +public class Program { - public class Program + public static void Main(string[] args) { - public static void Main(string[] args) - { - Host - .CreateDefaultBuilder(args) - .ConfigureWebHostDefaults(webBuilder => - { - webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, e => e.Level >= LogEventLevel.Error)); - }) - .Build() - .Run(); - } + Host + .CreateDefaultBuilder(args) + .ConfigureWebHostDefaults(webBuilder => + { + webBuilder.UseStartup(); + webBuilder.ConfigureLogging((hostingContext, logging) => + logging.AddSerilog(hostingContext, e => e.Level >= LogEventLevel.Error)); + }) + .Build() + .Run(); } } diff --git a/src/Icons/Services/DomainMappingService.cs b/src/Icons/Services/DomainMappingService.cs index 406145af9..b41d48233 100644 --- a/src/Icons/Services/DomainMappingService.cs +++ b/src/Icons/Services/DomainMappingService.cs @@ -1,24 +1,23 @@ -namespace Bit.Icons.Services +namespace Bit.Icons.Services; + +public class DomainMappingService : IDomainMappingService { - public class DomainMappingService : IDomainMappingService + private readonly Dictionary _map = new Dictionary { - private readonly Dictionary _map = new Dictionary - { - ["login.yahoo.com"] = "yahoo.com", - ["accounts.google.com"] = "google.com", - ["photo.walgreens.com"] = "walgreens.com", - ["passport.yandex.com"] = "yandex.com", - // TODO: Add others here - }; + ["login.yahoo.com"] = "yahoo.com", + ["accounts.google.com"] = "google.com", + ["photo.walgreens.com"] = "walgreens.com", + ["passport.yandex.com"] = "yandex.com", + // TODO: Add others here + }; - public string MapDomain(string hostname) + public string MapDomain(string hostname) + { + if (_map.ContainsKey(hostname)) { - if (_map.ContainsKey(hostname)) - { - return _map[hostname]; - } - - return hostname; + return _map[hostname]; } + + return hostname; } } diff --git a/src/Icons/Services/IDomainMappingService.cs b/src/Icons/Services/IDomainMappingService.cs index 194ee8f64..4ee3f4594 100644 --- a/src/Icons/Services/IDomainMappingService.cs +++ b/src/Icons/Services/IDomainMappingService.cs @@ -1,7 +1,6 @@ -namespace Bit.Icons.Services +namespace Bit.Icons.Services; + +public interface IDomainMappingService { - public interface IDomainMappingService - { - string MapDomain(string hostname); - } + string MapDomain(string hostname); } diff --git a/src/Icons/Services/IIconFetchingService.cs b/src/Icons/Services/IIconFetchingService.cs index 4c15ddffb..ff6704291 100644 --- a/src/Icons/Services/IIconFetchingService.cs +++ b/src/Icons/Services/IIconFetchingService.cs @@ -1,9 +1,8 @@ using Bit.Icons.Models; -namespace Bit.Icons.Services +namespace Bit.Icons.Services; + +public interface IIconFetchingService { - public interface IIconFetchingService - { - Task GetIconAsync(string domain); - } + Task GetIconAsync(string domain); } diff --git a/src/Icons/Services/IconFetchingService.cs b/src/Icons/Services/IconFetchingService.cs index e7ae38450..166d5a0aa 100644 --- a/src/Icons/Services/IconFetchingService.cs +++ b/src/Icons/Services/IconFetchingService.cs @@ -3,448 +3,447 @@ using System.Text; using AngleSharp.Html.Parser; using Bit.Icons.Models; -namespace Bit.Icons.Services +namespace Bit.Icons.Services; + +public class IconFetchingService : IIconFetchingService { - public class IconFetchingService : IIconFetchingService + private readonly HashSet _iconRels = + new HashSet { "icon", "apple-touch-icon", "shortcut icon" }; + private readonly HashSet _blacklistedRels = + new HashSet { "preload", "image_src", "preconnect", "canonical", "alternate", "stylesheet" }; + private readonly HashSet _iconExtensions = + new HashSet { ".ico", ".png", ".jpg", ".jpeg" }; + + private readonly string _pngMediaType = "image/png"; + private readonly byte[] _pngHeader = new byte[] { 137, 80, 78, 71 }; + private readonly byte[] _webpHeader = Encoding.UTF8.GetBytes("RIFF"); + + private readonly string _icoMediaType = "image/x-icon"; + private readonly string _icoAltMediaType = "image/vnd.microsoft.icon"; + private readonly byte[] _icoHeader = new byte[] { 00, 00, 01, 00 }; + + private readonly string _jpegMediaType = "image/jpeg"; + private readonly byte[] _jpegHeader = new byte[] { 255, 216, 255 }; + + private readonly HashSet _allowedMediaTypes; + private readonly HttpClient _httpClient; + private readonly ILogger _logger; + + public IconFetchingService(ILogger logger) { - private readonly HashSet _iconRels = - new HashSet { "icon", "apple-touch-icon", "shortcut icon" }; - private readonly HashSet _blacklistedRels = - new HashSet { "preload", "image_src", "preconnect", "canonical", "alternate", "stylesheet" }; - private readonly HashSet _iconExtensions = - new HashSet { ".ico", ".png", ".jpg", ".jpeg" }; - - private readonly string _pngMediaType = "image/png"; - private readonly byte[] _pngHeader = new byte[] { 137, 80, 78, 71 }; - private readonly byte[] _webpHeader = Encoding.UTF8.GetBytes("RIFF"); - - private readonly string _icoMediaType = "image/x-icon"; - private readonly string _icoAltMediaType = "image/vnd.microsoft.icon"; - private readonly byte[] _icoHeader = new byte[] { 00, 00, 01, 00 }; - - private readonly string _jpegMediaType = "image/jpeg"; - private readonly byte[] _jpegHeader = new byte[] { 255, 216, 255 }; - - private readonly HashSet _allowedMediaTypes; - private readonly HttpClient _httpClient; - private readonly ILogger _logger; - - public IconFetchingService(ILogger logger) + _logger = logger; + _allowedMediaTypes = new HashSet { - _logger = logger; - _allowedMediaTypes = new HashSet - { - _pngMediaType, - _icoMediaType, - _icoAltMediaType, - _jpegMediaType - }; + _pngMediaType, + _icoMediaType, + _icoAltMediaType, + _jpegMediaType + }; - _httpClient = new HttpClient(new HttpClientHandler - { - AllowAutoRedirect = false, - AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate, - }); - _httpClient.Timeout = TimeSpan.FromSeconds(20); - _httpClient.MaxResponseContentBufferSize = 5000000; // 5 MB + _httpClient = new HttpClient(new HttpClientHandler + { + AllowAutoRedirect = false, + AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate, + }); + _httpClient.Timeout = TimeSpan.FromSeconds(20); + _httpClient.MaxResponseContentBufferSize = 5000000; // 5 MB + } + + public async Task GetIconAsync(string domain) + { + if (IPAddress.TryParse(domain, out _)) + { + _logger.LogWarning("IP address: {0}.", domain); + return null; } - public async Task GetIconAsync(string domain) + if (!Uri.TryCreate($"https://{domain}", UriKind.Absolute, out var parsedHttpsUri)) { - if (IPAddress.TryParse(domain, out _)) - { - _logger.LogWarning("IP address: {0}.", domain); - return null; - } + _logger.LogWarning("Bad domain: {0}.", domain); + return null; + } - if (!Uri.TryCreate($"https://{domain}", UriKind.Absolute, out var parsedHttpsUri)) - { - _logger.LogWarning("Bad domain: {0}.", domain); - return null; - } + var uri = parsedHttpsUri; + var response = await GetAndFollowAsync(uri, 2); + if ((response == null || !response.IsSuccessStatusCode) && + Uri.TryCreate($"http://{parsedHttpsUri.Host}", UriKind.Absolute, out var parsedHttpUri)) + { + Cleanup(response); + uri = parsedHttpUri; + response = await GetAndFollowAsync(uri, 2); - var uri = parsedHttpsUri; - var response = await GetAndFollowAsync(uri, 2); - if ((response == null || !response.IsSuccessStatusCode) && - Uri.TryCreate($"http://{parsedHttpsUri.Host}", UriKind.Absolute, out var parsedHttpUri)) + if (response == null || !response.IsSuccessStatusCode) { - Cleanup(response); - uri = parsedHttpUri; - response = await GetAndFollowAsync(uri, 2); - - if (response == null || !response.IsSuccessStatusCode) + var dotCount = domain.Count(c => c == '.'); + if (dotCount > 1 && DomainName.TryParseBaseDomain(domain, out var baseDomain) && + Uri.TryCreate($"https://{baseDomain}", UriKind.Absolute, out var parsedBaseUri)) { - var dotCount = domain.Count(c => c == '.'); - if (dotCount > 1 && DomainName.TryParseBaseDomain(domain, out var baseDomain) && - Uri.TryCreate($"https://{baseDomain}", UriKind.Absolute, out var parsedBaseUri)) - { - Cleanup(response); - uri = parsedBaseUri; - response = await GetAndFollowAsync(uri, 2); - } - else if (dotCount < 2 && - Uri.TryCreate($"https://www.{parsedHttpsUri.Host}", UriKind.Absolute, out var parsedWwwUri)) - { - Cleanup(response); - uri = parsedWwwUri; - response = await GetAndFollowAsync(uri, 2); - } + Cleanup(response); + uri = parsedBaseUri; + response = await GetAndFollowAsync(uri, 2); + } + else if (dotCount < 2 && + Uri.TryCreate($"https://www.{parsedHttpsUri.Host}", UriKind.Absolute, out var parsedWwwUri)) + { + Cleanup(response); + uri = parsedWwwUri; + response = await GetAndFollowAsync(uri, 2); } } + } - if (response?.Content == null || !response.IsSuccessStatusCode) + if (response?.Content == null || !response.IsSuccessStatusCode) + { + _logger.LogWarning("Couldn't load a website for {0}: {1}.", domain, + response?.StatusCode.ToString() ?? "null"); + Cleanup(response); + return null; + } + + var parser = new HtmlParser(); + using (response) + using (var htmlStream = await response.Content.ReadAsStreamAsync()) + using (var document = await parser.ParseDocumentAsync(htmlStream)) + { + uri = response.RequestMessage.RequestUri; + if (document.DocumentElement == null) { - _logger.LogWarning("Couldn't load a website for {0}: {1}.", domain, - response?.StatusCode.ToString() ?? "null"); - Cleanup(response); + _logger.LogWarning("No DocumentElement for {0}.", domain); return null; } - var parser = new HtmlParser(); - using (response) - using (var htmlStream = await response.Content.ReadAsStreamAsync()) - using (var document = await parser.ParseDocumentAsync(htmlStream)) + var baseUrl = "/"; + var baseUrlNode = document.QuerySelector("head base[href]"); + if (baseUrlNode != null) { - uri = response.RequestMessage.RequestUri; - if (document.DocumentElement == null) + var hrefAttr = baseUrlNode.Attributes["href"]; + if (!string.IsNullOrWhiteSpace(hrefAttr?.Value)) { - _logger.LogWarning("No DocumentElement for {0}.", domain); - return null; + baseUrl = hrefAttr.Value; } - var baseUrl = "/"; - var baseUrlNode = document.QuerySelector("head base[href]"); - if (baseUrlNode != null) + baseUrlNode = null; + hrefAttr = null; + } + + var icons = new List(); + var links = document.QuerySelectorAll("head link[href]"); + if (links != null) + { + foreach (var link in links.Take(200)) { - var hrefAttr = baseUrlNode.Attributes["href"]; - if (!string.IsNullOrWhiteSpace(hrefAttr?.Value)) + var hrefAttr = link.Attributes["href"]; + if (string.IsNullOrWhiteSpace(hrefAttr?.Value)) { - baseUrl = hrefAttr.Value; + continue; } - baseUrlNode = null; + var relAttr = link.Attributes["rel"]; + var sizesAttr = link.Attributes["sizes"]; + if (relAttr != null && _iconRels.Contains(relAttr.Value.ToLower())) + { + icons.Add(new IconResult(hrefAttr.Value, sizesAttr?.Value)); + } + else if (relAttr == null || !_blacklistedRels.Contains(relAttr.Value.ToLower())) + { + try + { + var extension = Path.GetExtension(hrefAttr.Value); + if (_iconExtensions.Contains(extension.ToLower())) + { + icons.Add(new IconResult(hrefAttr.Value, sizesAttr?.Value)); + } + } + catch (ArgumentException) { } + } + + sizesAttr = null; + relAttr = null; hrefAttr = null; } - var icons = new List(); - var links = document.QuerySelectorAll("head link[href]"); - if (links != null) - { - foreach (var link in links.Take(200)) - { - var hrefAttr = link.Attributes["href"]; - if (string.IsNullOrWhiteSpace(hrefAttr?.Value)) - { - continue; - } - - var relAttr = link.Attributes["rel"]; - var sizesAttr = link.Attributes["sizes"]; - if (relAttr != null && _iconRels.Contains(relAttr.Value.ToLower())) - { - icons.Add(new IconResult(hrefAttr.Value, sizesAttr?.Value)); - } - else if (relAttr == null || !_blacklistedRels.Contains(relAttr.Value.ToLower())) - { - try - { - var extension = Path.GetExtension(hrefAttr.Value); - if (_iconExtensions.Contains(extension.ToLower())) - { - icons.Add(new IconResult(hrefAttr.Value, sizesAttr?.Value)); - } - } - catch (ArgumentException) { } - } - - sizesAttr = null; - relAttr = null; - hrefAttr = null; - } - - links = null; - } - - var iconResultTasks = new List(); - foreach (var icon in icons.OrderBy(i => i.Priority).Take(10)) - { - Uri iconUri = null; - if (icon.Path.StartsWith("//") && Uri.TryCreate($"{GetScheme(uri)}://{icon.Path.Substring(2)}", - UriKind.Absolute, out var slashUri)) - { - iconUri = slashUri; - } - else if (Uri.TryCreate(icon.Path, UriKind.Relative, out var relUri)) - { - iconUri = ResolveUri($"{GetScheme(uri)}://{uri.Host}", baseUrl, relUri.OriginalString); - } - else if (Uri.TryCreate(icon.Path, UriKind.Absolute, out var absUri)) - { - iconUri = absUri; - } - - if (iconUri != null) - { - var task = GetIconAsync(iconUri).ContinueWith(async (r) => - { - var result = await r; - if (result != null) - { - icon.Path = iconUri.ToString(); - icon.Icon = result.Icon; - } - }); - iconResultTasks.Add(task); - } - } - - await Task.WhenAll(iconResultTasks); - if (!icons.Any(i => i.Icon != null)) - { - var faviconUri = ResolveUri($"{GetScheme(uri)}://{uri.Host}", "favicon.ico"); - var result = await GetIconAsync(faviconUri); - if (result != null) - { - icons.Add(result); - } - else - { - _logger.LogWarning("No favicon.ico found for {0}.", uri.Host); - return null; - } - } - - return icons.Where(i => i.Icon != null).OrderBy(i => i.Priority).First(); + links = null; } - } - private async Task GetIconAsync(Uri uri) - { - using (var response = await GetAndFollowAsync(uri, 2)) + var iconResultTasks = new List(); + foreach (var icon in icons.OrderBy(i => i.Priority).Take(10)) { - if (response?.Content?.Headers == null || !response.IsSuccessStatusCode) + Uri iconUri = null; + if (icon.Path.StartsWith("//") && Uri.TryCreate($"{GetScheme(uri)}://{icon.Path.Substring(2)}", + UriKind.Absolute, out var slashUri)) { - response?.Content?.Dispose(); + iconUri = slashUri; + } + else if (Uri.TryCreate(icon.Path, UriKind.Relative, out var relUri)) + { + iconUri = ResolveUri($"{GetScheme(uri)}://{uri.Host}", baseUrl, relUri.OriginalString); + } + else if (Uri.TryCreate(icon.Path, UriKind.Absolute, out var absUri)) + { + iconUri = absUri; + } + + if (iconUri != null) + { + var task = GetIconAsync(iconUri).ContinueWith(async (r) => + { + var result = await r; + if (result != null) + { + icon.Path = iconUri.ToString(); + icon.Icon = result.Icon; + } + }); + iconResultTasks.Add(task); + } + } + + await Task.WhenAll(iconResultTasks); + if (!icons.Any(i => i.Icon != null)) + { + var faviconUri = ResolveUri($"{GetScheme(uri)}://{uri.Host}", "favicon.ico"); + var result = await GetIconAsync(faviconUri); + if (result != null) + { + icons.Add(result); + } + else + { + _logger.LogWarning("No favicon.ico found for {0}.", uri.Host); return null; } + } - var format = response.Content.Headers?.ContentType?.MediaType; - var bytes = await response.Content.ReadAsByteArrayAsync(); - response.Content.Dispose(); - if (format == null || !_allowedMediaTypes.Contains(format)) + return icons.Where(i => i.Icon != null).OrderBy(i => i.Priority).First(); + } + } + + private async Task GetIconAsync(Uri uri) + { + using (var response = await GetAndFollowAsync(uri, 2)) + { + if (response?.Content?.Headers == null || !response.IsSuccessStatusCode) + { + response?.Content?.Dispose(); + return null; + } + + var format = response.Content.Headers?.ContentType?.MediaType; + var bytes = await response.Content.ReadAsByteArrayAsync(); + response.Content.Dispose(); + if (format == null || !_allowedMediaTypes.Contains(format)) + { + if (HeaderMatch(bytes, _icoHeader)) { - if (HeaderMatch(bytes, _icoHeader)) - { - format = _icoMediaType; - } - else if (HeaderMatch(bytes, _pngHeader) || HeaderMatch(bytes, _webpHeader)) - { - format = _pngMediaType; - } - else if (HeaderMatch(bytes, _jpegHeader)) - { - format = _jpegMediaType; - } - else - { - return null; - } + format = _icoMediaType; + } + else if (HeaderMatch(bytes, _pngHeader) || HeaderMatch(bytes, _webpHeader)) + { + format = _pngMediaType; + } + else if (HeaderMatch(bytes, _jpegHeader)) + { + format = _jpegMediaType; + } + else + { + return null; } - - return new IconResult(uri, bytes, format); } + + return new IconResult(uri, bytes, format); + } + } + + private async Task GetAndFollowAsync(Uri uri, int maxRedirectCount) + { + var response = await GetAsync(uri); + if (response == null) + { + return null; + } + return await FollowRedirectsAsync(response, maxRedirectCount); + } + + private async Task GetAsync(Uri uri) + { + if (uri == null) + { + return null; } - private async Task GetAndFollowAsync(Uri uri, int maxRedirectCount) + // Prevent non-http(s) and non-default ports + if ((uri.Scheme != "http" && uri.Scheme != "https") || !uri.IsDefaultPort) { - var response = await GetAsync(uri); - if (response == null) - { - return null; - } - return await FollowRedirectsAsync(response, maxRedirectCount); + return null; } - private async Task GetAsync(Uri uri) + // Prevent local hosts (localhost, bobs-pc, etc) and IP addresses + if (!uri.Host.Contains(".") || IPAddress.TryParse(uri.Host, out _)) { - if (uri == null) + return null; + } + + // Resolve host to make sure it is not an internal/private IP address + try + { + var hostEntry = Dns.GetHostEntry(uri.Host); + if (hostEntry?.AddressList.Any(ip => IsInternal(ip)) ?? true) { return null; } + } + catch + { + return null; + } - // Prevent non-http(s) and non-default ports - if ((uri.Scheme != "http" && uri.Scheme != "https") || !uri.IsDefaultPort) - { - return null; - } + using (var message = new HttpRequestMessage()) + { + message.RequestUri = uri; + message.Method = HttpMethod.Get; - // Prevent local hosts (localhost, bobs-pc, etc) and IP addresses - if (!uri.Host.Contains(".") || IPAddress.TryParse(uri.Host, out _)) - { - return null; - } + // Let's add some headers to look like we're coming from a web browser request. Some websites + // will block our request without these. + message.Headers.Add("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 " + + "(KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36 Edge/16.16299"); + message.Headers.Add("Accept-Language", "en-US,en;q=0.8"); + message.Headers.Add("Cache-Control", "no-cache"); + message.Headers.Add("Pragma", "no-cache"); + message.Headers.Add("Accept", "text/html,application/xhtml+xml,application/xml;" + + "q=0.9,image/webp,image/apng,*/*;q=0.8"); - // Resolve host to make sure it is not an internal/private IP address try { - var hostEntry = Dns.GetHostEntry(uri.Host); - if (hostEntry?.AddressList.Any(ip => IsInternal(ip)) ?? true) - { - return null; - } + return await _httpClient.SendAsync(message); } catch { return null; } + } + } - using (var message = new HttpRequestMessage()) - { - message.RequestUri = uri; - message.Method = HttpMethod.Get; - - // Let's add some headers to look like we're coming from a web browser request. Some websites - // will block our request without these. - message.Headers.Add("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 " + - "(KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36 Edge/16.16299"); - message.Headers.Add("Accept-Language", "en-US,en;q=0.8"); - message.Headers.Add("Cache-Control", "no-cache"); - message.Headers.Add("Pragma", "no-cache"); - message.Headers.Add("Accept", "text/html,application/xhtml+xml,application/xml;" + - "q=0.9,image/webp,image/apng,*/*;q=0.8"); - - try - { - return await _httpClient.SendAsync(message); - } - catch - { - return null; - } - } + private async Task FollowRedirectsAsync(HttpResponseMessage response, + int maxFollowCount, int followCount = 0) + { + if (response == null || response.IsSuccessStatusCode || followCount > maxFollowCount) + { + return response; } - private async Task FollowRedirectsAsync(HttpResponseMessage response, - int maxFollowCount, int followCount = 0) + if (!(response.StatusCode == HttpStatusCode.Redirect || + response.StatusCode == HttpStatusCode.MovedPermanently || + response.StatusCode == HttpStatusCode.RedirectKeepVerb || + response.StatusCode == HttpStatusCode.SeeOther) || + response.Headers.Location == null) { - if (response == null || response.IsSuccessStatusCode || followCount > maxFollowCount) - { - return response; - } + Cleanup(response); + return null; + } - if (!(response.StatusCode == HttpStatusCode.Redirect || - response.StatusCode == HttpStatusCode.MovedPermanently || - response.StatusCode == HttpStatusCode.RedirectKeepVerb || - response.StatusCode == HttpStatusCode.SeeOther) || - response.Headers.Location == null) + Uri location = null; + if (response.Headers.Location.IsAbsoluteUri) + { + if (response.Headers.Location.Scheme != "http" && response.Headers.Location.Scheme != "https") { - Cleanup(response); - return null; - } - - Uri location = null; - if (response.Headers.Location.IsAbsoluteUri) - { - if (response.Headers.Location.Scheme != "http" && response.Headers.Location.Scheme != "https") + if (Uri.TryCreate($"https://{response.Headers.Location.OriginalString}", + UriKind.Absolute, out var newUri)) { - if (Uri.TryCreate($"https://{response.Headers.Location.OriginalString}", - UriKind.Absolute, out var newUri)) - { - location = newUri; - } - } - else - { - location = response.Headers.Location; + location = newUri; } } else { - var requestUri = response.RequestMessage.RequestUri; - location = ResolveUri($"{GetScheme(requestUri)}://{requestUri.Host}", - response.Headers.Location.OriginalString); + location = response.Headers.Location; } + } + else + { + var requestUri = response.RequestMessage.RequestUri; + location = ResolveUri($"{GetScheme(requestUri)}://{requestUri.Host}", + response.Headers.Location.OriginalString); + } - Cleanup(response); - var newResponse = await GetAsync(location); - if (newResponse != null) + Cleanup(response); + var newResponse = await GetAsync(location); + if (newResponse != null) + { + followCount++; + var redirectedResponse = await FollowRedirectsAsync(newResponse, maxFollowCount, followCount); + if (redirectedResponse != null) { - followCount++; - var redirectedResponse = await FollowRedirectsAsync(newResponse, maxFollowCount, followCount); - if (redirectedResponse != null) + if (redirectedResponse != newResponse) { - if (redirectedResponse != newResponse) - { - Cleanup(newResponse); - } - return redirectedResponse; + Cleanup(newResponse); } + return redirectedResponse; } - - return null; } - private bool HeaderMatch(byte[] imageBytes, byte[] header) - { - return imageBytes.Length >= header.Length && header.SequenceEqual(imageBytes.Take(header.Length)); - } + return null; + } - private Uri ResolveUri(string baseUrl, params string[] paths) + private bool HeaderMatch(byte[] imageBytes, byte[] header) + { + return imageBytes.Length >= header.Length && header.SequenceEqual(imageBytes.Take(header.Length)); + } + + private Uri ResolveUri(string baseUrl, params string[] paths) + { + var url = baseUrl; + foreach (var path in paths) { - var url = baseUrl; - foreach (var path in paths) + if (Uri.TryCreate(new Uri(url), path, out var r)) { - if (Uri.TryCreate(new Uri(url), path, out var r)) - { - url = r.ToString(); - } + url = r.ToString(); } - return new Uri(url); } + return new Uri(url); + } - private void Cleanup(IDisposable obj) + private void Cleanup(IDisposable obj) + { + obj?.Dispose(); + obj = null; + } + + private string GetScheme(Uri uri) + { + return uri != null && uri.Scheme == "http" ? "http" : "https"; + } + + public static bool IsInternal(IPAddress ip) + { + if (IPAddress.IsLoopback(ip)) { - obj?.Dispose(); - obj = null; + return true; } - private string GetScheme(Uri uri) + var ipString = ip.ToString(); + if (ipString == "::1" || ipString == "::" || ipString.StartsWith("::ffff:")) { - return uri != null && uri.Scheme == "http" ? "http" : "https"; + return true; } - public static bool IsInternal(IPAddress ip) + // IPv6 + if (ip.AddressFamily == System.Net.Sockets.AddressFamily.InterNetworkV6) { - if (IPAddress.IsLoopback(ip)) - { - return true; - } - - var ipString = ip.ToString(); - if (ipString == "::1" || ipString == "::" || ipString.StartsWith("::ffff:")) - { - return true; - } - - // IPv6 - if (ip.AddressFamily == System.Net.Sockets.AddressFamily.InterNetworkV6) - { - return ipString.StartsWith("fc") || ipString.StartsWith("fd") || - ipString.StartsWith("fe") || ipString.StartsWith("ff"); - } - - // IPv4 - var bytes = ip.GetAddressBytes(); - return (bytes[0]) switch - { - 0 => true, - 10 => true, - 127 => true, - 169 => bytes[1] == 254, // Cloud environments, such as AWS - 172 => bytes[1] < 32 && bytes[1] >= 16, - 192 => bytes[1] == 168, - _ => false, - }; + return ipString.StartsWith("fc") || ipString.StartsWith("fd") || + ipString.StartsWith("fe") || ipString.StartsWith("ff"); } + + // IPv4 + var bytes = ip.GetAddressBytes(); + return (bytes[0]) switch + { + 0 => true, + 10 => true, + 127 => true, + 169 => bytes[1] == 254, // Cloud environments, such as AWS + 172 => bytes[1] < 32 && bytes[1] >= 16, + 192 => bytes[1] == 168, + _ => false, + }; } } diff --git a/src/Icons/Startup.cs b/src/Icons/Startup.cs index 71442772b..f64ea07ed 100644 --- a/src/Icons/Startup.cs +++ b/src/Icons/Startup.cs @@ -5,73 +5,72 @@ using Bit.Icons.Services; using Bit.SharedWeb.Utilities; using Microsoft.Net.Http.Headers; -namespace Bit.Icons +namespace Bit.Icons; + +public class Startup { - public class Startup + public Startup(IWebHostEnvironment env, IConfiguration configuration) { - public Startup(IWebHostEnvironment env, IConfiguration configuration) + CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); + Configuration = configuration; + Environment = env; + } + + public IConfiguration Configuration { get; } + public IWebHostEnvironment Environment { get; } + + public void ConfigureServices(IServiceCollection services) + { + // Options + services.AddOptions(); + + // Settings + var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); + var iconsSettings = new IconsSettings(); + ConfigurationBinder.Bind(Configuration.GetSection("IconsSettings"), iconsSettings); + services.AddSingleton(s => iconsSettings); + + // Cache + services.AddMemoryCache(options => { - CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); - Configuration = configuration; - Environment = env; + options.SizeLimit = iconsSettings.CacheSizeLimit; + }); + + // Services + services.AddSingleton(); + services.AddSingleton(); + + // Mvc + services.AddMvc(); + } + + public void Configure( + IApplicationBuilder app, + IWebHostEnvironment env, + IHostApplicationLifetime appLifetime, + GlobalSettings globalSettings) + { + app.UseSerilog(env, appLifetime, globalSettings); + + // Add general security headers + app.UseMiddleware(); + + if (env.IsDevelopment()) + { + app.UseDeveloperExceptionPage(); } - public IConfiguration Configuration { get; } - public IWebHostEnvironment Environment { get; } - - public void ConfigureServices(IServiceCollection services) + app.Use(async (context, next) => { - // Options - services.AddOptions(); - - // Settings - var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); - var iconsSettings = new IconsSettings(); - ConfigurationBinder.Bind(Configuration.GetSection("IconsSettings"), iconsSettings); - services.AddSingleton(s => iconsSettings); - - // Cache - services.AddMemoryCache(options => + context.Response.GetTypedHeaders().CacheControl = new CacheControlHeaderValue { - options.SizeLimit = iconsSettings.CacheSizeLimit; - }); + Public = true, + MaxAge = TimeSpan.FromDays(7) + }; + await next(); + }); - // Services - services.AddSingleton(); - services.AddSingleton(); - - // Mvc - services.AddMvc(); - } - - public void Configure( - IApplicationBuilder app, - IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, - GlobalSettings globalSettings) - { - app.UseSerilog(env, appLifetime, globalSettings); - - // Add general security headers - app.UseMiddleware(); - - if (env.IsDevelopment()) - { - app.UseDeveloperExceptionPage(); - } - - app.Use(async (context, next) => - { - context.Response.GetTypedHeaders().CacheControl = new CacheControlHeaderValue - { - Public = true, - MaxAge = TimeSpan.FromDays(7) - }; - await next(); - }); - - app.UseRouting(); - app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); - } + app.UseRouting(); + app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); } } diff --git a/src/Identity/Controllers/AccountsController.cs b/src/Identity/Controllers/AccountsController.cs index d7151a3ee..940e2ab97 100644 --- a/src/Identity/Controllers/AccountsController.cs +++ b/src/Identity/Controllers/AccountsController.cs @@ -9,61 +9,60 @@ using Bit.Core.Utilities; using Bit.SharedWeb.Utilities; using Microsoft.AspNetCore.Mvc; -namespace Bit.Identity.Controllers +namespace Bit.Identity.Controllers; + +[Route("accounts")] +[ExceptionHandlerFilter] +public class AccountsController : Controller { - [Route("accounts")] - [ExceptionHandlerFilter] - public class AccountsController : Controller + private readonly ILogger _logger; + private readonly IUserRepository _userRepository; + private readonly IUserService _userService; + + public AccountsController( + ILogger logger, + IUserRepository userRepository, + IUserService userService) { - private readonly ILogger _logger; - private readonly IUserRepository _userRepository; - private readonly IUserService _userService; + _logger = logger; + _userRepository = userRepository; + _userService = userService; + } - public AccountsController( - ILogger logger, - IUserRepository userRepository, - IUserService userService) + // Moved from API, If you modify this endpoint, please update API as well. + [HttpPost("register")] + [CaptchaProtected] + public async Task PostRegister([FromBody] RegisterRequestModel model) + { + var result = await _userService.RegisterUserAsync(model.ToUser(), model.MasterPasswordHash, + model.Token, model.OrganizationUserId); + if (result.Succeeded) { - _logger = logger; - _userRepository = userRepository; - _userService = userService; + return; } - // Moved from API, If you modify this endpoint, please update API as well. - [HttpPost("register")] - [CaptchaProtected] - public async Task PostRegister([FromBody] RegisterRequestModel model) + foreach (var error in result.Errors.Where(e => e.Code != "DuplicateUserName")) { - var result = await _userService.RegisterUserAsync(model.ToUser(), model.MasterPasswordHash, - model.Token, model.OrganizationUserId); - if (result.Succeeded) - { - return; - } - - foreach (var error in result.Errors.Where(e => e.Code != "DuplicateUserName")) - { - ModelState.AddModelError(string.Empty, error.Description); - } - - await Task.Delay(2000); - throw new BadRequestException(ModelState); + ModelState.AddModelError(string.Empty, error.Description); } - // Moved from API, If you modify this endpoint, please update API as well. - [HttpPost("prelogin")] - public async Task PostPrelogin([FromBody] PreloginRequestModel model) + await Task.Delay(2000); + throw new BadRequestException(ModelState); + } + + // Moved from API, If you modify this endpoint, please update API as well. + [HttpPost("prelogin")] + public async Task PostPrelogin([FromBody] PreloginRequestModel model) + { + var kdfInformation = await _userRepository.GetKdfInformationByEmailAsync(model.Email); + if (kdfInformation == null) { - var kdfInformation = await _userRepository.GetKdfInformationByEmailAsync(model.Email); - if (kdfInformation == null) + kdfInformation = new UserKdfInformation { - kdfInformation = new UserKdfInformation - { - Kdf = KdfType.PBKDF2_SHA256, - KdfIterations = 100000, - }; - } - return new PreloginResponseModel(kdfInformation); + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 100000, + }; } + return new PreloginResponseModel(kdfInformation); } } diff --git a/src/Identity/Controllers/InfoController.cs b/src/Identity/Controllers/InfoController.cs index d8c161c61..c06812cdf 100644 --- a/src/Identity/Controllers/InfoController.cs +++ b/src/Identity/Controllers/InfoController.cs @@ -1,21 +1,20 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; -namespace Bit.Identity.Controllers -{ - public class InfoController : Controller - { - [HttpGet("~/alive")] - [HttpGet("~/now")] - public DateTime GetAlive() - { - return DateTime.UtcNow; - } +namespace Bit.Identity.Controllers; - [HttpGet("~/version")] - public JsonResult GetVersion() - { - return Json(CoreHelpers.GetVersion()); - } +public class InfoController : Controller +{ + [HttpGet("~/alive")] + [HttpGet("~/now")] + public DateTime GetAlive() + { + return DateTime.UtcNow; + } + + [HttpGet("~/version")] + public JsonResult GetVersion() + { + return Json(CoreHelpers.GetVersion()); } } diff --git a/src/Identity/Controllers/SsoController.cs b/src/Identity/Controllers/SsoController.cs index e3dc8f504..b61d89b86 100644 --- a/src/Identity/Controllers/SsoController.cs +++ b/src/Identity/Controllers/SsoController.cs @@ -11,265 +11,264 @@ using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Localization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Identity.Controllers +namespace Bit.Identity.Controllers; + +// TODO: 2022-01-12, Remove account alias +[Route("account/[action]")] +[Route("sso/[action]")] +public class SsoController : Controller { - // TODO: 2022-01-12, Remove account alias - [Route("account/[action]")] - [Route("sso/[action]")] - public class SsoController : Controller + private readonly IIdentityServerInteractionService _interaction; + private readonly ILogger _logger; + private readonly ISsoConfigRepository _ssoConfigRepository; + private readonly IUserRepository _userRepository; + private readonly IHttpClientFactory _clientFactory; + + public SsoController( + IIdentityServerInteractionService interaction, + ILogger logger, + ISsoConfigRepository ssoConfigRepository, + IUserRepository userRepository, + IHttpClientFactory clientFactory) { - private readonly IIdentityServerInteractionService _interaction; - private readonly ILogger _logger; - private readonly ISsoConfigRepository _ssoConfigRepository; - private readonly IUserRepository _userRepository; - private readonly IHttpClientFactory _clientFactory; + _interaction = interaction; + _logger = logger; + _ssoConfigRepository = ssoConfigRepository; + _userRepository = userRepository; + _clientFactory = clientFactory; + } - public SsoController( - IIdentityServerInteractionService interaction, - ILogger logger, - ISsoConfigRepository ssoConfigRepository, - IUserRepository userRepository, - IHttpClientFactory clientFactory) + [HttpGet] + public async Task PreValidate(string domainHint) + { + if (string.IsNullOrWhiteSpace(domainHint)) { - _interaction = interaction; - _logger = logger; - _ssoConfigRepository = ssoConfigRepository; - _userRepository = userRepository; - _clientFactory = clientFactory; + Response.StatusCode = 400; + return Json(new ErrorResponseModel("No domain hint was provided")); } - - [HttpGet] - public async Task PreValidate(string domainHint) + try { - if (string.IsNullOrWhiteSpace(domainHint)) - { - Response.StatusCode = 400; - return Json(new ErrorResponseModel("No domain hint was provided")); - } - try - { - // Calls Sso Pre-Validate, assumes baseUri set - var requestCultureFeature = Request.HttpContext.Features.Get(); - var culture = requestCultureFeature.RequestCulture.Culture.Name; - var requestPath = $"/Account/PreValidate?domainHint={domainHint}&culture={culture}"; - var httpClient = _clientFactory.CreateClient("InternalSso"); + // Calls Sso Pre-Validate, assumes baseUri set + var requestCultureFeature = Request.HttpContext.Features.Get(); + var culture = requestCultureFeature.RequestCulture.Culture.Name; + var requestPath = $"/Account/PreValidate?domainHint={domainHint}&culture={culture}"; + var httpClient = _clientFactory.CreateClient("InternalSso"); - // Forward the internal SSO result - using var responseMessage = await httpClient.GetAsync(requestPath); - var responseJson = await responseMessage.Content.ReadAsStringAsync(); - Response.StatusCode = (int)responseMessage.StatusCode; - return Content(responseJson, "application/json"); - } - catch (Exception ex) - { - _logger.LogError(ex, "Error pre-validating against SSO service"); - Response.StatusCode = 500; - return Json(new ErrorResponseModel("Error pre-validating SSO authentication") - { - ExceptionMessage = ex.Message, - ExceptionStackTrace = ex.StackTrace, - InnerExceptionMessage = ex.InnerException?.Message, - }); - } + // Forward the internal SSO result + using var responseMessage = await httpClient.GetAsync(requestPath); + var responseJson = await responseMessage.Content.ReadAsStringAsync(); + Response.StatusCode = (int)responseMessage.StatusCode; + return Content(responseJson, "application/json"); } - - [HttpGet] - public async Task Login(string returnUrl) + catch (Exception ex) { - var context = await _interaction.GetAuthorizationContextAsync(returnUrl); - - var domainHint = context.Parameters.AllKeys.Contains("domain_hint") ? - context.Parameters["domain_hint"] : null; - var ssoToken = context.Parameters[SsoTokenable.TokenIdentifier]; - - if (string.IsNullOrWhiteSpace(domainHint)) + _logger.LogError(ex, "Error pre-validating against SSO service"); + Response.StatusCode = 500; + return Json(new ErrorResponseModel("Error pre-validating SSO authentication") { - throw new Exception("No domain_hint provided"); - } - - var userIdentifier = context.Parameters.AllKeys.Contains("user_identifier") ? - context.Parameters["user_identifier"] : null; - - return RedirectToAction(nameof(ExternalChallenge), new - { - domainHint = domainHint, - returnUrl, - userIdentifier, - ssoToken, + ExceptionMessage = ex.Message, + ExceptionStackTrace = ex.StackTrace, + InnerExceptionMessage = ex.InnerException?.Message, }); } + } - [HttpGet] - public async Task ExternalChallenge(string domainHint, string returnUrl, - string userIdentifier, string ssoToken) + [HttpGet] + public async Task Login(string returnUrl) + { + var context = await _interaction.GetAuthorizationContextAsync(returnUrl); + + var domainHint = context.Parameters.AllKeys.Contains("domain_hint") ? + context.Parameters["domain_hint"] : null; + var ssoToken = context.Parameters[SsoTokenable.TokenIdentifier]; + + if (string.IsNullOrWhiteSpace(domainHint)) { - if (string.IsNullOrWhiteSpace(domainHint)) - { - throw new Exception("Invalid organization reference id."); - } - - var ssoConfig = await _ssoConfigRepository.GetByIdentifierAsync(domainHint); - if (ssoConfig == null || !ssoConfig.Enabled) - { - throw new Exception("Organization not found or SSO configuration not enabled"); - } - var organizationId = ssoConfig.OrganizationId.ToString(); - - var scheme = "sso"; - var props = new AuthenticationProperties - { - RedirectUri = Url.Action(nameof(ExternalCallback)), - Items = - { - { "return_url", returnUrl }, - { "domain_hint", domainHint }, - { "organizationId", organizationId }, - { "scheme", scheme }, - }, - Parameters = - { - { "ssoToken", ssoToken }, - } - }; - - if (!string.IsNullOrWhiteSpace(userIdentifier)) - { - props.Items.Add("user_identifier", userIdentifier); - } - - return Challenge(props, scheme); + throw new Exception("No domain_hint provided"); } - [HttpGet] - public async Task ExternalCallback() + var userIdentifier = context.Parameters.AllKeys.Contains("user_identifier") ? + context.Parameters["user_identifier"] : null; + + return RedirectToAction(nameof(ExternalChallenge), new { - // Read external identity from the temporary cookie - var result = await HttpContext.AuthenticateAsync( - Core.AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); - if (result?.Succeeded != true) - { - throw new Exception("External authentication error"); - } + domainHint = domainHint, + returnUrl, + userIdentifier, + ssoToken, + }); + } - // Debugging - var externalClaims = result.Principal.Claims.Select(c => $"{c.Type}: {c.Value}"); - _logger.LogDebug("External claims: {@claims}", externalClaims); - - var (user, provider, providerUserId, claims) = await FindUserFromExternalProviderAsync(result); - if (user == null) - { - // Should never happen - throw new Exception("Cannot find user."); - } - - // This allows us to collect any additional claims or properties - // for the specific protocols used and store them in the local auth cookie. - // this is typically used to store data needed for signout from those protocols. - var additionalLocalClaims = new List(); - var localSignInProps = new AuthenticationProperties - { - IsPersistent = true, - ExpiresUtc = DateTimeOffset.UtcNow.AddMinutes(1) - }; - if (result.Properties != null && result.Properties.Items.TryGetValue("organizationId", out var organization)) - { - additionalLocalClaims.Add(new Claim("organizationId", organization)); - } - ProcessLoginCallback(result, additionalLocalClaims, localSignInProps); - - // Issue authentication cookie for user - await HttpContext.SignInAsync(new IdentityServerUser(user.Id.ToString()) - { - DisplayName = user.Email, - IdentityProvider = provider, - AdditionalClaims = additionalLocalClaims.ToArray() - }, localSignInProps); - - // Delete temporary cookie used during external authentication - await HttpContext.SignOutAsync(Core.AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); - - // Retrieve return URL - var returnUrl = result.Properties.Items["return_url"] ?? "~/"; - - var context = await _interaction.GetAuthorizationContextAsync(returnUrl); - if (context != null) - { - if (IsNativeClient(context)) - { - // The client is native, so this change in how to - // return the response is for better UX for the end user. - HttpContext.Response.StatusCode = 200; - HttpContext.Response.Headers["Location"] = string.Empty; - return View("Redirect", new RedirectViewModel { RedirectUrl = returnUrl }); - } - - // We can trust model.ReturnUrl since GetAuthorizationContextAsync returned non-null - return Redirect(returnUrl); - } - - // Request for a local page - if (Url.IsLocalUrl(returnUrl)) - { - return Redirect(returnUrl); - } - else if (string.IsNullOrEmpty(returnUrl)) - { - return Redirect("~/"); - } - else - { - // User might have clicked on a malicious link - should be logged - throw new Exception("invalid return URL"); - } + [HttpGet] + public async Task ExternalChallenge(string domainHint, string returnUrl, + string userIdentifier, string ssoToken) + { + if (string.IsNullOrWhiteSpace(domainHint)) + { + throw new Exception("Invalid organization reference id."); } - private async Task<(User user, string provider, string providerUserId, IEnumerable claims)> - FindUserFromExternalProviderAsync(AuthenticateResult result) + var ssoConfig = await _ssoConfigRepository.GetByIdentifierAsync(domainHint); + if (ssoConfig == null || !ssoConfig.Enabled) { - var externalUser = result.Principal; + throw new Exception("Organization not found or SSO configuration not enabled"); + } + var organizationId = ssoConfig.OrganizationId.ToString(); - // Try to determine the unique id of the external user (issued by the provider) - // the most common claim type for that are the sub claim and the NameIdentifier - // depending on the external provider, some other claim type might be used - var userIdClaim = externalUser.FindFirst(JwtClaimTypes.Subject) ?? - externalUser.FindFirst(ClaimTypes.NameIdentifier) ?? - throw new Exception("Unknown userid"); + var scheme = "sso"; + var props = new AuthenticationProperties + { + RedirectUri = Url.Action(nameof(ExternalCallback)), + Items = + { + { "return_url", returnUrl }, + { "domain_hint", domainHint }, + { "organizationId", organizationId }, + { "scheme", scheme }, + }, + Parameters = + { + { "ssoToken", ssoToken }, + } + }; - // remove the user id claim so we don't include it as an extra claim if/when we provision the user - var claims = externalUser.Claims.ToList(); - claims.Remove(userIdClaim); - - var provider = result.Properties.Items["scheme"]; - var providerUserId = userIdClaim.Value; - var user = await _userRepository.GetByIdAsync(new Guid(providerUserId)); - - return (user, provider, providerUserId, claims); + if (!string.IsNullOrWhiteSpace(userIdentifier)) + { + props.Items.Add("user_identifier", userIdentifier); } - private void ProcessLoginCallback(AuthenticateResult externalResult, List localClaims, - AuthenticationProperties localSignInProps) - { - // If the external system sent a session id claim, copy it over - // so we can use it for single sign-out - var sid = externalResult.Principal.Claims.FirstOrDefault(x => x.Type == JwtClaimTypes.SessionId); - if (sid != null) - { - localClaims.Add(new Claim(JwtClaimTypes.SessionId, sid.Value)); - } + return Challenge(props, scheme); + } - // If the external provider issued an idToken, we'll keep it for signout - var idToken = externalResult.Properties.GetTokenValue("id_token"); - if (idToken != null) - { - localSignInProps.StoreTokens( - new[] { new AuthenticationToken { Name = "id_token", Value = idToken } }); - } + [HttpGet] + public async Task ExternalCallback() + { + // Read external identity from the temporary cookie + var result = await HttpContext.AuthenticateAsync( + Core.AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); + if (result?.Succeeded != true) + { + throw new Exception("External authentication error"); } - private bool IsNativeClient(IdentityServer4.Models.AuthorizationRequest context) + // Debugging + var externalClaims = result.Principal.Claims.Select(c => $"{c.Type}: {c.Value}"); + _logger.LogDebug("External claims: {@claims}", externalClaims); + + var (user, provider, providerUserId, claims) = await FindUserFromExternalProviderAsync(result); + if (user == null) { - return !context.RedirectUri.StartsWith("https", StringComparison.Ordinal) - && !context.RedirectUri.StartsWith("http", StringComparison.Ordinal); + // Should never happen + throw new Exception("Cannot find user."); + } + + // This allows us to collect any additional claims or properties + // for the specific protocols used and store them in the local auth cookie. + // this is typically used to store data needed for signout from those protocols. + var additionalLocalClaims = new List(); + var localSignInProps = new AuthenticationProperties + { + IsPersistent = true, + ExpiresUtc = DateTimeOffset.UtcNow.AddMinutes(1) + }; + if (result.Properties != null && result.Properties.Items.TryGetValue("organizationId", out var organization)) + { + additionalLocalClaims.Add(new Claim("organizationId", organization)); + } + ProcessLoginCallback(result, additionalLocalClaims, localSignInProps); + + // Issue authentication cookie for user + await HttpContext.SignInAsync(new IdentityServerUser(user.Id.ToString()) + { + DisplayName = user.Email, + IdentityProvider = provider, + AdditionalClaims = additionalLocalClaims.ToArray() + }, localSignInProps); + + // Delete temporary cookie used during external authentication + await HttpContext.SignOutAsync(Core.AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); + + // Retrieve return URL + var returnUrl = result.Properties.Items["return_url"] ?? "~/"; + + var context = await _interaction.GetAuthorizationContextAsync(returnUrl); + if (context != null) + { + if (IsNativeClient(context)) + { + // The client is native, so this change in how to + // return the response is for better UX for the end user. + HttpContext.Response.StatusCode = 200; + HttpContext.Response.Headers["Location"] = string.Empty; + return View("Redirect", new RedirectViewModel { RedirectUrl = returnUrl }); + } + + // We can trust model.ReturnUrl since GetAuthorizationContextAsync returned non-null + return Redirect(returnUrl); + } + + // Request for a local page + if (Url.IsLocalUrl(returnUrl)) + { + return Redirect(returnUrl); + } + else if (string.IsNullOrEmpty(returnUrl)) + { + return Redirect("~/"); + } + else + { + // User might have clicked on a malicious link - should be logged + throw new Exception("invalid return URL"); } } + + private async Task<(User user, string provider, string providerUserId, IEnumerable claims)> + FindUserFromExternalProviderAsync(AuthenticateResult result) + { + var externalUser = result.Principal; + + // Try to determine the unique id of the external user (issued by the provider) + // the most common claim type for that are the sub claim and the NameIdentifier + // depending on the external provider, some other claim type might be used + var userIdClaim = externalUser.FindFirst(JwtClaimTypes.Subject) ?? + externalUser.FindFirst(ClaimTypes.NameIdentifier) ?? + throw new Exception("Unknown userid"); + + // remove the user id claim so we don't include it as an extra claim if/when we provision the user + var claims = externalUser.Claims.ToList(); + claims.Remove(userIdClaim); + + var provider = result.Properties.Items["scheme"]; + var providerUserId = userIdClaim.Value; + var user = await _userRepository.GetByIdAsync(new Guid(providerUserId)); + + return (user, provider, providerUserId, claims); + } + + private void ProcessLoginCallback(AuthenticateResult externalResult, List localClaims, + AuthenticationProperties localSignInProps) + { + // If the external system sent a session id claim, copy it over + // so we can use it for single sign-out + var sid = externalResult.Principal.Claims.FirstOrDefault(x => x.Type == JwtClaimTypes.SessionId); + if (sid != null) + { + localClaims.Add(new Claim(JwtClaimTypes.SessionId, sid.Value)); + } + + // If the external provider issued an idToken, we'll keep it for signout + var idToken = externalResult.Properties.GetTokenValue("id_token"); + if (idToken != null) + { + localSignInProps.StoreTokens( + new[] { new AuthenticationToken { Name = "id_token", Value = idToken } }); + } + } + + private bool IsNativeClient(IdentityServer4.Models.AuthorizationRequest context) + { + return !context.RedirectUri.StartsWith("https", StringComparison.Ordinal) + && !context.RedirectUri.StartsWith("http", StringComparison.Ordinal); + } } diff --git a/src/Identity/Models/RedirectViewModel.cs b/src/Identity/Models/RedirectViewModel.cs index 848fdf871..5cf7663b4 100644 --- a/src/Identity/Models/RedirectViewModel.cs +++ b/src/Identity/Models/RedirectViewModel.cs @@ -1,7 +1,6 @@ -namespace Bit.Identity.Models +namespace Bit.Identity.Models; + +public class RedirectViewModel { - public class RedirectViewModel - { - public string RedirectUrl { get; set; } - } + public string RedirectUrl { get; set; } } diff --git a/src/Identity/Program.cs b/src/Identity/Program.cs index 540e3ac75..e87f81aa6 100644 --- a/src/Identity/Program.cs +++ b/src/Identity/Program.cs @@ -2,44 +2,43 @@ using Bit.Core.Utilities; using Serilog.Events; -namespace Bit.Identity +namespace Bit.Identity; + +public class Program { - public class Program + public static void Main(string[] args) { - public static void Main(string[] args) - { - CreateHostBuilder(args) - .Build() - .Run(); - } + CreateHostBuilder(args) + .Build() + .Run(); + } - public static IHostBuilder CreateHostBuilder(string[] args) - { - return Host - .CreateDefaultBuilder(args) - .ConfigureCustomAppConfiguration(args) - .ConfigureWebHostDefaults(webBuilder => - { - webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, e => + public static IHostBuilder CreateHostBuilder(string[] args) + { + return Host + .CreateDefaultBuilder(args) + .ConfigureCustomAppConfiguration(args) + .ConfigureWebHostDefaults(webBuilder => + { + webBuilder.UseStartup(); + webBuilder.ConfigureLogging((hostingContext, logging) => + logging.AddSerilog(hostingContext, e => + { + var context = e.Properties["SourceContext"].ToString(); + if (context.Contains(typeof(IpRateLimitMiddleware).FullName) && + e.Level == LogEventLevel.Information) { - var context = e.Properties["SourceContext"].ToString(); - if (context.Contains(typeof(IpRateLimitMiddleware).FullName) && - e.Level == LogEventLevel.Information) - { - return true; - } + return true; + } - if (context.Contains("IdentityServer4.Validation.TokenValidator") || - context.Contains("IdentityServer4.Validation.TokenRequestValidator")) - { - return e.Level > LogEventLevel.Error; - } + if (context.Contains("IdentityServer4.Validation.TokenValidator") || + context.Contains("IdentityServer4.Validation.TokenRequestValidator")) + { + return e.Level > LogEventLevel.Error; + } - return e.Level >= LogEventLevel.Error; - })); - }); - } + return e.Level >= LogEventLevel.Error; + })); + }); } } diff --git a/src/Identity/Startup.cs b/src/Identity/Startup.cs index e355f0123..170e2b931 100644 --- a/src/Identity/Startup.cs +++ b/src/Identity/Startup.cs @@ -13,214 +13,213 @@ using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.IdentityModel.Logging; using Microsoft.OpenApi.Models; -namespace Bit.Identity +namespace Bit.Identity; + +public class Startup { - public class Startup + public Startup(IWebHostEnvironment env, IConfiguration configuration) { - public Startup(IWebHostEnvironment env, IConfiguration configuration) + CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); + Configuration = configuration; + Environment = env; + } + + public IConfiguration Configuration { get; private set; } + public IWebHostEnvironment Environment { get; set; } + + public void ConfigureServices(IServiceCollection services) + { + // Options + services.AddOptions(); + + // Settings + var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); + if (!globalSettings.SelfHosted) { - CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); - Configuration = configuration; - Environment = env; + services.Configure(Configuration.GetSection("IpRateLimitOptions")); + services.Configure(Configuration.GetSection("IpRateLimitPolicies")); } - public IConfiguration Configuration { get; private set; } - public IWebHostEnvironment Environment { get; set; } + // Data Protection + services.AddCustomDataProtectionServices(Environment, globalSettings); - public void ConfigureServices(IServiceCollection services) + // Repositories + services.AddSqlServerRepositories(globalSettings); + + // Context + services.AddScoped(); + services.TryAddSingleton(); + + // Caching + services.AddMemoryCache(); + services.AddDistributedCache(globalSettings); + + // Mvc + services.AddMvc(config => { - // Options - services.AddOptions(); + config.Filters.Add(new ModelStateValidationFilterAttribute()); + }); - // Settings - var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); - if (!globalSettings.SelfHosted) + services.AddSwaggerGen(c => + { + c.SwaggerDoc("v1", new OpenApiInfo { Title = "Bitwarden Identity", Version = "v1" }); + }); + + if (!globalSettings.SelfHosted) + { + services.AddIpRateLimiting(globalSettings); + } + + // Cookies + if (Environment.IsDevelopment()) + { + services.Configure(options => { - services.Configure(Configuration.GetSection("IpRateLimitOptions")); - services.Configure(Configuration.GetSection("IpRateLimitPolicies")); - } - - // Data Protection - services.AddCustomDataProtectionServices(Environment, globalSettings); - - // Repositories - services.AddSqlServerRepositories(globalSettings); - - // Context - services.AddScoped(); - services.TryAddSingleton(); - - // Caching - services.AddMemoryCache(); - services.AddDistributedCache(globalSettings); - - // Mvc - services.AddMvc(config => - { - config.Filters.Add(new ModelStateValidationFilterAttribute()); - }); - - services.AddSwaggerGen(c => - { - c.SwaggerDoc("v1", new OpenApiInfo { Title = "Bitwarden Identity", Version = "v1" }); - }); - - if (!globalSettings.SelfHosted) - { - services.AddIpRateLimiting(globalSettings); - } - - // Cookies - if (Environment.IsDevelopment()) - { - services.Configure(options => + options.MinimumSameSitePolicy = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; + options.OnAppendCookie = ctx => { - options.MinimumSameSitePolicy = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; - options.OnAppendCookie = ctx => - { - ctx.CookieOptions.SameSite = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; - }; - }); - } + ctx.CookieOptions.SameSite = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; + }; + }); + } - JwtSecurityTokenHandler.DefaultMapInboundClaims = false; + JwtSecurityTokenHandler.DefaultMapInboundClaims = false; - // Authentication - services - .AddDistributedIdentityServices(globalSettings) - .AddAuthentication() - .AddCookie(AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme) - .AddOpenIdConnect("sso", "Single Sign On", options => + // Authentication + services + .AddDistributedIdentityServices(globalSettings) + .AddAuthentication() + .AddCookie(AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme) + .AddOpenIdConnect("sso", "Single Sign On", options => + { + options.Authority = globalSettings.BaseServiceUri.InternalSso; + options.RequireHttpsMetadata = !Environment.IsDevelopment() && + globalSettings.BaseServiceUri.InternalIdentity.StartsWith("https"); + options.ClientId = "oidc-identity"; + options.ClientSecret = globalSettings.OidcIdentityClientKey; + options.ResponseMode = "form_post"; + + options.SignInScheme = AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme; + options.ResponseType = "code"; + options.SaveTokens = false; + options.GetClaimsFromUserInfoEndpoint = true; + + options.Events = new Microsoft.AspNetCore.Authentication.OpenIdConnect.OpenIdConnectEvents { - options.Authority = globalSettings.BaseServiceUri.InternalSso; - options.RequireHttpsMetadata = !Environment.IsDevelopment() && - globalSettings.BaseServiceUri.InternalIdentity.StartsWith("https"); - options.ClientId = "oidc-identity"; - options.ClientSecret = globalSettings.OidcIdentityClientKey; - options.ResponseMode = "form_post"; - - options.SignInScheme = AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme; - options.ResponseType = "code"; - options.SaveTokens = false; - options.GetClaimsFromUserInfoEndpoint = true; - - options.Events = new Microsoft.AspNetCore.Authentication.OpenIdConnect.OpenIdConnectEvents + OnRedirectToIdentityProvider = context => { - OnRedirectToIdentityProvider = context => + // Pass domain_hint onto the sso idp + context.ProtocolMessage.DomainHint = context.Properties.Items["domain_hint"]; + context.ProtocolMessage.Parameters.Add("organizationId", context.Properties.Items["organizationId"]); + if (context.Properties.Items.ContainsKey("user_identifier")) { - // Pass domain_hint onto the sso idp - context.ProtocolMessage.DomainHint = context.Properties.Items["domain_hint"]; - context.ProtocolMessage.Parameters.Add("organizationId", context.Properties.Items["organizationId"]); - if (context.Properties.Items.ContainsKey("user_identifier")) - { - context.ProtocolMessage.SessionState = context.Properties.Items["user_identifier"]; - } - - if (context.Properties.Parameters.Count > 0 && - context.Properties.Parameters.TryGetValue(SsoTokenable.TokenIdentifier, out var tokenValue)) - { - var token = tokenValue?.ToString() ?? ""; - context.ProtocolMessage.Parameters.Add(SsoTokenable.TokenIdentifier, token); - } - return Task.FromResult(0); + context.ProtocolMessage.SessionState = context.Properties.Items["user_identifier"]; } - }; - }); - // IdentityServer - services.AddCustomIdentityServerServices(Environment, globalSettings); + if (context.Properties.Parameters.Count > 0 && + context.Properties.Parameters.TryGetValue(SsoTokenable.TokenIdentifier, out var tokenValue)) + { + var token = tokenValue?.ToString() ?? ""; + context.ProtocolMessage.Parameters.Add(SsoTokenable.TokenIdentifier, token); + } + return Task.FromResult(0); + } + }; + }); - // Identity - services.AddCustomIdentityServices(globalSettings); + // IdentityServer + services.AddCustomIdentityServerServices(Environment, globalSettings); - // Services - services.AddBaseServices(globalSettings); - services.AddDefaultServices(globalSettings); - services.AddCoreLocalizationServices(); + // Identity + services.AddCustomIdentityServices(globalSettings); - if (CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ConnectionString) && - CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ApplicationCacheTopicName)) + // Services + services.AddBaseServices(globalSettings); + services.AddDefaultServices(globalSettings); + services.AddCoreLocalizationServices(); + + if (CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ConnectionString) && + CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ApplicationCacheTopicName)) + { + services.AddHostedService(); + } + + // HttpClients + services.AddHttpClient("InternalSso", client => + { + client.BaseAddress = new Uri(globalSettings.BaseServiceUri.InternalSso); + }); + } + + public void Configure( + IApplicationBuilder app, + IWebHostEnvironment env, + IHostApplicationLifetime appLifetime, + GlobalSettings globalSettings, + ILogger logger) + { + IdentityModelEventSource.ShowPII = true; + + app.UseSerilog(env, appLifetime, globalSettings); + + // Add general security headers + app.UseMiddleware(); + + if (!env.IsDevelopment()) + { + var uri = new Uri(globalSettings.BaseServiceUri.Identity); + app.Use(async (ctx, next) => { - services.AddHostedService(); - } - - // HttpClients - services.AddHttpClient("InternalSso", client => - { - client.BaseAddress = new Uri(globalSettings.BaseServiceUri.InternalSso); + ctx.SetIdentityServerOrigin($"{uri.Scheme}://{uri.Host}"); + await next(); }); } - public void Configure( - IApplicationBuilder app, - IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, - GlobalSettings globalSettings, - ILogger logger) + if (globalSettings.SelfHosted) { - IdentityModelEventSource.ShowPII = true; - - app.UseSerilog(env, appLifetime, globalSettings); - - // Add general security headers - app.UseMiddleware(); - - if (!env.IsDevelopment()) - { - var uri = new Uri(globalSettings.BaseServiceUri.Identity); - app.Use(async (ctx, next) => - { - ctx.SetIdentityServerOrigin($"{uri.Scheme}://{uri.Host}"); - await next(); - }); - } - - if (globalSettings.SelfHosted) - { - app.UsePathBase("/identity"); - app.UseForwardedHeaders(globalSettings); - } - - // Default Middleware - app.UseDefaultMiddleware(env, globalSettings); - - if (!globalSettings.SelfHosted) - { - // Rate limiting - app.UseMiddleware(); - } - - if (env.IsDevelopment()) - { - app.UseSwagger(); - app.UseDeveloperExceptionPage(); - app.UseCookiePolicy(); - } - - // Add localization - app.UseCoreLocalization(); - - // Add static files to the request pipeline. - app.UseStaticFiles(); - - // Add routing - app.UseRouting(); - - // Add Cors - app.UseCors(policy => policy.SetIsOriginAllowed(o => CoreHelpers.IsCorsOriginAllowed(o, globalSettings)) - .AllowAnyMethod().AllowAnyHeader().AllowCredentials()); - - // Add current context - app.UseMiddleware(); - - // Add IdentityServer to the request pipeline. - app.UseIdentityServer(); - - // Add Mvc stuff - app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); - - // Log startup - logger.LogInformation(Constants.BypassFiltersEventId, globalSettings.ProjectName + " started."); + app.UsePathBase("/identity"); + app.UseForwardedHeaders(globalSettings); } + + // Default Middleware + app.UseDefaultMiddleware(env, globalSettings); + + if (!globalSettings.SelfHosted) + { + // Rate limiting + app.UseMiddleware(); + } + + if (env.IsDevelopment()) + { + app.UseSwagger(); + app.UseDeveloperExceptionPage(); + app.UseCookiePolicy(); + } + + // Add localization + app.UseCoreLocalization(); + + // Add static files to the request pipeline. + app.UseStaticFiles(); + + // Add routing + app.UseRouting(); + + // Add Cors + app.UseCors(policy => policy.SetIsOriginAllowed(o => CoreHelpers.IsCorsOriginAllowed(o, globalSettings)) + .AllowAnyMethod().AllowAnyHeader().AllowCredentials()); + + // Add current context + app.UseMiddleware(); + + // Add IdentityServer to the request pipeline. + app.UseIdentityServer(); + + // Add Mvc stuff + app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); + + // Log startup + logger.LogInformation(Constants.BypassFiltersEventId, globalSettings.ProjectName + " started."); } } diff --git a/src/Identity/Utilities/DiscoveryResponseGenerator.cs b/src/Identity/Utilities/DiscoveryResponseGenerator.cs index 32a5e6ddb..da0618098 100644 --- a/src/Identity/Utilities/DiscoveryResponseGenerator.cs +++ b/src/Identity/Utilities/DiscoveryResponseGenerator.cs @@ -5,32 +5,31 @@ using IdentityServer4.Services; using IdentityServer4.Stores; using IdentityServer4.Validation; -namespace Bit.Identity.Utilities +namespace Bit.Identity.Utilities; + +public class DiscoveryResponseGenerator : IdentityServer4.ResponseHandling.DiscoveryResponseGenerator { - public class DiscoveryResponseGenerator : IdentityServer4.ResponseHandling.DiscoveryResponseGenerator + private readonly GlobalSettings _globalSettings; + + public DiscoveryResponseGenerator( + IdentityServerOptions options, + IResourceStore resourceStore, + IKeyMaterialService keys, + ExtensionGrantValidator extensionGrants, + ISecretsListParser secretParsers, + IResourceOwnerPasswordValidator resourceOwnerValidator, + ILogger logger, + GlobalSettings globalSettings) + : base(options, resourceStore, keys, extensionGrants, secretParsers, resourceOwnerValidator, logger) { - private readonly GlobalSettings _globalSettings; + _globalSettings = globalSettings; + } - public DiscoveryResponseGenerator( - IdentityServerOptions options, - IResourceStore resourceStore, - IKeyMaterialService keys, - ExtensionGrantValidator extensionGrants, - ISecretsListParser secretParsers, - IResourceOwnerPasswordValidator resourceOwnerValidator, - ILogger logger, - GlobalSettings globalSettings) - : base(options, resourceStore, keys, extensionGrants, secretParsers, resourceOwnerValidator, logger) - { - _globalSettings = globalSettings; - } - - public override async Task> CreateDiscoveryDocumentAsync( - string baseUrl, string issuerUri) - { - var dict = await base.CreateDiscoveryDocumentAsync(baseUrl, issuerUri); - return CoreHelpers.AdjustIdentityServerConfig(dict, _globalSettings.BaseServiceUri.Identity, - _globalSettings.BaseServiceUri.InternalIdentity); - } + public override async Task> CreateDiscoveryDocumentAsync( + string baseUrl, string issuerUri) + { + var dict = await base.CreateDiscoveryDocumentAsync(baseUrl, issuerUri); + return CoreHelpers.AdjustIdentityServerConfig(dict, _globalSettings.BaseServiceUri.Identity, + _globalSettings.BaseServiceUri.InternalIdentity); } } diff --git a/src/Identity/Utilities/ServiceCollectionExtensions.cs b/src/Identity/Utilities/ServiceCollectionExtensions.cs index 82000ebcf..df3a6dec8 100644 --- a/src/Identity/Utilities/ServiceCollectionExtensions.cs +++ b/src/Identity/Utilities/ServiceCollectionExtensions.cs @@ -5,48 +5,47 @@ using IdentityServer4.ResponseHandling; using IdentityServer4.Services; using IdentityServer4.Stores; -namespace Bit.Identity.Utilities +namespace Bit.Identity.Utilities; + +public static class ServiceCollectionExtensions { - public static class ServiceCollectionExtensions + public static IIdentityServerBuilder AddCustomIdentityServerServices(this IServiceCollection services, + IWebHostEnvironment env, GlobalSettings globalSettings) { - public static IIdentityServerBuilder AddCustomIdentityServerServices(this IServiceCollection services, - IWebHostEnvironment env, GlobalSettings globalSettings) - { - services.AddTransient(); + services.AddTransient(); - services.AddSingleton(); - services.AddTransient(); + services.AddSingleton(); + services.AddTransient(); - var issuerUri = new Uri(globalSettings.BaseServiceUri.InternalIdentity); - var identityServerBuilder = services - .AddIdentityServer(options => + var issuerUri = new Uri(globalSettings.BaseServiceUri.InternalIdentity); + var identityServerBuilder = services + .AddIdentityServer(options => + { + options.Endpoints.EnableIntrospectionEndpoint = false; + options.Endpoints.EnableEndSessionEndpoint = false; + options.Endpoints.EnableUserInfoEndpoint = false; + options.Endpoints.EnableCheckSessionEndpoint = false; + options.Endpoints.EnableTokenRevocationEndpoint = false; + options.IssuerUri = $"{issuerUri.Scheme}://{issuerUri.Host}"; + options.Caching.ClientStoreExpiration = new TimeSpan(0, 5, 0); + if (env.IsDevelopment()) { - options.Endpoints.EnableIntrospectionEndpoint = false; - options.Endpoints.EnableEndSessionEndpoint = false; - options.Endpoints.EnableUserInfoEndpoint = false; - options.Endpoints.EnableCheckSessionEndpoint = false; - options.Endpoints.EnableTokenRevocationEndpoint = false; - options.IssuerUri = $"{issuerUri.Scheme}://{issuerUri.Host}"; - options.Caching.ClientStoreExpiration = new TimeSpan(0, 5, 0); - if (env.IsDevelopment()) - { - options.Authentication.CookieSameSiteMode = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; - } - options.InputLengthRestrictions.UserName = 256; - }) - .AddInMemoryCaching() - .AddInMemoryApiResources(ApiResources.GetApiResources()) - .AddInMemoryApiScopes(ApiScopes.GetApiScopes()) - .AddClientStoreCache() - .AddCustomTokenRequestValidator() - .AddProfileService() - .AddResourceOwnerValidator() - .AddPersistedGrantStore() - .AddClientStore() - .AddIdentityServerCertificate(env, globalSettings); + options.Authentication.CookieSameSiteMode = Microsoft.AspNetCore.Http.SameSiteMode.Unspecified; + } + options.InputLengthRestrictions.UserName = 256; + }) + .AddInMemoryCaching() + .AddInMemoryApiResources(ApiResources.GetApiResources()) + .AddInMemoryApiScopes(ApiScopes.GetApiScopes()) + .AddClientStoreCache() + .AddCustomTokenRequestValidator() + .AddProfileService() + .AddResourceOwnerValidator() + .AddPersistedGrantStore() + .AddClientStore() + .AddIdentityServerCertificate(env, globalSettings); - services.AddTransient(); - return identityServerBuilder; - } + services.AddTransient(); + return identityServerBuilder; } } diff --git a/src/Infrastructure.Dapper/DapperHelpers.cs b/src/Infrastructure.Dapper/DapperHelpers.cs index 720355649..48949df67 100644 --- a/src/Infrastructure.Dapper/DapperHelpers.cs +++ b/src/Infrastructure.Dapper/DapperHelpers.cs @@ -3,133 +3,132 @@ using Bit.Core.Entities; using Bit.Core.Models.Data; using Dapper; -namespace Bit.Infrastructure.Dapper +namespace Bit.Infrastructure.Dapper; + +public static class DapperHelpers { - public static class DapperHelpers + public static DataTable ToGuidIdArrayTVP(this IEnumerable ids) { - public static DataTable ToGuidIdArrayTVP(this IEnumerable ids) - { - return ids.ToArrayTVP("GuidId"); - } + return ids.ToArrayTVP("GuidId"); + } - public static DataTable ToArrayTVP(this IEnumerable values, string columnName) - { - var table = new DataTable(); - table.SetTypeName($"[dbo].[{columnName}Array]"); - table.Columns.Add(columnName, typeof(T)); + public static DataTable ToArrayTVP(this IEnumerable values, string columnName) + { + var table = new DataTable(); + table.SetTypeName($"[dbo].[{columnName}Array]"); + table.Columns.Add(columnName, typeof(T)); - if (values != null) + if (values != null) + { + foreach (var value in values) { - foreach (var value in values) - { - table.Rows.Add(value); - } + table.Rows.Add(value); } - - return table; } - public static DataTable ToArrayTVP(this IEnumerable values) + return table; + } + + public static DataTable ToArrayTVP(this IEnumerable values) + { + var table = new DataTable(); + table.SetTypeName("[dbo].[SelectionReadOnlyArray]"); + + var idColumn = new DataColumn("Id", typeof(Guid)); + table.Columns.Add(idColumn); + var readOnlyColumn = new DataColumn("ReadOnly", typeof(bool)); + table.Columns.Add(readOnlyColumn); + var hidePasswordsColumn = new DataColumn("HidePasswords", typeof(bool)); + table.Columns.Add(hidePasswordsColumn); + + if (values != null) { - var table = new DataTable(); - table.SetTypeName("[dbo].[SelectionReadOnlyArray]"); - - var idColumn = new DataColumn("Id", typeof(Guid)); - table.Columns.Add(idColumn); - var readOnlyColumn = new DataColumn("ReadOnly", typeof(bool)); - table.Columns.Add(readOnlyColumn); - var hidePasswordsColumn = new DataColumn("HidePasswords", typeof(bool)); - table.Columns.Add(hidePasswordsColumn); - - if (values != null) - { - foreach (var value in values) - { - var row = table.NewRow(); - row[idColumn] = value.Id; - row[readOnlyColumn] = value.ReadOnly; - row[hidePasswordsColumn] = value.HidePasswords; - table.Rows.Add(row); - } - } - - return table; - } - - public static DataTable ToTvp(this IEnumerable orgUsers) - { - var table = new DataTable(); - table.SetTypeName("[dbo].[OrganizationUserType]"); - - var columnData = new List<(string name, Type type, Func getter)> - { - (nameof(OrganizationUser.Id), typeof(Guid), ou => ou.Id), - (nameof(OrganizationUser.OrganizationId), typeof(Guid), ou => ou.OrganizationId), - (nameof(OrganizationUser.UserId), typeof(Guid), ou => ou.UserId), - (nameof(OrganizationUser.Email), typeof(string), ou => ou.Email), - (nameof(OrganizationUser.Key), typeof(string), ou => ou.Key), - (nameof(OrganizationUser.Status), typeof(byte), ou => ou.Status), - (nameof(OrganizationUser.Type), typeof(byte), ou => ou.Type), - (nameof(OrganizationUser.AccessAll), typeof(bool), ou => ou.AccessAll), - (nameof(OrganizationUser.ExternalId), typeof(string), ou => ou.ExternalId), - (nameof(OrganizationUser.CreationDate), typeof(DateTime), ou => ou.CreationDate), - (nameof(OrganizationUser.RevisionDate), typeof(DateTime), ou => ou.RevisionDate), - (nameof(OrganizationUser.Permissions), typeof(string), ou => ou.Permissions), - (nameof(OrganizationUser.ResetPasswordKey), typeof(string), ou => ou.ResetPasswordKey), - }; - - return orgUsers.BuildTable(table, columnData); - } - - public static DataTable ToTvp(this IEnumerable organizationSponsorships) - { - var table = new DataTable(); - table.SetTypeName("[dbo].[OrganizationSponsorshipType]"); - - var columnData = new List<(string name, Type type, Func getter)> - { - (nameof(OrganizationSponsorship.Id), typeof(Guid), ou => ou.Id), - (nameof(OrganizationSponsorship.SponsoringOrganizationId), typeof(Guid), ou => ou.SponsoringOrganizationId), - (nameof(OrganizationSponsorship.SponsoringOrganizationUserId), typeof(Guid), ou => ou.SponsoringOrganizationUserId), - (nameof(OrganizationSponsorship.SponsoredOrganizationId), typeof(Guid), ou => ou.SponsoredOrganizationId), - (nameof(OrganizationSponsorship.FriendlyName), typeof(string), ou => ou.FriendlyName), - (nameof(OrganizationSponsorship.OfferedToEmail), typeof(string), ou => ou.OfferedToEmail), - (nameof(OrganizationSponsorship.PlanSponsorshipType), typeof(byte), ou => ou.PlanSponsorshipType), - (nameof(OrganizationSponsorship.LastSyncDate), typeof(DateTime), ou => ou.LastSyncDate), - (nameof(OrganizationSponsorship.ValidUntil), typeof(DateTime), ou => ou.ValidUntil), - (nameof(OrganizationSponsorship.ToDelete), typeof(bool), ou => ou.ToDelete), - }; - - return organizationSponsorships.BuildTable(table, columnData); - } - - private static DataTable BuildTable(this IEnumerable entities, DataTable table, List<(string name, Type type, Func getter)> columnData) - { - foreach (var (name, type, getter) in columnData) - { - var column = new DataColumn(name, type); - table.Columns.Add(column); - } - - foreach (var entity in entities ?? new T[] { }) + foreach (var value in values) { var row = table.NewRow(); - foreach (var (name, type, getter) in columnData) - { - var val = getter(entity); - if (val == null) - { - row[name] = DBNull.Value; - } - else - { - row[name] = val; - } - } + row[idColumn] = value.Id; + row[readOnlyColumn] = value.ReadOnly; + row[hidePasswordsColumn] = value.HidePasswords; table.Rows.Add(row); } - - return table; } + + return table; + } + + public static DataTable ToTvp(this IEnumerable orgUsers) + { + var table = new DataTable(); + table.SetTypeName("[dbo].[OrganizationUserType]"); + + var columnData = new List<(string name, Type type, Func getter)> + { + (nameof(OrganizationUser.Id), typeof(Guid), ou => ou.Id), + (nameof(OrganizationUser.OrganizationId), typeof(Guid), ou => ou.OrganizationId), + (nameof(OrganizationUser.UserId), typeof(Guid), ou => ou.UserId), + (nameof(OrganizationUser.Email), typeof(string), ou => ou.Email), + (nameof(OrganizationUser.Key), typeof(string), ou => ou.Key), + (nameof(OrganizationUser.Status), typeof(byte), ou => ou.Status), + (nameof(OrganizationUser.Type), typeof(byte), ou => ou.Type), + (nameof(OrganizationUser.AccessAll), typeof(bool), ou => ou.AccessAll), + (nameof(OrganizationUser.ExternalId), typeof(string), ou => ou.ExternalId), + (nameof(OrganizationUser.CreationDate), typeof(DateTime), ou => ou.CreationDate), + (nameof(OrganizationUser.RevisionDate), typeof(DateTime), ou => ou.RevisionDate), + (nameof(OrganizationUser.Permissions), typeof(string), ou => ou.Permissions), + (nameof(OrganizationUser.ResetPasswordKey), typeof(string), ou => ou.ResetPasswordKey), + }; + + return orgUsers.BuildTable(table, columnData); + } + + public static DataTable ToTvp(this IEnumerable organizationSponsorships) + { + var table = new DataTable(); + table.SetTypeName("[dbo].[OrganizationSponsorshipType]"); + + var columnData = new List<(string name, Type type, Func getter)> + { + (nameof(OrganizationSponsorship.Id), typeof(Guid), ou => ou.Id), + (nameof(OrganizationSponsorship.SponsoringOrganizationId), typeof(Guid), ou => ou.SponsoringOrganizationId), + (nameof(OrganizationSponsorship.SponsoringOrganizationUserId), typeof(Guid), ou => ou.SponsoringOrganizationUserId), + (nameof(OrganizationSponsorship.SponsoredOrganizationId), typeof(Guid), ou => ou.SponsoredOrganizationId), + (nameof(OrganizationSponsorship.FriendlyName), typeof(string), ou => ou.FriendlyName), + (nameof(OrganizationSponsorship.OfferedToEmail), typeof(string), ou => ou.OfferedToEmail), + (nameof(OrganizationSponsorship.PlanSponsorshipType), typeof(byte), ou => ou.PlanSponsorshipType), + (nameof(OrganizationSponsorship.LastSyncDate), typeof(DateTime), ou => ou.LastSyncDate), + (nameof(OrganizationSponsorship.ValidUntil), typeof(DateTime), ou => ou.ValidUntil), + (nameof(OrganizationSponsorship.ToDelete), typeof(bool), ou => ou.ToDelete), + }; + + return organizationSponsorships.BuildTable(table, columnData); + } + + private static DataTable BuildTable(this IEnumerable entities, DataTable table, List<(string name, Type type, Func getter)> columnData) + { + foreach (var (name, type, getter) in columnData) + { + var column = new DataColumn(name, type); + table.Columns.Add(column); + } + + foreach (var entity in entities ?? new T[] { }) + { + var row = table.NewRow(); + foreach (var (name, type, getter) in columnData) + { + var val = getter(entity); + if (val == null) + { + row[name] = DBNull.Value; + } + else + { + row[name] = val; + } + } + table.Rows.Add(row); + } + + return table; } } diff --git a/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs b/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs index a75a05fcc..9c138f7b0 100644 --- a/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs +++ b/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs @@ -2,43 +2,42 @@ using Bit.Infrastructure.Dapper.Repositories; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.Dapper -{ - public static class DapperServiceCollectionExtensions - { - public static void AddDapperRepositories(this IServiceCollection services, bool selfHosted) - { - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); +namespace Bit.Infrastructure.Dapper; - if (selfHosted) - { - services.AddSingleton(); - } +public static class DapperServiceCollectionExtensions +{ + public static void AddDapperRepositories(this IServiceCollection services, bool selfHosted) + { + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + + if (selfHosted) + { + services.AddSingleton(); } } } diff --git a/src/Infrastructure.Dapper/Repositories/BaseRepository.cs b/src/Infrastructure.Dapper/Repositories/BaseRepository.cs index 135f024a6..4a3694d85 100644 --- a/src/Infrastructure.Dapper/Repositories/BaseRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/BaseRepository.cs @@ -1,30 +1,29 @@ using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public abstract class BaseRepository { - public abstract class BaseRepository + static BaseRepository() { - static BaseRepository() - { - SqlMapper.AddTypeHandler(new DateTimeHandler()); - } - - public BaseRepository(string connectionString, string readOnlyConnectionString) - { - if (string.IsNullOrWhiteSpace(connectionString)) - { - throw new ArgumentNullException(nameof(connectionString)); - } - if (string.IsNullOrWhiteSpace(readOnlyConnectionString)) - { - throw new ArgumentNullException(nameof(readOnlyConnectionString)); - } - - ConnectionString = connectionString; - ReadOnlyConnectionString = readOnlyConnectionString; - } - - protected string ConnectionString { get; private set; } - protected string ReadOnlyConnectionString { get; private set; } + SqlMapper.AddTypeHandler(new DateTimeHandler()); } + + public BaseRepository(string connectionString, string readOnlyConnectionString) + { + if (string.IsNullOrWhiteSpace(connectionString)) + { + throw new ArgumentNullException(nameof(connectionString)); + } + if (string.IsNullOrWhiteSpace(readOnlyConnectionString)) + { + throw new ArgumentNullException(nameof(readOnlyConnectionString)); + } + + ConnectionString = connectionString; + ReadOnlyConnectionString = readOnlyConnectionString; + } + + protected string ConnectionString { get; private set; } + protected string ReadOnlyConnectionString { get; private set; } } diff --git a/src/Infrastructure.Dapper/Repositories/CipherRepository.cs b/src/Infrastructure.Dapper/Repositories/CipherRepository.cs index 33560aad5..a2b757a71 100644 --- a/src/Infrastructure.Dapper/Repositories/CipherRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/CipherRepository.cs @@ -8,325 +8,325 @@ using Bit.Core.Settings; using Core.Models.Data; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class CipherRepository : Repository, ICipherRepository { - public class CipherRepository : Repository, ICipherRepository + public CipherRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public CipherRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task GetByIdAsync(Guid id, Guid userId) { - public CipherRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public CipherRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task GetByIdAsync(Guid id, Guid userId) + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[CipherDetails_ReadByIdUserId]", - new { Id = id, UserId = userId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[CipherDetails_ReadByIdUserId]", + new { Id = id, UserId = userId }, + commandType: CommandType.StoredProcedure); - return results.FirstOrDefault(); - } + return results.FirstOrDefault(); + } + } + + public async Task GetOrganizationDetailsByIdAsync(Guid id) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[CipherOrganizationDetails_ReadById]", + new { Id = id }, + commandType: CommandType.StoredProcedure); + + return results.FirstOrDefault(); + } + } + + public async Task> GetManyOrganizationDetailsByOrganizationIdAsync( + Guid organizationId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[CipherOrganizationDetails_ReadByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } + + public async Task GetCanEditByIdAsync(Guid userId, Guid cipherId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var result = await connection.QueryFirstOrDefaultAsync( + $"[{Schema}].[Cipher_ReadCanEditByIdUserId]", + new { UserId = userId, Id = cipherId }, + commandType: CommandType.StoredProcedure); + + return result; + } + } + + public async Task> GetManyByUserIdAsync(Guid userId, bool withOrganizations = true) + { + string sprocName = null; + if (withOrganizations) + { + sprocName = $"[{Schema}].[CipherDetails_ReadByUserId]"; + } + else + { + sprocName = $"[{Schema}].[CipherDetails_ReadWithoutOrganizationsByUserId]"; } - public async Task GetOrganizationDetailsByIdAsync(Guid id) + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[CipherOrganizationDetails_ReadById]", - new { Id = id }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + sprocName, + new { UserId = userId }, + commandType: CommandType.StoredProcedure); - return results.FirstOrDefault(); - } + return results + .GroupBy(c => c.Id) + .Select(g => g.OrderByDescending(og => og.Edit).First()) + .ToList(); } + } - public async Task> GetManyOrganizationDetailsByOrganizationIdAsync( - Guid organizationId) + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[CipherOrganizationDetails_ReadByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[Cipher_ReadByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task GetCanEditByIdAsync(Guid userId, Guid cipherId) + public async Task CreateAsync(Cipher cipher, IEnumerable collectionIds) + { + cipher.SetNewId(); + var objWithCollections = JsonSerializer.Deserialize( + JsonSerializer.Serialize(cipher)); + objWithCollections.CollectionIds = collectionIds.ToGuidIdArrayTVP(); + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var result = await connection.QueryFirstOrDefaultAsync( - $"[{Schema}].[Cipher_ReadCanEditByIdUserId]", - new { UserId = userId, Id = cipherId }, - commandType: CommandType.StoredProcedure); - - return result; - } + var results = await connection.ExecuteAsync( + $"[{Schema}].[Cipher_CreateWithCollections]", + objWithCollections, + commandType: CommandType.StoredProcedure); } + } - public async Task> GetManyByUserIdAsync(Guid userId, bool withOrganizations = true) + public async Task CreateAsync(CipherDetails cipher) + { + cipher.SetNewId(); + using (var connection = new SqlConnection(ConnectionString)) { - string sprocName = null; - if (withOrganizations) - { - sprocName = $"[{Schema}].[CipherDetails_ReadByUserId]"; - } - else - { - sprocName = $"[{Schema}].[CipherDetails_ReadWithoutOrganizationsByUserId]"; - } - - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - sprocName, - new { UserId = userId }, - commandType: CommandType.StoredProcedure); - - return results - .GroupBy(c => c.Id) - .Select(g => g.OrderByDescending(og => og.Edit).First()) - .ToList(); - } + var results = await connection.ExecuteAsync( + $"[{Schema}].[CipherDetails_Create]", + cipher, + commandType: CommandType.StoredProcedure); } + } - public async Task> GetManyByOrganizationIdAsync(Guid organizationId) + public async Task CreateAsync(CipherDetails cipher, IEnumerable collectionIds) + { + cipher.SetNewId(); + var objWithCollections = JsonSerializer.Deserialize( + JsonSerializer.Serialize(cipher)); + objWithCollections.CollectionIds = collectionIds.ToGuidIdArrayTVP(); + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[Cipher_ReadByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); - - return results.ToList(); - } + var results = await connection.ExecuteAsync( + $"[{Schema}].[CipherDetails_CreateWithCollections]", + objWithCollections, + commandType: CommandType.StoredProcedure); } + } - public async Task CreateAsync(Cipher cipher, IEnumerable collectionIds) + public async Task ReplaceAsync(CipherDetails obj) + { + using (var connection = new SqlConnection(ConnectionString)) { - cipher.SetNewId(); - var objWithCollections = JsonSerializer.Deserialize( - JsonSerializer.Serialize(cipher)); - objWithCollections.CollectionIds = collectionIds.ToGuidIdArrayTVP(); - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Cipher_CreateWithCollections]", - objWithCollections, - commandType: CommandType.StoredProcedure); - } + var results = await connection.ExecuteAsync( + $"[{Schema}].[CipherDetails_Update]", + obj, + commandType: CommandType.StoredProcedure); } + } - public async Task CreateAsync(CipherDetails cipher) + public async Task UpsertAsync(CipherDetails cipher) + { + if (cipher.Id.Equals(default)) { - cipher.SetNewId(); - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[CipherDetails_Create]", - cipher, - commandType: CommandType.StoredProcedure); - } + await CreateAsync(cipher); } - - public async Task CreateAsync(CipherDetails cipher, IEnumerable collectionIds) + else { - cipher.SetNewId(); - var objWithCollections = JsonSerializer.Deserialize( - JsonSerializer.Serialize(cipher)); - objWithCollections.CollectionIds = collectionIds.ToGuidIdArrayTVP(); - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[CipherDetails_CreateWithCollections]", - objWithCollections, - commandType: CommandType.StoredProcedure); - } + await ReplaceAsync(cipher); } + } - public async Task ReplaceAsync(CipherDetails obj) + public async Task ReplaceAsync(Cipher obj, IEnumerable collectionIds) + { + var objWithCollections = JsonSerializer.Deserialize( + JsonSerializer.Serialize(obj)); + objWithCollections.CollectionIds = collectionIds.ToGuidIdArrayTVP(); + + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[CipherDetails_Update]", - obj, - commandType: CommandType.StoredProcedure); - } + var result = await connection.ExecuteScalarAsync( + $"[{Schema}].[Cipher_UpdateWithCollections]", + objWithCollections, + commandType: CommandType.StoredProcedure); + return result >= 0; } + } - public async Task UpsertAsync(CipherDetails cipher) + public async Task UpdatePartialAsync(Guid id, Guid userId, Guid? folderId, bool favorite) + { + using (var connection = new SqlConnection(ConnectionString)) { - if (cipher.Id.Equals(default)) - { - await CreateAsync(cipher); - } - else - { - await ReplaceAsync(cipher); - } + var results = await connection.ExecuteAsync( + $"[{Schema}].[Cipher_UpdatePartial]", + new { Id = id, UserId = userId, FolderId = folderId, Favorite = favorite }, + commandType: CommandType.StoredProcedure); } + } - public async Task ReplaceAsync(Cipher obj, IEnumerable collectionIds) + public async Task UpdateAttachmentAsync(CipherAttachment attachment) + { + using (var connection = new SqlConnection(ConnectionString)) { - var objWithCollections = JsonSerializer.Deserialize( - JsonSerializer.Serialize(obj)); - objWithCollections.CollectionIds = collectionIds.ToGuidIdArrayTVP(); - - using (var connection = new SqlConnection(ConnectionString)) - { - var result = await connection.ExecuteScalarAsync( - $"[{Schema}].[Cipher_UpdateWithCollections]", - objWithCollections, - commandType: CommandType.StoredProcedure); - return result >= 0; - } + var results = await connection.ExecuteAsync( + $"[{Schema}].[Cipher_UpdateAttachment]", + attachment, + commandType: CommandType.StoredProcedure); } + } - public async Task UpdatePartialAsync(Guid id, Guid userId, Guid? folderId, bool favorite) + public async Task DeleteAttachmentAsync(Guid cipherId, string attachmentId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Cipher_UpdatePartial]", - new { Id = id, UserId = userId, FolderId = folderId, Favorite = favorite }, - commandType: CommandType.StoredProcedure); - } + var results = await connection.ExecuteAsync( + $"[{Schema}].[Cipher_DeleteAttachment]", + new { Id = cipherId, AttachmentId = attachmentId }, + commandType: CommandType.StoredProcedure); } + } - public async Task UpdateAttachmentAsync(CipherAttachment attachment) + public async Task DeleteAsync(IEnumerable ids, Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Cipher_UpdateAttachment]", - attachment, - commandType: CommandType.StoredProcedure); - } + var results = await connection.ExecuteAsync( + $"[{Schema}].[Cipher_Delete]", + new { Ids = ids.ToGuidIdArrayTVP(), UserId = userId }, + commandType: CommandType.StoredProcedure); } + } - public async Task DeleteAttachmentAsync(Guid cipherId, string attachmentId) + public async Task DeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Cipher_DeleteAttachment]", - new { Id = cipherId, AttachmentId = attachmentId }, - commandType: CommandType.StoredProcedure); - } + var results = await connection.ExecuteAsync( + $"[{Schema}].[Cipher_DeleteByIdsOrganizationId]", + new { Ids = ids.ToGuidIdArrayTVP(), OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); } + } - public async Task DeleteAsync(IEnumerable ids, Guid userId) + public async Task SoftDeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Cipher_Delete]", - new { Ids = ids.ToGuidIdArrayTVP(), UserId = userId }, - commandType: CommandType.StoredProcedure); - } + var results = await connection.ExecuteAsync( + $"[{Schema}].[Cipher_SoftDeleteByIdsOrganizationId]", + new { Ids = ids.ToGuidIdArrayTVP(), OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); } + } - public async Task DeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId) + public async Task MoveAsync(IEnumerable ids, Guid? folderId, Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Cipher_DeleteByIdsOrganizationId]", - new { Ids = ids.ToGuidIdArrayTVP(), OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); - } + var results = await connection.ExecuteAsync( + $"[{Schema}].[Cipher_Move]", + new { Ids = ids.ToGuidIdArrayTVP(), FolderId = folderId, UserId = userId }, + commandType: CommandType.StoredProcedure); } + } - public async Task SoftDeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId) + public async Task DeleteByUserIdAsync(Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Cipher_SoftDeleteByIdsOrganizationId]", - new { Ids = ids.ToGuidIdArrayTVP(), OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); - } + var results = await connection.ExecuteAsync( + $"[{Schema}].[Cipher_DeleteByUserId]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); } + } - public async Task MoveAsync(IEnumerable ids, Guid? folderId, Guid userId) + public async Task DeleteByOrganizationIdAsync(Guid organizationId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Cipher_Move]", - new { Ids = ids.ToGuidIdArrayTVP(), FolderId = folderId, UserId = userId }, - commandType: CommandType.StoredProcedure); - } + var results = await connection.ExecuteAsync( + $"[{Schema}].[Cipher_DeleteByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); } + } - public async Task DeleteByUserIdAsync(Guid userId) + public Task UpdateUserKeysAndCiphersAsync(User user, IEnumerable ciphers, IEnumerable folders, IEnumerable sends) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Cipher_DeleteByUserId]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); - } - } + connection.Open(); - public async Task DeleteByOrganizationIdAsync(Guid organizationId) - { - using (var connection = new SqlConnection(ConnectionString)) + using (var transaction = connection.BeginTransaction()) { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Cipher_DeleteByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); - } - } - - public Task UpdateUserKeysAndCiphersAsync(User user, IEnumerable ciphers, IEnumerable folders, IEnumerable sends) - { - using (var connection = new SqlConnection(ConnectionString)) - { - connection.Open(); - - using (var transaction = connection.BeginTransaction()) + try { - try + // 1. Update user. + + using (var cmd = new SqlCommand("[dbo].[User_UpdateKeys]", connection, transaction)) { - // 1. Update user. + cmd.CommandType = CommandType.StoredProcedure; + cmd.Parameters.Add("@Id", SqlDbType.UniqueIdentifier).Value = user.Id; + cmd.Parameters.Add("@SecurityStamp", SqlDbType.NVarChar).Value = user.SecurityStamp; + cmd.Parameters.Add("@Key", SqlDbType.VarChar).Value = user.Key; - using (var cmd = new SqlCommand("[dbo].[User_UpdateKeys]", connection, transaction)) + if (string.IsNullOrWhiteSpace(user.PrivateKey)) { - cmd.CommandType = CommandType.StoredProcedure; - cmd.Parameters.Add("@Id", SqlDbType.UniqueIdentifier).Value = user.Id; - cmd.Parameters.Add("@SecurityStamp", SqlDbType.NVarChar).Value = user.SecurityStamp; - cmd.Parameters.Add("@Key", SqlDbType.VarChar).Value = user.Key; - - if (string.IsNullOrWhiteSpace(user.PrivateKey)) - { - cmd.Parameters.Add("@PrivateKey", SqlDbType.VarChar).Value = DBNull.Value; - } - else - { - cmd.Parameters.Add("@PrivateKey", SqlDbType.VarChar).Value = user.PrivateKey; - } - - cmd.Parameters.Add("@RevisionDate", SqlDbType.DateTime2).Value = user.RevisionDate; - cmd.ExecuteNonQuery(); + cmd.Parameters.Add("@PrivateKey", SqlDbType.VarChar).Value = DBNull.Value; + } + else + { + cmd.Parameters.Add("@PrivateKey", SqlDbType.VarChar).Value = user.PrivateKey; } - // 2. Create temp tables to bulk copy into. + cmd.Parameters.Add("@RevisionDate", SqlDbType.DateTime2).Value = user.RevisionDate; + cmd.ExecuteNonQuery(); + } - var sqlCreateTemp = @" + // 2. Create temp tables to bulk copy into. + + var sqlCreateTemp = @" SELECT TOP 0 * INTO #TempCipher FROM [dbo].[Cipher] @@ -339,50 +339,50 @@ namespace Bit.Infrastructure.Dapper.Repositories INTO #TempSend FROM [dbo].[Send]"; - using (var cmd = new SqlCommand(sqlCreateTemp, connection, transaction)) + using (var cmd = new SqlCommand(sqlCreateTemp, connection, transaction)) + { + cmd.ExecuteNonQuery(); + } + + // 3. Bulk copy into temp tables. + + if (ciphers.Any()) + { + using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) { - cmd.ExecuteNonQuery(); + bulkCopy.DestinationTableName = "#TempCipher"; + var dataTable = BuildCiphersTable(bulkCopy, ciphers); + bulkCopy.WriteToServer(dataTable); } + } - // 3. Bulk copy into temp tables. - - if (ciphers.Any()) + if (folders.Any()) + { + using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) { - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) - { - bulkCopy.DestinationTableName = "#TempCipher"; - var dataTable = BuildCiphersTable(bulkCopy, ciphers); - bulkCopy.WriteToServer(dataTable); - } + bulkCopy.DestinationTableName = "#TempFolder"; + var dataTable = BuildFoldersTable(bulkCopy, folders); + bulkCopy.WriteToServer(dataTable); } + } - if (folders.Any()) + if (sends.Any()) + { + using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) { - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) - { - bulkCopy.DestinationTableName = "#TempFolder"; - var dataTable = BuildFoldersTable(bulkCopy, folders); - bulkCopy.WriteToServer(dataTable); - } + bulkCopy.DestinationTableName = "#TempSend"; + var dataTable = BuildSendsTable(bulkCopy, sends); + bulkCopy.WriteToServer(dataTable); } + } - if (sends.Any()) - { - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) - { - bulkCopy.DestinationTableName = "#TempSend"; - var dataTable = BuildSendsTable(bulkCopy, sends); - bulkCopy.WriteToServer(dataTable); - } - } + // 4. Insert into real tables from temp tables and clean up. - // 4. Insert into real tables from temp tables and clean up. + var sql = string.Empty; - var sql = string.Empty; - - if (ciphers.Any()) - { - sql += @" + if (ciphers.Any()) + { + sql += @" UPDATE [dbo].[Cipher] SET @@ -395,11 +395,11 @@ namespace Bit.Infrastructure.Dapper.Repositories #TempCipher TC ON C.Id = TC.Id WHERE C.[UserId] = @UserId"; - } + } - if (folders.Any()) - { - sql += @" + if (folders.Any()) + { + sql += @" UPDATE [dbo].[Folder] SET @@ -411,11 +411,11 @@ namespace Bit.Infrastructure.Dapper.Repositories #TempFolder TF ON F.Id = TF.Id WHERE F.[UserId] = @UserId"; - } + } - if (sends.Any()) - { - sql += @" + if (sends.Any()) + { + sql += @" UPDATE [dbo].[Send] SET @@ -427,72 +427,72 @@ namespace Bit.Infrastructure.Dapper.Repositories #TempSend TS ON S.Id = TS.Id WHERE S.[UserId] = @UserId"; - } + } - sql += @" + sql += @" DROP TABLE #TempCipher DROP TABLE #TempFolder DROP TABLE #TempSend"; - using (var cmd = new SqlCommand(sql, connection, transaction)) - { - cmd.Parameters.Add("@UserId", SqlDbType.UniqueIdentifier).Value = user.Id; - cmd.ExecuteNonQuery(); - } - - transaction.Commit(); - } - catch + using (var cmd = new SqlCommand(sql, connection, transaction)) { - transaction.Rollback(); - throw; + cmd.Parameters.Add("@UserId", SqlDbType.UniqueIdentifier).Value = user.Id; + cmd.ExecuteNonQuery(); } + + transaction.Commit(); + } + catch + { + transaction.Rollback(); + throw; } } - - return Task.FromResult(0); } - public async Task UpdateCiphersAsync(Guid userId, IEnumerable ciphers) + return Task.FromResult(0); + } + + public async Task UpdateCiphersAsync(Guid userId, IEnumerable ciphers) + { + if (!ciphers.Any()) { - if (!ciphers.Any()) - { - return; - } + return; + } - using (var connection = new SqlConnection(ConnectionString)) - { - connection.Open(); + using (var connection = new SqlConnection(ConnectionString)) + { + connection.Open(); - using (var transaction = connection.BeginTransaction()) + using (var transaction = connection.BeginTransaction()) + { + try { - try - { - // 1. Create temp tables to bulk copy into. + // 1. Create temp tables to bulk copy into. - var sqlCreateTemp = @" + var sqlCreateTemp = @" SELECT TOP 0 * INTO #TempCipher FROM [dbo].[Cipher]"; - using (var cmd = new SqlCommand(sqlCreateTemp, connection, transaction)) - { - cmd.ExecuteNonQuery(); - } + using (var cmd = new SqlCommand(sqlCreateTemp, connection, transaction)) + { + cmd.ExecuteNonQuery(); + } - // 2. Bulk copy into temp tables. - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) - { - bulkCopy.DestinationTableName = "#TempCipher"; - var dataTable = BuildCiphersTable(bulkCopy, ciphers); - bulkCopy.WriteToServer(dataTable); - } + // 2. Bulk copy into temp tables. + using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + { + bulkCopy.DestinationTableName = "#TempCipher"; + var dataTable = BuildCiphersTable(bulkCopy, ciphers); + bulkCopy.WriteToServer(dataTable); + } - // 3. Insert into real tables from temp tables and clean up. + // 3. Insert into real tables from temp tables and clean up. - // Intentionally not including Favorites, Folders, and CreationDate - // since those are not meant to be bulk updated at this time - var sql = @" + // Intentionally not including Favorites, Folders, and CreationDate + // since those are not meant to be bulk updated at this time + var sql = @" UPDATE [dbo].[Cipher] SET @@ -512,452 +512,451 @@ namespace Bit.Infrastructure.Dapper.Repositories DROP TABLE #TempCipher"; - using (var cmd = new SqlCommand(sql, connection, transaction)) - { - cmd.Parameters.Add("@UserId", SqlDbType.UniqueIdentifier).Value = userId; - cmd.ExecuteNonQuery(); - } - - await connection.ExecuteAsync( - $"[{Schema}].[User_BumpAccountRevisionDate]", - new { Id = userId }, - commandType: CommandType.StoredProcedure, transaction: transaction); - - transaction.Commit(); - } - catch + using (var cmd = new SqlCommand(sql, connection, transaction)) { - transaction.Rollback(); - throw; + cmd.Parameters.Add("@UserId", SqlDbType.UniqueIdentifier).Value = userId; + cmd.ExecuteNonQuery(); } + + await connection.ExecuteAsync( + $"[{Schema}].[User_BumpAccountRevisionDate]", + new { Id = userId }, + commandType: CommandType.StoredProcedure, transaction: transaction); + + transaction.Commit(); } - } - } - - public async Task CreateAsync(IEnumerable ciphers, IEnumerable folders) - { - if (!ciphers.Any()) - { - return; - } - - using (var connection = new SqlConnection(ConnectionString)) - { - connection.Open(); - - using (var transaction = connection.BeginTransaction()) + catch { - try - { - if (folders.Any()) - { - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) - { - bulkCopy.DestinationTableName = "[dbo].[Folder]"; - var dataTable = BuildFoldersTable(bulkCopy, folders); - bulkCopy.WriteToServer(dataTable); - } - } - - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) - { - bulkCopy.DestinationTableName = "[dbo].[Cipher]"; - var dataTable = BuildCiphersTable(bulkCopy, ciphers); - bulkCopy.WriteToServer(dataTable); - } - - await connection.ExecuteAsync( - $"[{Schema}].[User_BumpAccountRevisionDate]", - new { Id = ciphers.First().UserId }, - commandType: CommandType.StoredProcedure, transaction: transaction); - - transaction.Commit(); - } - catch - { - transaction.Rollback(); - throw; - } + transaction.Rollback(); + throw; } } } - - public async Task CreateAsync(IEnumerable ciphers, IEnumerable collections, - IEnumerable collectionCiphers) - { - if (!ciphers.Any()) - { - return; - } - - using (var connection = new SqlConnection(ConnectionString)) - { - connection.Open(); - - using (var transaction = connection.BeginTransaction()) - { - try - { - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) - { - bulkCopy.DestinationTableName = "[dbo].[Cipher]"; - var dataTable = BuildCiphersTable(bulkCopy, ciphers); - bulkCopy.WriteToServer(dataTable); - } - - if (collections.Any()) - { - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) - { - bulkCopy.DestinationTableName = "[dbo].[Collection]"; - var dataTable = BuildCollectionsTable(bulkCopy, collections); - bulkCopy.WriteToServer(dataTable); - } - - if (collectionCiphers.Any()) - { - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) - { - bulkCopy.DestinationTableName = "[dbo].[CollectionCipher]"; - var dataTable = BuildCollectionCiphersTable(bulkCopy, collectionCiphers); - bulkCopy.WriteToServer(dataTable); - } - } - } - - await connection.ExecuteAsync( - $"[{Schema}].[User_BumpAccountRevisionDateByOrganizationId]", - new { OrganizationId = ciphers.First().OrganizationId }, - commandType: CommandType.StoredProcedure, transaction: transaction); - - transaction.Commit(); - } - catch - { - transaction.Rollback(); - throw; - } - } - } - } - - public async Task SoftDeleteAsync(IEnumerable ids, Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Cipher_SoftDelete]", - new { Ids = ids.ToGuidIdArrayTVP(), UserId = userId }, - commandType: CommandType.StoredProcedure); - } - } - - public async Task RestoreAsync(IEnumerable ids, Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteScalarAsync( - $"[{Schema}].[Cipher_Restore]", - new { Ids = ids.ToGuidIdArrayTVP(), UserId = userId }, - commandType: CommandType.StoredProcedure); - - return results; - } - } - - public async Task DeleteDeletedAsync(DateTime deletedDateBefore) - { - using (var connection = new SqlConnection(ConnectionString)) - { - await connection.ExecuteAsync( - $"[{Schema}].[Cipher_DeleteDeleted]", - new { DeletedDateBefore = deletedDateBefore }, - commandType: CommandType.StoredProcedure, - commandTimeout: 43200); - } - } - - private DataTable BuildCiphersTable(SqlBulkCopy bulkCopy, IEnumerable ciphers) - { - var c = ciphers.FirstOrDefault(); - if (c == null) - { - throw new ApplicationException("Must have some ciphers to bulk import."); - } - - var ciphersTable = new DataTable("CipherDataTable"); - - var idColumn = new DataColumn(nameof(c.Id), c.Id.GetType()); - ciphersTable.Columns.Add(idColumn); - var userIdColumn = new DataColumn(nameof(c.UserId), typeof(Guid)); - ciphersTable.Columns.Add(userIdColumn); - var organizationId = new DataColumn(nameof(c.OrganizationId), typeof(Guid)); - ciphersTable.Columns.Add(organizationId); - var typeColumn = new DataColumn(nameof(c.Type), typeof(short)); - ciphersTable.Columns.Add(typeColumn); - var dataColumn = new DataColumn(nameof(c.Data), typeof(string)); - ciphersTable.Columns.Add(dataColumn); - var favoritesColumn = new DataColumn(nameof(c.Favorites), typeof(string)); - ciphersTable.Columns.Add(favoritesColumn); - var foldersColumn = new DataColumn(nameof(c.Folders), typeof(string)); - ciphersTable.Columns.Add(foldersColumn); - var attachmentsColumn = new DataColumn(nameof(c.Attachments), typeof(string)); - ciphersTable.Columns.Add(attachmentsColumn); - var creationDateColumn = new DataColumn(nameof(c.CreationDate), c.CreationDate.GetType()); - ciphersTable.Columns.Add(creationDateColumn); - var revisionDateColumn = new DataColumn(nameof(c.RevisionDate), c.RevisionDate.GetType()); - ciphersTable.Columns.Add(revisionDateColumn); - var deletedDateColumn = new DataColumn(nameof(c.DeletedDate), typeof(DateTime)); - ciphersTable.Columns.Add(deletedDateColumn); - var repromptColumn = new DataColumn(nameof(c.Reprompt), typeof(short)); - ciphersTable.Columns.Add(repromptColumn); - - foreach (DataColumn col in ciphersTable.Columns) - { - bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); - } - - var keys = new DataColumn[1]; - keys[0] = idColumn; - ciphersTable.PrimaryKey = keys; - - foreach (var cipher in ciphers) - { - var row = ciphersTable.NewRow(); - - row[idColumn] = cipher.Id; - row[userIdColumn] = cipher.UserId.HasValue ? (object)cipher.UserId.Value : DBNull.Value; - row[organizationId] = cipher.OrganizationId.HasValue ? (object)cipher.OrganizationId.Value : DBNull.Value; - row[typeColumn] = (short)cipher.Type; - row[dataColumn] = cipher.Data; - row[favoritesColumn] = cipher.Favorites; - row[foldersColumn] = cipher.Folders; - row[attachmentsColumn] = cipher.Attachments; - row[creationDateColumn] = cipher.CreationDate; - row[revisionDateColumn] = cipher.RevisionDate; - row[deletedDateColumn] = cipher.DeletedDate.HasValue ? (object)cipher.DeletedDate : DBNull.Value; - row[repromptColumn] = cipher.Reprompt; - - ciphersTable.Rows.Add(row); - } - - return ciphersTable; - } - - private DataTable BuildFoldersTable(SqlBulkCopy bulkCopy, IEnumerable folders) - { - var f = folders.FirstOrDefault(); - if (f == null) - { - throw new ApplicationException("Must have some folders to bulk import."); - } - - var foldersTable = new DataTable("FolderDataTable"); - - var idColumn = new DataColumn(nameof(f.Id), f.Id.GetType()); - foldersTable.Columns.Add(idColumn); - var userIdColumn = new DataColumn(nameof(f.UserId), f.UserId.GetType()); - foldersTable.Columns.Add(userIdColumn); - var nameColumn = new DataColumn(nameof(f.Name), typeof(string)); - foldersTable.Columns.Add(nameColumn); - var creationDateColumn = new DataColumn(nameof(f.CreationDate), f.CreationDate.GetType()); - foldersTable.Columns.Add(creationDateColumn); - var revisionDateColumn = new DataColumn(nameof(f.RevisionDate), f.RevisionDate.GetType()); - foldersTable.Columns.Add(revisionDateColumn); - - foreach (DataColumn col in foldersTable.Columns) - { - bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); - } - - var keys = new DataColumn[1]; - keys[0] = idColumn; - foldersTable.PrimaryKey = keys; - - foreach (var folder in folders) - { - var row = foldersTable.NewRow(); - - row[idColumn] = folder.Id; - row[userIdColumn] = folder.UserId; - row[nameColumn] = folder.Name; - row[creationDateColumn] = folder.CreationDate; - row[revisionDateColumn] = folder.RevisionDate; - - foldersTable.Rows.Add(row); - } - - return foldersTable; - } - - private DataTable BuildCollectionsTable(SqlBulkCopy bulkCopy, IEnumerable collections) - { - var c = collections.FirstOrDefault(); - if (c == null) - { - throw new ApplicationException("Must have some collections to bulk import."); - } - - var collectionsTable = new DataTable("CollectionDataTable"); - - var idColumn = new DataColumn(nameof(c.Id), c.Id.GetType()); - collectionsTable.Columns.Add(idColumn); - var organizationIdColumn = new DataColumn(nameof(c.OrganizationId), c.OrganizationId.GetType()); - collectionsTable.Columns.Add(organizationIdColumn); - var nameColumn = new DataColumn(nameof(c.Name), typeof(string)); - collectionsTable.Columns.Add(nameColumn); - var creationDateColumn = new DataColumn(nameof(c.CreationDate), c.CreationDate.GetType()); - collectionsTable.Columns.Add(creationDateColumn); - var revisionDateColumn = new DataColumn(nameof(c.RevisionDate), c.RevisionDate.GetType()); - collectionsTable.Columns.Add(revisionDateColumn); - - foreach (DataColumn col in collectionsTable.Columns) - { - bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); - } - - var keys = new DataColumn[1]; - keys[0] = idColumn; - collectionsTable.PrimaryKey = keys; - - foreach (var collection in collections) - { - var row = collectionsTable.NewRow(); - - row[idColumn] = collection.Id; - row[organizationIdColumn] = collection.OrganizationId; - row[nameColumn] = collection.Name; - row[creationDateColumn] = collection.CreationDate; - row[revisionDateColumn] = collection.RevisionDate; - - collectionsTable.Rows.Add(row); - } - - return collectionsTable; - } - - private DataTable BuildCollectionCiphersTable(SqlBulkCopy bulkCopy, IEnumerable collectionCiphers) - { - var cc = collectionCiphers.FirstOrDefault(); - if (cc == null) - { - throw new ApplicationException("Must have some collectionCiphers to bulk import."); - } - - var collectionCiphersTable = new DataTable("CollectionCipherDataTable"); - - var collectionIdColumn = new DataColumn(nameof(cc.CollectionId), cc.CollectionId.GetType()); - collectionCiphersTable.Columns.Add(collectionIdColumn); - var cipherIdColumn = new DataColumn(nameof(cc.CipherId), cc.CipherId.GetType()); - collectionCiphersTable.Columns.Add(cipherIdColumn); - - foreach (DataColumn col in collectionCiphersTable.Columns) - { - bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); - } - - var keys = new DataColumn[2]; - keys[0] = collectionIdColumn; - keys[1] = cipherIdColumn; - collectionCiphersTable.PrimaryKey = keys; - - foreach (var collectionCipher in collectionCiphers) - { - var row = collectionCiphersTable.NewRow(); - - row[collectionIdColumn] = collectionCipher.CollectionId; - row[cipherIdColumn] = collectionCipher.CipherId; - - collectionCiphersTable.Rows.Add(row); - } - - return collectionCiphersTable; - } - - private DataTable BuildSendsTable(SqlBulkCopy bulkCopy, IEnumerable sends) - { - var s = sends.FirstOrDefault(); - if (s == null) - { - throw new ApplicationException("Must have some Sends to bulk import."); - } - - var sendsTable = new DataTable("SendsDataTable"); - - var idColumn = new DataColumn(nameof(s.Id), s.Id.GetType()); - sendsTable.Columns.Add(idColumn); - var userIdColumn = new DataColumn(nameof(s.UserId), typeof(Guid)); - sendsTable.Columns.Add(userIdColumn); - var organizationIdColumn = new DataColumn(nameof(s.OrganizationId), typeof(Guid)); - sendsTable.Columns.Add(organizationIdColumn); - var typeColumn = new DataColumn(nameof(s.Type), s.Type.GetType()); - sendsTable.Columns.Add(typeColumn); - var dataColumn = new DataColumn(nameof(s.Data), s.Data.GetType()); - sendsTable.Columns.Add(dataColumn); - var keyColumn = new DataColumn(nameof(s.Key), s.Key.GetType()); - sendsTable.Columns.Add(keyColumn); - var passwordColumn = new DataColumn(nameof(s.Password), typeof(string)); - sendsTable.Columns.Add(passwordColumn); - var maxAccessCountColumn = new DataColumn(nameof(s.MaxAccessCount), typeof(int)); - sendsTable.Columns.Add(maxAccessCountColumn); - var accessCountColumn = new DataColumn(nameof(s.AccessCount), s.AccessCount.GetType()); - sendsTable.Columns.Add(accessCountColumn); - var creationDateColumn = new DataColumn(nameof(s.CreationDate), s.CreationDate.GetType()); - sendsTable.Columns.Add(creationDateColumn); - var revisionDateColumn = new DataColumn(nameof(s.RevisionDate), s.RevisionDate.GetType()); - sendsTable.Columns.Add(revisionDateColumn); - var expirationDateColumn = new DataColumn(nameof(s.ExpirationDate), typeof(DateTime)); - sendsTable.Columns.Add(expirationDateColumn); - var deletionDateColumn = new DataColumn(nameof(s.DeletionDate), s.DeletionDate.GetType()); - sendsTable.Columns.Add(deletionDateColumn); - var disabledColumn = new DataColumn(nameof(s.Disabled), s.Disabled.GetType()); - sendsTable.Columns.Add(disabledColumn); - var hideEmailColumn = new DataColumn(nameof(s.HideEmail), typeof(bool)); - sendsTable.Columns.Add(hideEmailColumn); - - foreach (DataColumn col in sendsTable.Columns) - { - bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); - } - - var keys = new DataColumn[1]; - keys[0] = idColumn; - sendsTable.PrimaryKey = keys; - - foreach (var send in sends) - { - var row = sendsTable.NewRow(); - - row[idColumn] = send.Id; - row[userIdColumn] = send.UserId.HasValue ? (object)send.UserId.Value : DBNull.Value; - row[organizationIdColumn] = send.OrganizationId.HasValue ? (object)send.OrganizationId.Value : DBNull.Value; - row[typeColumn] = (short)send.Type; - row[dataColumn] = send.Data; - row[keyColumn] = send.Key; - row[passwordColumn] = send.Password; - row[maxAccessCountColumn] = send.MaxAccessCount.HasValue ? (object)send.MaxAccessCount : DBNull.Value; - row[accessCountColumn] = send.AccessCount; - row[creationDateColumn] = send.CreationDate; - row[revisionDateColumn] = send.RevisionDate; - row[expirationDateColumn] = send.ExpirationDate.HasValue ? (object)send.ExpirationDate : DBNull.Value; - row[deletionDateColumn] = send.DeletionDate; - row[disabledColumn] = send.Disabled; - row[hideEmailColumn] = send.HideEmail.HasValue ? (object)send.HideEmail : DBNull.Value; - - sendsTable.Rows.Add(row); - } - - return sendsTable; - } - - public class CipherDetailsWithCollections : CipherDetails - { - public DataTable CollectionIds { get; set; } - } - - public class CipherWithCollections : Cipher - { - public DataTable CollectionIds { get; set; } - } + } + + public async Task CreateAsync(IEnumerable ciphers, IEnumerable folders) + { + if (!ciphers.Any()) + { + return; + } + + using (var connection = new SqlConnection(ConnectionString)) + { + connection.Open(); + + using (var transaction = connection.BeginTransaction()) + { + try + { + if (folders.Any()) + { + using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + { + bulkCopy.DestinationTableName = "[dbo].[Folder]"; + var dataTable = BuildFoldersTable(bulkCopy, folders); + bulkCopy.WriteToServer(dataTable); + } + } + + using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + { + bulkCopy.DestinationTableName = "[dbo].[Cipher]"; + var dataTable = BuildCiphersTable(bulkCopy, ciphers); + bulkCopy.WriteToServer(dataTable); + } + + await connection.ExecuteAsync( + $"[{Schema}].[User_BumpAccountRevisionDate]", + new { Id = ciphers.First().UserId }, + commandType: CommandType.StoredProcedure, transaction: transaction); + + transaction.Commit(); + } + catch + { + transaction.Rollback(); + throw; + } + } + } + } + + public async Task CreateAsync(IEnumerable ciphers, IEnumerable collections, + IEnumerable collectionCiphers) + { + if (!ciphers.Any()) + { + return; + } + + using (var connection = new SqlConnection(ConnectionString)) + { + connection.Open(); + + using (var transaction = connection.BeginTransaction()) + { + try + { + using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + { + bulkCopy.DestinationTableName = "[dbo].[Cipher]"; + var dataTable = BuildCiphersTable(bulkCopy, ciphers); + bulkCopy.WriteToServer(dataTable); + } + + if (collections.Any()) + { + using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + { + bulkCopy.DestinationTableName = "[dbo].[Collection]"; + var dataTable = BuildCollectionsTable(bulkCopy, collections); + bulkCopy.WriteToServer(dataTable); + } + + if (collectionCiphers.Any()) + { + using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) + { + bulkCopy.DestinationTableName = "[dbo].[CollectionCipher]"; + var dataTable = BuildCollectionCiphersTable(bulkCopy, collectionCiphers); + bulkCopy.WriteToServer(dataTable); + } + } + } + + await connection.ExecuteAsync( + $"[{Schema}].[User_BumpAccountRevisionDateByOrganizationId]", + new { OrganizationId = ciphers.First().OrganizationId }, + commandType: CommandType.StoredProcedure, transaction: transaction); + + transaction.Commit(); + } + catch + { + transaction.Rollback(); + throw; + } + } + } + } + + public async Task SoftDeleteAsync(IEnumerable ids, Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[Cipher_SoftDelete]", + new { Ids = ids.ToGuidIdArrayTVP(), UserId = userId }, + commandType: CommandType.StoredProcedure); + } + } + + public async Task RestoreAsync(IEnumerable ids, Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteScalarAsync( + $"[{Schema}].[Cipher_Restore]", + new { Ids = ids.ToGuidIdArrayTVP(), UserId = userId }, + commandType: CommandType.StoredProcedure); + + return results; + } + } + + public async Task DeleteDeletedAsync(DateTime deletedDateBefore) + { + using (var connection = new SqlConnection(ConnectionString)) + { + await connection.ExecuteAsync( + $"[{Schema}].[Cipher_DeleteDeleted]", + new { DeletedDateBefore = deletedDateBefore }, + commandType: CommandType.StoredProcedure, + commandTimeout: 43200); + } + } + + private DataTable BuildCiphersTable(SqlBulkCopy bulkCopy, IEnumerable ciphers) + { + var c = ciphers.FirstOrDefault(); + if (c == null) + { + throw new ApplicationException("Must have some ciphers to bulk import."); + } + + var ciphersTable = new DataTable("CipherDataTable"); + + var idColumn = new DataColumn(nameof(c.Id), c.Id.GetType()); + ciphersTable.Columns.Add(idColumn); + var userIdColumn = new DataColumn(nameof(c.UserId), typeof(Guid)); + ciphersTable.Columns.Add(userIdColumn); + var organizationId = new DataColumn(nameof(c.OrganizationId), typeof(Guid)); + ciphersTable.Columns.Add(organizationId); + var typeColumn = new DataColumn(nameof(c.Type), typeof(short)); + ciphersTable.Columns.Add(typeColumn); + var dataColumn = new DataColumn(nameof(c.Data), typeof(string)); + ciphersTable.Columns.Add(dataColumn); + var favoritesColumn = new DataColumn(nameof(c.Favorites), typeof(string)); + ciphersTable.Columns.Add(favoritesColumn); + var foldersColumn = new DataColumn(nameof(c.Folders), typeof(string)); + ciphersTable.Columns.Add(foldersColumn); + var attachmentsColumn = new DataColumn(nameof(c.Attachments), typeof(string)); + ciphersTable.Columns.Add(attachmentsColumn); + var creationDateColumn = new DataColumn(nameof(c.CreationDate), c.CreationDate.GetType()); + ciphersTable.Columns.Add(creationDateColumn); + var revisionDateColumn = new DataColumn(nameof(c.RevisionDate), c.RevisionDate.GetType()); + ciphersTable.Columns.Add(revisionDateColumn); + var deletedDateColumn = new DataColumn(nameof(c.DeletedDate), typeof(DateTime)); + ciphersTable.Columns.Add(deletedDateColumn); + var repromptColumn = new DataColumn(nameof(c.Reprompt), typeof(short)); + ciphersTable.Columns.Add(repromptColumn); + + foreach (DataColumn col in ciphersTable.Columns) + { + bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + } + + var keys = new DataColumn[1]; + keys[0] = idColumn; + ciphersTable.PrimaryKey = keys; + + foreach (var cipher in ciphers) + { + var row = ciphersTable.NewRow(); + + row[idColumn] = cipher.Id; + row[userIdColumn] = cipher.UserId.HasValue ? (object)cipher.UserId.Value : DBNull.Value; + row[organizationId] = cipher.OrganizationId.HasValue ? (object)cipher.OrganizationId.Value : DBNull.Value; + row[typeColumn] = (short)cipher.Type; + row[dataColumn] = cipher.Data; + row[favoritesColumn] = cipher.Favorites; + row[foldersColumn] = cipher.Folders; + row[attachmentsColumn] = cipher.Attachments; + row[creationDateColumn] = cipher.CreationDate; + row[revisionDateColumn] = cipher.RevisionDate; + row[deletedDateColumn] = cipher.DeletedDate.HasValue ? (object)cipher.DeletedDate : DBNull.Value; + row[repromptColumn] = cipher.Reprompt; + + ciphersTable.Rows.Add(row); + } + + return ciphersTable; + } + + private DataTable BuildFoldersTable(SqlBulkCopy bulkCopy, IEnumerable folders) + { + var f = folders.FirstOrDefault(); + if (f == null) + { + throw new ApplicationException("Must have some folders to bulk import."); + } + + var foldersTable = new DataTable("FolderDataTable"); + + var idColumn = new DataColumn(nameof(f.Id), f.Id.GetType()); + foldersTable.Columns.Add(idColumn); + var userIdColumn = new DataColumn(nameof(f.UserId), f.UserId.GetType()); + foldersTable.Columns.Add(userIdColumn); + var nameColumn = new DataColumn(nameof(f.Name), typeof(string)); + foldersTable.Columns.Add(nameColumn); + var creationDateColumn = new DataColumn(nameof(f.CreationDate), f.CreationDate.GetType()); + foldersTable.Columns.Add(creationDateColumn); + var revisionDateColumn = new DataColumn(nameof(f.RevisionDate), f.RevisionDate.GetType()); + foldersTable.Columns.Add(revisionDateColumn); + + foreach (DataColumn col in foldersTable.Columns) + { + bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + } + + var keys = new DataColumn[1]; + keys[0] = idColumn; + foldersTable.PrimaryKey = keys; + + foreach (var folder in folders) + { + var row = foldersTable.NewRow(); + + row[idColumn] = folder.Id; + row[userIdColumn] = folder.UserId; + row[nameColumn] = folder.Name; + row[creationDateColumn] = folder.CreationDate; + row[revisionDateColumn] = folder.RevisionDate; + + foldersTable.Rows.Add(row); + } + + return foldersTable; + } + + private DataTable BuildCollectionsTable(SqlBulkCopy bulkCopy, IEnumerable collections) + { + var c = collections.FirstOrDefault(); + if (c == null) + { + throw new ApplicationException("Must have some collections to bulk import."); + } + + var collectionsTable = new DataTable("CollectionDataTable"); + + var idColumn = new DataColumn(nameof(c.Id), c.Id.GetType()); + collectionsTable.Columns.Add(idColumn); + var organizationIdColumn = new DataColumn(nameof(c.OrganizationId), c.OrganizationId.GetType()); + collectionsTable.Columns.Add(organizationIdColumn); + var nameColumn = new DataColumn(nameof(c.Name), typeof(string)); + collectionsTable.Columns.Add(nameColumn); + var creationDateColumn = new DataColumn(nameof(c.CreationDate), c.CreationDate.GetType()); + collectionsTable.Columns.Add(creationDateColumn); + var revisionDateColumn = new DataColumn(nameof(c.RevisionDate), c.RevisionDate.GetType()); + collectionsTable.Columns.Add(revisionDateColumn); + + foreach (DataColumn col in collectionsTable.Columns) + { + bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + } + + var keys = new DataColumn[1]; + keys[0] = idColumn; + collectionsTable.PrimaryKey = keys; + + foreach (var collection in collections) + { + var row = collectionsTable.NewRow(); + + row[idColumn] = collection.Id; + row[organizationIdColumn] = collection.OrganizationId; + row[nameColumn] = collection.Name; + row[creationDateColumn] = collection.CreationDate; + row[revisionDateColumn] = collection.RevisionDate; + + collectionsTable.Rows.Add(row); + } + + return collectionsTable; + } + + private DataTable BuildCollectionCiphersTable(SqlBulkCopy bulkCopy, IEnumerable collectionCiphers) + { + var cc = collectionCiphers.FirstOrDefault(); + if (cc == null) + { + throw new ApplicationException("Must have some collectionCiphers to bulk import."); + } + + var collectionCiphersTable = new DataTable("CollectionCipherDataTable"); + + var collectionIdColumn = new DataColumn(nameof(cc.CollectionId), cc.CollectionId.GetType()); + collectionCiphersTable.Columns.Add(collectionIdColumn); + var cipherIdColumn = new DataColumn(nameof(cc.CipherId), cc.CipherId.GetType()); + collectionCiphersTable.Columns.Add(cipherIdColumn); + + foreach (DataColumn col in collectionCiphersTable.Columns) + { + bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + } + + var keys = new DataColumn[2]; + keys[0] = collectionIdColumn; + keys[1] = cipherIdColumn; + collectionCiphersTable.PrimaryKey = keys; + + foreach (var collectionCipher in collectionCiphers) + { + var row = collectionCiphersTable.NewRow(); + + row[collectionIdColumn] = collectionCipher.CollectionId; + row[cipherIdColumn] = collectionCipher.CipherId; + + collectionCiphersTable.Rows.Add(row); + } + + return collectionCiphersTable; + } + + private DataTable BuildSendsTable(SqlBulkCopy bulkCopy, IEnumerable sends) + { + var s = sends.FirstOrDefault(); + if (s == null) + { + throw new ApplicationException("Must have some Sends to bulk import."); + } + + var sendsTable = new DataTable("SendsDataTable"); + + var idColumn = new DataColumn(nameof(s.Id), s.Id.GetType()); + sendsTable.Columns.Add(idColumn); + var userIdColumn = new DataColumn(nameof(s.UserId), typeof(Guid)); + sendsTable.Columns.Add(userIdColumn); + var organizationIdColumn = new DataColumn(nameof(s.OrganizationId), typeof(Guid)); + sendsTable.Columns.Add(organizationIdColumn); + var typeColumn = new DataColumn(nameof(s.Type), s.Type.GetType()); + sendsTable.Columns.Add(typeColumn); + var dataColumn = new DataColumn(nameof(s.Data), s.Data.GetType()); + sendsTable.Columns.Add(dataColumn); + var keyColumn = new DataColumn(nameof(s.Key), s.Key.GetType()); + sendsTable.Columns.Add(keyColumn); + var passwordColumn = new DataColumn(nameof(s.Password), typeof(string)); + sendsTable.Columns.Add(passwordColumn); + var maxAccessCountColumn = new DataColumn(nameof(s.MaxAccessCount), typeof(int)); + sendsTable.Columns.Add(maxAccessCountColumn); + var accessCountColumn = new DataColumn(nameof(s.AccessCount), s.AccessCount.GetType()); + sendsTable.Columns.Add(accessCountColumn); + var creationDateColumn = new DataColumn(nameof(s.CreationDate), s.CreationDate.GetType()); + sendsTable.Columns.Add(creationDateColumn); + var revisionDateColumn = new DataColumn(nameof(s.RevisionDate), s.RevisionDate.GetType()); + sendsTable.Columns.Add(revisionDateColumn); + var expirationDateColumn = new DataColumn(nameof(s.ExpirationDate), typeof(DateTime)); + sendsTable.Columns.Add(expirationDateColumn); + var deletionDateColumn = new DataColumn(nameof(s.DeletionDate), s.DeletionDate.GetType()); + sendsTable.Columns.Add(deletionDateColumn); + var disabledColumn = new DataColumn(nameof(s.Disabled), s.Disabled.GetType()); + sendsTable.Columns.Add(disabledColumn); + var hideEmailColumn = new DataColumn(nameof(s.HideEmail), typeof(bool)); + sendsTable.Columns.Add(hideEmailColumn); + + foreach (DataColumn col in sendsTable.Columns) + { + bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + } + + var keys = new DataColumn[1]; + keys[0] = idColumn; + sendsTable.PrimaryKey = keys; + + foreach (var send in sends) + { + var row = sendsTable.NewRow(); + + row[idColumn] = send.Id; + row[userIdColumn] = send.UserId.HasValue ? (object)send.UserId.Value : DBNull.Value; + row[organizationIdColumn] = send.OrganizationId.HasValue ? (object)send.OrganizationId.Value : DBNull.Value; + row[typeColumn] = (short)send.Type; + row[dataColumn] = send.Data; + row[keyColumn] = send.Key; + row[passwordColumn] = send.Password; + row[maxAccessCountColumn] = send.MaxAccessCount.HasValue ? (object)send.MaxAccessCount : DBNull.Value; + row[accessCountColumn] = send.AccessCount; + row[creationDateColumn] = send.CreationDate; + row[revisionDateColumn] = send.RevisionDate; + row[expirationDateColumn] = send.ExpirationDate.HasValue ? (object)send.ExpirationDate : DBNull.Value; + row[deletionDateColumn] = send.DeletionDate; + row[disabledColumn] = send.Disabled; + row[hideEmailColumn] = send.HideEmail.HasValue ? (object)send.HideEmail : DBNull.Value; + + sendsTable.Rows.Add(row); + } + + return sendsTable; + } + + public class CipherDetailsWithCollections : CipherDetails + { + public DataTable CollectionIds { get; set; } + } + + public class CipherWithCollections : Cipher + { + public DataTable CollectionIds { get; set; } } } diff --git a/src/Infrastructure.Dapper/Repositories/CollectionCipherRepository.cs b/src/Infrastructure.Dapper/Repositories/CollectionCipherRepository.cs index 287697948..1368be21e 100644 --- a/src/Infrastructure.Dapper/Repositories/CollectionCipherRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/CollectionCipherRepository.cs @@ -5,95 +5,94 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class CollectionCipherRepository : BaseRepository, ICollectionCipherRepository { - public class CollectionCipherRepository : BaseRepository, ICollectionCipherRepository + public CollectionCipherRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public CollectionCipherRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task> GetManyByUserIdAsync(Guid userId) { - public CollectionCipherRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public CollectionCipherRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task> GetManyByUserIdAsync(Guid userId) + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[CollectionCipher_ReadByUserId]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[CollectionCipher_ReadByUserId]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task> GetManyByOrganizationIdAsync(Guid organizationId) + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[CollectionCipher_ReadByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[CollectionCipher_ReadByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task> GetManyByUserIdCipherIdAsync(Guid userId, Guid cipherId) + public async Task> GetManyByUserIdCipherIdAsync(Guid userId, Guid cipherId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[CollectionCipher_ReadByUserIdCipherId]", - new { UserId = userId, CipherId = cipherId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[CollectionCipher_ReadByUserIdCipherId]", + new { UserId = userId, CipherId = cipherId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task UpdateCollectionsAsync(Guid cipherId, Guid userId, IEnumerable collectionIds) + public async Task UpdateCollectionsAsync(Guid cipherId, Guid userId, IEnumerable collectionIds) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - "[dbo].[CollectionCipher_UpdateCollections]", - new { CipherId = cipherId, UserId = userId, CollectionIds = collectionIds.ToGuidIdArrayTVP() }, - commandType: CommandType.StoredProcedure); - } + var results = await connection.ExecuteAsync( + "[dbo].[CollectionCipher_UpdateCollections]", + new { CipherId = cipherId, UserId = userId, CollectionIds = collectionIds.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); } + } - public async Task UpdateCollectionsForAdminAsync(Guid cipherId, Guid organizationId, IEnumerable collectionIds) + public async Task UpdateCollectionsForAdminAsync(Guid cipherId, Guid organizationId, IEnumerable collectionIds) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - "[dbo].[CollectionCipher_UpdateCollectionsAdmin]", - new { CipherId = cipherId, OrganizationId = organizationId, CollectionIds = collectionIds.ToGuidIdArrayTVP() }, - commandType: CommandType.StoredProcedure); - } + var results = await connection.ExecuteAsync( + "[dbo].[CollectionCipher_UpdateCollectionsAdmin]", + new { CipherId = cipherId, OrganizationId = organizationId, CollectionIds = collectionIds.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); } + } - public async Task UpdateCollectionsForCiphersAsync(IEnumerable cipherIds, Guid userId, - Guid organizationId, IEnumerable collectionIds) + public async Task UpdateCollectionsForCiphersAsync(IEnumerable cipherIds, Guid userId, + Guid organizationId, IEnumerable collectionIds) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - "[dbo].[CollectionCipher_UpdateCollectionsForCiphers]", - new - { - CipherIds = cipherIds.ToGuidIdArrayTVP(), - UserId = userId, - OrganizationId = organizationId, - CollectionIds = collectionIds.ToGuidIdArrayTVP() - }, - commandType: CommandType.StoredProcedure); - } + var results = await connection.ExecuteAsync( + "[dbo].[CollectionCipher_UpdateCollectionsForCiphers]", + new + { + CipherIds = cipherIds.ToGuidIdArrayTVP(), + UserId = userId, + OrganizationId = organizationId, + CollectionIds = collectionIds.ToGuidIdArrayTVP() + }, + commandType: CommandType.StoredProcedure); } } } diff --git a/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs b/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs index ee8bc1e2e..3fd0a2430 100644 --- a/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs @@ -7,181 +7,180 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class CollectionRepository : Repository, ICollectionRepository { - public class CollectionRepository : Repository, ICollectionRepository + public CollectionRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public CollectionRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task GetCountByOrganizationIdAsync(Guid organizationId) { - public CollectionRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public CollectionRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task GetCountByOrganizationIdAsync(Guid organizationId) + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteScalarAsync( - "[dbo].[Collection_ReadCountByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); + var results = await connection.ExecuteScalarAsync( + "[dbo].[Collection_ReadCountByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); - return results; - } - } - - public async Task>> GetByIdWithGroupsAsync(Guid id) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryMultipleAsync( - $"[{Schema}].[Collection_ReadWithGroupsById]", - new { Id = id }, - commandType: CommandType.StoredProcedure); - - var collection = await results.ReadFirstOrDefaultAsync(); - var groups = (await results.ReadAsync()).ToList(); - - return new Tuple>(collection, groups); - } - } - - public async Task>> GetByIdWithGroupsAsync( - Guid id, Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryMultipleAsync( - $"[{Schema}].[Collection_ReadWithGroupsByIdUserId]", - new { Id = id, UserId = userId }, - commandType: CommandType.StoredProcedure); - - var collection = await results.ReadFirstOrDefaultAsync(); - var groups = (await results.ReadAsync()).ToList(); - - return new Tuple>(collection, groups); - } - } - - public async Task> GetManyByOrganizationIdAsync(Guid organizationId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); - - return results.ToList(); - } - } - - public async Task GetByIdAsync(Guid id, Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[Collection_ReadByIdUserId]", - new { Id = id, UserId = userId }, - commandType: CommandType.StoredProcedure); - - return results.FirstOrDefault(); - } - } - - public async Task> GetManyByUserIdAsync(Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[Collection_ReadByUserId]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); - - return results.ToList(); - } - } - - public async Task CreateAsync(Collection obj, IEnumerable groups) - { - obj.SetNewId(); - var objWithGroups = JsonSerializer.Deserialize(JsonSerializer.Serialize(obj)); - objWithGroups.Groups = groups.ToArrayTVP(); - - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Collection_CreateWithGroups]", - objWithGroups, - commandType: CommandType.StoredProcedure); - } - } - - public async Task ReplaceAsync(Collection obj, IEnumerable groups) - { - var objWithGroups = JsonSerializer.Deserialize(JsonSerializer.Serialize(obj)); - objWithGroups.Groups = groups.ToArrayTVP(); - - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Collection_UpdateWithGroups]", - objWithGroups, - commandType: CommandType.StoredProcedure); - } - } - - public async Task CreateUserAsync(Guid collectionId, Guid organizationUserId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[CollectionUser_Create]", - new { CollectionId = collectionId, OrganizationUserId = organizationUserId }, - commandType: CommandType.StoredProcedure); - } - } - - public async Task DeleteUserAsync(Guid collectionId, Guid organizationUserId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[CollectionUser_Delete]", - new { CollectionId = collectionId, OrganizationUserId = organizationUserId }, - commandType: CommandType.StoredProcedure); - } - } - - public async Task UpdateUsersAsync(Guid id, IEnumerable users) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[CollectionUser_UpdateUsers]", - new { CollectionId = id, Users = users.ToArrayTVP() }, - commandType: CommandType.StoredProcedure); - } - } - - public async Task> GetManyUsersByIdAsync(Guid id) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[CollectionUser_ReadByCollectionId]", - new { CollectionId = id }, - commandType: CommandType.StoredProcedure); - - return results.ToList(); - } - } - - public class CollectionWithGroups : Collection - { - public DataTable Groups { get; set; } + return results; } } + + public async Task>> GetByIdWithGroupsAsync(Guid id) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryMultipleAsync( + $"[{Schema}].[Collection_ReadWithGroupsById]", + new { Id = id }, + commandType: CommandType.StoredProcedure); + + var collection = await results.ReadFirstOrDefaultAsync(); + var groups = (await results.ReadAsync()).ToList(); + + return new Tuple>(collection, groups); + } + } + + public async Task>> GetByIdWithGroupsAsync( + Guid id, Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryMultipleAsync( + $"[{Schema}].[Collection_ReadWithGroupsByIdUserId]", + new { Id = id, UserId = userId }, + commandType: CommandType.StoredProcedure); + + var collection = await results.ReadFirstOrDefaultAsync(); + var groups = (await results.ReadAsync()).ToList(); + + return new Tuple>(collection, groups); + } + } + + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } + + public async Task GetByIdAsync(Guid id, Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[Collection_ReadByIdUserId]", + new { Id = id, UserId = userId }, + commandType: CommandType.StoredProcedure); + + return results.FirstOrDefault(); + } + } + + public async Task> GetManyByUserIdAsync(Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[Collection_ReadByUserId]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } + + public async Task CreateAsync(Collection obj, IEnumerable groups) + { + obj.SetNewId(); + var objWithGroups = JsonSerializer.Deserialize(JsonSerializer.Serialize(obj)); + objWithGroups.Groups = groups.ToArrayTVP(); + + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[Collection_CreateWithGroups]", + objWithGroups, + commandType: CommandType.StoredProcedure); + } + } + + public async Task ReplaceAsync(Collection obj, IEnumerable groups) + { + var objWithGroups = JsonSerializer.Deserialize(JsonSerializer.Serialize(obj)); + objWithGroups.Groups = groups.ToArrayTVP(); + + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[Collection_UpdateWithGroups]", + objWithGroups, + commandType: CommandType.StoredProcedure); + } + } + + public async Task CreateUserAsync(Guid collectionId, Guid organizationUserId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[CollectionUser_Create]", + new { CollectionId = collectionId, OrganizationUserId = organizationUserId }, + commandType: CommandType.StoredProcedure); + } + } + + public async Task DeleteUserAsync(Guid collectionId, Guid organizationUserId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[CollectionUser_Delete]", + new { CollectionId = collectionId, OrganizationUserId = organizationUserId }, + commandType: CommandType.StoredProcedure); + } + } + + public async Task UpdateUsersAsync(Guid id, IEnumerable users) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[CollectionUser_UpdateUsers]", + new { CollectionId = id, Users = users.ToArrayTVP() }, + commandType: CommandType.StoredProcedure); + } + } + + public async Task> GetManyUsersByIdAsync(Guid id) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[CollectionUser_ReadByCollectionId]", + new { CollectionId = id }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } + + public class CollectionWithGroups : Collection + { + public DataTable Groups { get; set; } + } } diff --git a/src/Infrastructure.Dapper/Repositories/DateTimeHandler.cs b/src/Infrastructure.Dapper/Repositories/DateTimeHandler.cs index 8aedf2321..ac48653ec 100644 --- a/src/Infrastructure.Dapper/Repositories/DateTimeHandler.cs +++ b/src/Infrastructure.Dapper/Repositories/DateTimeHandler.cs @@ -1,18 +1,17 @@ using System.Data; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories -{ - public class DateTimeHandler : SqlMapper.TypeHandler - { - public override void SetValue(IDbDataParameter parameter, DateTime value) - { - parameter.Value = value; - } +namespace Bit.Infrastructure.Dapper.Repositories; - public override DateTime Parse(object value) - { - return DateTime.SpecifyKind((DateTime)value, DateTimeKind.Utc); - } +public class DateTimeHandler : SqlMapper.TypeHandler +{ + public override void SetValue(IDbDataParameter parameter, DateTime value) + { + parameter.Value = value; + } + + public override DateTime Parse(object value) + { + return DateTime.SpecifyKind((DateTime)value, DateTimeKind.Utc); } } diff --git a/src/Infrastructure.Dapper/Repositories/DeviceRepository.cs b/src/Infrastructure.Dapper/Repositories/DeviceRepository.cs index 039ff90ae..325cee307 100644 --- a/src/Infrastructure.Dapper/Repositories/DeviceRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/DeviceRepository.cs @@ -5,84 +5,83 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class DeviceRepository : Repository, IDeviceRepository { - public class DeviceRepository : Repository, IDeviceRepository + public DeviceRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public DeviceRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task GetByIdAsync(Guid id, Guid userId) { - public DeviceRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public DeviceRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task GetByIdAsync(Guid id, Guid userId) + var device = await GetByIdAsync(id); + if (device == null || device.UserId != userId) { - var device = await GetByIdAsync(id); - if (device == null || device.UserId != userId) - { - return null; - } - - return device; + return null; } - public async Task GetByIdentifierAsync(string identifier) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadByIdentifier]", - new - { - Identifier = identifier - }, - commandType: CommandType.StoredProcedure); + return device; + } - return results.FirstOrDefault(); - } + public async Task GetByIdentifierAsync(string identifier) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadByIdentifier]", + new + { + Identifier = identifier + }, + commandType: CommandType.StoredProcedure); + + return results.FirstOrDefault(); } + } - public async Task GetByIdentifierAsync(string identifier, Guid userId) + public async Task GetByIdentifierAsync(string identifier, Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadByIdentifierUserId]", - new - { - UserId = userId, - Identifier = identifier - }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadByIdentifierUserId]", + new + { + UserId = userId, + Identifier = identifier + }, + commandType: CommandType.StoredProcedure); - return results.FirstOrDefault(); - } + return results.FirstOrDefault(); } + } - public async Task> GetManyByUserIdAsync(Guid userId) + public async Task> GetManyByUserIdAsync(Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadByUserId]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadByUserId]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task ClearPushTokenAsync(Guid id) + public async Task ClearPushTokenAsync(Guid id) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - await connection.ExecuteAsync( - $"[{Schema}].[{Table}_ClearPushTokenById]", - new { Id = id }, - commandType: CommandType.StoredProcedure); - } + await connection.ExecuteAsync( + $"[{Schema}].[{Table}_ClearPushTokenById]", + new { Id = id }, + commandType: CommandType.StoredProcedure); } } } diff --git a/src/Infrastructure.Dapper/Repositories/EmergencyAccessRepository.cs b/src/Infrastructure.Dapper/Repositories/EmergencyAccessRepository.cs index c88664c7a..9f1f9a971 100644 --- a/src/Infrastructure.Dapper/Repositories/EmergencyAccessRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/EmergencyAccessRepository.cs @@ -6,92 +6,91 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class EmergencyAccessRepository : Repository, IEmergencyAccessRepository { - public class EmergencyAccessRepository : Repository, IEmergencyAccessRepository + public EmergencyAccessRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public EmergencyAccessRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task GetCountByGrantorIdEmailAsync(Guid grantorId, string email, bool onlyRegisteredUsers) { - public EmergencyAccessRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public EmergencyAccessRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task GetCountByGrantorIdEmailAsync(Guid grantorId, string email, bool onlyRegisteredUsers) + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteScalarAsync( - "[dbo].[EmergencyAccess_ReadCountByGrantorIdEmail]", - new { GrantorId = grantorId, Email = email, OnlyUsers = onlyRegisteredUsers }, - commandType: CommandType.StoredProcedure); + var results = await connection.ExecuteScalarAsync( + "[dbo].[EmergencyAccess_ReadCountByGrantorIdEmail]", + new { GrantorId = grantorId, Email = email, OnlyUsers = onlyRegisteredUsers }, + commandType: CommandType.StoredProcedure); - return results; - } + return results; } + } - public async Task> GetManyDetailsByGrantorIdAsync(Guid grantorId) + public async Task> GetManyDetailsByGrantorIdAsync(Guid grantorId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[EmergencyAccessDetails_ReadByGrantorId]", - new { GrantorId = grantorId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[EmergencyAccessDetails_ReadByGrantorId]", + new { GrantorId = grantorId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task> GetManyDetailsByGranteeIdAsync(Guid granteeId) + public async Task> GetManyDetailsByGranteeIdAsync(Guid granteeId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[EmergencyAccessDetails_ReadByGranteeId]", - new { GranteeId = granteeId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[EmergencyAccessDetails_ReadByGranteeId]", + new { GranteeId = granteeId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task GetDetailsByIdGrantorIdAsync(Guid id, Guid grantorId) + public async Task GetDetailsByIdGrantorIdAsync(Guid id, Guid grantorId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[EmergencyAccessDetails_ReadByIdGrantorId]", - new { Id = id, GrantorId = grantorId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[EmergencyAccessDetails_ReadByIdGrantorId]", + new { Id = id, GrantorId = grantorId }, + commandType: CommandType.StoredProcedure); - return results.FirstOrDefault(); - } + return results.FirstOrDefault(); } + } - public async Task> GetManyToNotifyAsync() + public async Task> GetManyToNotifyAsync() + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[EmergencyAccess_ReadToNotify]", - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[EmergencyAccess_ReadToNotify]", + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task> GetExpiredRecoveriesAsync() + public async Task> GetExpiredRecoveriesAsync() + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[EmergencyAccessDetails_ReadExpiredRecoveries]", - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[EmergencyAccessDetails_ReadExpiredRecoveries]", + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } } } diff --git a/src/Infrastructure.Dapper/Repositories/EventRepository.cs b/src/Infrastructure.Dapper/Repositories/EventRepository.cs index 82491cb03..ba4c68b35 100644 --- a/src/Infrastructure.Dapper/Repositories/EventRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/EventRepository.cs @@ -6,221 +6,220 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class EventRepository : Repository, IEventRepository { - public class EventRepository : Repository, IEventRepository + public EventRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public EventRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task> GetManyByUserAsync(Guid userId, DateTime startDate, DateTime endDate, + PageOptions pageOptions) { - public EventRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } + return await GetManyAsync($"[{Schema}].[Event_ReadPageByUserId]", + new Dictionary + { + ["@UserId"] = userId + }, startDate, endDate, pageOptions); + } - public EventRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } + public async Task> GetManyByOrganizationAsync(Guid organizationId, + DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + return await GetManyAsync($"[{Schema}].[Event_ReadPageByOrganizationId]", + new Dictionary + { + ["@OrganizationId"] = organizationId + }, startDate, endDate, pageOptions); + } - public async Task> GetManyByUserAsync(Guid userId, DateTime startDate, DateTime endDate, - PageOptions pageOptions) + public async Task> GetManyByOrganizationActingUserAsync(Guid organizationId, Guid actingUserId, + DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + return await GetManyAsync($"[{Schema}].[Event_ReadPageByOrganizationIdActingUserId]", + new Dictionary + { + ["@OrganizationId"] = organizationId, + ["@ActingUserId"] = actingUserId + }, startDate, endDate, pageOptions); + } + + public async Task> GetManyByProviderAsync(Guid providerId, + DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + return await GetManyAsync($"[{Schema}].[Event_ReadPageByProviderId]", + new Dictionary + { + ["@ProviderId"] = providerId + }, startDate, endDate, pageOptions); + } + + public async Task> GetManyByProviderActingUserAsync(Guid providerId, Guid actingUserId, + DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + return await GetManyAsync($"[{Schema}].[Event_ReadPageByProviderIdActingUserId]", + new Dictionary + { + ["@ProviderId"] = providerId, + ["@ActingUserId"] = actingUserId + }, startDate, endDate, pageOptions); + } + + public async Task> GetManyByCipherAsync(Cipher cipher, DateTime startDate, DateTime endDate, + PageOptions pageOptions) + { + return await GetManyAsync($"[{Schema}].[Event_ReadPageByCipherId]", + new Dictionary + { + ["@OrganizationId"] = cipher.OrganizationId, + ["@UserId"] = cipher.UserId, + ["@CipherId"] = cipher.Id + }, startDate, endDate, pageOptions); + } + + public async Task CreateAsync(IEvent e) + { + if (!(e is Event ev)) { - return await GetManyAsync($"[{Schema}].[Event_ReadPageByUserId]", - new Dictionary - { - ["@UserId"] = userId - }, startDate, endDate, pageOptions); + ev = new Event(e); } - public async Task> GetManyByOrganizationAsync(Guid organizationId, - DateTime startDate, DateTime endDate, PageOptions pageOptions) + await base.CreateAsync(ev); + } + + public async Task CreateManyAsync(IEnumerable entities) + { + if (!entities?.Any() ?? true) { - return await GetManyAsync($"[{Schema}].[Event_ReadPageByOrganizationId]", - new Dictionary - { - ["@OrganizationId"] = organizationId - }, startDate, endDate, pageOptions); + return; } - public async Task> GetManyByOrganizationActingUserAsync(Guid organizationId, Guid actingUserId, - DateTime startDate, DateTime endDate, PageOptions pageOptions) + if (!entities.Skip(1).Any()) { - return await GetManyAsync($"[{Schema}].[Event_ReadPageByOrganizationIdActingUserId]", - new Dictionary - { - ["@OrganizationId"] = organizationId, - ["@ActingUserId"] = actingUserId - }, startDate, endDate, pageOptions); + await CreateAsync(entities.First()); + return; } - public async Task> GetManyByProviderAsync(Guid providerId, - DateTime startDate, DateTime endDate, PageOptions pageOptions) + using (var connection = new SqlConnection(ConnectionString)) { - return await GetManyAsync($"[{Schema}].[Event_ReadPageByProviderId]", - new Dictionary - { - ["@ProviderId"] = providerId - }, startDate, endDate, pageOptions); - } - - public async Task> GetManyByProviderActingUserAsync(Guid providerId, Guid actingUserId, - DateTime startDate, DateTime endDate, PageOptions pageOptions) - { - return await GetManyAsync($"[{Schema}].[Event_ReadPageByProviderIdActingUserId]", - new Dictionary - { - ["@ProviderId"] = providerId, - ["@ActingUserId"] = actingUserId - }, startDate, endDate, pageOptions); - } - - public async Task> GetManyByCipherAsync(Cipher cipher, DateTime startDate, DateTime endDate, - PageOptions pageOptions) - { - return await GetManyAsync($"[{Schema}].[Event_ReadPageByCipherId]", - new Dictionary - { - ["@OrganizationId"] = cipher.OrganizationId, - ["@UserId"] = cipher.UserId, - ["@CipherId"] = cipher.Id - }, startDate, endDate, pageOptions); - } - - public async Task CreateAsync(IEvent e) - { - if (!(e is Event ev)) + connection.Open(); + using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, null)) { - ev = new Event(e); + bulkCopy.DestinationTableName = "[dbo].[Event]"; + var dataTable = BuildEventsTable(bulkCopy, entities.Select(e => e is Event ? e as Event : new Event(e))); + await bulkCopy.WriteToServerAsync(dataTable); } - - await base.CreateAsync(ev); - } - - public async Task CreateManyAsync(IEnumerable entities) - { - if (!entities?.Any() ?? true) - { - return; - } - - if (!entities.Skip(1).Any()) - { - await CreateAsync(entities.First()); - return; - } - - using (var connection = new SqlConnection(ConnectionString)) - { - connection.Open(); - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, null)) - { - bulkCopy.DestinationTableName = "[dbo].[Event]"; - var dataTable = BuildEventsTable(bulkCopy, entities.Select(e => e is Event ? e as Event : new Event(e))); - await bulkCopy.WriteToServerAsync(dataTable); - } - } - } - - private async Task> GetManyAsync(string sprocName, - IDictionary sprocParams, DateTime startDate, DateTime endDate, PageOptions pageOptions) - { - DateTime? beforeDate = null; - if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && - long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) - { - beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); - } - - var parameters = new DynamicParameters(sprocParams); - parameters.Add("@PageSize", pageOptions.PageSize, DbType.Int32); - // Explicitly use DbType.DateTime2 for proper precision. - // ref: https://github.com/StackExchange/Dapper/issues/229 - parameters.Add("@StartDate", startDate.ToUniversalTime(), DbType.DateTime2, null, 7); - parameters.Add("@EndDate", endDate.ToUniversalTime(), DbType.DateTime2, null, 7); - parameters.Add("@BeforeDate", beforeDate, DbType.DateTime2, null, 7); - - using (var connection = new SqlConnection(ConnectionString)) - { - var events = (await connection.QueryAsync(sprocName, parameters, - commandType: CommandType.StoredProcedure)).ToList(); - - var result = new PagedResult(); - if (events.Any() && events.Count >= pageOptions.PageSize) - { - result.ContinuationToken = events.Last().Date.ToBinary().ToString(); - } - result.Data.AddRange(events); - return result; - } - } - - private DataTable BuildEventsTable(SqlBulkCopy bulkCopy, IEnumerable events) - { - var e = events.FirstOrDefault(); - if (e == null) - { - throw new ApplicationException("Must have some events to bulk import."); - } - - var eventsTable = new DataTable("EventDataTable"); - - var idColumn = new DataColumn(nameof(e.Id), e.Id.GetType()); - eventsTable.Columns.Add(idColumn); - var typeColumn = new DataColumn(nameof(e.Type), typeof(int)); - eventsTable.Columns.Add(typeColumn); - var userIdColumn = new DataColumn(nameof(e.UserId), typeof(Guid)); - eventsTable.Columns.Add(userIdColumn); - var organizationIdColumn = new DataColumn(nameof(e.OrganizationId), typeof(Guid)); - eventsTable.Columns.Add(organizationIdColumn); - var cipherIdColumn = new DataColumn(nameof(e.CipherId), typeof(Guid)); - eventsTable.Columns.Add(cipherIdColumn); - var collectionIdColumn = new DataColumn(nameof(e.CollectionId), typeof(Guid)); - eventsTable.Columns.Add(collectionIdColumn); - var policyIdColumn = new DataColumn(nameof(e.PolicyId), typeof(Guid)); - eventsTable.Columns.Add(policyIdColumn); - var groupIdColumn = new DataColumn(nameof(e.GroupId), typeof(Guid)); - eventsTable.Columns.Add(groupIdColumn); - var organizationUserIdColumn = new DataColumn(nameof(e.OrganizationUserId), typeof(Guid)); - eventsTable.Columns.Add(organizationUserIdColumn); - var actingUserIdColumn = new DataColumn(nameof(e.ActingUserId), typeof(Guid)); - eventsTable.Columns.Add(actingUserIdColumn); - var deviceTypeColumn = new DataColumn(nameof(e.DeviceType), typeof(int)); - eventsTable.Columns.Add(deviceTypeColumn); - var ipAddressColumn = new DataColumn(nameof(e.IpAddress), typeof(string)); - eventsTable.Columns.Add(ipAddressColumn); - var dateColumn = new DataColumn(nameof(e.Date), typeof(DateTime)); - eventsTable.Columns.Add(dateColumn); - - foreach (DataColumn col in eventsTable.Columns) - { - bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); - } - - var keys = new DataColumn[1]; - keys[0] = idColumn; - eventsTable.PrimaryKey = keys; - - foreach (var ev in events) - { - ev.SetNewId(); - - var row = eventsTable.NewRow(); - - row[idColumn] = ev.Id; - row[typeColumn] = (int)ev.Type; - row[userIdColumn] = ev.UserId.HasValue ? (object)ev.UserId.Value : DBNull.Value; - row[organizationIdColumn] = ev.OrganizationId.HasValue ? (object)ev.OrganizationId.Value : DBNull.Value; - row[cipherIdColumn] = ev.CipherId.HasValue ? (object)ev.CipherId.Value : DBNull.Value; - row[collectionIdColumn] = ev.CollectionId.HasValue ? (object)ev.CollectionId.Value : DBNull.Value; - row[policyIdColumn] = ev.PolicyId.HasValue ? (object)ev.PolicyId.Value : DBNull.Value; - row[groupIdColumn] = ev.GroupId.HasValue ? (object)ev.GroupId.Value : DBNull.Value; - row[organizationUserIdColumn] = ev.OrganizationUserId.HasValue ? - (object)ev.OrganizationUserId.Value : DBNull.Value; - row[actingUserIdColumn] = ev.ActingUserId.HasValue ? (object)ev.ActingUserId.Value : DBNull.Value; - row[deviceTypeColumn] = ev.DeviceType.HasValue ? (object)ev.DeviceType.Value : DBNull.Value; - row[ipAddressColumn] = ev.IpAddress != null ? (object)ev.IpAddress : DBNull.Value; - row[dateColumn] = ev.Date; - - eventsTable.Rows.Add(row); - } - - return eventsTable; } } + + private async Task> GetManyAsync(string sprocName, + IDictionary sprocParams, DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + DateTime? beforeDate = null; + if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && + long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) + { + beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); + } + + var parameters = new DynamicParameters(sprocParams); + parameters.Add("@PageSize", pageOptions.PageSize, DbType.Int32); + // Explicitly use DbType.DateTime2 for proper precision. + // ref: https://github.com/StackExchange/Dapper/issues/229 + parameters.Add("@StartDate", startDate.ToUniversalTime(), DbType.DateTime2, null, 7); + parameters.Add("@EndDate", endDate.ToUniversalTime(), DbType.DateTime2, null, 7); + parameters.Add("@BeforeDate", beforeDate, DbType.DateTime2, null, 7); + + using (var connection = new SqlConnection(ConnectionString)) + { + var events = (await connection.QueryAsync(sprocName, parameters, + commandType: CommandType.StoredProcedure)).ToList(); + + var result = new PagedResult(); + if (events.Any() && events.Count >= pageOptions.PageSize) + { + result.ContinuationToken = events.Last().Date.ToBinary().ToString(); + } + result.Data.AddRange(events); + return result; + } + } + + private DataTable BuildEventsTable(SqlBulkCopy bulkCopy, IEnumerable events) + { + var e = events.FirstOrDefault(); + if (e == null) + { + throw new ApplicationException("Must have some events to bulk import."); + } + + var eventsTable = new DataTable("EventDataTable"); + + var idColumn = new DataColumn(nameof(e.Id), e.Id.GetType()); + eventsTable.Columns.Add(idColumn); + var typeColumn = new DataColumn(nameof(e.Type), typeof(int)); + eventsTable.Columns.Add(typeColumn); + var userIdColumn = new DataColumn(nameof(e.UserId), typeof(Guid)); + eventsTable.Columns.Add(userIdColumn); + var organizationIdColumn = new DataColumn(nameof(e.OrganizationId), typeof(Guid)); + eventsTable.Columns.Add(organizationIdColumn); + var cipherIdColumn = new DataColumn(nameof(e.CipherId), typeof(Guid)); + eventsTable.Columns.Add(cipherIdColumn); + var collectionIdColumn = new DataColumn(nameof(e.CollectionId), typeof(Guid)); + eventsTable.Columns.Add(collectionIdColumn); + var policyIdColumn = new DataColumn(nameof(e.PolicyId), typeof(Guid)); + eventsTable.Columns.Add(policyIdColumn); + var groupIdColumn = new DataColumn(nameof(e.GroupId), typeof(Guid)); + eventsTable.Columns.Add(groupIdColumn); + var organizationUserIdColumn = new DataColumn(nameof(e.OrganizationUserId), typeof(Guid)); + eventsTable.Columns.Add(organizationUserIdColumn); + var actingUserIdColumn = new DataColumn(nameof(e.ActingUserId), typeof(Guid)); + eventsTable.Columns.Add(actingUserIdColumn); + var deviceTypeColumn = new DataColumn(nameof(e.DeviceType), typeof(int)); + eventsTable.Columns.Add(deviceTypeColumn); + var ipAddressColumn = new DataColumn(nameof(e.IpAddress), typeof(string)); + eventsTable.Columns.Add(ipAddressColumn); + var dateColumn = new DataColumn(nameof(e.Date), typeof(DateTime)); + eventsTable.Columns.Add(dateColumn); + + foreach (DataColumn col in eventsTable.Columns) + { + bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + } + + var keys = new DataColumn[1]; + keys[0] = idColumn; + eventsTable.PrimaryKey = keys; + + foreach (var ev in events) + { + ev.SetNewId(); + + var row = eventsTable.NewRow(); + + row[idColumn] = ev.Id; + row[typeColumn] = (int)ev.Type; + row[userIdColumn] = ev.UserId.HasValue ? (object)ev.UserId.Value : DBNull.Value; + row[organizationIdColumn] = ev.OrganizationId.HasValue ? (object)ev.OrganizationId.Value : DBNull.Value; + row[cipherIdColumn] = ev.CipherId.HasValue ? (object)ev.CipherId.Value : DBNull.Value; + row[collectionIdColumn] = ev.CollectionId.HasValue ? (object)ev.CollectionId.Value : DBNull.Value; + row[policyIdColumn] = ev.PolicyId.HasValue ? (object)ev.PolicyId.Value : DBNull.Value; + row[groupIdColumn] = ev.GroupId.HasValue ? (object)ev.GroupId.Value : DBNull.Value; + row[organizationUserIdColumn] = ev.OrganizationUserId.HasValue ? + (object)ev.OrganizationUserId.Value : DBNull.Value; + row[actingUserIdColumn] = ev.ActingUserId.HasValue ? (object)ev.ActingUserId.Value : DBNull.Value; + row[deviceTypeColumn] = ev.DeviceType.HasValue ? (object)ev.DeviceType.Value : DBNull.Value; + row[ipAddressColumn] = ev.IpAddress != null ? (object)ev.IpAddress : DBNull.Value; + row[dateColumn] = ev.Date; + + eventsTable.Rows.Add(row); + } + + return eventsTable; + } } diff --git a/src/Infrastructure.Dapper/Repositories/FolderRepository.cs b/src/Infrastructure.Dapper/Repositories/FolderRepository.cs index a0bd11c9d..6500d35dd 100644 --- a/src/Infrastructure.Dapper/Repositories/FolderRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/FolderRepository.cs @@ -5,40 +5,39 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class FolderRepository : Repository, IFolderRepository { - public class FolderRepository : Repository, IFolderRepository + public FolderRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public FolderRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task GetByIdAsync(Guid id, Guid userId) { - public FolderRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public FolderRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task GetByIdAsync(Guid id, Guid userId) + var folder = await GetByIdAsync(id); + if (folder == null || folder.UserId != userId) { - var folder = await GetByIdAsync(id); - if (folder == null || folder.UserId != userId) - { - return null; - } - - return folder; + return null; } - public async Task> GetManyByUserIdAsync(Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[Folder_ReadByUserId]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); + return folder; + } - return results.ToList(); - } + public async Task> GetManyByUserIdAsync(Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[Folder_ReadByUserId]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); } } } diff --git a/src/Infrastructure.Dapper/Repositories/GrantRepository.cs b/src/Infrastructure.Dapper/Repositories/GrantRepository.cs index 6596fa510..168576fa9 100644 --- a/src/Infrastructure.Dapper/Repositories/GrantRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/GrantRepository.cs @@ -5,76 +5,75 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class GrantRepository : BaseRepository, IGrantRepository { - public class GrantRepository : BaseRepository, IGrantRepository + public GrantRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public GrantRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task GetByKeyAsync(string key) { - public GrantRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public GrantRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task GetByKeyAsync(string key) + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[Grant_ReadByKey]", - new { Key = key }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[Grant_ReadByKey]", + new { Key = key }, + commandType: CommandType.StoredProcedure); - return results.SingleOrDefault(); - } + return results.SingleOrDefault(); } + } - public async Task> GetManyAsync(string subjectId, string sessionId, - string clientId, string type) + public async Task> GetManyAsync(string subjectId, string sessionId, + string clientId, string type) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[Grant_Read]", - new { SubjectId = subjectId, SessionId = sessionId, ClientId = clientId, Type = type }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[Grant_Read]", + new { SubjectId = subjectId, SessionId = sessionId, ClientId = clientId, Type = type }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task SaveAsync(Grant obj) + public async Task SaveAsync(Grant obj) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - "[dbo].[Grant_Save]", - obj, - commandType: CommandType.StoredProcedure); - } + var results = await connection.ExecuteAsync( + "[dbo].[Grant_Save]", + obj, + commandType: CommandType.StoredProcedure); } + } - public async Task DeleteByKeyAsync(string key) + public async Task DeleteByKeyAsync(string key) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - await connection.ExecuteAsync( - "[dbo].[Grant_DeleteByKey]", - new { Key = key }, - commandType: CommandType.StoredProcedure); - } + await connection.ExecuteAsync( + "[dbo].[Grant_DeleteByKey]", + new { Key = key }, + commandType: CommandType.StoredProcedure); } + } - public async Task DeleteManyAsync(string subjectId, string sessionId, string clientId, string type) + public async Task DeleteManyAsync(string subjectId, string sessionId, string clientId, string type) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - await connection.ExecuteAsync( - "[dbo].[Grant_Delete]", - new { SubjectId = subjectId, SessionId = sessionId, ClientId = clientId, Type = type }, - commandType: CommandType.StoredProcedure); - } + await connection.ExecuteAsync( + "[dbo].[Grant_Delete]", + new { SubjectId = subjectId, SessionId = sessionId, ClientId = clientId, Type = type }, + commandType: CommandType.StoredProcedure); } } } diff --git a/src/Infrastructure.Dapper/Repositories/GroupRepository.cs b/src/Infrastructure.Dapper/Repositories/GroupRepository.cs index 31f6a29bb..eb0482bf3 100644 --- a/src/Infrastructure.Dapper/Repositories/GroupRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/GroupRepository.cs @@ -7,135 +7,134 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class GroupRepository : Repository, IGroupRepository { - public class GroupRepository : Repository, IGroupRepository + public GroupRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public GroupRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task>> GetByIdWithCollectionsAsync(Guid id) { - public GroupRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public GroupRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task>> GetByIdWithCollectionsAsync(Guid id) + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryMultipleAsync( - $"[{Schema}].[Group_ReadWithCollectionsById]", - new { Id = id }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryMultipleAsync( + $"[{Schema}].[Group_ReadWithCollectionsById]", + new { Id = id }, + commandType: CommandType.StoredProcedure); - var group = await results.ReadFirstOrDefaultAsync(); - var colletions = (await results.ReadAsync()).ToList(); + var group = await results.ReadFirstOrDefaultAsync(); + var colletions = (await results.ReadAsync()).ToList(); - return new Tuple>(group, colletions); - } + return new Tuple>(group, colletions); } + } - public async Task> GetManyByOrganizationIdAsync(Guid organizationId) + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[Group_ReadByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[Group_ReadByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task> GetManyIdsByUserIdAsync(Guid organizationUserId) + public async Task> GetManyIdsByUserIdAsync(Guid organizationUserId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[GroupUser_ReadGroupIdsByOrganizationUserId]", - new { OrganizationUserId = organizationUserId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[GroupUser_ReadGroupIdsByOrganizationUserId]", + new { OrganizationUserId = organizationUserId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task> GetManyUserIdsByIdAsync(Guid id) + public async Task> GetManyUserIdsByIdAsync(Guid id) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[GroupUser_ReadOrganizationUserIdsByGroupId]", - new { GroupId = id }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[GroupUser_ReadOrganizationUserIdsByGroupId]", + new { GroupId = id }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task> GetManyGroupUsersByOrganizationIdAsync(Guid organizationId) + public async Task> GetManyGroupUsersByOrganizationIdAsync(Guid organizationId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[GroupUser_ReadByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[GroupUser_ReadByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task CreateAsync(Group obj, IEnumerable collections) + public async Task CreateAsync(Group obj, IEnumerable collections) + { + obj.SetNewId(); + var objWithCollections = JsonSerializer.Deserialize(JsonSerializer.Serialize(obj)); + objWithCollections.Collections = collections.ToArrayTVP(); + + using (var connection = new SqlConnection(ConnectionString)) { - obj.SetNewId(); - var objWithCollections = JsonSerializer.Deserialize(JsonSerializer.Serialize(obj)); - objWithCollections.Collections = collections.ToArrayTVP(); - - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Group_CreateWithCollections]", - objWithCollections, - commandType: CommandType.StoredProcedure); - } + var results = await connection.ExecuteAsync( + $"[{Schema}].[Group_CreateWithCollections]", + objWithCollections, + commandType: CommandType.StoredProcedure); } + } - public async Task ReplaceAsync(Group obj, IEnumerable collections) + public async Task ReplaceAsync(Group obj, IEnumerable collections) + { + var objWithCollections = JsonSerializer.Deserialize(JsonSerializer.Serialize(obj)); + objWithCollections.Collections = collections.ToArrayTVP(); + + using (var connection = new SqlConnection(ConnectionString)) { - var objWithCollections = JsonSerializer.Deserialize(JsonSerializer.Serialize(obj)); - objWithCollections.Collections = collections.ToArrayTVP(); - - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Group_UpdateWithCollections]", - objWithCollections, - commandType: CommandType.StoredProcedure); - } + var results = await connection.ExecuteAsync( + $"[{Schema}].[Group_UpdateWithCollections]", + objWithCollections, + commandType: CommandType.StoredProcedure); } + } - public async Task DeleteUserAsync(Guid groupId, Guid organizationUserId) + public async Task DeleteUserAsync(Guid groupId, Guid organizationUserId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[GroupUser_Delete]", - new { GroupId = groupId, OrganizationUserId = organizationUserId }, - commandType: CommandType.StoredProcedure); - } + var results = await connection.ExecuteAsync( + $"[{Schema}].[GroupUser_Delete]", + new { GroupId = groupId, OrganizationUserId = organizationUserId }, + commandType: CommandType.StoredProcedure); } + } - public async Task UpdateUsersAsync(Guid groupId, IEnumerable organizationUserIds) + public async Task UpdateUsersAsync(Guid groupId, IEnumerable organizationUserIds) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - "[dbo].[GroupUser_UpdateUsers]", - new { GroupId = groupId, OrganizationUserIds = organizationUserIds.ToGuidIdArrayTVP() }, - commandType: CommandType.StoredProcedure); - } + var results = await connection.ExecuteAsync( + "[dbo].[GroupUser_UpdateUsers]", + new { GroupId = groupId, OrganizationUserIds = organizationUserIds.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); } } } diff --git a/src/Infrastructure.Dapper/Repositories/InstallationRepository.cs b/src/Infrastructure.Dapper/Repositories/InstallationRepository.cs index b82b13c49..0bb38761c 100644 --- a/src/Infrastructure.Dapper/Repositories/InstallationRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/InstallationRepository.cs @@ -2,16 +2,15 @@ using Bit.Core.Repositories; using Bit.Core.Settings; -namespace Bit.Infrastructure.Dapper.Repositories -{ - public class InstallationRepository : Repository, IInstallationRepository - { - public InstallationRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } +namespace Bit.Infrastructure.Dapper.Repositories; - public InstallationRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - } +public class InstallationRepository : Repository, IInstallationRepository +{ + public InstallationRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public InstallationRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } } diff --git a/src/Infrastructure.Dapper/Repositories/MaintenanceRepository.cs b/src/Infrastructure.Dapper/Repositories/MaintenanceRepository.cs index 05c6e9d63..fb5bf3091 100644 --- a/src/Infrastructure.Dapper/Repositories/MaintenanceRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/MaintenanceRepository.cs @@ -4,74 +4,73 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class MaintenanceRepository : BaseRepository, IMaintenanceRepository { - public class MaintenanceRepository : BaseRepository, IMaintenanceRepository + public MaintenanceRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public MaintenanceRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task UpdateStatisticsAsync() { - public MaintenanceRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public MaintenanceRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task UpdateStatisticsAsync() + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - await connection.ExecuteAsync( - "[dbo].[AzureSQLMaintenance]", - new { operation = "statistics", mode = "smart", LogToTable = true }, - commandType: CommandType.StoredProcedure, - commandTimeout: 172800); - } + await connection.ExecuteAsync( + "[dbo].[AzureSQLMaintenance]", + new { operation = "statistics", mode = "smart", LogToTable = true }, + commandType: CommandType.StoredProcedure, + commandTimeout: 172800); } + } - public async Task DisableCipherAutoStatsAsync() + public async Task DisableCipherAutoStatsAsync() + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - await connection.ExecuteAsync( - "sp_autostats", - new { tblname = "[dbo].[Cipher]", flagc = "OFF" }, - commandType: CommandType.StoredProcedure); - } + await connection.ExecuteAsync( + "sp_autostats", + new { tblname = "[dbo].[Cipher]", flagc = "OFF" }, + commandType: CommandType.StoredProcedure); } + } - public async Task RebuildIndexesAsync() + public async Task RebuildIndexesAsync() + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - await connection.ExecuteAsync( - "[dbo].[AzureSQLMaintenance]", - new { operation = "index", mode = "smart", LogToTable = true }, - commandType: CommandType.StoredProcedure, - commandTimeout: 172800); - } + await connection.ExecuteAsync( + "[dbo].[AzureSQLMaintenance]", + new { operation = "index", mode = "smart", LogToTable = true }, + commandType: CommandType.StoredProcedure, + commandTimeout: 172800); } + } - public async Task DeleteExpiredGrantsAsync() + public async Task DeleteExpiredGrantsAsync() + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - await connection.ExecuteAsync( - "[dbo].[Grant_DeleteExpired]", - commandType: CommandType.StoredProcedure, - commandTimeout: 172800); - } + await connection.ExecuteAsync( + "[dbo].[Grant_DeleteExpired]", + commandType: CommandType.StoredProcedure, + commandTimeout: 172800); } + } - public async Task DeleteExpiredSponsorshipsAsync(DateTime validUntilBeforeDate) + public async Task DeleteExpiredSponsorshipsAsync(DateTime validUntilBeforeDate) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - await connection.ExecuteAsync( - "[dbo].[OrganizationSponsorship_DeleteExpired]", - new { ValidUntilBeforeDate = validUntilBeforeDate }, - commandType: CommandType.StoredProcedure, - commandTimeout: 172800); - } + await connection.ExecuteAsync( + "[dbo].[OrganizationSponsorship_DeleteExpired]", + new { ValidUntilBeforeDate = validUntilBeforeDate }, + commandType: CommandType.StoredProcedure, + commandTimeout: 172800); } } } diff --git a/src/Infrastructure.Dapper/Repositories/OrganizationApiKeyRepository.cs b/src/Infrastructure.Dapper/Repositories/OrganizationApiKeyRepository.cs index b0694862f..05eaac68f 100644 --- a/src/Infrastructure.Dapper/Repositories/OrganizationApiKeyRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/OrganizationApiKeyRepository.cs @@ -6,33 +6,32 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class OrganizationApiKeyRepository : Repository, IOrganizationApiKeyRepository { - public class OrganizationApiKeyRepository : Repository, IOrganizationApiKeyRepository + public OrganizationApiKeyRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) { - public OrganizationApiKeyRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + + } + + public OrganizationApiKeyRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task> GetManyByOrganizationIdTypeAsync(Guid organizationId, OrganizationApiKeyType? type = null) + { + using (var connection = new SqlConnection(ConnectionString)) { - - } - - public OrganizationApiKeyRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task> GetManyByOrganizationIdTypeAsync(Guid organizationId, OrganizationApiKeyType? type = null) - { - using (var connection = new SqlConnection(ConnectionString)) - { - return await connection.QueryAsync( - "[dbo].[OrganizationApikey_ReadManyByOrganizationIdType]", - new - { - OrganizationId = organizationId, - Type = type, - }, - commandType: CommandType.StoredProcedure); - } + return await connection.QueryAsync( + "[dbo].[OrganizationApikey_ReadManyByOrganizationIdType]", + new + { + OrganizationId = organizationId, + Type = type, + }, + commandType: CommandType.StoredProcedure); } } } diff --git a/src/Infrastructure.Dapper/Repositories/OrganizationConnectionRepository.cs b/src/Infrastructure.Dapper/Repositories/OrganizationConnectionRepository.cs index 6de4559fa..1cc997588 100644 --- a/src/Infrastructure.Dapper/Repositories/OrganizationConnectionRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/OrganizationConnectionRepository.cs @@ -6,32 +6,31 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class OrganizationConnectionRepository : Repository, IOrganizationConnectionRepository { - public class OrganizationConnectionRepository : Repository, IOrganizationConnectionRepository + public OrganizationConnectionRepository(GlobalSettings globalSettings) + : base(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public async Task> GetByOrganizationIdTypeAsync(Guid organizationId, OrganizationConnectionType type) { - public OrganizationConnectionRepository(GlobalSettings globalSettings) - : base(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public async Task> GetByOrganizationIdTypeAsync(Guid organizationId, OrganizationConnectionType type) + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[OrganizationConnection_ReadByOrganizationIdType]", - new - { - OrganizationId = organizationId, - Type = type - }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[OrganizationConnection_ReadByOrganizationIdType]", + new + { + OrganizationId = organizationId, + Type = type + }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } - - public async Task> GetEnabledByOrganizationIdTypeAsync(Guid organizationId, OrganizationConnectionType type) => - (await GetByOrganizationIdTypeAsync(organizationId, type)).Where(c => c.Enabled).ToList(); } + + public async Task> GetEnabledByOrganizationIdTypeAsync(Guid organizationId, OrganizationConnectionType type) => + (await GetByOrganizationIdTypeAsync(organizationId, type)).Where(c => c.Enabled).ToList(); } diff --git a/src/Infrastructure.Dapper/Repositories/OrganizationRepository.cs b/src/Infrastructure.Dapper/Repositories/OrganizationRepository.cs index 05cc3c92a..a4e294b29 100644 --- a/src/Infrastructure.Dapper/Repositories/OrganizationRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/OrganizationRepository.cs @@ -6,93 +6,92 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class OrganizationRepository : Repository, IOrganizationRepository { - public class OrganizationRepository : Repository, IOrganizationRepository + public OrganizationRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public OrganizationRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task GetByIdentifierAsync(string identifier) { - public OrganizationRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public OrganizationRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task GetByIdentifierAsync(string identifier) + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[Organization_ReadByIdentifier]", - new { Identifier = identifier }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[Organization_ReadByIdentifier]", + new { Identifier = identifier }, + commandType: CommandType.StoredProcedure); - return results.SingleOrDefault(); - } + return results.SingleOrDefault(); } + } - public async Task> GetManyByEnabledAsync() + public async Task> GetManyByEnabledAsync() + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[Organization_ReadByEnabled]", - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[Organization_ReadByEnabled]", + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task> GetManyByUserIdAsync(Guid userId) + public async Task> GetManyByUserIdAsync(Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[Organization_ReadByUserId]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[Organization_ReadByUserId]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task> SearchAsync(string name, string userEmail, bool? paid, - int skip, int take) + public async Task> SearchAsync(string name, string userEmail, bool? paid, + int skip, int take) + { + using (var connection = new SqlConnection(ReadOnlyConnectionString)) { - using (var connection = new SqlConnection(ReadOnlyConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[Organization_Search]", - new { Name = name, UserEmail = userEmail, Paid = paid, Skip = skip, Take = take }, - commandType: CommandType.StoredProcedure, - commandTimeout: 120); + var results = await connection.QueryAsync( + "[dbo].[Organization_Search]", + new { Name = name, UserEmail = userEmail, Paid = paid, Skip = skip, Take = take }, + commandType: CommandType.StoredProcedure, + commandTimeout: 120); - return results.ToList(); - } + return results.ToList(); } + } - public async Task UpdateStorageAsync(Guid id) + public async Task UpdateStorageAsync(Guid id) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - await connection.ExecuteAsync( - "[dbo].[Organization_UpdateStorage]", - new { Id = id }, - commandType: CommandType.StoredProcedure, - commandTimeout: 180); - } + await connection.ExecuteAsync( + "[dbo].[Organization_UpdateStorage]", + new { Id = id }, + commandType: CommandType.StoredProcedure, + commandTimeout: 180); } + } - public async Task> GetManyAbilitiesAsync() + public async Task> GetManyAbilitiesAsync() + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[Organization_ReadAbilities]", - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[Organization_ReadAbilities]", + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } } } diff --git a/src/Infrastructure.Dapper/Repositories/OrganizationSponsorshipRepository.cs b/src/Infrastructure.Dapper/Repositories/OrganizationSponsorshipRepository.cs index 6e4ca9904..11e453cac 100644 --- a/src/Infrastructure.Dapper/Repositories/OrganizationSponsorshipRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/OrganizationSponsorshipRepository.cs @@ -5,143 +5,142 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class OrganizationSponsorshipRepository : Repository, IOrganizationSponsorshipRepository { - public class OrganizationSponsorshipRepository : Repository, IOrganizationSponsorshipRepository + public OrganizationSponsorshipRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public OrganizationSponsorshipRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task> CreateManyAsync(IEnumerable organizationSponsorships) { - public OrganizationSponsorshipRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public OrganizationSponsorshipRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task> CreateManyAsync(IEnumerable organizationSponsorships) + if (!organizationSponsorships.Any()) { - if (!organizationSponsorships.Any()) - { - return default; - } - - foreach (var organizationSponsorship in organizationSponsorships) - { - organizationSponsorship.SetNewId(); - } - - var orgSponsorshipsTVP = organizationSponsorships.ToTvp(); - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[dbo].[OrganizationSponsorship_CreateMany]", - new { OrganizationSponsorshipsInput = orgSponsorshipsTVP }, - commandType: CommandType.StoredProcedure); - } - - return organizationSponsorships.Select(u => u.Id).ToList(); + return default; } - public async Task ReplaceManyAsync(IEnumerable organizationSponsorships) + foreach (var organizationSponsorship in organizationSponsorships) { - if (!organizationSponsorships.Any()) - { - return; - } - - var orgSponsorshipsTVP = organizationSponsorships.ToTvp(); - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[dbo].[OrganizationSponsorship_UpdateMany]", - new { OrganizationSponsorshipsInput = orgSponsorshipsTVP }, - commandType: CommandType.StoredProcedure); - } + organizationSponsorship.SetNewId(); } - public async Task UpsertManyAsync(IEnumerable organizationSponsorships) + var orgSponsorshipsTVP = organizationSponsorships.ToTvp(); + using (var connection = new SqlConnection(ConnectionString)) { - var createSponsorships = new List(); - var replaceSponsorships = new List(); - foreach (var organizationSponsorship in organizationSponsorships) - { - if (organizationSponsorship.Id.Equals(default)) - { - createSponsorships.Add(organizationSponsorship); - } - else - { - replaceSponsorships.Add(organizationSponsorship); - } - } - - await CreateManyAsync(createSponsorships); - await ReplaceManyAsync(replaceSponsorships); - } - - public async Task DeleteManyAsync(IEnumerable organizationSponsorshipIds) - { - using (var connection = new SqlConnection(ConnectionString)) - { - await connection.ExecuteAsync("[dbo].[OrganizationSponsorship_DeleteByIds]", - new { Ids = organizationSponsorshipIds.ToGuidIdArrayTVP() }, commandType: CommandType.StoredProcedure); - } - } - - public async Task GetBySponsoringOrganizationUserIdAsync(Guid sponsoringOrganizationUserId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationSponsorship_ReadBySponsoringOrganizationUserId]", - new - { - SponsoringOrganizationUserId = sponsoringOrganizationUserId - }, - commandType: CommandType.StoredProcedure); - - return results.SingleOrDefault(); - } - } - - public async Task GetBySponsoredOrganizationIdAsync(Guid sponsoredOrganizationId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationSponsorship_ReadBySponsoredOrganizationId]", - new { SponsoredOrganizationId = sponsoredOrganizationId }, - commandType: CommandType.StoredProcedure); - - return results.SingleOrDefault(); - } - } - - public async Task GetLatestSyncDateBySponsoringOrganizationIdAsync(Guid sponsoringOrganizationId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - return await connection.QuerySingleOrDefaultAsync( - "[dbo].[OrganizationSponsorship_ReadLatestBySponsoringOrganizationId]", - new { SponsoringOrganizationId = sponsoringOrganizationId }, - commandType: CommandType.StoredProcedure); - } - } - - public async Task> GetManyBySponsoringOrganizationAsync(Guid sponsoringOrganizationId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationSponsorship_ReadBySponsoringOrganizationId]", - new - { - SponsoringOrganizationId = sponsoringOrganizationId - }, - commandType: CommandType.StoredProcedure); - - return results.ToList(); - } + var results = await connection.ExecuteAsync( + $"[dbo].[OrganizationSponsorship_CreateMany]", + new { OrganizationSponsorshipsInput = orgSponsorshipsTVP }, + commandType: CommandType.StoredProcedure); } + return organizationSponsorships.Select(u => u.Id).ToList(); } + + public async Task ReplaceManyAsync(IEnumerable organizationSponsorships) + { + if (!organizationSponsorships.Any()) + { + return; + } + + var orgSponsorshipsTVP = organizationSponsorships.ToTvp(); + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[dbo].[OrganizationSponsorship_UpdateMany]", + new { OrganizationSponsorshipsInput = orgSponsorshipsTVP }, + commandType: CommandType.StoredProcedure); + } + } + + public async Task UpsertManyAsync(IEnumerable organizationSponsorships) + { + var createSponsorships = new List(); + var replaceSponsorships = new List(); + foreach (var organizationSponsorship in organizationSponsorships) + { + if (organizationSponsorship.Id.Equals(default)) + { + createSponsorships.Add(organizationSponsorship); + } + else + { + replaceSponsorships.Add(organizationSponsorship); + } + } + + await CreateManyAsync(createSponsorships); + await ReplaceManyAsync(replaceSponsorships); + } + + public async Task DeleteManyAsync(IEnumerable organizationSponsorshipIds) + { + using (var connection = new SqlConnection(ConnectionString)) + { + await connection.ExecuteAsync("[dbo].[OrganizationSponsorship_DeleteByIds]", + new { Ids = organizationSponsorshipIds.ToGuidIdArrayTVP() }, commandType: CommandType.StoredProcedure); + } + } + + public async Task GetBySponsoringOrganizationUserIdAsync(Guid sponsoringOrganizationUserId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationSponsorship_ReadBySponsoringOrganizationUserId]", + new + { + SponsoringOrganizationUserId = sponsoringOrganizationUserId + }, + commandType: CommandType.StoredProcedure); + + return results.SingleOrDefault(); + } + } + + public async Task GetBySponsoredOrganizationIdAsync(Guid sponsoredOrganizationId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationSponsorship_ReadBySponsoredOrganizationId]", + new { SponsoredOrganizationId = sponsoredOrganizationId }, + commandType: CommandType.StoredProcedure); + + return results.SingleOrDefault(); + } + } + + public async Task GetLatestSyncDateBySponsoringOrganizationIdAsync(Guid sponsoringOrganizationId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + return await connection.QuerySingleOrDefaultAsync( + "[dbo].[OrganizationSponsorship_ReadLatestBySponsoringOrganizationId]", + new { SponsoringOrganizationId = sponsoringOrganizationId }, + commandType: CommandType.StoredProcedure); + } + } + + public async Task> GetManyBySponsoringOrganizationAsync(Guid sponsoringOrganizationId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationSponsorship_ReadBySponsoringOrganizationId]", + new + { + SponsoringOrganizationId = sponsoringOrganizationId + }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } + } diff --git a/src/Infrastructure.Dapper/Repositories/OrganizationUserRepository.cs b/src/Infrastructure.Dapper/Repositories/OrganizationUserRepository.cs index 856fcb7a4..06aede3da 100644 --- a/src/Infrastructure.Dapper/Repositories/OrganizationUserRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/OrganizationUserRepository.cs @@ -9,423 +9,422 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class OrganizationUserRepository : Repository, IOrganizationUserRepository { - public class OrganizationUserRepository : Repository, IOrganizationUserRepository + /// + /// For use with methods with TDS stream issues. + /// This has been observed in Linux-hosted SqlServers with large table-valued-parameters + /// https://github.com/dotnet/SqlClient/issues/54 + /// + private string _marsConnectionString; + + public OrganizationUserRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) { - /// - /// For use with methods with TDS stream issues. - /// This has been observed in Linux-hosted SqlServers with large table-valued-parameters - /// https://github.com/dotnet/SqlClient/issues/54 - /// - private string _marsConnectionString; - - public OrganizationUserRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + var builder = new SqlConnectionStringBuilder(ConnectionString) { - var builder = new SqlConnectionStringBuilder(ConnectionString) - { - MultipleActiveResultSets = true, - }; - _marsConnectionString = builder.ToString(); + MultipleActiveResultSets = true, + }; + _marsConnectionString = builder.ToString(); + } + + public OrganizationUserRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task GetCountByOrganizationIdAsync(Guid organizationId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteScalarAsync( + "[dbo].[OrganizationUser_ReadCountByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); + + return results; + } + } + + public async Task GetCountByFreeOrganizationAdminUserAsync(Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteScalarAsync( + "[dbo].[OrganizationUser_ReadCountByFreeOrganizationAdminUser]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); + + return results; + } + } + + public async Task GetCountByOnlyOwnerAsync(Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteScalarAsync( + "[dbo].[OrganizationUser_ReadCountByOnlyOwner]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); + + return results; + } + } + + public async Task GetCountByOrganizationAsync(Guid organizationId, string email, bool onlyRegisteredUsers) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var result = await connection.ExecuteScalarAsync( + "[dbo].[OrganizationUser_ReadCountByOrganizationIdEmail]", + new { OrganizationId = organizationId, Email = email, OnlyUsers = onlyRegisteredUsers }, + commandType: CommandType.StoredProcedure); + + return result; + } + } + + public async Task> SelectKnownEmailsAsync(Guid organizationId, IEnumerable emails, + bool onlyRegisteredUsers) + { + var emailsTvp = emails.ToArrayTVP("Email"); + using (var connection = new SqlConnection(_marsConnectionString)) + { + var result = await connection.QueryAsync( + "[dbo].[OrganizationUser_SelectKnownEmails]", + new { OrganizationId = organizationId, Emails = emailsTvp, OnlyUsers = onlyRegisteredUsers }, + commandType: CommandType.StoredProcedure); + + // Return as a list to avoid timing out the sql connection + return result.ToList(); + } + } + + public async Task GetByOrganizationAsync(Guid organizationId, Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationUser_ReadByOrganizationIdUserId]", + new { OrganizationId = organizationId, UserId = userId }, + commandType: CommandType.StoredProcedure); + + return results.SingleOrDefault(); + } + } + + public async Task> GetManyByUserAsync(Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationUser_ReadByUserId]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } + + public async Task> GetManyByOrganizationAsync(Guid organizationId, + OrganizationUserType? type) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationUser_ReadByOrganizationId]", + new { OrganizationId = organizationId, Type = type }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } + + public async Task>> GetByIdWithCollectionsAsync(Guid id) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryMultipleAsync( + "[dbo].[OrganizationUser_ReadWithCollectionsById]", + new { Id = id }, + commandType: CommandType.StoredProcedure); + + var user = (await results.ReadAsync()).SingleOrDefault(); + var collections = (await results.ReadAsync()).ToList(); + return new Tuple>(user, collections); + } + } + + public async Task GetDetailsByIdAsync(Guid id) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationUserUserDetails_ReadById]", + new { Id = id }, + commandType: CommandType.StoredProcedure); + + return results.SingleOrDefault(); + } + } + public async Task>> + GetDetailsByIdWithCollectionsAsync(Guid id) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryMultipleAsync( + "[dbo].[OrganizationUserUserDetails_ReadWithCollectionsById]", + new { Id = id }, + commandType: CommandType.StoredProcedure); + + var user = (await results.ReadAsync()).SingleOrDefault(); + var collections = (await results.ReadAsync()).ToList(); + return new Tuple>(user, collections); + } + } + + public async Task> GetManyDetailsByOrganizationAsync(Guid organizationId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationUserUserDetails_ReadByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } + + public async Task> GetManyDetailsByUserAsync(Guid userId, + OrganizationUserStatusType? status = null) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationUserOrganizationDetails_ReadByUserIdStatus]", + new { UserId = userId, Status = status }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } + + public async Task GetDetailsByUserAsync(Guid userId, + Guid organizationId, OrganizationUserStatusType? status = null) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationUserOrganizationDetails_ReadByUserIdStatusOrganizationId]", + new { UserId = userId, Status = status, OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); + + return results.SingleOrDefault(); + } + } + + public async Task UpdateGroupsAsync(Guid orgUserId, IEnumerable groupIds) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + "[dbo].[GroupUser_UpdateGroups]", + new { OrganizationUserId = orgUserId, GroupIds = groupIds.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); + } + } + + public async Task CreateAsync(OrganizationUser obj, IEnumerable collections) + { + obj.SetNewId(); + var objWithCollections = JsonSerializer.Deserialize( + JsonSerializer.Serialize(obj)); + objWithCollections.Collections = collections.ToArrayTVP(); + + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.ExecuteAsync( + $"[{Schema}].[OrganizationUser_CreateWithCollections]", + objWithCollections, + commandType: CommandType.StoredProcedure); } - public OrganizationUserRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } + return obj.Id; + } - public async Task GetCountByOrganizationIdAsync(Guid organizationId) + public async Task ReplaceAsync(OrganizationUser obj, IEnumerable collections) + { + var objWithCollections = JsonSerializer.Deserialize( + JsonSerializer.Serialize(obj)); + objWithCollections.Collections = collections.ToArrayTVP(); + + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteScalarAsync( - "[dbo].[OrganizationUser_ReadCountByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); + var results = await connection.ExecuteAsync( + $"[{Schema}].[OrganizationUser_UpdateWithCollections]", + objWithCollections, + commandType: CommandType.StoredProcedure); + } + } - return results; + public async Task> GetManyByManyUsersAsync(IEnumerable userIds) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationUser_ReadByUserIds]", + new { UserIds = userIds.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } + + public async Task> GetManyAsync(IEnumerable Ids) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationUser_ReadByIds]", + new { Ids = Ids.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } + + public async Task GetByOrganizationEmailAsync(Guid organizationId, string email) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var results = await connection.QueryAsync( + "[dbo].[OrganizationUser_ReadByOrganizationIdEmail]", + new { OrganizationId = organizationId, Email = email }, + commandType: CommandType.StoredProcedure); + + return results.SingleOrDefault(); + } + } + + public async Task DeleteManyAsync(IEnumerable organizationUserIds) + { + using (var connection = new SqlConnection(ConnectionString)) + { + await connection.ExecuteAsync("[dbo].[OrganizationUser_DeleteByIds]", + new { Ids = organizationUserIds.ToGuidIdArrayTVP() }, commandType: CommandType.StoredProcedure); + } + } + + public async Task UpsertManyAsync(IEnumerable organizationUsers) + { + var createUsers = new List(); + var replaceUsers = new List(); + foreach (var organizationUser in organizationUsers) + { + if (organizationUser.Id.Equals(default)) + { + createUsers.Add(organizationUser); + } + else + { + replaceUsers.Add(organizationUser); } } - public async Task GetCountByFreeOrganizationAdminUserAsync(Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteScalarAsync( - "[dbo].[OrganizationUser_ReadCountByFreeOrganizationAdminUser]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); + await CreateManyAsync(createUsers); + await ReplaceManyAsync(replaceUsers); + } - return results; - } + public async Task> CreateManyAsync(IEnumerable organizationUsers) + { + if (!organizationUsers.Any()) + { + return default; } - public async Task GetCountByOnlyOwnerAsync(Guid userId) + foreach (var organizationUser in organizationUsers) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteScalarAsync( - "[dbo].[OrganizationUser_ReadCountByOnlyOwner]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); - - return results; - } + organizationUser.SetNewId(); } - public async Task GetCountByOrganizationAsync(Guid organizationId, string email, bool onlyRegisteredUsers) + var orgUsersTVP = organizationUsers.ToTvp(); + using (var connection = new SqlConnection(_marsConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var result = await connection.ExecuteScalarAsync( - "[dbo].[OrganizationUser_ReadCountByOrganizationIdEmail]", - new { OrganizationId = organizationId, Email = email, OnlyUsers = onlyRegisteredUsers }, - commandType: CommandType.StoredProcedure); - - return result; - } + var results = await connection.ExecuteAsync( + $"[{Schema}].[{Table}_CreateMany]", + new { OrganizationUsersInput = orgUsersTVP }, + commandType: CommandType.StoredProcedure); } - public async Task> SelectKnownEmailsAsync(Guid organizationId, IEnumerable emails, - bool onlyRegisteredUsers) - { - var emailsTvp = emails.ToArrayTVP("Email"); - using (var connection = new SqlConnection(_marsConnectionString)) - { - var result = await connection.QueryAsync( - "[dbo].[OrganizationUser_SelectKnownEmails]", - new { OrganizationId = organizationId, Emails = emailsTvp, OnlyUsers = onlyRegisteredUsers }, - commandType: CommandType.StoredProcedure); + return organizationUsers.Select(u => u.Id).ToList(); + } - // Return as a list to avoid timing out the sql connection - return result.ToList(); - } + public async Task ReplaceManyAsync(IEnumerable organizationUsers) + { + if (!organizationUsers.Any()) + { + return; } - public async Task GetByOrganizationAsync(Guid organizationId, Guid userId) + var orgUsersTVP = organizationUsers.ToTvp(); + using (var connection = new SqlConnection(_marsConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationUser_ReadByOrganizationIdUserId]", - new { OrganizationId = organizationId, UserId = userId }, - commandType: CommandType.StoredProcedure); - - return results.SingleOrDefault(); - } + var results = await connection.ExecuteAsync( + $"[{Schema}].[{Table}_UpdateMany]", + new { OrganizationUsersInput = orgUsersTVP }, + commandType: CommandType.StoredProcedure); } + } - public async Task> GetManyByUserAsync(Guid userId) + public async Task> GetManyPublicKeysByOrganizationUserAsync( + Guid organizationId, IEnumerable Ids) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationUser_ReadByUserId]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[User_ReadPublicKeysByOrganizationUserIds]", + new { OrganizationId = organizationId, OrganizationUserIds = Ids.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task> GetManyByOrganizationAsync(Guid organizationId, - OrganizationUserType? type) + public async Task> GetManyByMinimumRoleAsync(Guid organizationId, OrganizationUserType minRole) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationUser_ReadByOrganizationId]", - new { OrganizationId = organizationId, Type = type }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[OrganizationUser_ReadByMinimumRole]", + new { OrganizationId = organizationId, MinRole = minRole }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task>> GetByIdWithCollectionsAsync(Guid id) + public async Task RevokeAsync(Guid id) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryMultipleAsync( - "[dbo].[OrganizationUser_ReadWithCollectionsById]", - new { Id = id }, - commandType: CommandType.StoredProcedure); - - var user = (await results.ReadAsync()).SingleOrDefault(); - var collections = (await results.ReadAsync()).ToList(); - return new Tuple>(user, collections); - } + var results = await connection.ExecuteAsync( + $"[{Schema}].[{Table}_Deactivate]", + new { Id = id }, + commandType: CommandType.StoredProcedure); } + } - public async Task GetDetailsByIdAsync(Guid id) + public async Task RestoreAsync(Guid id, OrganizationUserStatusType status) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationUserUserDetails_ReadById]", - new { Id = id }, - commandType: CommandType.StoredProcedure); - - return results.SingleOrDefault(); - } - } - public async Task>> - GetDetailsByIdWithCollectionsAsync(Guid id) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryMultipleAsync( - "[dbo].[OrganizationUserUserDetails_ReadWithCollectionsById]", - new { Id = id }, - commandType: CommandType.StoredProcedure); - - var user = (await results.ReadAsync()).SingleOrDefault(); - var collections = (await results.ReadAsync()).ToList(); - return new Tuple>(user, collections); - } - } - - public async Task> GetManyDetailsByOrganizationAsync(Guid organizationId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationUserUserDetails_ReadByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); - - return results.ToList(); - } - } - - public async Task> GetManyDetailsByUserAsync(Guid userId, - OrganizationUserStatusType? status = null) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationUserOrganizationDetails_ReadByUserIdStatus]", - new { UserId = userId, Status = status }, - commandType: CommandType.StoredProcedure); - - return results.ToList(); - } - } - - public async Task GetDetailsByUserAsync(Guid userId, - Guid organizationId, OrganizationUserStatusType? status = null) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationUserOrganizationDetails_ReadByUserIdStatusOrganizationId]", - new { UserId = userId, Status = status, OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); - - return results.SingleOrDefault(); - } - } - - public async Task UpdateGroupsAsync(Guid orgUserId, IEnumerable groupIds) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - "[dbo].[GroupUser_UpdateGroups]", - new { OrganizationUserId = orgUserId, GroupIds = groupIds.ToGuidIdArrayTVP() }, - commandType: CommandType.StoredProcedure); - } - } - - public async Task CreateAsync(OrganizationUser obj, IEnumerable collections) - { - obj.SetNewId(); - var objWithCollections = JsonSerializer.Deserialize( - JsonSerializer.Serialize(obj)); - objWithCollections.Collections = collections.ToArrayTVP(); - - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[OrganizationUser_CreateWithCollections]", - objWithCollections, - commandType: CommandType.StoredProcedure); - } - - return obj.Id; - } - - public async Task ReplaceAsync(OrganizationUser obj, IEnumerable collections) - { - var objWithCollections = JsonSerializer.Deserialize( - JsonSerializer.Serialize(obj)); - objWithCollections.Collections = collections.ToArrayTVP(); - - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[OrganizationUser_UpdateWithCollections]", - objWithCollections, - commandType: CommandType.StoredProcedure); - } - } - - public async Task> GetManyByManyUsersAsync(IEnumerable userIds) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationUser_ReadByUserIds]", - new { UserIds = userIds.ToGuidIdArrayTVP() }, - commandType: CommandType.StoredProcedure); - - return results.ToList(); - } - } - - public async Task> GetManyAsync(IEnumerable Ids) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationUser_ReadByIds]", - new { Ids = Ids.ToGuidIdArrayTVP() }, - commandType: CommandType.StoredProcedure); - - return results.ToList(); - } - } - - public async Task GetByOrganizationEmailAsync(Guid organizationId, string email) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationUser_ReadByOrganizationIdEmail]", - new { OrganizationId = organizationId, Email = email }, - commandType: CommandType.StoredProcedure); - - return results.SingleOrDefault(); - } - } - - public async Task DeleteManyAsync(IEnumerable organizationUserIds) - { - using (var connection = new SqlConnection(ConnectionString)) - { - await connection.ExecuteAsync("[dbo].[OrganizationUser_DeleteByIds]", - new { Ids = organizationUserIds.ToGuidIdArrayTVP() }, commandType: CommandType.StoredProcedure); - } - } - - public async Task UpsertManyAsync(IEnumerable organizationUsers) - { - var createUsers = new List(); - var replaceUsers = new List(); - foreach (var organizationUser in organizationUsers) - { - if (organizationUser.Id.Equals(default)) - { - createUsers.Add(organizationUser); - } - else - { - replaceUsers.Add(organizationUser); - } - } - - await CreateManyAsync(createUsers); - await ReplaceManyAsync(replaceUsers); - } - - public async Task> CreateManyAsync(IEnumerable organizationUsers) - { - if (!organizationUsers.Any()) - { - return default; - } - - foreach (var organizationUser in organizationUsers) - { - organizationUser.SetNewId(); - } - - var orgUsersTVP = organizationUsers.ToTvp(); - using (var connection = new SqlConnection(_marsConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[{Table}_CreateMany]", - new { OrganizationUsersInput = orgUsersTVP }, - commandType: CommandType.StoredProcedure); - } - - return organizationUsers.Select(u => u.Id).ToList(); - } - - public async Task ReplaceManyAsync(IEnumerable organizationUsers) - { - if (!organizationUsers.Any()) - { - return; - } - - var orgUsersTVP = organizationUsers.ToTvp(); - using (var connection = new SqlConnection(_marsConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[{Table}_UpdateMany]", - new { OrganizationUsersInput = orgUsersTVP }, - commandType: CommandType.StoredProcedure); - } - } - - public async Task> GetManyPublicKeysByOrganizationUserAsync( - Guid organizationId, IEnumerable Ids) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[User_ReadPublicKeysByOrganizationUserIds]", - new { OrganizationId = organizationId, OrganizationUserIds = Ids.ToGuidIdArrayTVP() }, - commandType: CommandType.StoredProcedure); - - return results.ToList(); - } - } - - public async Task> GetManyByMinimumRoleAsync(Guid organizationId, OrganizationUserType minRole) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[OrganizationUser_ReadByMinimumRole]", - new { OrganizationId = organizationId, MinRole = minRole }, - commandType: CommandType.StoredProcedure); - - return results.ToList(); - } - } - - public async Task RevokeAsync(Guid id) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[{Table}_Deactivate]", - new { Id = id }, - commandType: CommandType.StoredProcedure); - } - } - - public async Task RestoreAsync(Guid id, OrganizationUserStatusType status) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[{Table}_Activate]", - new { Id = id, Status = status }, - commandType: CommandType.StoredProcedure); - } + var results = await connection.ExecuteAsync( + $"[{Schema}].[{Table}_Activate]", + new { Id = id, Status = status }, + commandType: CommandType.StoredProcedure); } } } diff --git a/src/Infrastructure.Dapper/Repositories/PolicyRepository.cs b/src/Infrastructure.Dapper/Repositories/PolicyRepository.cs index 46cd6e29e..59552e51e 100644 --- a/src/Infrastructure.Dapper/Repositories/PolicyRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/PolicyRepository.cs @@ -6,83 +6,82 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class PolicyRepository : Repository, IPolicyRepository { - public class PolicyRepository : Repository, IPolicyRepository + public PolicyRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public PolicyRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task GetByOrganizationIdTypeAsync(Guid organizationId, PolicyType type) { - public PolicyRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public PolicyRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task GetByOrganizationIdTypeAsync(Guid organizationId, PolicyType type) + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadByOrganizationIdType]", - new { OrganizationId = organizationId, Type = (byte)type }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadByOrganizationIdType]", + new { OrganizationId = organizationId, Type = (byte)type }, + commandType: CommandType.StoredProcedure); - return results.SingleOrDefault(); - } + return results.SingleOrDefault(); } + } - public async Task> GetManyByOrganizationIdAsync(Guid organizationId) + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task> GetManyByUserIdAsync(Guid userId) + public async Task> GetManyByUserIdAsync(Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadByUserId]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadByUserId]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task> GetManyByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType, - OrganizationUserStatusType minStatus) + public async Task> GetManyByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType, + OrganizationUserStatusType minStatus) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadByTypeApplicableToUser]", - new { UserId = userId, PolicyType = policyType, MinimumStatus = minStatus }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadByTypeApplicableToUser]", + new { UserId = userId, PolicyType = policyType, MinimumStatus = minStatus }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task GetCountByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType, - OrganizationUserStatusType minStatus) + public async Task GetCountByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType, + OrganizationUserStatusType minStatus) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var result = await connection.ExecuteScalarAsync( - $"[{Schema}].[{Table}_CountByTypeApplicableToUser]", - new { UserId = userId, PolicyType = policyType, MinimumStatus = minStatus }, - commandType: CommandType.StoredProcedure); + var result = await connection.ExecuteScalarAsync( + $"[{Schema}].[{Table}_CountByTypeApplicableToUser]", + new { UserId = userId, PolicyType = policyType, MinimumStatus = minStatus }, + commandType: CommandType.StoredProcedure); - return result; - } + return result; } } } diff --git a/src/Infrastructure.Dapper/Repositories/ProviderOrganizationRepository.cs b/src/Infrastructure.Dapper/Repositories/ProviderOrganizationRepository.cs index 282f1d273..18ce67866 100644 --- a/src/Infrastructure.Dapper/Repositories/ProviderOrganizationRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/ProviderOrganizationRepository.cs @@ -6,42 +6,41 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class ProviderOrganizationRepository : Repository, IProviderOrganizationRepository { - public class ProviderOrganizationRepository : Repository, IProviderOrganizationRepository + public ProviderOrganizationRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public ProviderOrganizationRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task> GetManyDetailsByProviderAsync(Guid providerId) { - public ProviderOrganizationRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public ProviderOrganizationRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task> GetManyDetailsByProviderAsync(Guid providerId) + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[ProviderOrganizationOrganizationDetails_ReadByProviderId]", - new { ProviderId = providerId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[ProviderOrganizationOrganizationDetails_ReadByProviderId]", + new { ProviderId = providerId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task GetByOrganizationId(Guid organizationId) + public async Task GetByOrganizationId(Guid organizationId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[ProviderOrganization_ReadByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[ProviderOrganization_ReadByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); - return results.SingleOrDefault(); - } + return results.SingleOrDefault(); } } } diff --git a/src/Infrastructure.Dapper/Repositories/ProviderRepository.cs b/src/Infrastructure.Dapper/Repositories/ProviderRepository.cs index 4619771a5..3bc38727c 100644 --- a/src/Infrastructure.Dapper/Repositories/ProviderRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/ProviderRepository.cs @@ -6,42 +6,41 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class ProviderRepository : Repository, IProviderRepository { - public class ProviderRepository : Repository, IProviderRepository + public ProviderRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public ProviderRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task> SearchAsync(string name, string userEmail, int skip, int take) { - public ProviderRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public ProviderRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task> SearchAsync(string name, string userEmail, int skip, int take) + using (var connection = new SqlConnection(ReadOnlyConnectionString)) { - using (var connection = new SqlConnection(ReadOnlyConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[Provider_Search]", - new { Name = name, UserEmail = userEmail, Skip = skip, Take = take }, - commandType: CommandType.StoredProcedure, - commandTimeout: 120); + var results = await connection.QueryAsync( + "[dbo].[Provider_Search]", + new { Name = name, UserEmail = userEmail, Skip = skip, Take = take }, + commandType: CommandType.StoredProcedure, + commandTimeout: 120); - return results.ToList(); - } + return results.ToList(); } + } - public async Task> GetManyAbilitiesAsync() + public async Task> GetManyAbilitiesAsync() + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[Provider_ReadAbilities]", - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[Provider_ReadAbilities]", + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } } } diff --git a/src/Infrastructure.Dapper/Repositories/ProviderUserRepository.cs b/src/Infrastructure.Dapper/Repositories/ProviderUserRepository.cs index 98375ab6a..22a475321 100644 --- a/src/Infrastructure.Dapper/Repositories/ProviderUserRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/ProviderUserRepository.cs @@ -7,158 +7,157 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class ProviderUserRepository : Repository, IProviderUserRepository { - public class ProviderUserRepository : Repository, IProviderUserRepository + public ProviderUserRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public ProviderUserRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task GetCountByProviderAsync(Guid providerId, string email, bool onlyRegisteredUsers) { - public ProviderUserRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public ProviderUserRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task GetCountByProviderAsync(Guid providerId, string email, bool onlyRegisteredUsers) + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var result = await connection.ExecuteScalarAsync( - "[dbo].[ProviderUser_ReadCountByProviderIdEmail]", - new { ProviderId = providerId, Email = email, OnlyUsers = onlyRegisteredUsers }, - commandType: CommandType.StoredProcedure); + var result = await connection.ExecuteScalarAsync( + "[dbo].[ProviderUser_ReadCountByProviderIdEmail]", + new { ProviderId = providerId, Email = email, OnlyUsers = onlyRegisteredUsers }, + commandType: CommandType.StoredProcedure); - return result; - } + return result; } + } - public async Task> GetManyAsync(IEnumerable ids) + public async Task> GetManyAsync(IEnumerable ids) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[ProviderUser_ReadByIds]", - new { Ids = ids.ToGuidIdArrayTVP() }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[ProviderUser_ReadByIds]", + new { Ids = ids.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task> GetManyByUserAsync(Guid userId) + public async Task> GetManyByUserAsync(Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[ProviderUser_ReadByUserId]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[ProviderUser_ReadByUserId]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task GetByProviderUserAsync(Guid providerId, Guid userId) + public async Task GetByProviderUserAsync(Guid providerId, Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[ProviderUser_ReadByProviderIdUserId]", - new { ProviderId = providerId, UserId = userId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[ProviderUser_ReadByProviderIdUserId]", + new { ProviderId = providerId, UserId = userId }, + commandType: CommandType.StoredProcedure); - return results.SingleOrDefault(); - } + return results.SingleOrDefault(); } + } - public async Task> GetManyByProviderAsync(Guid providerId, ProviderUserType? type) + public async Task> GetManyByProviderAsync(Guid providerId, ProviderUserType? type) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[ProviderUser_ReadByProviderId]", - new { ProviderId = providerId, Type = type }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[ProviderUser_ReadByProviderId]", + new { ProviderId = providerId, Type = type }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task> GetManyDetailsByProviderAsync(Guid providerId) + public async Task> GetManyDetailsByProviderAsync(Guid providerId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[ProviderUserUserDetails_ReadByProviderId]", - new { ProviderId = providerId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[ProviderUserUserDetails_ReadByProviderId]", + new { ProviderId = providerId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task> GetManyDetailsByUserAsync(Guid userId, - ProviderUserStatusType? status = null) + public async Task> GetManyDetailsByUserAsync(Guid userId, + ProviderUserStatusType? status = null) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[ProviderUserProviderDetails_ReadByUserIdStatus]", - new { UserId = userId, Status = status }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[ProviderUserProviderDetails_ReadByUserIdStatus]", + new { UserId = userId, Status = status }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task> GetManyOrganizationDetailsByUserAsync(Guid userId, - ProviderUserStatusType? status = null) + public async Task> GetManyOrganizationDetailsByUserAsync(Guid userId, + ProviderUserStatusType? status = null) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[ProviderUserProviderOrganizationDetails_ReadByUserIdStatus]", - new { UserId = userId, Status = status }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[ProviderUserProviderOrganizationDetails_ReadByUserIdStatus]", + new { UserId = userId, Status = status }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task DeleteManyAsync(IEnumerable providerUserIds) + public async Task DeleteManyAsync(IEnumerable providerUserIds) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - await connection.ExecuteAsync("[dbo].[ProviderUser_DeleteByIds]", - new { Ids = providerUserIds.ToGuidIdArrayTVP() }, commandType: CommandType.StoredProcedure); - } + await connection.ExecuteAsync("[dbo].[ProviderUser_DeleteByIds]", + new { Ids = providerUserIds.ToGuidIdArrayTVP() }, commandType: CommandType.StoredProcedure); } + } - public async Task> GetManyPublicKeysByProviderUserAsync( - Guid providerId, IEnumerable Ids) + public async Task> GetManyPublicKeysByProviderUserAsync( + Guid providerId, IEnumerable Ids) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[User_ReadPublicKeysByProviderUserIds]", - new { ProviderId = providerId, ProviderUserIds = Ids.ToGuidIdArrayTVP() }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + "[dbo].[User_ReadPublicKeysByProviderUserIds]", + new { ProviderId = providerId, ProviderUserIds = Ids.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task GetCountByOnlyOwnerAsync(Guid userId) + public async Task GetCountByOnlyOwnerAsync(Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteScalarAsync( - "[dbo].[ProviderUser_ReadCountByOnlyOwner]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); + var results = await connection.ExecuteScalarAsync( + "[dbo].[ProviderUser_ReadCountByOnlyOwner]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); - return results; - } + return results; } } } diff --git a/src/Infrastructure.Dapper/Repositories/Repository.cs b/src/Infrastructure.Dapper/Repositories/Repository.cs index 4bc0b91b1..0c46a6d0a 100644 --- a/src/Infrastructure.Dapper/Repositories/Repository.cs +++ b/src/Infrastructure.Dapper/Repositories/Repository.cs @@ -4,92 +4,91 @@ using Bit.Core.Entities; using Bit.Core.Repositories; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public abstract class Repository : BaseRepository, IRepository + where TId : IEquatable + where T : class, ITableObject { - public abstract class Repository : BaseRepository, IRepository - where TId : IEquatable - where T : class, ITableObject + public Repository(string connectionString, string readOnlyConnectionString, + string schema = null, string table = null) + : base(connectionString, readOnlyConnectionString) { - public Repository(string connectionString, string readOnlyConnectionString, - string schema = null, string table = null) - : base(connectionString, readOnlyConnectionString) + if (!string.IsNullOrWhiteSpace(table)) { - if (!string.IsNullOrWhiteSpace(table)) - { - Table = table; - } - - if (!string.IsNullOrWhiteSpace(schema)) - { - Schema = schema; - } + Table = table; } - protected string Schema { get; private set; } = "dbo"; - protected string Table { get; private set; } = typeof(T).Name; - - public virtual async Task GetByIdAsync(TId id) + if (!string.IsNullOrWhiteSpace(schema)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadById]", - new { Id = id }, - commandType: CommandType.StoredProcedure); - - return results.SingleOrDefault(); - } + Schema = schema; } + } - public virtual async Task CreateAsync(T obj) + protected string Schema { get; private set; } = "dbo"; + protected string Table { get; private set; } = typeof(T).Name; + + public virtual async Task GetByIdAsync(TId id) + { + using (var connection = new SqlConnection(ConnectionString)) { - obj.SetNewId(); - using (var connection = new SqlConnection(ConnectionString)) - { - var parameters = new DynamicParameters(); - parameters.AddDynamicParams(obj); - parameters.Add("Id", obj.Id, direction: ParameterDirection.InputOutput); - var results = await connection.ExecuteAsync( - $"[{Schema}].[{Table}_Create]", - parameters, - commandType: CommandType.StoredProcedure); - obj.Id = parameters.Get(nameof(obj.Id)); - } - return obj; + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadById]", + new { Id = id }, + commandType: CommandType.StoredProcedure); + + return results.SingleOrDefault(); } + } - public virtual async Task ReplaceAsync(T obj) + public virtual async Task CreateAsync(T obj) + { + obj.SetNewId(); + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[{Table}_Update]", - obj, - commandType: CommandType.StoredProcedure); - } + var parameters = new DynamicParameters(); + parameters.AddDynamicParams(obj); + parameters.Add("Id", obj.Id, direction: ParameterDirection.InputOutput); + var results = await connection.ExecuteAsync( + $"[{Schema}].[{Table}_Create]", + parameters, + commandType: CommandType.StoredProcedure); + obj.Id = parameters.Get(nameof(obj.Id)); } + return obj; + } - public virtual async Task UpsertAsync(T obj) + public virtual async Task ReplaceAsync(T obj) + { + using (var connection = new SqlConnection(ConnectionString)) { - if (obj.Id.Equals(default(TId))) - { - await CreateAsync(obj); - } - else - { - await ReplaceAsync(obj); - } + var results = await connection.ExecuteAsync( + $"[{Schema}].[{Table}_Update]", + obj, + commandType: CommandType.StoredProcedure); } + } - public virtual async Task DeleteAsync(T obj) + public virtual async Task UpsertAsync(T obj) + { + if (obj.Id.Equals(default(TId))) { - using (var connection = new SqlConnection(ConnectionString)) - { - await connection.ExecuteAsync( - $"[{Schema}].[{Table}_DeleteById]", - new { Id = obj.Id }, - commandType: CommandType.StoredProcedure); - } + await CreateAsync(obj); + } + else + { + await ReplaceAsync(obj); + } + } + + public virtual async Task DeleteAsync(T obj) + { + using (var connection = new SqlConnection(ConnectionString)) + { + await connection.ExecuteAsync( + $"[{Schema}].[{Table}_DeleteById]", + new { Id = obj.Id }, + commandType: CommandType.StoredProcedure); } } } diff --git a/src/Infrastructure.Dapper/Repositories/SendRepository.cs b/src/Infrastructure.Dapper/Repositories/SendRepository.cs index 6d8ba6c19..b64af45cd 100644 --- a/src/Infrastructure.Dapper/Repositories/SendRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/SendRepository.cs @@ -5,42 +5,41 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class SendRepository : Repository, ISendRepository { - public class SendRepository : Repository, ISendRepository + public SendRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public SendRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task> GetManyByUserIdAsync(Guid userId) { - public SendRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public SendRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task> GetManyByUserIdAsync(Guid userId) + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[Send_ReadByUserId]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[Send_ReadByUserId]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task> GetManyByDeletionDateAsync(DateTime deletionDateBefore) + public async Task> GetManyByDeletionDateAsync(DateTime deletionDateBefore) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[Send_ReadByDeletionDateBefore]", - new { DeletionDate = deletionDateBefore }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[Send_ReadByDeletionDateBefore]", + new { DeletionDate = deletionDateBefore }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } } } diff --git a/src/Infrastructure.Dapper/Repositories/SsoConfigRepository.cs b/src/Infrastructure.Dapper/Repositories/SsoConfigRepository.cs index 70d527c1c..3b8a5a904 100644 --- a/src/Infrastructure.Dapper/Repositories/SsoConfigRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/SsoConfigRepository.cs @@ -5,55 +5,54 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class SsoConfigRepository : Repository, ISsoConfigRepository { - public class SsoConfigRepository : Repository, ISsoConfigRepository + public SsoConfigRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public SsoConfigRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task GetByOrganizationIdAsync(Guid organizationId) { - public SsoConfigRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public SsoConfigRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task GetByOrganizationIdAsync(Guid organizationId) + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); - return results.SingleOrDefault(); - } + return results.SingleOrDefault(); } + } - public async Task GetByIdentifierAsync(string identifier) + public async Task GetByIdentifierAsync(string identifier) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadByIdentifier]", - new { Identifier = identifier }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadByIdentifier]", + new { Identifier = identifier }, + commandType: CommandType.StoredProcedure); - return results.SingleOrDefault(); - } + return results.SingleOrDefault(); } + } - public async Task> GetManyByRevisionNotBeforeDate(DateTime? notBefore) + public async Task> GetManyByRevisionNotBeforeDate(DateTime? notBefore) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadManyByNotBeforeRevisionDate]", - new { NotBefore = notBefore }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadManyByNotBeforeRevisionDate]", + new { NotBefore = notBefore }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } } } diff --git a/src/Infrastructure.Dapper/Repositories/SsoUserRepository.cs b/src/Infrastructure.Dapper/Repositories/SsoUserRepository.cs index fd32c708d..e393762fa 100644 --- a/src/Infrastructure.Dapper/Repositories/SsoUserRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/SsoUserRepository.cs @@ -5,40 +5,39 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class SsoUserRepository : Repository, ISsoUserRepository { - public class SsoUserRepository : Repository, ISsoUserRepository + public SsoUserRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public SsoUserRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task DeleteAsync(Guid userId, Guid? organizationId) { - public SsoUserRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public SsoUserRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task DeleteAsync(Guid userId, Guid? organizationId) + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[SsoUser_Delete]", - new { UserId = userId, OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); - } + var results = await connection.ExecuteAsync( + $"[{Schema}].[SsoUser_Delete]", + new { UserId = userId, OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); } + } - public async Task GetByUserIdOrganizationIdAsync(Guid organizationId, Guid userId) + public async Task GetByUserIdOrganizationIdAsync(Guid organizationId, Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[SsoUser_ReadByUserIdOrganizationId]", - new { UserId = userId, OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[SsoUser_ReadByUserIdOrganizationId]", + new { UserId = userId, OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); - return results.SingleOrDefault(); - } + return results.SingleOrDefault(); } } } diff --git a/src/Infrastructure.Dapper/Repositories/TaxRateRepository.cs b/src/Infrastructure.Dapper/Repositories/TaxRateRepository.cs index 1c9982e01..7a9ad7d09 100644 --- a/src/Infrastructure.Dapper/Repositories/TaxRateRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/TaxRateRepository.cs @@ -5,65 +5,64 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class TaxRateRepository : Repository, ITaxRateRepository { - public class TaxRateRepository : Repository, ITaxRateRepository + public TaxRateRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public TaxRateRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task> SearchAsync(int skip, int count) { - public TaxRateRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public TaxRateRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task> SearchAsync(int skip, int count) + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[TaxRate_Search]", - new { Skip = skip, Count = count }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[TaxRate_Search]", + new { Skip = skip, Count = count }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task> GetAllActiveAsync() + public async Task> GetAllActiveAsync() + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[TaxRate_ReadAllActive]", - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[TaxRate_ReadAllActive]", + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task ArchiveAsync(TaxRate model) + public async Task ArchiveAsync(TaxRate model) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.ExecuteAsync( - $"[{Schema}].[TaxRate_Archive]", - new { Id = model.Id }, - commandType: CommandType.StoredProcedure); - } + var results = await connection.ExecuteAsync( + $"[{Schema}].[TaxRate_Archive]", + new { Id = model.Id }, + commandType: CommandType.StoredProcedure); } + } - public async Task> GetByLocationAsync(TaxRate model) + public async Task> GetByLocationAsync(TaxRate model) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[TaxRate_ReadByLocation]", - new { Country = model.Country, PostalCode = model.PostalCode }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[TaxRate_ReadByLocation]", + new { Country = model.Country, PostalCode = model.PostalCode }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } } } diff --git a/src/Infrastructure.Dapper/Repositories/TransactionRepository.cs b/src/Infrastructure.Dapper/Repositories/TransactionRepository.cs index ed8b16f91..ff9c900bf 100644 --- a/src/Infrastructure.Dapper/Repositories/TransactionRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/TransactionRepository.cs @@ -6,55 +6,54 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class TransactionRepository : Repository, ITransactionRepository { - public class TransactionRepository : Repository, ITransactionRepository + public TransactionRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public TransactionRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public async Task> GetManyByUserIdAsync(Guid userId) { - public TransactionRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } - - public TransactionRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public async Task> GetManyByUserIdAsync(Guid userId) + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[Transaction_ReadByUserId]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[Transaction_ReadByUserId]", + new { UserId = userId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task> GetManyByOrganizationIdAsync(Guid organizationId) + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[Transaction_ReadByOrganizationId]", - new { OrganizationId = organizationId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[Transaction_ReadByOrganizationId]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task GetByGatewayIdAsync(GatewayType gatewayType, string gatewayId) + public async Task GetByGatewayIdAsync(GatewayType gatewayType, string gatewayId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[Transaction_ReadByGatewayId]", - new { Gateway = gatewayType, GatewayId = gatewayId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[Transaction_ReadByGatewayId]", + new { Gateway = gatewayType, GatewayId = gatewayId }, + commandType: CommandType.StoredProcedure); - return results.SingleOrDefault(); - } + return results.SingleOrDefault(); } } } diff --git a/src/Infrastructure.Dapper/Repositories/UserRepository.cs b/src/Infrastructure.Dapper/Repositories/UserRepository.cs index 077fdd59a..19c7a83be 100644 --- a/src/Infrastructure.Dapper/Repositories/UserRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/UserRepository.cs @@ -6,166 +6,165 @@ using Bit.Core.Repositories; using Bit.Core.Settings; using Dapper; -namespace Bit.Infrastructure.Dapper.Repositories +namespace Bit.Infrastructure.Dapper.Repositories; + +public class UserRepository : Repository, IUserRepository { - public class UserRepository : Repository, IUserRepository + public UserRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { } + + public UserRepository(string connectionString, string readOnlyConnectionString) + : base(connectionString, readOnlyConnectionString) + { } + + public override async Task GetByIdAsync(Guid id) { - public UserRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) - { } + return await base.GetByIdAsync(id); + } - public UserRepository(string connectionString, string readOnlyConnectionString) - : base(connectionString, readOnlyConnectionString) - { } - - public override async Task GetByIdAsync(Guid id) + public async Task GetByEmailAsync(string email) + { + using (var connection = new SqlConnection(ConnectionString)) { - return await base.GetByIdAsync(id); + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadByEmail]", + new { Email = email }, + commandType: CommandType.StoredProcedure); + + return results.SingleOrDefault(); } + } - public async Task GetByEmailAsync(string email) + public async Task GetBySsoUserAsync(string externalId, Guid? organizationId) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadByEmail]", - new { Email = email }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadBySsoUserOrganizationIdExternalId]", + new { OrganizationId = organizationId, ExternalId = externalId }, + commandType: CommandType.StoredProcedure); - return results.SingleOrDefault(); - } + return results.SingleOrDefault(); } + } - public async Task GetBySsoUserAsync(string externalId, Guid? organizationId) + public async Task GetKdfInformationByEmailAsync(string email) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadBySsoUserOrganizationIdExternalId]", - new { OrganizationId = organizationId, ExternalId = externalId }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadKdfByEmail]", + new { Email = email }, + commandType: CommandType.StoredProcedure); - return results.SingleOrDefault(); - } + return results.SingleOrDefault(); } + } - public async Task GetKdfInformationByEmailAsync(string email) + public async Task> SearchAsync(string email, int skip, int take) + { + using (var connection = new SqlConnection(ReadOnlyConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadKdfByEmail]", - new { Email = email }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_Search]", + new { Email = email, Skip = skip, Take = take }, + commandType: CommandType.StoredProcedure, + commandTimeout: 120); - return results.SingleOrDefault(); - } + return results.ToList(); } + } - public async Task> SearchAsync(string email, int skip, int take) + public async Task> GetManyByPremiumAsync(bool premium) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ReadOnlyConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_Search]", - new { Email = email, Skip = skip, Take = take }, - commandType: CommandType.StoredProcedure, - commandTimeout: 120); + var results = await connection.QueryAsync( + "[dbo].[User_ReadByPremium]", + new { Premium = premium }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.ToList(); } + } - public async Task> GetManyByPremiumAsync(bool premium) + public async Task GetPublicKeyAsync(Guid id) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - "[dbo].[User_ReadByPremium]", - new { Premium = premium }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadPublicKeyById]", + new { Id = id }, + commandType: CommandType.StoredProcedure); - return results.ToList(); - } + return results.SingleOrDefault(); } + } - public async Task GetPublicKeyAsync(Guid id) + public async Task GetAccountRevisionDateAsync(Guid id) + { + using (var connection = new SqlConnection(ReadOnlyConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadPublicKeyById]", - new { Id = id }, - commandType: CommandType.StoredProcedure); + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadAccountRevisionDateById]", + new { Id = id }, + commandType: CommandType.StoredProcedure); - return results.SingleOrDefault(); - } + return results.SingleOrDefault(); } + } - public async Task GetAccountRevisionDateAsync(Guid id) + public override async Task ReplaceAsync(User user) + { + await base.ReplaceAsync(user); + } + + public override async Task DeleteAsync(User user) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ReadOnlyConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadAccountRevisionDateById]", - new { Id = id }, - commandType: CommandType.StoredProcedure); - - return results.SingleOrDefault(); - } + await connection.ExecuteAsync( + $"[{Schema}].[{Table}_DeleteById]", + new { Id = user.Id }, + commandType: CommandType.StoredProcedure, + commandTimeout: 180); } + } - public override async Task ReplaceAsync(User user) + public async Task UpdateStorageAsync(Guid id) + { + using (var connection = new SqlConnection(ConnectionString)) { - await base.ReplaceAsync(user); + await connection.ExecuteAsync( + $"[{Schema}].[{Table}_UpdateStorage]", + new { Id = id }, + commandType: CommandType.StoredProcedure, + commandTimeout: 180); } + } - public override async Task DeleteAsync(User user) + public async Task UpdateRenewalReminderDateAsync(Guid id, DateTime renewalReminderDate) + { + using (var connection = new SqlConnection(ConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - await connection.ExecuteAsync( - $"[{Schema}].[{Table}_DeleteById]", - new { Id = user.Id }, - commandType: CommandType.StoredProcedure, - commandTimeout: 180); - } + await connection.ExecuteAsync( + $"[{Schema}].[User_UpdateRenewalReminderDate]", + new { Id = id, RenewalReminderDate = renewalReminderDate }, + commandType: CommandType.StoredProcedure); } + } - public async Task UpdateStorageAsync(Guid id) + public async Task> GetManyAsync(IEnumerable ids) + { + using (var connection = new SqlConnection(ReadOnlyConnectionString)) { - using (var connection = new SqlConnection(ConnectionString)) - { - await connection.ExecuteAsync( - $"[{Schema}].[{Table}_UpdateStorage]", - new { Id = id }, - commandType: CommandType.StoredProcedure, - commandTimeout: 180); - } - } + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadByIds]", + new { Ids = ids.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); - public async Task UpdateRenewalReminderDateAsync(Guid id, DateTime renewalReminderDate) - { - using (var connection = new SqlConnection(ConnectionString)) - { - await connection.ExecuteAsync( - $"[{Schema}].[User_UpdateRenewalReminderDate]", - new { Id = id, RenewalReminderDate = renewalReminderDate }, - commandType: CommandType.StoredProcedure); - } - } - - public async Task> GetManyAsync(IEnumerable ids) - { - using (var connection = new SqlConnection(ReadOnlyConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[{Table}_ReadByIds]", - new { Ids = ids.ToGuidIdArrayTVP() }, - commandType: CommandType.StoredProcedure); - - return results.ToList(); - } + return results.ToList(); } } } diff --git a/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs b/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs index 259deb2e1..c8a99b274 100644 --- a/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs +++ b/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs @@ -5,62 +5,61 @@ using LinqToDB.EntityFrameworkCore; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework -{ - public static class EntityFrameworkServiceCollectionExtensions - { - public static void AddEFRepositories(this IServiceCollection services, bool selfHosted, string connectionString, - SupportedDatabaseProviders provider) - { - if (string.IsNullOrWhiteSpace(connectionString)) - { - throw new Exception($"Database provider type {provider} was selected but no connection string was found."); - } - LinqToDBForEFTools.Initialize(); - services.AddAutoMapper(typeof(UserRepository)); - services.AddDbContext(options => - { - if (provider == SupportedDatabaseProviders.Postgres) - { - options.UseNpgsql(connectionString); - // Handle NpgSql Legacy Support for `timestamp without timezone` issue - AppContext.SetSwitch("Npgsql.EnableLegacyTimestampBehavior", true); - } - else if (provider == SupportedDatabaseProviders.MySql) - { - options.UseMySql(connectionString, ServerVersion.AutoDetect(connectionString)); - } - }); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); +namespace Bit.Infrastructure.EntityFramework; - if (selfHosted) +public static class EntityFrameworkServiceCollectionExtensions +{ + public static void AddEFRepositories(this IServiceCollection services, bool selfHosted, string connectionString, + SupportedDatabaseProviders provider) + { + if (string.IsNullOrWhiteSpace(connectionString)) + { + throw new Exception($"Database provider type {provider} was selected but no connection string was found."); + } + LinqToDBForEFTools.Initialize(); + services.AddAutoMapper(typeof(UserRepository)); + services.AddDbContext(options => + { + if (provider == SupportedDatabaseProviders.Postgres) { - services.AddSingleton(); + options.UseNpgsql(connectionString); + // Handle NpgSql Legacy Support for `timestamp without timezone` issue + AppContext.SetSwitch("Npgsql.EnableLegacyTimestampBehavior", true); } + else if (provider == SupportedDatabaseProviders.MySql) + { + options.UseMySql(connectionString, ServerVersion.AutoDetect(connectionString)); + } + }); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + + if (selfHosted) + { + services.AddSingleton(); } } } diff --git a/src/Infrastructure.EntityFramework/Models/Cipher.cs b/src/Infrastructure.EntityFramework/Models/Cipher.cs index 4cf008d52..ec5ddc53d 100644 --- a/src/Infrastructure.EntityFramework/Models/Cipher.cs +++ b/src/Infrastructure.EntityFramework/Models/Cipher.cs @@ -1,19 +1,18 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class Cipher : Core.Entities.Cipher - { - public virtual User User { get; set; } - public virtual Organization Organization { get; set; } - public virtual ICollection CollectionCiphers { get; set; } - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class CipherMapperProfile : Profile +public class Cipher : Core.Entities.Cipher +{ + public virtual User User { get; set; } + public virtual Organization Organization { get; set; } + public virtual ICollection CollectionCiphers { get; set; } +} + +public class CipherMapperProfile : Profile +{ + public CipherMapperProfile() { - public CipherMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/Collection.cs b/src/Infrastructure.EntityFramework/Models/Collection.cs index 2e4337238..29495081d 100644 --- a/src/Infrastructure.EntityFramework/Models/Collection.cs +++ b/src/Infrastructure.EntityFramework/Models/Collection.cs @@ -1,20 +1,19 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class Collection : Core.Entities.Collection - { - public virtual Organization Organization { get; set; } - public virtual ICollection CollectionUsers { get; set; } - public virtual ICollection CollectionCiphers { get; set; } - public virtual ICollection CollectionGroups { get; set; } - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class CollectionMapperProfile : Profile +public class Collection : Core.Entities.Collection +{ + public virtual Organization Organization { get; set; } + public virtual ICollection CollectionUsers { get; set; } + public virtual ICollection CollectionCiphers { get; set; } + public virtual ICollection CollectionGroups { get; set; } +} + +public class CollectionMapperProfile : Profile +{ + public CollectionMapperProfile() { - public CollectionMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/CollectionCipher.cs b/src/Infrastructure.EntityFramework/Models/CollectionCipher.cs index 8a7de5a78..93d1deae1 100644 --- a/src/Infrastructure.EntityFramework/Models/CollectionCipher.cs +++ b/src/Infrastructure.EntityFramework/Models/CollectionCipher.cs @@ -1,18 +1,17 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class CollectionCipher : Core.Entities.CollectionCipher - { - public virtual Cipher Cipher { get; set; } - public virtual Collection Collection { get; set; } - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class CollectionCipherMapperProfile : Profile +public class CollectionCipher : Core.Entities.CollectionCipher +{ + public virtual Cipher Cipher { get; set; } + public virtual Collection Collection { get; set; } +} + +public class CollectionCipherMapperProfile : Profile +{ + public CollectionCipherMapperProfile() { - public CollectionCipherMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/CollectionGroup.cs b/src/Infrastructure.EntityFramework/Models/CollectionGroup.cs index fdded3521..623a5d808 100644 --- a/src/Infrastructure.EntityFramework/Models/CollectionGroup.cs +++ b/src/Infrastructure.EntityFramework/Models/CollectionGroup.cs @@ -1,18 +1,17 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class CollectionGroup : Core.Entities.CollectionGroup - { - public virtual Collection Collection { get; set; } - public virtual Group Group { get; set; } - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class CollectionGroupMapperProfile : Profile +public class CollectionGroup : Core.Entities.CollectionGroup +{ + public virtual Collection Collection { get; set; } + public virtual Group Group { get; set; } +} + +public class CollectionGroupMapperProfile : Profile +{ + public CollectionGroupMapperProfile() { - public CollectionGroupMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/CollectionUser.cs b/src/Infrastructure.EntityFramework/Models/CollectionUser.cs index 24d10c2a7..308673492 100644 --- a/src/Infrastructure.EntityFramework/Models/CollectionUser.cs +++ b/src/Infrastructure.EntityFramework/Models/CollectionUser.cs @@ -1,18 +1,17 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class CollectionUser : Core.Entities.CollectionUser - { - public virtual Collection Collection { get; set; } - public virtual OrganizationUser OrganizationUser { get; set; } - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class CollectionUserMapperProfile : Profile +public class CollectionUser : Core.Entities.CollectionUser +{ + public virtual Collection Collection { get; set; } + public virtual OrganizationUser OrganizationUser { get; set; } +} + +public class CollectionUserMapperProfile : Profile +{ + public CollectionUserMapperProfile() { - public CollectionUserMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/Device.cs b/src/Infrastructure.EntityFramework/Models/Device.cs index 675ed917a..1eace238d 100644 --- a/src/Infrastructure.EntityFramework/Models/Device.cs +++ b/src/Infrastructure.EntityFramework/Models/Device.cs @@ -1,17 +1,16 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class Device : Core.Entities.Device - { - public virtual User User { get; set; } - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class DeviceMapperProfile : Profile +public class Device : Core.Entities.Device +{ + public virtual User User { get; set; } +} + +public class DeviceMapperProfile : Profile +{ + public DeviceMapperProfile() { - public DeviceMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/EmergencyAccess.cs b/src/Infrastructure.EntityFramework/Models/EmergencyAccess.cs index e92eba8ee..867912c5e 100644 --- a/src/Infrastructure.EntityFramework/Models/EmergencyAccess.cs +++ b/src/Infrastructure.EntityFramework/Models/EmergencyAccess.cs @@ -1,18 +1,17 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class EmergencyAccess : Core.Entities.EmergencyAccess - { - public virtual User Grantee { get; set; } - public virtual User Grantor { get; set; } - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class EmergencyAccessMapperProfile : Profile +public class EmergencyAccess : Core.Entities.EmergencyAccess +{ + public virtual User Grantee { get; set; } + public virtual User Grantor { get; set; } +} + +public class EmergencyAccessMapperProfile : Profile +{ + public EmergencyAccessMapperProfile() { - public EmergencyAccessMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/Event.cs b/src/Infrastructure.EntityFramework/Models/Event.cs index 558f2a285..b7bad9c78 100644 --- a/src/Infrastructure.EntityFramework/Models/Event.cs +++ b/src/Infrastructure.EntityFramework/Models/Event.cs @@ -1,16 +1,15 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class Event : Core.Entities.Event - { - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class EventMapperProfile : Profile +public class Event : Core.Entities.Event +{ +} + +public class EventMapperProfile : Profile +{ + public EventMapperProfile() { - public EventMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/Folder.cs b/src/Infrastructure.EntityFramework/Models/Folder.cs index 1918dfe73..466833785 100644 --- a/src/Infrastructure.EntityFramework/Models/Folder.cs +++ b/src/Infrastructure.EntityFramework/Models/Folder.cs @@ -1,17 +1,16 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class Folder : Core.Entities.Folder - { - public virtual User User { get; set; } - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class FolderMapperProfile : Profile +public class Folder : Core.Entities.Folder +{ + public virtual User User { get; set; } +} + +public class FolderMapperProfile : Profile +{ + public FolderMapperProfile() { - public FolderMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/Grant.cs b/src/Infrastructure.EntityFramework/Models/Grant.cs index 251d16437..78b4b4582 100644 --- a/src/Infrastructure.EntityFramework/Models/Grant.cs +++ b/src/Infrastructure.EntityFramework/Models/Grant.cs @@ -1,16 +1,15 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class Grant : Core.Entities.Grant - { - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class GrantMapperProfile : Profile +public class Grant : Core.Entities.Grant +{ +} + +public class GrantMapperProfile : Profile +{ + public GrantMapperProfile() { - public GrantMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/Group.cs b/src/Infrastructure.EntityFramework/Models/Group.cs index 98f482012..eaa41bed8 100644 --- a/src/Infrastructure.EntityFramework/Models/Group.cs +++ b/src/Infrastructure.EntityFramework/Models/Group.cs @@ -1,18 +1,17 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class Group : Core.Entities.Group - { - public virtual Organization Organization { get; set; } - public virtual ICollection GroupUsers { get; set; } - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class GroupMapperProfile : Profile +public class Group : Core.Entities.Group +{ + public virtual Organization Organization { get; set; } + public virtual ICollection GroupUsers { get; set; } +} + +public class GroupMapperProfile : Profile +{ + public GroupMapperProfile() { - public GroupMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/GroupUser.cs b/src/Infrastructure.EntityFramework/Models/GroupUser.cs index 5a81ed884..3f25e7d87 100644 --- a/src/Infrastructure.EntityFramework/Models/GroupUser.cs +++ b/src/Infrastructure.EntityFramework/Models/GroupUser.cs @@ -1,19 +1,18 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class GroupUser : Core.Entities.GroupUser - { - public virtual Group Group { get; set; } - public virtual OrganizationUser OrganizationUser { get; set; } - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class GroupUserMapperProfile : Profile +public class GroupUser : Core.Entities.GroupUser +{ + public virtual Group Group { get; set; } + public virtual OrganizationUser OrganizationUser { get; set; } +} + +public class GroupUserMapperProfile : Profile +{ + public GroupUserMapperProfile() { - public GroupUserMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/Installation.cs b/src/Infrastructure.EntityFramework/Models/Installation.cs index 92bbd2abb..35223a33d 100644 --- a/src/Infrastructure.EntityFramework/Models/Installation.cs +++ b/src/Infrastructure.EntityFramework/Models/Installation.cs @@ -1,16 +1,15 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class Installation : Core.Entities.Installation - { - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class InstallationMapperProfile : Profile +public class Installation : Core.Entities.Installation +{ +} + +public class InstallationMapperProfile : Profile +{ + public InstallationMapperProfile() { - public InstallationMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/Organization.cs b/src/Infrastructure.EntityFramework/Models/Organization.cs index 3d46027ef..c1969cab0 100644 --- a/src/Infrastructure.EntityFramework/Models/Organization.cs +++ b/src/Infrastructure.EntityFramework/Models/Organization.cs @@ -1,25 +1,24 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class Organization : Core.Entities.Organization - { - public virtual ICollection Ciphers { get; set; } - public virtual ICollection OrganizationUsers { get; set; } - public virtual ICollection Groups { get; set; } - public virtual ICollection Policies { get; set; } - public virtual ICollection SsoConfigs { get; set; } - public virtual ICollection SsoUsers { get; set; } - public virtual ICollection Transactions { get; set; } - public virtual ICollection ApiKeys { get; set; } - public virtual ICollection Connections { get; set; } - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class OrganizationMapperProfile : Profile +public class Organization : Core.Entities.Organization +{ + public virtual ICollection Ciphers { get; set; } + public virtual ICollection OrganizationUsers { get; set; } + public virtual ICollection Groups { get; set; } + public virtual ICollection Policies { get; set; } + public virtual ICollection SsoConfigs { get; set; } + public virtual ICollection SsoUsers { get; set; } + public virtual ICollection Transactions { get; set; } + public virtual ICollection ApiKeys { get; set; } + public virtual ICollection Connections { get; set; } +} + +public class OrganizationMapperProfile : Profile +{ + public OrganizationMapperProfile() { - public OrganizationMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/OrganizationApiKey.cs b/src/Infrastructure.EntityFramework/Models/OrganizationApiKey.cs index c0e6c33e0..b8a4f4e74 100644 --- a/src/Infrastructure.EntityFramework/Models/OrganizationApiKey.cs +++ b/src/Infrastructure.EntityFramework/Models/OrganizationApiKey.cs @@ -1,17 +1,16 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class OrganizationApiKey : Core.Entities.OrganizationApiKey - { - public virtual Organization Organization { get; set; } - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class OrganizationApiKeyMapperProfile : Profile +public class OrganizationApiKey : Core.Entities.OrganizationApiKey +{ + public virtual Organization Organization { get; set; } +} + +public class OrganizationApiKeyMapperProfile : Profile +{ + public OrganizationApiKeyMapperProfile() { - public OrganizationApiKeyMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/OrganizationConnection.cs b/src/Infrastructure.EntityFramework/Models/OrganizationConnection.cs index f53ee711c..5c41d5f6c 100644 --- a/src/Infrastructure.EntityFramework/Models/OrganizationConnection.cs +++ b/src/Infrastructure.EntityFramework/Models/OrganizationConnection.cs @@ -1,17 +1,16 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class OrganizationConnection : Core.Entities.OrganizationConnection - { - public virtual Organization Organization { get; set; } - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class OrganizationConnectionMapperProfile : Profile +public class OrganizationConnection : Core.Entities.OrganizationConnection +{ + public virtual Organization Organization { get; set; } +} + +public class OrganizationConnectionMapperProfile : Profile +{ + public OrganizationConnectionMapperProfile() { - public OrganizationConnectionMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/OrganizationSponsorship.cs b/src/Infrastructure.EntityFramework/Models/OrganizationSponsorship.cs index c9eee03e5..3d8b8acf7 100644 --- a/src/Infrastructure.EntityFramework/Models/OrganizationSponsorship.cs +++ b/src/Infrastructure.EntityFramework/Models/OrganizationSponsorship.cs @@ -1,18 +1,17 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class OrganizationSponsorship : Core.Entities.OrganizationSponsorship - { - public virtual Organization SponsoringOrganization { get; set; } - public virtual Organization SponsoredOrganization { get; set; } - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class OrganizationSponsorshipMapperProfile : Profile +public class OrganizationSponsorship : Core.Entities.OrganizationSponsorship +{ + public virtual Organization SponsoringOrganization { get; set; } + public virtual Organization SponsoredOrganization { get; set; } +} + +public class OrganizationSponsorshipMapperProfile : Profile +{ + public OrganizationSponsorshipMapperProfile() { - public OrganizationSponsorshipMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/OrganizationUser.cs b/src/Infrastructure.EntityFramework/Models/OrganizationUser.cs index f1489bd46..abab1a4d5 100644 --- a/src/Infrastructure.EntityFramework/Models/OrganizationUser.cs +++ b/src/Infrastructure.EntityFramework/Models/OrganizationUser.cs @@ -1,19 +1,18 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class OrganizationUser : Core.Entities.OrganizationUser - { - public virtual Organization Organization { get; set; } - public virtual User User { get; set; } - public virtual ICollection CollectionUsers { get; set; } - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class OrganizationUserMapperProfile : Profile +public class OrganizationUser : Core.Entities.OrganizationUser +{ + public virtual Organization Organization { get; set; } + public virtual User User { get; set; } + public virtual ICollection CollectionUsers { get; set; } +} + +public class OrganizationUserMapperProfile : Profile +{ + public OrganizationUserMapperProfile() { - public OrganizationUserMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/Policy.cs b/src/Infrastructure.EntityFramework/Models/Policy.cs index 953556cdd..22b17c6f6 100644 --- a/src/Infrastructure.EntityFramework/Models/Policy.cs +++ b/src/Infrastructure.EntityFramework/Models/Policy.cs @@ -1,17 +1,16 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class Policy : Core.Entities.Policy - { - public virtual Organization Organization { get; set; } - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class PolicyMapperProfile : Profile +public class Policy : Core.Entities.Policy +{ + public virtual Organization Organization { get; set; } +} + +public class PolicyMapperProfile : Profile +{ + public PolicyMapperProfile() { - public PolicyMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/Provider/Provider.cs b/src/Infrastructure.EntityFramework/Models/Provider/Provider.cs index 8efa1558d..d639d6d01 100644 --- a/src/Infrastructure.EntityFramework/Models/Provider/Provider.cs +++ b/src/Infrastructure.EntityFramework/Models/Provider/Provider.cs @@ -1,16 +1,15 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class Provider : Core.Entities.Provider.Provider - { - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class ProviderMapperProfile : Profile +public class Provider : Core.Entities.Provider.Provider +{ +} + +public class ProviderMapperProfile : Profile +{ + public ProviderMapperProfile() { - public ProviderMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/Provider/ProviderOrganization.cs b/src/Infrastructure.EntityFramework/Models/Provider/ProviderOrganization.cs index 13aa52110..af23ba978 100644 --- a/src/Infrastructure.EntityFramework/Models/Provider/ProviderOrganization.cs +++ b/src/Infrastructure.EntityFramework/Models/Provider/ProviderOrganization.cs @@ -1,18 +1,17 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class ProviderOrganization : Core.Entities.Provider.ProviderOrganization - { - public virtual Provider Provider { get; set; } - public virtual Organization Organization { get; set; } - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class ProviderOrganizationMapperProfile : Profile +public class ProviderOrganization : Core.Entities.Provider.ProviderOrganization +{ + public virtual Provider Provider { get; set; } + public virtual Organization Organization { get; set; } +} + +public class ProviderOrganizationMapperProfile : Profile +{ + public ProviderOrganizationMapperProfile() { - public ProviderOrganizationMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/Provider/ProviderUser.cs b/src/Infrastructure.EntityFramework/Models/Provider/ProviderUser.cs index 9aac138be..5c53c4d97 100644 --- a/src/Infrastructure.EntityFramework/Models/Provider/ProviderUser.cs +++ b/src/Infrastructure.EntityFramework/Models/Provider/ProviderUser.cs @@ -1,18 +1,17 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class ProviderUser : Core.Entities.Provider.ProviderUser - { - public virtual User User { get; set; } - public virtual Provider Provider { get; set; } - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class ProviderUserMapperProfile : Profile +public class ProviderUser : Core.Entities.Provider.ProviderUser +{ + public virtual User User { get; set; } + public virtual Provider Provider { get; set; } +} + +public class ProviderUserMapperProfile : Profile +{ + public ProviderUserMapperProfile() { - public ProviderUserMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/Role.cs b/src/Infrastructure.EntityFramework/Models/Role.cs index a92682e2e..4cc2e099c 100644 --- a/src/Infrastructure.EntityFramework/Models/Role.cs +++ b/src/Infrastructure.EntityFramework/Models/Role.cs @@ -1,16 +1,15 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class Role : Core.Entities.Role - { - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class RoleMapperProfile : Profile +public class Role : Core.Entities.Role +{ +} + +public class RoleMapperProfile : Profile +{ + public RoleMapperProfile() { - public RoleMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/Send.cs b/src/Infrastructure.EntityFramework/Models/Send.cs index 5732ac2a1..13bfbb61b 100644 --- a/src/Infrastructure.EntityFramework/Models/Send.cs +++ b/src/Infrastructure.EntityFramework/Models/Send.cs @@ -1,18 +1,17 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class Send : Core.Entities.Send - { - public virtual Organization Organization { get; set; } - public virtual User User { get; set; } - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class SendMapperProfile : Profile +public class Send : Core.Entities.Send +{ + public virtual Organization Organization { get; set; } + public virtual User User { get; set; } +} + +public class SendMapperProfile : Profile +{ + public SendMapperProfile() { - public SendMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/SsoConfig.cs b/src/Infrastructure.EntityFramework/Models/SsoConfig.cs index d748934f2..70e007b99 100644 --- a/src/Infrastructure.EntityFramework/Models/SsoConfig.cs +++ b/src/Infrastructure.EntityFramework/Models/SsoConfig.cs @@ -1,17 +1,16 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class SsoConfig : Core.Entities.SsoConfig - { - public virtual Organization Organization { get; set; } - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class SsoConfigMapperProfile : Profile +public class SsoConfig : Core.Entities.SsoConfig +{ + public virtual Organization Organization { get; set; } +} + +public class SsoConfigMapperProfile : Profile +{ + public SsoConfigMapperProfile() { - public SsoConfigMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/SsoUser.cs b/src/Infrastructure.EntityFramework/Models/SsoUser.cs index eb0298442..01333dbca 100644 --- a/src/Infrastructure.EntityFramework/Models/SsoUser.cs +++ b/src/Infrastructure.EntityFramework/Models/SsoUser.cs @@ -1,18 +1,17 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class SsoUser : Core.Entities.SsoUser - { - public virtual Organization Organization { get; set; } - public virtual User User { get; set; } - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class SsoUserMapperProfile : Profile +public class SsoUser : Core.Entities.SsoUser +{ + public virtual Organization Organization { get; set; } + public virtual User User { get; set; } +} + +public class SsoUserMapperProfile : Profile +{ + public SsoUserMapperProfile() { - public SsoUserMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/TaxRate.cs b/src/Infrastructure.EntityFramework/Models/TaxRate.cs index f464724ae..d47a92237 100644 --- a/src/Infrastructure.EntityFramework/Models/TaxRate.cs +++ b/src/Infrastructure.EntityFramework/Models/TaxRate.cs @@ -1,16 +1,15 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class TaxRate : Core.Entities.TaxRate - { - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class TaxRateMapperProfile : Profile +public class TaxRate : Core.Entities.TaxRate +{ +} + +public class TaxRateMapperProfile : Profile +{ + public TaxRateMapperProfile() { - public TaxRateMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/Transaction.cs b/src/Infrastructure.EntityFramework/Models/Transaction.cs index b9d4bc954..4eb63646c 100644 --- a/src/Infrastructure.EntityFramework/Models/Transaction.cs +++ b/src/Infrastructure.EntityFramework/Models/Transaction.cs @@ -1,18 +1,17 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class Transaction : Core.Entities.Transaction - { - public virtual Organization Organization { get; set; } - public virtual User User { get; set; } - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class TransactionMapperProfile : Profile +public class Transaction : Core.Entities.Transaction +{ + public virtual Organization Organization { get; set; } + public virtual User User { get; set; } +} + +public class TransactionMapperProfile : Profile +{ + public TransactionMapperProfile() { - public TransactionMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Models/User.cs b/src/Infrastructure.EntityFramework/Models/User.cs index 9ff81e90b..1316acccf 100644 --- a/src/Infrastructure.EntityFramework/Models/User.cs +++ b/src/Infrastructure.EntityFramework/Models/User.cs @@ -1,23 +1,22 @@ using AutoMapper; -namespace Bit.Infrastructure.EntityFramework.Models -{ - public class User : Core.Entities.User - { - public virtual ICollection Ciphers { get; set; } - public virtual ICollection Folders { get; set; } - public virtual ICollection CollectionUsers { get; set; } - public virtual ICollection GroupUsers { get; set; } - public virtual ICollection OrganizationUsers { get; set; } - public virtual ICollection SsoUsers { get; set; } - public virtual ICollection Transactions { get; set; } - } +namespace Bit.Infrastructure.EntityFramework.Models; - public class UserMapperProfile : Profile +public class User : Core.Entities.User +{ + public virtual ICollection Ciphers { get; set; } + public virtual ICollection Folders { get; set; } + public virtual ICollection CollectionUsers { get; set; } + public virtual ICollection GroupUsers { get; set; } + public virtual ICollection OrganizationUsers { get; set; } + public virtual ICollection SsoUsers { get; set; } + public virtual ICollection Transactions { get; set; } +} + +public class UserMapperProfile : Profile +{ + public UserMapperProfile() { - public UserMapperProfile() - { - CreateMap().ReverseMap(); - } + CreateMap().ReverseMap(); } } diff --git a/src/Infrastructure.EntityFramework/Repositories/BaseEntityFrameworkRepository.cs b/src/Infrastructure.EntityFramework/Repositories/BaseEntityFrameworkRepository.cs index 833994fc5..9dc7818d7 100644 --- a/src/Infrastructure.EntityFramework/Repositories/BaseEntityFrameworkRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/BaseEntityFrameworkRepository.cs @@ -10,80 +10,63 @@ using Microsoft.Extensions.DependencyInjection; using Cipher = Bit.Core.Entities.Cipher; using User = Bit.Core.Entities.User; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public abstract class BaseEntityFrameworkRepository { - public abstract class BaseEntityFrameworkRepository + protected BulkCopyOptions DefaultBulkCopyOptions { get; set; } = new BulkCopyOptions { - protected BulkCopyOptions DefaultBulkCopyOptions { get; set; } = new BulkCopyOptions - { - KeepIdentity = true, - BulkCopyType = BulkCopyType.MultipleRows, - }; + KeepIdentity = true, + BulkCopyType = BulkCopyType.MultipleRows, + }; - public BaseEntityFrameworkRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + public BaseEntityFrameworkRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + { + ServiceScopeFactory = serviceScopeFactory; + Mapper = mapper; + } + + protected IServiceScopeFactory ServiceScopeFactory { get; private set; } + protected IMapper Mapper { get; private set; } + + public DatabaseContext GetDatabaseContext(IServiceScope serviceScope) + { + return serviceScope.ServiceProvider.GetRequiredService(); + } + + public void ClearChangeTracking() + { + using (var scope = ServiceScopeFactory.CreateScope()) { - ServiceScopeFactory = serviceScopeFactory; - Mapper = mapper; + var dbContext = GetDatabaseContext(scope); + dbContext.ChangeTracker.Clear(); } + } - protected IServiceScopeFactory ServiceScopeFactory { get; private set; } - protected IMapper Mapper { get; private set; } - - public DatabaseContext GetDatabaseContext(IServiceScope serviceScope) + public async Task GetCountFromQuery(IQuery query) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - return serviceScope.ServiceProvider.GetRequiredService(); + return await query.Run(GetDatabaseContext(scope)).CountAsync(); } + } - public void ClearChangeTracking() + protected async Task UserBumpAccountRevisionDateByCipherId(Cipher cipher) + { + var list = new List { cipher }; + await UserBumpAccountRevisionDateByCipherId(list); + } + + protected async Task UserBumpAccountRevisionDateByCipherId(IEnumerable ciphers) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) + foreach (var cipher in ciphers) { var dbContext = GetDatabaseContext(scope); - dbContext.ChangeTracker.Clear(); - } - } - - public async Task GetCountFromQuery(IQuery query) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - return await query.Run(GetDatabaseContext(scope)).CountAsync(); - } - } - - protected async Task UserBumpAccountRevisionDateByCipherId(Cipher cipher) - { - var list = new List { cipher }; - await UserBumpAccountRevisionDateByCipherId(list); - } - - protected async Task UserBumpAccountRevisionDateByCipherId(IEnumerable ciphers) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - foreach (var cipher in ciphers) - { - var dbContext = GetDatabaseContext(scope); - var query = new UserBumpAccountRevisionDateByCipherIdQuery(cipher); - var users = query.Run(dbContext); - - await users.ForEachAsync(e => - { - dbContext.Attach(e); - e.RevisionDate = DateTime.UtcNow; - }); - await dbContext.SaveChangesAsync(); - } - } - } - - protected async Task UserBumpAccountRevisionDateByOrganizationId(Guid organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = new UserBumpAccountRevisionDateByOrganizationIdQuery(organizationId); + var query = new UserBumpAccountRevisionDateByCipherIdQuery(cipher); var users = query.Run(dbContext); + await users.ForEachAsync(e => { dbContext.Attach(e); @@ -92,175 +75,191 @@ namespace Bit.Infrastructure.EntityFramework.Repositories await dbContext.SaveChangesAsync(); } } + } - protected async Task UserBumpAccountRevisionDate(Guid userId) + protected async Task UserBumpAccountRevisionDateByOrganizationId(Guid organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - await UserBumpManyAccountRevisionDates(new[] { userId }); - } - - protected async Task UserBumpManyAccountRevisionDates(ICollection userIds) - { - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + var query = new UserBumpAccountRevisionDateByOrganizationIdQuery(organizationId); + var users = query.Run(dbContext); + await users.ForEachAsync(e => { - var dbContext = GetDatabaseContext(scope); - var users = dbContext.Users.Where(u => userIds.Contains(u.Id)); - await users.ForEachAsync(u => - { - dbContext.Attach(u); - u.RevisionDate = DateTime.UtcNow; - }); - await dbContext.SaveChangesAsync(); - } + dbContext.Attach(e); + e.RevisionDate = DateTime.UtcNow; + }); + await dbContext.SaveChangesAsync(); } + } - protected async Task OrganizationUpdateStorage(Guid organizationId) + protected async Task UserBumpAccountRevisionDate(Guid userId) + { + await UserBumpManyAccountRevisionDates(new[] { userId }); + } + + protected async Task UserBumpManyAccountRevisionDates(ICollection userIds) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + var users = dbContext.Users.Where(u => userIds.Contains(u.Id)); + await users.ForEachAsync(u => { - var dbContext = GetDatabaseContext(scope); - var attachments = await dbContext.Ciphers - .Where(e => e.UserId == null && - e.OrganizationId == organizationId && - !string.IsNullOrWhiteSpace(e.Attachments)) - .Select(e => e.Attachments) - .ToListAsync(); - var storage = attachments.Sum(e => JsonDocument.Parse(e)?.RootElement.EnumerateArray() - .Sum(p => p.GetProperty("Size").GetInt64()) ?? 0); - var organization = new Organization - { - Id = organizationId, - RevisionDate = DateTime.UtcNow, - Storage = storage, - }; - dbContext.Organizations.Attach(organization); - var entry = dbContext.Entry(organization); - entry.Property(e => e.RevisionDate).IsModified = true; - entry.Property(e => e.Storage).IsModified = true; - await dbContext.SaveChangesAsync(); - } + dbContext.Attach(u); + u.RevisionDate = DateTime.UtcNow; + }); + await dbContext.SaveChangesAsync(); } + } - protected async Task UserUpdateStorage(Guid userId) + protected async Task OrganizationUpdateStorage(Guid organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + var attachments = await dbContext.Ciphers + .Where(e => e.UserId == null && + e.OrganizationId == organizationId && + !string.IsNullOrWhiteSpace(e.Attachments)) + .Select(e => e.Attachments) + .ToListAsync(); + var storage = attachments.Sum(e => JsonDocument.Parse(e)?.RootElement.EnumerateArray() + .Sum(p => p.GetProperty("Size").GetInt64()) ?? 0); + var organization = new Organization { - var dbContext = GetDatabaseContext(scope); - var attachments = await dbContext.Ciphers - .Where(e => e.UserId.HasValue && - e.UserId.Value == userId && - e.OrganizationId == null && - !string.IsNullOrWhiteSpace(e.Attachments)) - .Select(e => e.Attachments) - .ToListAsync(); - var storage = attachments.Sum(e => JsonDocument.Parse(e)?.RootElement.EnumerateArray() - .Sum(p => p.GetProperty("Size").GetInt64()) ?? 0); - var user = new Models.User - { - Id = userId, - RevisionDate = DateTime.UtcNow, - Storage = storage, - }; - dbContext.Users.Attach(user); - var entry = dbContext.Entry(user); - entry.Property(e => e.RevisionDate).IsModified = true; - entry.Property(e => e.Storage).IsModified = true; - await dbContext.SaveChangesAsync(); - } + Id = organizationId, + RevisionDate = DateTime.UtcNow, + Storage = storage, + }; + dbContext.Organizations.Attach(organization); + var entry = dbContext.Entry(organization); + entry.Property(e => e.RevisionDate).IsModified = true; + entry.Property(e => e.Storage).IsModified = true; + await dbContext.SaveChangesAsync(); } + } - protected async Task UserUpdateKeys(User user) + protected async Task UserUpdateStorage(Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + var attachments = await dbContext.Ciphers + .Where(e => e.UserId.HasValue && + e.UserId.Value == userId && + e.OrganizationId == null && + !string.IsNullOrWhiteSpace(e.Attachments)) + .Select(e => e.Attachments) + .ToListAsync(); + var storage = attachments.Sum(e => JsonDocument.Parse(e)?.RootElement.EnumerateArray() + .Sum(p => p.GetProperty("Size").GetInt64()) ?? 0); + var user = new Models.User { - var dbContext = GetDatabaseContext(scope); - var entity = await dbContext.Users.FindAsync(user.Id); - if (entity == null) - { - return; - } - entity.SecurityStamp = user.SecurityStamp; - entity.Key = user.Key; - entity.PrivateKey = user.PrivateKey; - entity.RevisionDate = DateTime.UtcNow; - await dbContext.SaveChangesAsync(); - } + Id = userId, + RevisionDate = DateTime.UtcNow, + Storage = storage, + }; + dbContext.Users.Attach(user); + var entry = dbContext.Entry(user); + entry.Property(e => e.RevisionDate).IsModified = true; + entry.Property(e => e.Storage).IsModified = true; + await dbContext.SaveChangesAsync(); } + } - protected async Task UserBumpAccountRevisionDateByCollectionId(Guid collectionId, Guid organizationId) + protected async Task UserUpdateKeys(User user) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + var entity = await dbContext.Users.FindAsync(user.Id); + if (entity == null) { - var dbContext = GetDatabaseContext(scope); - var query = from u in dbContext.Users - join ou in dbContext.OrganizationUsers - on u.Id equals ou.UserId - join cu in dbContext.CollectionUsers - on ou.Id equals cu.OrganizationUserId into cu_g - from cu in cu_g.DefaultIfEmpty() - where !ou.AccessAll && cu.CollectionId.Equals(collectionId) - join gu in dbContext.GroupUsers - on ou.Id equals gu.OrganizationUserId into gu_g - from gu in gu_g.DefaultIfEmpty() - where cu.CollectionId == default(Guid) && !ou.AccessAll - join g in dbContext.Groups - on gu.GroupId equals g.Id into g_g - from g in g_g.DefaultIfEmpty() - join cg in dbContext.CollectionGroups - on gu.GroupId equals cg.GroupId into cg_g - from cg in cg_g.DefaultIfEmpty() - where !g.AccessAll && cg.CollectionId == collectionId && - (ou.OrganizationId == organizationId && ou.Status == OrganizationUserStatusType.Confirmed && - (cu.CollectionId != default(Guid) || cg.CollectionId != default(Guid) || ou.AccessAll || g.AccessAll)) - select new { u, ou, cu, gu, g, cg }; - var users = query.Select(x => x.u); - await users.ForEachAsync(u => - { - dbContext.Attach(u); - u.RevisionDate = DateTime.UtcNow; - }); - await dbContext.SaveChangesAsync(); + return; } + entity.SecurityStamp = user.SecurityStamp; + entity.Key = user.Key; + entity.PrivateKey = user.PrivateKey; + entity.RevisionDate = DateTime.UtcNow; + await dbContext.SaveChangesAsync(); } + } - protected async Task UserBumpAccountRevisionDateByOrganizationUserId(Guid organizationUserId) + protected async Task UserBumpAccountRevisionDateByCollectionId(Guid collectionId, Guid organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + var query = from u in dbContext.Users + join ou in dbContext.OrganizationUsers + on u.Id equals ou.UserId + join cu in dbContext.CollectionUsers + on ou.Id equals cu.OrganizationUserId into cu_g + from cu in cu_g.DefaultIfEmpty() + where !ou.AccessAll && cu.CollectionId.Equals(collectionId) + join gu in dbContext.GroupUsers + on ou.Id equals gu.OrganizationUserId into gu_g + from gu in gu_g.DefaultIfEmpty() + where cu.CollectionId == default(Guid) && !ou.AccessAll + join g in dbContext.Groups + on gu.GroupId equals g.Id into g_g + from g in g_g.DefaultIfEmpty() + join cg in dbContext.CollectionGroups + on gu.GroupId equals cg.GroupId into cg_g + from cg in cg_g.DefaultIfEmpty() + where !g.AccessAll && cg.CollectionId == collectionId && + (ou.OrganizationId == organizationId && ou.Status == OrganizationUserStatusType.Confirmed && + (cu.CollectionId != default(Guid) || cg.CollectionId != default(Guid) || ou.AccessAll || g.AccessAll)) + select new { u, ou, cu, gu, g, cg }; + var users = query.Select(x => x.u); + await users.ForEachAsync(u => { - var dbContext = GetDatabaseContext(scope); - var query = from u in dbContext.Users - join ou in dbContext.OrganizationUsers - on u.Id equals ou.UserId - where ou.Id.Equals(organizationUserId) && ou.Status.Equals(OrganizationUserStatusType.Confirmed) - select new { u, ou }; - var users = query.Select(x => x.u); - await users.ForEachAsync(u => - { - dbContext.Attach(u); - u.AccountRevisionDate = DateTime.UtcNow; - }); - await dbContext.SaveChangesAsync(); - } + dbContext.Attach(u); + u.RevisionDate = DateTime.UtcNow; + }); + await dbContext.SaveChangesAsync(); } + } - protected async Task UserBumpAccountRevisionDateByProviderUserIds(ICollection providerUserIds) + protected async Task UserBumpAccountRevisionDateByOrganizationUserId(Guid organizationUserId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + var query = from u in dbContext.Users + join ou in dbContext.OrganizationUsers + on u.Id equals ou.UserId + where ou.Id.Equals(organizationUserId) && ou.Status.Equals(OrganizationUserStatusType.Confirmed) + select new { u, ou }; + var users = query.Select(x => x.u); + await users.ForEachAsync(u => { - var dbContext = GetDatabaseContext(scope); - var query = from pu in dbContext.ProviderUsers - join u in dbContext.Users - on pu.UserId equals u.Id - where pu.Status.Equals(ProviderUserStatusType.Confirmed) && - providerUserIds.Contains(pu.Id) - select new { pu, u }; - var users = query.Select(x => x.u); - await users.ForEachAsync(u => - { - dbContext.Attach(u); - u.AccountRevisionDate = DateTime.UtcNow; - }); - await dbContext.SaveChangesAsync(); - } + dbContext.Attach(u); + u.AccountRevisionDate = DateTime.UtcNow; + }); + await dbContext.SaveChangesAsync(); + } + } + + protected async Task UserBumpAccountRevisionDateByProviderUserIds(ICollection providerUserIds) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from pu in dbContext.ProviderUsers + join u in dbContext.Users + on pu.UserId equals u.Id + where pu.Status.Equals(ProviderUserStatusType.Confirmed) && + providerUserIds.Contains(pu.Id) + select new { pu, u }; + var users = query.Select(x => x.u); + await users.ForEachAsync(u => + { + dbContext.Attach(u); + u.AccountRevisionDate = DateTime.UtcNow; + }); + await dbContext.SaveChangesAsync(); } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/CipherRepository.cs b/src/Infrastructure.EntityFramework/Repositories/CipherRepository.cs index fdf528393..17aaedfac 100644 --- a/src/Infrastructure.EntityFramework/Repositories/CipherRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/CipherRepository.cs @@ -13,638 +13,637 @@ using Newtonsoft.Json; using Newtonsoft.Json.Linq; using User = Bit.Core.Entities.User; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class CipherRepository : Repository, ICipherRepository { - public class CipherRepository : Repository, ICipherRepository + public CipherRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Ciphers) + { } + + public override async Task CreateAsync(Core.Entities.Cipher cipher) { - public CipherRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Ciphers) - { } - - public override async Task CreateAsync(Core.Entities.Cipher cipher) + cipher = await base.CreateAsync(cipher); + using (var scope = ServiceScopeFactory.CreateScope()) { - cipher = await base.CreateAsync(cipher); - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + if (cipher.OrganizationId.HasValue) { - var dbContext = GetDatabaseContext(scope); - if (cipher.OrganizationId.HasValue) + await UserBumpAccountRevisionDateByCipherId(cipher); + } + else if (cipher.UserId.HasValue) + { + await UserBumpAccountRevisionDate(cipher.UserId.Value); + } + } + return cipher; + } + + public IQueryable GetBumpedAccountsByCipherId(Core.Entities.Cipher cipher) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = new UserBumpAccountRevisionDateByCipherIdQuery(cipher); + return query.Run(dbContext); + } + } + + public async Task CreateAsync(Core.Entities.Cipher cipher, IEnumerable collectionIds) + { + cipher = await base.CreateAsync(cipher); + await UpdateCollections(cipher, collectionIds); + } + + private async Task UpdateCollections(Core.Entities.Cipher cipher, IEnumerable collectionIds) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var cipherEntity = await dbContext.Ciphers.FindAsync(cipher.Id); + var query = new CipherUpdateCollectionsQuery(cipherEntity, collectionIds).Run(dbContext); + await dbContext.AddRangeAsync(query); + await dbContext.SaveChangesAsync(); + } + } + + public async Task CreateAsync(CipherDetails cipher) + { + await CreateAsyncReturnCipher(cipher); + } + + private async Task CreateAsyncReturnCipher(CipherDetails cipher) + { + cipher.SetNewId(); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var userIdKey = $"\"{cipher.UserId}\""; + cipher.UserId = cipher.OrganizationId.HasValue ? null : cipher.UserId; + cipher.Favorites = cipher.Favorite ? + $"{{{userIdKey}:true}}" : + null; + cipher.Folders = cipher.FolderId.HasValue ? + $"{{{userIdKey}:\"{cipher.FolderId}\"}}" : + null; + var entity = Mapper.Map((Core.Entities.Cipher)cipher); + await dbContext.AddAsync(entity); + await dbContext.SaveChangesAsync(); + } + await UserBumpAccountRevisionDateByCipherId(cipher); + return cipher; + } + + public async Task CreateAsync(CipherDetails cipher, IEnumerable collectionIds) + { + cipher = await CreateAsyncReturnCipher(cipher); + await UpdateCollections(cipher, collectionIds); + } + + public async Task CreateAsync(IEnumerable ciphers, IEnumerable folders) + { + if (!ciphers.Any()) + { + return; + } + + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var folderEntities = Mapper.Map>(folders); + await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, folderEntities); + var cipherEntities = Mapper.Map>(ciphers); + await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, cipherEntities); + await UserBumpAccountRevisionDateByCipherId(ciphers); + } + } + + public async Task CreateAsync(IEnumerable ciphers, IEnumerable collections, IEnumerable collectionCiphers) + { + if (!ciphers.Any()) + { + return; + } + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var cipherEntities = Mapper.Map>(ciphers); + await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, cipherEntities); + if (collections.Any()) + { + var collectionEntities = Mapper.Map>(collections); + await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, collectionEntities); + + if (collectionCiphers.Any()) { - await UserBumpAccountRevisionDateByCipherId(cipher); - } - else if (cipher.UserId.HasValue) - { - await UserBumpAccountRevisionDate(cipher.UserId.Value); + var collectionCipherEntities = Mapper.Map>(collectionCiphers); + await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, collectionCipherEntities); } } - return cipher; + await UserBumpAccountRevisionDateByOrganizationId(ciphers.First().OrganizationId.Value); } + } - public IQueryable GetBumpedAccountsByCipherId(Core.Entities.Cipher cipher) + public async Task DeleteAsync(IEnumerable ids, Guid userId) + { + await ToggleCipherStates(ids, userId, CipherStateAction.HardDelete); + } + + public async Task DeleteAttachmentAsync(Guid cipherId, string attachmentId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + var cipher = await dbContext.Ciphers.FindAsync(cipherId); + var attachmentsJson = JObject.Parse(cipher.Attachments); + attachmentsJson.Remove(attachmentId); + cipher.Attachments = JsonConvert.SerializeObject(attachmentsJson); + await dbContext.SaveChangesAsync(); + + if (cipher.OrganizationId.HasValue) { - var dbContext = GetDatabaseContext(scope); - var query = new UserBumpAccountRevisionDateByCipherIdQuery(cipher); - return query.Run(dbContext); + await OrganizationUpdateStorage(cipher.OrganizationId.Value); + await UserBumpAccountRevisionDateByCipherId(cipher); + } + else if (cipher.UserId.HasValue) + { + await UserUpdateStorage(cipher.UserId.Value); + await UserBumpAccountRevisionDate(cipher.UserId.Value); } } + } - public async Task CreateAsync(Core.Entities.Cipher cipher, IEnumerable collectionIds) + public async Task DeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - cipher = await base.CreateAsync(cipher); - await UpdateCollections(cipher, collectionIds); + var dbContext = GetDatabaseContext(scope); + var ciphers = from c in dbContext.Ciphers + where c.OrganizationId == organizationId && + ids.Contains(c.Id) + select c; + dbContext.RemoveRange(ciphers); + await dbContext.SaveChangesAsync(); + } + await OrganizationUpdateStorage(organizationId); + await UserBumpAccountRevisionDateByOrganizationId(organizationId); + } + + public async Task DeleteByOrganizationIdAsync(Guid organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + + var collectionCiphers = from cc in dbContext.CollectionCiphers + join c in dbContext.Collections + on cc.CollectionId equals c.Id + where c.OrganizationId == organizationId + select cc; + dbContext.RemoveRange(collectionCiphers); + + var ciphers = from c in dbContext.Ciphers + where c.OrganizationId == organizationId + select c; + dbContext.RemoveRange(ciphers); + + await dbContext.SaveChangesAsync(); + } + await OrganizationUpdateStorage(organizationId); + await UserBumpAccountRevisionDateByOrganizationId(organizationId); + } + + public async Task DeleteByUserIdAsync(Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var ciphers = from c in dbContext.Ciphers + where c.UserId == userId + select c; + dbContext.RemoveRange(ciphers); + var folders = from f in dbContext.Folders + where f.UserId == userId + select f; + dbContext.RemoveRange(folders); + await dbContext.SaveChangesAsync(); + await UserUpdateStorage(userId); + await UserBumpAccountRevisionDate(userId); } - private async Task UpdateCollections(Core.Entities.Cipher cipher, IEnumerable collectionIds) + } + + public async Task DeleteDeletedAsync(DateTime deletedDateBefore) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + var query = dbContext.Ciphers.Where(c => c.DeletedDate < deletedDateBefore); + dbContext.RemoveRange(query); + await dbContext.SaveChangesAsync(); + } + } + + public async Task GetByIdAsync(Guid id, Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var userCipherDetails = new UserCipherDetailsQuery(userId); + var data = await userCipherDetails.Run(dbContext).FirstOrDefaultAsync(c => c.Id == id); + return data; + } + } + + public async Task> GetManyOrganizationDetailsByOrganizationIdAsync( + Guid organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = new CipherOrganizationDetailsReadByIdQuery(organizationId); + var data = await query.Run(dbContext).ToListAsync(); + return data; + } + } + + public async Task GetCanEditByIdAsync(Guid userId, Guid cipherId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = new CipherReadCanEditByIdUserIdQuery(userId, cipherId); + var canEdit = await query.Run(dbContext).AnyAsync(); + return canEdit; + } + } + + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = dbContext.Ciphers.Where(x => !x.UserId.HasValue && x.OrganizationId == organizationId); + var data = await query.ToListAsync(); + return Mapper.Map>(data); + } + } + + public async Task> GetManyByUserIdAsync(Guid userId, bool withOrganizations = true) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + IQueryable cipherDetailsView = withOrganizations ? + new UserCipherDetailsQuery(userId).Run(dbContext) : + new CipherDetailsQuery(userId).Run(dbContext); + if (!withOrganizations) { - var dbContext = GetDatabaseContext(scope); - var cipherEntity = await dbContext.Ciphers.FindAsync(cipher.Id); - var query = new CipherUpdateCollectionsQuery(cipherEntity, collectionIds).Run(dbContext); - await dbContext.AddRangeAsync(query); - await dbContext.SaveChangesAsync(); + cipherDetailsView = from c in cipherDetailsView + where c.UserId == userId + select new CipherDetails + { + Id = c.Id, + UserId = c.UserId, + OrganizationId = c.OrganizationId, + Type = c.Type, + Data = c.Data, + Attachments = c.Attachments, + CreationDate = c.CreationDate, + RevisionDate = c.RevisionDate, + DeletedDate = c.DeletedDate, + Favorite = c.Favorite, + FolderId = c.FolderId, + Edit = true, + ViewPassword = true, + OrganizationUseTotp = false, + }; } + var ciphers = await cipherDetailsView.ToListAsync(); + return ciphers; } + } - public async Task CreateAsync(CipherDetails cipher) + public async Task GetOrganizationDetailsByIdAsync(Guid id) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - await CreateAsyncReturnCipher(cipher); + var dbContext = GetDatabaseContext(scope); + var query = new CipherOrganizationDetailsReadByIdQuery(id); + var data = await query.Run(dbContext).FirstOrDefaultAsync(); + return data; } + } - private async Task CreateAsyncReturnCipher(CipherDetails cipher) + public async Task MoveAsync(IEnumerable ids, Guid? folderId, Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - cipher.SetNewId(); - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + var cipherEntities = dbContext.Ciphers.Where(c => ids.Contains(c.Id)); + var userCipherDetails = new UserCipherDetailsQuery(userId).Run(dbContext); + var idsToMove = from ucd in userCipherDetails + join c in cipherEntities + on ucd.Id equals c.Id + where ucd.Edit + select c; + await idsToMove.ForEachAsync(cipher => + { + var foldersJson = string.IsNullOrWhiteSpace(cipher.Folders) ? + new JObject() : + JObject.Parse(cipher.Folders); + + if (folderId.HasValue) + { + foldersJson.Remove(userId.ToString()); + foldersJson.Add(userId.ToString(), folderId.Value.ToString()); + } + else if (!string.IsNullOrWhiteSpace(cipher.Folders)) + { + foldersJson.Remove(userId.ToString()); + } + dbContext.Attach(cipher); + cipher.Folders = JsonConvert.SerializeObject(foldersJson); + }); + await dbContext.SaveChangesAsync(); + await UserBumpAccountRevisionDate(userId); + } + } + + public async Task ReplaceAsync(CipherDetails cipher) + { + cipher.UserId = cipher.OrganizationId.HasValue ? + null : + cipher.UserId; + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var entity = await dbContext.Ciphers.FindAsync(cipher.Id); + if (entity != null) { - var dbContext = GetDatabaseContext(scope); var userIdKey = $"\"{cipher.UserId}\""; - cipher.UserId = cipher.OrganizationId.HasValue ? null : cipher.UserId; - cipher.Favorites = cipher.Favorite ? - $"{{{userIdKey}:true}}" : - null; - cipher.Folders = cipher.FolderId.HasValue ? - $"{{{userIdKey}:\"{cipher.FolderId}\"}}" : - null; - var entity = Mapper.Map((Core.Entities.Cipher)cipher); - await dbContext.AddAsync(entity); - await dbContext.SaveChangesAsync(); - } - await UserBumpAccountRevisionDateByCipherId(cipher); - return cipher; - } - - public async Task CreateAsync(CipherDetails cipher, IEnumerable collectionIds) - { - cipher = await CreateAsyncReturnCipher(cipher); - await UpdateCollections(cipher, collectionIds); - } - - public async Task CreateAsync(IEnumerable ciphers, IEnumerable folders) - { - if (!ciphers.Any()) - { - return; - } - - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var folderEntities = Mapper.Map>(folders); - await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, folderEntities); - var cipherEntities = Mapper.Map>(ciphers); - await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, cipherEntities); - await UserBumpAccountRevisionDateByCipherId(ciphers); - } - } - - public async Task CreateAsync(IEnumerable ciphers, IEnumerable collections, IEnumerable collectionCiphers) - { - if (!ciphers.Any()) - { - return; - } - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var cipherEntities = Mapper.Map>(ciphers); - await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, cipherEntities); - if (collections.Any()) + if (cipher.Favorite) { - var collectionEntities = Mapper.Map>(collections); - await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, collectionEntities); - - if (collectionCiphers.Any()) + if (cipher.Favorites == null) { - var collectionCipherEntities = Mapper.Map>(collectionCiphers); - await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, collectionCipherEntities); + cipher.Favorites = $"{{{userIdKey}:true}}"; + } + else + { + var favorites = CoreHelpers.LoadClassFromJsonData>(cipher.Favorites); + favorites.Add(cipher.UserId.Value, true); + cipher.Favorites = JsonConvert.SerializeObject(favorites); } } - await UserBumpAccountRevisionDateByOrganizationId(ciphers.First().OrganizationId.Value); + else + { + if (cipher.Favorites != null && cipher.Favorites.Contains(cipher.UserId.Value.ToString())) + { + var favorites = CoreHelpers.LoadClassFromJsonData>(cipher.Favorites); + favorites.Remove(cipher.UserId.Value); + cipher.Favorites = JsonConvert.SerializeObject(favorites); + } + } + if (cipher.FolderId.HasValue) + { + if (cipher.Folders == null) + { + cipher.Folders = $"{{{userIdKey}:\"{cipher.FolderId}\"}}"; + } + else + { + var folders = CoreHelpers.LoadClassFromJsonData>(cipher.Folders); + folders.Add(cipher.UserId.Value, cipher.FolderId.Value); + cipher.Folders = JsonConvert.SerializeObject(folders); + } + } + else + { + if (cipher.Folders != null && cipher.Folders.Contains(cipher.UserId.Value.ToString())) + { + var folders = CoreHelpers.LoadClassFromJsonData>(cipher.Favorites); + folders.Remove(cipher.UserId.Value); + cipher.Favorites = JsonConvert.SerializeObject(folders); + } + } + var mappedEntity = Mapper.Map((Core.Entities.Cipher)cipher); + dbContext.Entry(entity).CurrentValues.SetValues(mappedEntity); + await UserBumpAccountRevisionDateByCipherId(cipher); + await dbContext.SaveChangesAsync(); } } + } - public async Task DeleteAsync(IEnumerable ids, Guid userId) + public async Task ReplaceAsync(Core.Entities.Cipher obj, IEnumerable collectionIds) + { + await UpdateCollections(obj, collectionIds); + using (var scope = ServiceScopeFactory.CreateScope()) { - await ToggleCipherStates(ids, userId, CipherStateAction.HardDelete); - } + var dbContext = GetDatabaseContext(scope); + var cipher = await dbContext.Ciphers.FindAsync(obj.Id); + cipher.UserId = null; + cipher.OrganizationId = obj.OrganizationId; + cipher.Data = obj.Data; + cipher.Attachments = obj.Attachments; + cipher.RevisionDate = obj.RevisionDate; + cipher.DeletedDate = obj.DeletedDate; + await dbContext.SaveChangesAsync(); - public async Task DeleteAttachmentAsync(Guid cipherId, string attachmentId) - { - using (var scope = ServiceScopeFactory.CreateScope()) + if (!string.IsNullOrWhiteSpace(cipher.Attachments)) { - var dbContext = GetDatabaseContext(scope); - var cipher = await dbContext.Ciphers.FindAsync(cipherId); - var attachmentsJson = JObject.Parse(cipher.Attachments); - attachmentsJson.Remove(attachmentId); - cipher.Attachments = JsonConvert.SerializeObject(attachmentsJson); - await dbContext.SaveChangesAsync(); - if (cipher.OrganizationId.HasValue) { await OrganizationUpdateStorage(cipher.OrganizationId.Value); - await UserBumpAccountRevisionDateByCipherId(cipher); } else if (cipher.UserId.HasValue) { await UserUpdateStorage(cipher.UserId.Value); - await UserBumpAccountRevisionDate(cipher.UserId.Value); } } - } - public async Task DeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId) + await UserBumpAccountRevisionDateByCipherId(cipher); + return true; + } + } + + public async Task RestoreAsync(IEnumerable ids, Guid userId) + { + return await ToggleCipherStates(ids, userId, CipherStateAction.Restore); + } + + public async Task SoftDeleteAsync(IEnumerable ids, Guid userId) + { + await ToggleCipherStates(ids, userId, CipherStateAction.SoftDelete); + } + + private async Task ToggleCipherStates(IEnumerable ids, Guid userId, CipherStateAction action) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + var userCipherDetailsQuery = new UserCipherDetailsQuery(userId); + var cipherEntitiesToCheck = await (dbContext.Ciphers.Where(c => ids.Contains(c.Id))).ToListAsync(); + var query = from ucd in await (userCipherDetailsQuery.Run(dbContext)).ToListAsync() + join c in cipherEntitiesToCheck + on ucd.Id equals c.Id + where ucd.Edit && ucd.DeletedDate == null + select c; + + var utcNow = DateTime.UtcNow; + var cipherIdsToModify = query.Select(c => c.Id); + var cipherEntitiesToModify = dbContext.Ciphers.Where(x => cipherIdsToModify.Contains(x.Id)); + if (action == CipherStateAction.HardDelete) { - var dbContext = GetDatabaseContext(scope); - var ciphers = from c in dbContext.Ciphers - where c.OrganizationId == organizationId && - ids.Contains(c.Id) - select c; - dbContext.RemoveRange(ciphers); - await dbContext.SaveChangesAsync(); - } - await OrganizationUpdateStorage(organizationId); - await UserBumpAccountRevisionDateByOrganizationId(organizationId); - } - - public async Task DeleteByOrganizationIdAsync(Guid organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - - var collectionCiphers = from cc in dbContext.CollectionCiphers - join c in dbContext.Collections - on cc.CollectionId equals c.Id - where c.OrganizationId == organizationId - select cc; - dbContext.RemoveRange(collectionCiphers); - - var ciphers = from c in dbContext.Ciphers - where c.OrganizationId == organizationId - select c; - dbContext.RemoveRange(ciphers); - - await dbContext.SaveChangesAsync(); - } - await OrganizationUpdateStorage(organizationId); - await UserBumpAccountRevisionDateByOrganizationId(organizationId); - } - - public async Task DeleteByUserIdAsync(Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var ciphers = from c in dbContext.Ciphers - where c.UserId == userId - select c; - dbContext.RemoveRange(ciphers); - var folders = from f in dbContext.Folders - where f.UserId == userId - select f; - dbContext.RemoveRange(folders); - await dbContext.SaveChangesAsync(); - await UserUpdateStorage(userId); - await UserBumpAccountRevisionDate(userId); - } - - } - - public async Task DeleteDeletedAsync(DateTime deletedDateBefore) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = dbContext.Ciphers.Where(c => c.DeletedDate < deletedDateBefore); - dbContext.RemoveRange(query); - await dbContext.SaveChangesAsync(); - } - } - - public async Task GetByIdAsync(Guid id, Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var userCipherDetails = new UserCipherDetailsQuery(userId); - var data = await userCipherDetails.Run(dbContext).FirstOrDefaultAsync(c => c.Id == id); - return data; - } - } - - public async Task> GetManyOrganizationDetailsByOrganizationIdAsync( - Guid organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = new CipherOrganizationDetailsReadByIdQuery(organizationId); - var data = await query.Run(dbContext).ToListAsync(); - return data; - } - } - - public async Task GetCanEditByIdAsync(Guid userId, Guid cipherId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = new CipherReadCanEditByIdUserIdQuery(userId, cipherId); - var canEdit = await query.Run(dbContext).AnyAsync(); - return canEdit; - } - } - - public async Task> GetManyByOrganizationIdAsync(Guid organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = dbContext.Ciphers.Where(x => !x.UserId.HasValue && x.OrganizationId == organizationId); - var data = await query.ToListAsync(); - return Mapper.Map>(data); - } - } - - public async Task> GetManyByUserIdAsync(Guid userId, bool withOrganizations = true) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - IQueryable cipherDetailsView = withOrganizations ? - new UserCipherDetailsQuery(userId).Run(dbContext) : - new CipherDetailsQuery(userId).Run(dbContext); - if (!withOrganizations) - { - cipherDetailsView = from c in cipherDetailsView - where c.UserId == userId - select new CipherDetails - { - Id = c.Id, - UserId = c.UserId, - OrganizationId = c.OrganizationId, - Type = c.Type, - Data = c.Data, - Attachments = c.Attachments, - CreationDate = c.CreationDate, - RevisionDate = c.RevisionDate, - DeletedDate = c.DeletedDate, - Favorite = c.Favorite, - FolderId = c.FolderId, - Edit = true, - ViewPassword = true, - OrganizationUseTotp = false, - }; - } - var ciphers = await cipherDetailsView.ToListAsync(); - return ciphers; - } - } - - public async Task GetOrganizationDetailsByIdAsync(Guid id) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = new CipherOrganizationDetailsReadByIdQuery(id); - var data = await query.Run(dbContext).FirstOrDefaultAsync(); - return data; - } - } - - public async Task MoveAsync(IEnumerable ids, Guid? folderId, Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var cipherEntities = dbContext.Ciphers.Where(c => ids.Contains(c.Id)); - var userCipherDetails = new UserCipherDetailsQuery(userId).Run(dbContext); - var idsToMove = from ucd in userCipherDetails - join c in cipherEntities - on ucd.Id equals c.Id - where ucd.Edit - select c; - await idsToMove.ForEachAsync(cipher => - { - var foldersJson = string.IsNullOrWhiteSpace(cipher.Folders) ? - new JObject() : - JObject.Parse(cipher.Folders); - - if (folderId.HasValue) - { - foldersJson.Remove(userId.ToString()); - foldersJson.Add(userId.ToString(), folderId.Value.ToString()); - } - else if (!string.IsNullOrWhiteSpace(cipher.Folders)) - { - foldersJson.Remove(userId.ToString()); - } - dbContext.Attach(cipher); - cipher.Folders = JsonConvert.SerializeObject(foldersJson); - }); - await dbContext.SaveChangesAsync(); - await UserBumpAccountRevisionDate(userId); - } - } - - public async Task ReplaceAsync(CipherDetails cipher) - { - cipher.UserId = cipher.OrganizationId.HasValue ? - null : - cipher.UserId; - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var entity = await dbContext.Ciphers.FindAsync(cipher.Id); - if (entity != null) - { - var userIdKey = $"\"{cipher.UserId}\""; - if (cipher.Favorite) - { - if (cipher.Favorites == null) - { - cipher.Favorites = $"{{{userIdKey}:true}}"; - } - else - { - var favorites = CoreHelpers.LoadClassFromJsonData>(cipher.Favorites); - favorites.Add(cipher.UserId.Value, true); - cipher.Favorites = JsonConvert.SerializeObject(favorites); - } - } - else - { - if (cipher.Favorites != null && cipher.Favorites.Contains(cipher.UserId.Value.ToString())) - { - var favorites = CoreHelpers.LoadClassFromJsonData>(cipher.Favorites); - favorites.Remove(cipher.UserId.Value); - cipher.Favorites = JsonConvert.SerializeObject(favorites); - } - } - if (cipher.FolderId.HasValue) - { - if (cipher.Folders == null) - { - cipher.Folders = $"{{{userIdKey}:\"{cipher.FolderId}\"}}"; - } - else - { - var folders = CoreHelpers.LoadClassFromJsonData>(cipher.Folders); - folders.Add(cipher.UserId.Value, cipher.FolderId.Value); - cipher.Folders = JsonConvert.SerializeObject(folders); - } - } - else - { - if (cipher.Folders != null && cipher.Folders.Contains(cipher.UserId.Value.ToString())) - { - var folders = CoreHelpers.LoadClassFromJsonData>(cipher.Favorites); - folders.Remove(cipher.UserId.Value); - cipher.Favorites = JsonConvert.SerializeObject(folders); - } - } - var mappedEntity = Mapper.Map((Core.Entities.Cipher)cipher); - dbContext.Entry(entity).CurrentValues.SetValues(mappedEntity); - await UserBumpAccountRevisionDateByCipherId(cipher); - await dbContext.SaveChangesAsync(); - } - } - } - - public async Task ReplaceAsync(Core.Entities.Cipher obj, IEnumerable collectionIds) - { - await UpdateCollections(obj, collectionIds); - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var cipher = await dbContext.Ciphers.FindAsync(obj.Id); - cipher.UserId = null; - cipher.OrganizationId = obj.OrganizationId; - cipher.Data = obj.Data; - cipher.Attachments = obj.Attachments; - cipher.RevisionDate = obj.RevisionDate; - cipher.DeletedDate = obj.DeletedDate; - await dbContext.SaveChangesAsync(); - - if (!string.IsNullOrWhiteSpace(cipher.Attachments)) - { - if (cipher.OrganizationId.HasValue) - { - await OrganizationUpdateStorage(cipher.OrganizationId.Value); - } - else if (cipher.UserId.HasValue) - { - await UserUpdateStorage(cipher.UserId.Value); - } - } - - await UserBumpAccountRevisionDateByCipherId(cipher); - return true; - } - } - - public async Task RestoreAsync(IEnumerable ids, Guid userId) - { - return await ToggleCipherStates(ids, userId, CipherStateAction.Restore); - } - - public async Task SoftDeleteAsync(IEnumerable ids, Guid userId) - { - await ToggleCipherStates(ids, userId, CipherStateAction.SoftDelete); - } - - private async Task ToggleCipherStates(IEnumerable ids, Guid userId, CipherStateAction action) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var userCipherDetailsQuery = new UserCipherDetailsQuery(userId); - var cipherEntitiesToCheck = await (dbContext.Ciphers.Where(c => ids.Contains(c.Id))).ToListAsync(); - var query = from ucd in await (userCipherDetailsQuery.Run(dbContext)).ToListAsync() - join c in cipherEntitiesToCheck - on ucd.Id equals c.Id - where ucd.Edit && ucd.DeletedDate == null - select c; - - var utcNow = DateTime.UtcNow; - var cipherIdsToModify = query.Select(c => c.Id); - var cipherEntitiesToModify = dbContext.Ciphers.Where(x => cipherIdsToModify.Contains(x.Id)); - if (action == CipherStateAction.HardDelete) - { - dbContext.RemoveRange(cipherEntitiesToModify); - } - else - { - await cipherEntitiesToModify.ForEachAsync(cipher => - { - dbContext.Attach(cipher); - cipher.DeletedDate = action == CipherStateAction.Restore ? null : utcNow; - cipher.RevisionDate = utcNow; - }); - } - - var orgIds = query - .Where(c => c.OrganizationId.HasValue) - .GroupBy(c => c.OrganizationId).Select(x => x.Key); - - foreach (var orgId in orgIds) - { - await OrganizationUpdateStorage(orgId.Value); - await UserBumpAccountRevisionDateByOrganizationId(orgId.Value); - } - if (query.Any(c => c.UserId.HasValue && !string.IsNullOrWhiteSpace(c.Attachments))) - { - await UserUpdateStorage(userId); - } - await UserBumpAccountRevisionDate(userId); - await dbContext.SaveChangesAsync(); - return utcNow; - } - } - - public async Task SoftDeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var utcNow = DateTime.UtcNow; - var ciphers = dbContext.Ciphers.Where(c => ids.Contains(c.Id) && c.OrganizationId == organizationId); - await ciphers.ForEachAsync(cipher => - { - dbContext.Attach(cipher); - cipher.DeletedDate = utcNow; - cipher.RevisionDate = utcNow; - }); - await dbContext.SaveChangesAsync(); - await OrganizationUpdateStorage(organizationId); - await UserBumpAccountRevisionDateByOrganizationId(organizationId); - } - } - - public async Task UpdateAttachmentAsync(CipherAttachment attachment) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var cipher = await dbContext.Ciphers.FindAsync(attachment.Id); - var attachmentsJson = string.IsNullOrWhiteSpace(cipher.Attachments) ? new JObject() : JObject.Parse(cipher.Attachments); - attachmentsJson.Add(attachment.AttachmentId, attachment.AttachmentData); - cipher.Attachments = JsonConvert.SerializeObject(attachmentsJson); - await dbContext.SaveChangesAsync(); - - if (attachment.OrganizationId.HasValue) - { - await OrganizationUpdateStorage(cipher.OrganizationId.Value); - await UserBumpAccountRevisionDateByCipherId(new List { cipher }); - } - else if (attachment.UserId.HasValue) - { - await UserUpdateStorage(attachment.UserId.Value); - await UserBumpAccountRevisionDate(attachment.UserId.Value); - } - } - } - - public async Task UpdateCiphersAsync(Guid userId, IEnumerable ciphers) - { - if (!ciphers.Any()) - { - return; - } - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var entities = Mapper.Map>(ciphers); - await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, entities); - await UserBumpAccountRevisionDate(userId); - } - } - - public async Task UpdatePartialAsync(Guid id, Guid userId, Guid? folderId, bool favorite) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var cipher = await dbContext.Ciphers.FindAsync(id); - - var foldersJson = JObject.Parse(cipher.Folders); - if (foldersJson == null && folderId.HasValue) - { - foldersJson.Add(userId.ToString(), folderId.Value); - } - else if (foldersJson != null && folderId.HasValue) - { - foldersJson[userId] = folderId.Value; - } - else - { - foldersJson.Remove(userId.ToString()); - } - - var favoritesJson = JObject.Parse(cipher.Favorites); - if (favorite) - { - favoritesJson.Add(userId.ToString(), favorite); - } - else - { - favoritesJson.Remove(userId.ToString()); - } - - await dbContext.SaveChangesAsync(); - await UserBumpAccountRevisionDate(userId); - } - } - - public async Task UpdateUserKeysAndCiphersAsync(User user, IEnumerable ciphers, IEnumerable folders, IEnumerable sends) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - await UserUpdateKeys(user); - var cipherEntities = Mapper.Map>(ciphers); - await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, cipherEntities); - var folderEntities = Mapper.Map>(folders); - await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, folderEntities); - var sendEntities = Mapper.Map>(sends); - await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, sendEntities); - await dbContext.SaveChangesAsync(); - } - } - - public async Task UpsertAsync(CipherDetails cipher) - { - if (cipher.Id.Equals(default)) - { - await CreateAsync(cipher); + dbContext.RemoveRange(cipherEntitiesToModify); } else { - await ReplaceAsync(cipher); + await cipherEntitiesToModify.ForEachAsync(cipher => + { + dbContext.Attach(cipher); + cipher.DeletedDate = action == CipherStateAction.Restore ? null : utcNow; + cipher.RevisionDate = utcNow; + }); + } + + var orgIds = query + .Where(c => c.OrganizationId.HasValue) + .GroupBy(c => c.OrganizationId).Select(x => x.Key); + + foreach (var orgId in orgIds) + { + await OrganizationUpdateStorage(orgId.Value); + await UserBumpAccountRevisionDateByOrganizationId(orgId.Value); + } + if (query.Any(c => c.UserId.HasValue && !string.IsNullOrWhiteSpace(c.Attachments))) + { + await UserUpdateStorage(userId); + } + await UserBumpAccountRevisionDate(userId); + await dbContext.SaveChangesAsync(); + return utcNow; + } + } + + public async Task SoftDeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var utcNow = DateTime.UtcNow; + var ciphers = dbContext.Ciphers.Where(c => ids.Contains(c.Id) && c.OrganizationId == organizationId); + await ciphers.ForEachAsync(cipher => + { + dbContext.Attach(cipher); + cipher.DeletedDate = utcNow; + cipher.RevisionDate = utcNow; + }); + await dbContext.SaveChangesAsync(); + await OrganizationUpdateStorage(organizationId); + await UserBumpAccountRevisionDateByOrganizationId(organizationId); + } + } + + public async Task UpdateAttachmentAsync(CipherAttachment attachment) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var cipher = await dbContext.Ciphers.FindAsync(attachment.Id); + var attachmentsJson = string.IsNullOrWhiteSpace(cipher.Attachments) ? new JObject() : JObject.Parse(cipher.Attachments); + attachmentsJson.Add(attachment.AttachmentId, attachment.AttachmentData); + cipher.Attachments = JsonConvert.SerializeObject(attachmentsJson); + await dbContext.SaveChangesAsync(); + + if (attachment.OrganizationId.HasValue) + { + await OrganizationUpdateStorage(cipher.OrganizationId.Value); + await UserBumpAccountRevisionDateByCipherId(new List { cipher }); + } + else if (attachment.UserId.HasValue) + { + await UserUpdateStorage(attachment.UserId.Value); + await UserBumpAccountRevisionDate(attachment.UserId.Value); } } } + + public async Task UpdateCiphersAsync(Guid userId, IEnumerable ciphers) + { + if (!ciphers.Any()) + { + return; + } + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var entities = Mapper.Map>(ciphers); + await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, entities); + await UserBumpAccountRevisionDate(userId); + } + } + + public async Task UpdatePartialAsync(Guid id, Guid userId, Guid? folderId, bool favorite) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var cipher = await dbContext.Ciphers.FindAsync(id); + + var foldersJson = JObject.Parse(cipher.Folders); + if (foldersJson == null && folderId.HasValue) + { + foldersJson.Add(userId.ToString(), folderId.Value); + } + else if (foldersJson != null && folderId.HasValue) + { + foldersJson[userId] = folderId.Value; + } + else + { + foldersJson.Remove(userId.ToString()); + } + + var favoritesJson = JObject.Parse(cipher.Favorites); + if (favorite) + { + favoritesJson.Add(userId.ToString(), favorite); + } + else + { + favoritesJson.Remove(userId.ToString()); + } + + await dbContext.SaveChangesAsync(); + await UserBumpAccountRevisionDate(userId); + } + } + + public async Task UpdateUserKeysAndCiphersAsync(User user, IEnumerable ciphers, IEnumerable folders, IEnumerable sends) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + await UserUpdateKeys(user); + var cipherEntities = Mapper.Map>(ciphers); + await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, cipherEntities); + var folderEntities = Mapper.Map>(folders); + await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, folderEntities); + var sendEntities = Mapper.Map>(sends); + await dbContext.BulkCopyAsync(base.DefaultBulkCopyOptions, sendEntities); + await dbContext.SaveChangesAsync(); + } + } + + public async Task UpsertAsync(CipherDetails cipher) + { + if (cipher.Id.Equals(default)) + { + await CreateAsync(cipher); + } + else + { + await ReplaceAsync(cipher); + } + } } diff --git a/src/Infrastructure.EntityFramework/Repositories/CollectionCipherRepository.cs b/src/Infrastructure.EntityFramework/Repositories/CollectionCipherRepository.cs index 1d717ce2e..fd23237a9 100644 --- a/src/Infrastructure.EntityFramework/Repositories/CollectionCipherRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/CollectionCipherRepository.cs @@ -6,233 +6,232 @@ using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; using CollectionCipher = Bit.Core.Entities.CollectionCipher; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class CollectionCipherRepository : BaseEntityFrameworkRepository, ICollectionCipherRepository { - public class CollectionCipherRepository : BaseEntityFrameworkRepository, ICollectionCipherRepository + public CollectionCipherRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper) + { } + + public async Task CreateAsync(CollectionCipher obj) { - public CollectionCipherRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper) - { } - - public async Task CreateAsync(CollectionCipher obj) + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + var entity = Mapper.Map(obj); + dbContext.Add(entity); + await dbContext.SaveChangesAsync(); + var organizationId = (await dbContext.Ciphers.FirstOrDefaultAsync(c => c.Id.Equals(obj.CipherId))).OrganizationId; + if (organizationId.HasValue) { - var dbContext = GetDatabaseContext(scope); - var entity = Mapper.Map(obj); - dbContext.Add(entity); - await dbContext.SaveChangesAsync(); - var organizationId = (await dbContext.Ciphers.FirstOrDefaultAsync(c => c.Id.Equals(obj.CipherId))).OrganizationId; - if (organizationId.HasValue) + await UserBumpAccountRevisionDateByCollectionId(obj.CollectionId, organizationId.Value); + } + return obj; + } + } + + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var data = await (from cc in dbContext.CollectionCiphers + join c in dbContext.Collections + on cc.CollectionId equals c.Id + where c.OrganizationId == organizationId + select cc).ToArrayAsync(); + return data; + } + } + + public async Task> GetManyByUserIdAsync(Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var data = await new CollectionCipherReadByUserIdQuery(userId) + .Run(dbContext) + .ToArrayAsync(); + return data; + } + } + + public async Task> GetManyByUserIdCipherIdAsync(Guid userId, Guid cipherId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var data = await new CollectionCipherReadByUserIdCipherIdQuery(userId, cipherId) + .Run(dbContext) + .ToArrayAsync(); + return data; + } + } + + public async Task UpdateCollectionsAsync(Guid cipherId, Guid userId, IEnumerable collectionIds) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var organizationId = (await dbContext.Ciphers.FindAsync(cipherId)).OrganizationId; + var availableCollectionsCte = from c in dbContext.Collections + join o in dbContext.Organizations + on c.OrganizationId equals o.Id + join ou in dbContext.OrganizationUsers + on o.Id equals ou.OrganizationId + where ou.UserId == userId + join cu in dbContext.CollectionUsers + on ou.Id equals cu.OrganizationUserId into cu_g + from cu in cu_g.DefaultIfEmpty() + where !ou.AccessAll && cu.CollectionId == c.Id + join gu in dbContext.GroupUsers + on ou.Id equals gu.OrganizationUserId into gu_g + from gu in gu_g.DefaultIfEmpty() + where cu.CollectionId == null && !ou.AccessAll + join g in dbContext.Groups + on gu.GroupId equals g.Id into g_g + from g in g_g.DefaultIfEmpty() + join cg in dbContext.CollectionGroups + on gu.GroupId equals cg.GroupId into cg_g + from cg in cg_g.DefaultIfEmpty() + where !g.AccessAll && cg.CollectionId == c.Id && + (o.Id == organizationId && o.Enabled && ou.Status == OrganizationUserStatusType.Confirmed && ( + ou.AccessAll || !cu.ReadOnly || g.AccessAll || !cg.ReadOnly)) + select new { c, o, cu, gu, g, cg }; + var target = from cc in dbContext.CollectionCiphers + where cc.CipherId == cipherId + select new { cc.CollectionId, cc.CipherId }; + var source = collectionIds.Select(x => new { CollectionId = x, CipherId = cipherId }); + var merge1 = from t in target + join s in source + on t.CollectionId equals s.CollectionId into s_g + from s in s_g.DefaultIfEmpty() + where t.CipherId == s.CipherId + select new { t, s }; + var merge2 = from s in source + join t in target + on s.CollectionId equals t.CollectionId into t_g + from t in t_g.DefaultIfEmpty() + where t.CipherId == s.CipherId + select new { t, s }; + var union = merge1.Union(merge2).Distinct(); + var insert = union + .Where(x => x.t == null && collectionIds.Contains(x.s.CollectionId)) + .Select(x => new Models.CollectionCipher { - await UserBumpAccountRevisionDateByCollectionId(obj.CollectionId, organizationId.Value); - } - return obj; - } - } - - public async Task> GetManyByOrganizationIdAsync(Guid organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var data = await (from cc in dbContext.CollectionCiphers - join c in dbContext.Collections - on cc.CollectionId equals c.Id - where c.OrganizationId == organizationId - select cc).ToArrayAsync(); - return data; - } - } - - public async Task> GetManyByUserIdAsync(Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var data = await new CollectionCipherReadByUserIdQuery(userId) - .Run(dbContext) - .ToArrayAsync(); - return data; - } - } - - public async Task> GetManyByUserIdCipherIdAsync(Guid userId, Guid cipherId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var data = await new CollectionCipherReadByUserIdCipherIdQuery(userId, cipherId) - .Run(dbContext) - .ToArrayAsync(); - return data; - } - } - - public async Task UpdateCollectionsAsync(Guid cipherId, Guid userId, IEnumerable collectionIds) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var organizationId = (await dbContext.Ciphers.FindAsync(cipherId)).OrganizationId; - var availableCollectionsCte = from c in dbContext.Collections - join o in dbContext.Organizations - on c.OrganizationId equals o.Id - join ou in dbContext.OrganizationUsers - on o.Id equals ou.OrganizationId - where ou.UserId == userId - join cu in dbContext.CollectionUsers - on ou.Id equals cu.OrganizationUserId into cu_g - from cu in cu_g.DefaultIfEmpty() - where !ou.AccessAll && cu.CollectionId == c.Id - join gu in dbContext.GroupUsers - on ou.Id equals gu.OrganizationUserId into gu_g - from gu in gu_g.DefaultIfEmpty() - where cu.CollectionId == null && !ou.AccessAll - join g in dbContext.Groups - on gu.GroupId equals g.Id into g_g - from g in g_g.DefaultIfEmpty() - join cg in dbContext.CollectionGroups - on gu.GroupId equals cg.GroupId into cg_g - from cg in cg_g.DefaultIfEmpty() - where !g.AccessAll && cg.CollectionId == c.Id && - (o.Id == organizationId && o.Enabled && ou.Status == OrganizationUserStatusType.Confirmed && ( - ou.AccessAll || !cu.ReadOnly || g.AccessAll || !cg.ReadOnly)) - select new { c, o, cu, gu, g, cg }; - var target = from cc in dbContext.CollectionCiphers - where cc.CipherId == cipherId - select new { cc.CollectionId, cc.CipherId }; - var source = collectionIds.Select(x => new { CollectionId = x, CipherId = cipherId }); - var merge1 = from t in target - join s in source - on t.CollectionId equals s.CollectionId into s_g - from s in s_g.DefaultIfEmpty() - where t.CipherId == s.CipherId - select new { t, s }; - var merge2 = from s in source - join t in target - on s.CollectionId equals t.CollectionId into t_g - from t in t_g.DefaultIfEmpty() - where t.CipherId == s.CipherId - select new { t, s }; - var union = merge1.Union(merge2).Distinct(); - var insert = union - .Where(x => x.t == null && collectionIds.Contains(x.s.CollectionId)) - .Select(x => new Models.CollectionCipher - { - CollectionId = x.s.CollectionId, - CipherId = x.s.CipherId, - }); - var delete = union - .Where(x => x.s == null && x.t.CipherId == cipherId && collectionIds.Contains(x.t.CollectionId)) - .Select(x => new Models.CollectionCipher - { - CollectionId = x.t.CollectionId, - CipherId = x.t.CipherId, - }); - await dbContext.AddRangeAsync(insert); - dbContext.RemoveRange(delete); - await dbContext.SaveChangesAsync(); - - if (organizationId.HasValue) + CollectionId = x.s.CollectionId, + CipherId = x.s.CipherId, + }); + var delete = union + .Where(x => x.s == null && x.t.CipherId == cipherId && collectionIds.Contains(x.t.CollectionId)) + .Select(x => new Models.CollectionCipher { - await UserBumpAccountRevisionDateByOrganizationId(organizationId.Value); - } - } - } + CollectionId = x.t.CollectionId, + CipherId = x.t.CipherId, + }); + await dbContext.AddRangeAsync(insert); + dbContext.RemoveRange(delete); + await dbContext.SaveChangesAsync(); - public async Task UpdateCollectionsForAdminAsync(Guid cipherId, Guid organizationId, IEnumerable collectionIds) - { - using (var scope = ServiceScopeFactory.CreateScope()) + if (organizationId.HasValue) { - var dbContext = GetDatabaseContext(scope); - var availableCollectionsCte = from c in dbContext.Collections - where c.OrganizationId == organizationId - select c; - var target = from cc in dbContext.CollectionCiphers - where cc.CipherId == cipherId - select new { cc.CollectionId, cc.CipherId }; - var source = collectionIds.Select(x => new { CollectionId = x, CipherId = cipherId }); - var merge1 = from t in target - join s in source - on t.CollectionId equals s.CollectionId into s_g - from s in s_g.DefaultIfEmpty() - where t.CipherId == s.CipherId - select new { t, s }; - var merge2 = from s in source - join t in target - on s.CollectionId equals t.CollectionId into t_g - from t in t_g.DefaultIfEmpty() - where t.CipherId == s.CipherId - select new { t, s }; - var union = merge1.Union(merge2).Distinct(); - var insert = union - .Where(x => x.t == null && collectionIds.Contains(x.s.CollectionId)) - .Select(x => new Models.CollectionCipher - { - CollectionId = x.s.CollectionId, - CipherId = x.s.CipherId, - }); - var delete = union - .Where(x => x.s == null && x.t.CipherId == cipherId) - .Select(x => new Models.CollectionCipher - { - CollectionId = x.t.CollectionId, - CipherId = x.t.CipherId, - }); - await dbContext.AddRangeAsync(insert); - dbContext.RemoveRange(delete); - await dbContext.SaveChangesAsync(); - await UserBumpAccountRevisionDateByOrganizationId(organizationId); - } - } - - public async Task UpdateCollectionsForCiphersAsync(IEnumerable cipherIds, Guid userId, Guid organizationId, IEnumerable collectionIds) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var availibleCollections = from c in dbContext.Collections - join o in dbContext.Organizations - on c.OrganizationId equals o.Id - join ou in dbContext.OrganizationUsers - on o.Id equals ou.OrganizationId - where ou.UserId == userId - join cu in dbContext.CollectionUsers - on ou.Id equals cu.OrganizationUserId into cu_g - from cu in cu_g.DefaultIfEmpty() - where !ou.AccessAll && cu.CollectionId == c.Id - join gu in dbContext.GroupUsers - on ou.Id equals gu.OrganizationUserId into gu_g - from gu in gu_g.DefaultIfEmpty() - where cu.CollectionId == null && !ou.AccessAll - join g in dbContext.Groups - on gu.GroupId equals g.Id into g_g - from g in g_g.DefaultIfEmpty() - join cg in dbContext.CollectionGroups - on gu.GroupId equals cg.GroupId into cg_g - from cg in cg_g.DefaultIfEmpty() - where !g.AccessAll && cg.CollectionId == c.Id && - (o.Id == organizationId && o.Enabled && ou.Status == OrganizationUserStatusType.Confirmed && - (ou.AccessAll || !cu.ReadOnly || g.AccessAll || !cg.ReadOnly)) - select new { c, o, ou, cu, gu, g, cg }; - var count = await availibleCollections.CountAsync(); - if (await availibleCollections.CountAsync() < 1) - { - return; - } - - var insertData = from collectionId in collectionIds - from cipherId in cipherIds - where availibleCollections.Select(x => x.c.Id).Contains(collectionId) - select new Models.CollectionCipher - { - CollectionId = collectionId, - CipherId = cipherId, - }; - await dbContext.AddRangeAsync(insertData); - await UserBumpAccountRevisionDateByOrganizationId(organizationId); + await UserBumpAccountRevisionDateByOrganizationId(organizationId.Value); } } } + + public async Task UpdateCollectionsForAdminAsync(Guid cipherId, Guid organizationId, IEnumerable collectionIds) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var availableCollectionsCte = from c in dbContext.Collections + where c.OrganizationId == organizationId + select c; + var target = from cc in dbContext.CollectionCiphers + where cc.CipherId == cipherId + select new { cc.CollectionId, cc.CipherId }; + var source = collectionIds.Select(x => new { CollectionId = x, CipherId = cipherId }); + var merge1 = from t in target + join s in source + on t.CollectionId equals s.CollectionId into s_g + from s in s_g.DefaultIfEmpty() + where t.CipherId == s.CipherId + select new { t, s }; + var merge2 = from s in source + join t in target + on s.CollectionId equals t.CollectionId into t_g + from t in t_g.DefaultIfEmpty() + where t.CipherId == s.CipherId + select new { t, s }; + var union = merge1.Union(merge2).Distinct(); + var insert = union + .Where(x => x.t == null && collectionIds.Contains(x.s.CollectionId)) + .Select(x => new Models.CollectionCipher + { + CollectionId = x.s.CollectionId, + CipherId = x.s.CipherId, + }); + var delete = union + .Where(x => x.s == null && x.t.CipherId == cipherId) + .Select(x => new Models.CollectionCipher + { + CollectionId = x.t.CollectionId, + CipherId = x.t.CipherId, + }); + await dbContext.AddRangeAsync(insert); + dbContext.RemoveRange(delete); + await dbContext.SaveChangesAsync(); + await UserBumpAccountRevisionDateByOrganizationId(organizationId); + } + } + + public async Task UpdateCollectionsForCiphersAsync(IEnumerable cipherIds, Guid userId, Guid organizationId, IEnumerable collectionIds) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var availibleCollections = from c in dbContext.Collections + join o in dbContext.Organizations + on c.OrganizationId equals o.Id + join ou in dbContext.OrganizationUsers + on o.Id equals ou.OrganizationId + where ou.UserId == userId + join cu in dbContext.CollectionUsers + on ou.Id equals cu.OrganizationUserId into cu_g + from cu in cu_g.DefaultIfEmpty() + where !ou.AccessAll && cu.CollectionId == c.Id + join gu in dbContext.GroupUsers + on ou.Id equals gu.OrganizationUserId into gu_g + from gu in gu_g.DefaultIfEmpty() + where cu.CollectionId == null && !ou.AccessAll + join g in dbContext.Groups + on gu.GroupId equals g.Id into g_g + from g in g_g.DefaultIfEmpty() + join cg in dbContext.CollectionGroups + on gu.GroupId equals cg.GroupId into cg_g + from cg in cg_g.DefaultIfEmpty() + where !g.AccessAll && cg.CollectionId == c.Id && + (o.Id == organizationId && o.Enabled && ou.Status == OrganizationUserStatusType.Confirmed && + (ou.AccessAll || !cu.ReadOnly || g.AccessAll || !cg.ReadOnly)) + select new { c, o, ou, cu, gu, g, cg }; + var count = await availibleCollections.CountAsync(); + if (await availibleCollections.CountAsync() < 1) + { + return; + } + + var insertData = from collectionId in collectionIds + from cipherId in cipherIds + where availibleCollections.Select(x => x.c.Id).Contains(collectionId) + select new Models.CollectionCipher + { + CollectionId = collectionId, + CipherId = cipherId, + }; + await dbContext.AddRangeAsync(insertData); + await UserBumpAccountRevisionDateByOrganizationId(organizationId); + } + } } diff --git a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs index 74d714bb1..d8338b470 100644 --- a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs @@ -6,245 +6,244 @@ using Bit.Infrastructure.EntityFramework.Repositories.Queries; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class CollectionRepository : Repository, ICollectionRepository { - public class CollectionRepository : Repository, ICollectionRepository + public CollectionRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Collections) + { } + + public override async Task CreateAsync(Core.Entities.Collection obj) { - public CollectionRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Collections) - { } + await base.CreateAsync(obj); + await UserBumpAccountRevisionDateByCollectionId(obj.Id, obj.OrganizationId); + return obj; + } - public override async Task CreateAsync(Core.Entities.Collection obj) + public async Task CreateAsync(Core.Entities.Collection obj, IEnumerable groups) + { + await base.CreateAsync(obj); + using (var scope = ServiceScopeFactory.CreateScope()) { - await base.CreateAsync(obj); - await UserBumpAccountRevisionDateByCollectionId(obj.Id, obj.OrganizationId); - return obj; - } - - public async Task CreateAsync(Core.Entities.Collection obj, IEnumerable groups) - { - await base.CreateAsync(obj); - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var availibleGroups = await (from g in dbContext.Groups - where g.OrganizationId == obj.OrganizationId - select g.Id).ToListAsync(); - var collectionGroups = groups - .Where(g => availibleGroups.Contains(g.Id)) - .Select(g => new CollectionGroup - { - CollectionId = obj.Id, - GroupId = g.Id, - ReadOnly = g.ReadOnly, - HidePasswords = g.HidePasswords, - }); - await dbContext.AddRangeAsync(collectionGroups); - await dbContext.SaveChangesAsync(); - await UserBumpAccountRevisionDateByOrganizationId(obj.OrganizationId); - } - } - - public async Task DeleteUserAsync(Guid collectionId, Guid organizationUserId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from cu in dbContext.CollectionUsers - where cu.CollectionId == collectionId && - cu.OrganizationUserId == organizationUserId - select cu; - dbContext.RemoveRange(await query.ToListAsync()); - await dbContext.SaveChangesAsync(); - await UserBumpAccountRevisionDateByOrganizationUserId(organizationUserId); - } - } - - public async Task GetByIdAsync(Guid id, Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - return (await GetManyByUserIdAsync(userId)).FirstOrDefault(c => c.Id == id); - } - } - - public async Task>> GetByIdWithGroupsAsync(Guid id) - { - var collection = await base.GetByIdAsync(id); - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var collectionGroups = await (from cg in dbContext.CollectionGroups - where cg.CollectionId == id - select cg).ToListAsync(); - var selectionReadOnlys = collectionGroups.Select(cg => new SelectionReadOnly + var dbContext = GetDatabaseContext(scope); + var availibleGroups = await (from g in dbContext.Groups + where g.OrganizationId == obj.OrganizationId + select g.Id).ToListAsync(); + var collectionGroups = groups + .Where(g => availibleGroups.Contains(g.Id)) + .Select(g => new CollectionGroup { - Id = cg.GroupId, - ReadOnly = cg.ReadOnly, - HidePasswords = cg.HidePasswords, + CollectionId = obj.Id, + GroupId = g.Id, + ReadOnly = g.ReadOnly, + HidePasswords = g.HidePasswords, + }); + await dbContext.AddRangeAsync(collectionGroups); + await dbContext.SaveChangesAsync(); + await UserBumpAccountRevisionDateByOrganizationId(obj.OrganizationId); + } + } + + public async Task DeleteUserAsync(Guid collectionId, Guid organizationUserId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from cu in dbContext.CollectionUsers + where cu.CollectionId == collectionId && + cu.OrganizationUserId == organizationUserId + select cu; + dbContext.RemoveRange(await query.ToListAsync()); + await dbContext.SaveChangesAsync(); + await UserBumpAccountRevisionDateByOrganizationUserId(organizationUserId); + } + } + + public async Task GetByIdAsync(Guid id, Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + return (await GetManyByUserIdAsync(userId)).FirstOrDefault(c => c.Id == id); + } + } + + public async Task>> GetByIdWithGroupsAsync(Guid id) + { + var collection = await base.GetByIdAsync(id); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var collectionGroups = await (from cg in dbContext.CollectionGroups + where cg.CollectionId == id + select cg).ToListAsync(); + var selectionReadOnlys = collectionGroups.Select(cg => new SelectionReadOnly + { + Id = cg.GroupId, + ReadOnly = cg.ReadOnly, + HidePasswords = cg.HidePasswords, + }).ToList(); + return new Tuple>(collection, selectionReadOnlys); + } + } + + public async Task>> GetByIdWithGroupsAsync(Guid id, Guid userId) + { + var collection = await GetByIdAsync(id, userId); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from cg in dbContext.CollectionGroups + where cg.CollectionId.Equals(id) + select new SelectionReadOnly + { + Id = cg.GroupId, + ReadOnly = cg.ReadOnly, + HidePasswords = cg.HidePasswords, + }; + var configurations = await query.ToArrayAsync(); + return new Tuple>(collection, configurations); + } + } + + public async Task GetCountByOrganizationIdAsync(Guid organizationId) + { + var query = new CollectionReadCountByOrganizationIdQuery(organizationId); + return await GetCountFromQuery(query); + } + + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from c in dbContext.Collections + where c.OrganizationId == organizationId + select c; + var collections = await query.ToArrayAsync(); + return collections; + } + } + + public async Task> GetManyByUserIdAsync(Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + return (await new UserCollectionDetailsQuery(userId).Run(dbContext).ToListAsync()) + .GroupBy(c => c.Id) + .Select(g => new CollectionDetails + { + Id = g.Key, + OrganizationId = g.FirstOrDefault().OrganizationId, + Name = g.FirstOrDefault().Name, + ExternalId = g.FirstOrDefault().ExternalId, + CreationDate = g.FirstOrDefault().CreationDate, + RevisionDate = g.FirstOrDefault().RevisionDate, + ReadOnly = g.Min(c => c.ReadOnly), + HidePasswords = g.Min(c => c.HidePasswords) }).ToList(); - return new Tuple>(collection, selectionReadOnlys); - } } + } - public async Task>> GetByIdWithGroupsAsync(Guid id, Guid userId) + public async Task> GetManyUsersByIdAsync(Guid id) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - var collection = await GetByIdAsync(id, userId); - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + var query = from cu in dbContext.CollectionUsers + where cu.CollectionId == id + select cu; + var collectionUsers = await query.ToListAsync(); + return collectionUsers.Select(cu => new SelectionReadOnly { - var dbContext = GetDatabaseContext(scope); - var query = from cg in dbContext.CollectionGroups - where cg.CollectionId.Equals(id) - select new SelectionReadOnly - { - Id = cg.GroupId, - ReadOnly = cg.ReadOnly, - HidePasswords = cg.HidePasswords, - }; - var configurations = await query.ToArrayAsync(); - return new Tuple>(collection, configurations); - } + Id = cu.OrganizationUserId, + ReadOnly = cu.ReadOnly, + HidePasswords = cu.HidePasswords, + }).ToArray(); } + } - public async Task GetCountByOrganizationIdAsync(Guid organizationId) + public async Task ReplaceAsync(Core.Entities.Collection collection, IEnumerable groups) + { + await base.ReplaceAsync(collection); + using (var scope = ServiceScopeFactory.CreateScope()) { - var query = new CollectionReadCountByOrganizationIdQuery(organizationId); - return await GetCountFromQuery(query); - } - - public async Task> GetManyByOrganizationIdAsync(Guid organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from c in dbContext.Collections - where c.OrganizationId == organizationId - select c; - var collections = await query.ToArrayAsync(); - return collections; - } - } - - public async Task> GetManyByUserIdAsync(Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - return (await new UserCollectionDetailsQuery(userId).Run(dbContext).ToListAsync()) - .GroupBy(c => c.Id) - .Select(g => new CollectionDetails - { - Id = g.Key, - OrganizationId = g.FirstOrDefault().OrganizationId, - Name = g.FirstOrDefault().Name, - ExternalId = g.FirstOrDefault().ExternalId, - CreationDate = g.FirstOrDefault().CreationDate, - RevisionDate = g.FirstOrDefault().RevisionDate, - ReadOnly = g.Min(c => c.ReadOnly), - HidePasswords = g.Min(c => c.HidePasswords) - }).ToList(); - } - } - - public async Task> GetManyUsersByIdAsync(Guid id) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from cu in dbContext.CollectionUsers - where cu.CollectionId == id - select cu; - var collectionUsers = await query.ToListAsync(); - return collectionUsers.Select(cu => new SelectionReadOnly + var dbContext = GetDatabaseContext(scope); + var groupsInOrg = dbContext.Groups.Where(g => g.OrganizationId == collection.OrganizationId); + var modifiedGroupEntities = dbContext.Groups.Where(x => groups.Select(x => x.Id).Contains(x.Id)); + var target = (from cg in dbContext.CollectionGroups + join g in modifiedGroupEntities + on cg.CollectionId equals collection.Id into s_g + from g in s_g.DefaultIfEmpty() + where g == null || cg.GroupId == g.Id + select new { cg, g }).AsNoTracking(); + var source = (from g in modifiedGroupEntities + from cg in dbContext.CollectionGroups + .Where(cg => cg.CollectionId == collection.Id && cg.GroupId == g.Id).DefaultIfEmpty() + select new { cg, g }).AsNoTracking(); + var union = await target + .Union(source) + .Where(x => + x.cg == null || + ((x.g == null || x.g.Id == x.cg.GroupId) && + (x.cg.CollectionId == collection.Id))) + .AsNoTracking() + .ToListAsync(); + var insert = union.Where(x => x.cg == null && groupsInOrg.Any(c => x.g.Id == c.Id)) + .Select(x => new CollectionGroup { - Id = cu.OrganizationUserId, - ReadOnly = cu.ReadOnly, - HidePasswords = cu.HidePasswords, - }).ToArray(); - } + CollectionId = collection.Id, + GroupId = x.g.Id, + ReadOnly = groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly, + HidePasswords = groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords, + }).ToList(); + var update = union + .Where( + x => x.g != null && + x.cg != null && + (x.cg.ReadOnly != groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly || + x.cg.HidePasswords != groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords) + ) + .Select(x => new CollectionGroup + { + CollectionId = collection.Id, + GroupId = x.g.Id, + ReadOnly = groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly, + HidePasswords = groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords, + }); + var delete = union + .Where( + x => x.g == null && + x.cg.CollectionId == collection.Id + ) + .Select(x => new CollectionGroup + { + CollectionId = collection.Id, + GroupId = x.cg.GroupId, + }) + .ToList(); + + await dbContext.AddRangeAsync(insert); + dbContext.UpdateRange(update); + dbContext.RemoveRange(delete); + await dbContext.SaveChangesAsync(); + await UserBumpAccountRevisionDateByCollectionId(collection.Id, collection.OrganizationId); } + } - public async Task ReplaceAsync(Core.Entities.Collection collection, IEnumerable groups) + public async Task UpdateUsersAsync(Guid id, IEnumerable users) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - await base.ReplaceAsync(collection); - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var groupsInOrg = dbContext.Groups.Where(g => g.OrganizationId == collection.OrganizationId); - var modifiedGroupEntities = dbContext.Groups.Where(x => groups.Select(x => x.Id).Contains(x.Id)); - var target = (from cg in dbContext.CollectionGroups - join g in modifiedGroupEntities - on cg.CollectionId equals collection.Id into s_g - from g in s_g.DefaultIfEmpty() - where g == null || cg.GroupId == g.Id - select new { cg, g }).AsNoTracking(); - var source = (from g in modifiedGroupEntities - from cg in dbContext.CollectionGroups - .Where(cg => cg.CollectionId == collection.Id && cg.GroupId == g.Id).DefaultIfEmpty() - select new { cg, g }).AsNoTracking(); - var union = await target - .Union(source) - .Where(x => - x.cg == null || - ((x.g == null || x.g.Id == x.cg.GroupId) && - (x.cg.CollectionId == collection.Id))) - .AsNoTracking() - .ToListAsync(); - var insert = union.Where(x => x.cg == null && groupsInOrg.Any(c => x.g.Id == c.Id)) - .Select(x => new CollectionGroup - { - CollectionId = collection.Id, - GroupId = x.g.Id, - ReadOnly = groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly, - HidePasswords = groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords, - }).ToList(); - var update = union - .Where( - x => x.g != null && - x.cg != null && - (x.cg.ReadOnly != groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly || - x.cg.HidePasswords != groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords) - ) - .Select(x => new CollectionGroup - { - CollectionId = collection.Id, - GroupId = x.g.Id, - ReadOnly = groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly, - HidePasswords = groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords, - }); - var delete = union - .Where( - x => x.g == null && - x.cg.CollectionId == collection.Id - ) - .Select(x => new CollectionGroup - { - CollectionId = collection.Id, - GroupId = x.cg.GroupId, - }) - .ToList(); - - await dbContext.AddRangeAsync(insert); - dbContext.UpdateRange(update); - dbContext.RemoveRange(delete); - await dbContext.SaveChangesAsync(); - await UserBumpAccountRevisionDateByCollectionId(collection.Id, collection.OrganizationId); - } - } - - public async Task UpdateUsersAsync(Guid id, IEnumerable users) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var procedure = new CollectionUserUpdateUsersQuery(id, users); - var updateData = await procedure.Update.BuildInMemory(dbContext); - dbContext.UpdateRange(updateData); - var insertData = await procedure.Insert.BuildInMemory(dbContext); - await dbContext.AddRangeAsync(insertData); - dbContext.RemoveRange(await procedure.Delete.Run(dbContext).ToListAsync()); - } + var dbContext = GetDatabaseContext(scope); + var procedure = new CollectionUserUpdateUsersQuery(id, users); + var updateData = await procedure.Update.BuildInMemory(dbContext); + dbContext.UpdateRange(updateData); + var insertData = await procedure.Insert.BuildInMemory(dbContext); + await dbContext.AddRangeAsync(insertData); + dbContext.RemoveRange(await procedure.Delete.Run(dbContext).ToListAsync()); } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs b/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs index 8d3af7a4f..88c2bb464 100644 --- a/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs +++ b/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs @@ -1,140 +1,139 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class DatabaseContext : DbContext { - public class DatabaseContext : DbContext + public const string postgresIndetermanisticCollation = "postgresIndetermanisticCollation"; + + public DatabaseContext(DbContextOptions options) + : base(options) + { } + + public DbSet Ciphers { get; set; } + public DbSet Collections { get; set; } + public DbSet CollectionCiphers { get; set; } + public DbSet CollectionGroups { get; set; } + public DbSet CollectionUsers { get; set; } + public DbSet Devices { get; set; } + public DbSet EmergencyAccesses { get; set; } + public DbSet Events { get; set; } + public DbSet Folders { get; set; } + public DbSet Grants { get; set; } + public DbSet Groups { get; set; } + public DbSet GroupUsers { get; set; } + public DbSet Installations { get; set; } + public DbSet Organizations { get; set; } + public DbSet OrganizationApiKeys { get; set; } + public DbSet OrganizationSponsorships { get; set; } + public DbSet OrganizationConnections { get; set; } + public DbSet OrganizationUsers { get; set; } + public DbSet Policies { get; set; } + public DbSet Providers { get; set; } + public DbSet ProviderUsers { get; set; } + public DbSet ProviderOrganizations { get; set; } + public DbSet Sends { get; set; } + public DbSet SsoConfigs { get; set; } + public DbSet SsoUsers { get; set; } + public DbSet TaxRates { get; set; } + public DbSet Transactions { get; set; } + public DbSet Users { get; set; } + + protected override void OnModelCreating(ModelBuilder builder) { - public const string postgresIndetermanisticCollation = "postgresIndetermanisticCollation"; + var eCipher = builder.Entity(); + var eCollection = builder.Entity(); + var eCollectionCipher = builder.Entity(); + var eCollectionUser = builder.Entity(); + var eCollectionGroup = builder.Entity(); + var eDevice = builder.Entity(); + var eEmergencyAccess = builder.Entity(); + var eEvent = builder.Entity(); + var eFolder = builder.Entity(); + var eGrant = builder.Entity(); + var eGroup = builder.Entity(); + var eGroupUser = builder.Entity(); + var eInstallation = builder.Entity(); + var eOrganization = builder.Entity(); + var eOrganizationSponsorship = builder.Entity(); + var eOrganizationUser = builder.Entity(); + var ePolicy = builder.Entity(); + var eProvider = builder.Entity(); + var eProviderUser = builder.Entity(); + var eProviderOrganization = builder.Entity(); + var eSend = builder.Entity(); + var eSsoConfig = builder.Entity(); + var eSsoUser = builder.Entity(); + var eTaxRate = builder.Entity(); + var eTransaction = builder.Entity(); + var eUser = builder.Entity(); + var eOrganizationApiKey = builder.Entity(); + var eOrganizationConnection = builder.Entity(); - public DatabaseContext(DbContextOptions options) - : base(options) - { } + eCipher.Property(c => c.Id).ValueGeneratedNever(); + eCollection.Property(c => c.Id).ValueGeneratedNever(); + eEmergencyAccess.Property(c => c.Id).ValueGeneratedNever(); + eEvent.Property(c => c.Id).ValueGeneratedNever(); + eFolder.Property(c => c.Id).ValueGeneratedNever(); + eGroup.Property(c => c.Id).ValueGeneratedNever(); + eInstallation.Property(c => c.Id).ValueGeneratedNever(); + eOrganization.Property(c => c.Id).ValueGeneratedNever(); + eOrganizationSponsorship.Property(c => c.Id).ValueGeneratedNever(); + eOrganizationUser.Property(c => c.Id).ValueGeneratedNever(); + ePolicy.Property(c => c.Id).ValueGeneratedNever(); + eProvider.Property(c => c.Id).ValueGeneratedNever(); + eProviderUser.Property(c => c.Id).ValueGeneratedNever(); + eProviderOrganization.Property(c => c.Id).ValueGeneratedNever(); + eSend.Property(c => c.Id).ValueGeneratedNever(); + eTransaction.Property(c => c.Id).ValueGeneratedNever(); + eUser.Property(c => c.Id).ValueGeneratedNever(); + eOrganizationApiKey.Property(c => c.Id).ValueGeneratedNever(); + eOrganizationConnection.Property(c => c.Id).ValueGeneratedNever(); - public DbSet Ciphers { get; set; } - public DbSet Collections { get; set; } - public DbSet CollectionCiphers { get; set; } - public DbSet CollectionGroups { get; set; } - public DbSet CollectionUsers { get; set; } - public DbSet Devices { get; set; } - public DbSet EmergencyAccesses { get; set; } - public DbSet Events { get; set; } - public DbSet Folders { get; set; } - public DbSet Grants { get; set; } - public DbSet Groups { get; set; } - public DbSet GroupUsers { get; set; } - public DbSet Installations { get; set; } - public DbSet Organizations { get; set; } - public DbSet OrganizationApiKeys { get; set; } - public DbSet OrganizationSponsorships { get; set; } - public DbSet OrganizationConnections { get; set; } - public DbSet OrganizationUsers { get; set; } - public DbSet Policies { get; set; } - public DbSet Providers { get; set; } - public DbSet ProviderUsers { get; set; } - public DbSet ProviderOrganizations { get; set; } - public DbSet Sends { get; set; } - public DbSet SsoConfigs { get; set; } - public DbSet SsoUsers { get; set; } - public DbSet TaxRates { get; set; } - public DbSet Transactions { get; set; } - public DbSet Users { get; set; } + eCollectionCipher.HasKey(cc => new { cc.CollectionId, cc.CipherId }); + eCollectionUser.HasKey(cu => new { cu.CollectionId, cu.OrganizationUserId }); + eCollectionGroup.HasKey(cg => new { cg.CollectionId, cg.GroupId }); + eGrant.HasKey(x => x.Key); + eGroupUser.HasKey(gu => new { gu.GroupId, gu.OrganizationUserId }); - protected override void OnModelCreating(ModelBuilder builder) + + if (Database.IsNpgsql()) { - var eCipher = builder.Entity(); - var eCollection = builder.Entity(); - var eCollectionCipher = builder.Entity(); - var eCollectionUser = builder.Entity(); - var eCollectionGroup = builder.Entity(); - var eDevice = builder.Entity(); - var eEmergencyAccess = builder.Entity(); - var eEvent = builder.Entity(); - var eFolder = builder.Entity(); - var eGrant = builder.Entity(); - var eGroup = builder.Entity(); - var eGroupUser = builder.Entity(); - var eInstallation = builder.Entity(); - var eOrganization = builder.Entity(); - var eOrganizationSponsorship = builder.Entity(); - var eOrganizationUser = builder.Entity(); - var ePolicy = builder.Entity(); - var eProvider = builder.Entity(); - var eProviderUser = builder.Entity(); - var eProviderOrganization = builder.Entity(); - var eSend = builder.Entity(); - var eSsoConfig = builder.Entity(); - var eSsoUser = builder.Entity(); - var eTaxRate = builder.Entity(); - var eTransaction = builder.Entity(); - var eUser = builder.Entity(); - var eOrganizationApiKey = builder.Entity(); - var eOrganizationConnection = builder.Entity(); - - eCipher.Property(c => c.Id).ValueGeneratedNever(); - eCollection.Property(c => c.Id).ValueGeneratedNever(); - eEmergencyAccess.Property(c => c.Id).ValueGeneratedNever(); - eEvent.Property(c => c.Id).ValueGeneratedNever(); - eFolder.Property(c => c.Id).ValueGeneratedNever(); - eGroup.Property(c => c.Id).ValueGeneratedNever(); - eInstallation.Property(c => c.Id).ValueGeneratedNever(); - eOrganization.Property(c => c.Id).ValueGeneratedNever(); - eOrganizationSponsorship.Property(c => c.Id).ValueGeneratedNever(); - eOrganizationUser.Property(c => c.Id).ValueGeneratedNever(); - ePolicy.Property(c => c.Id).ValueGeneratedNever(); - eProvider.Property(c => c.Id).ValueGeneratedNever(); - eProviderUser.Property(c => c.Id).ValueGeneratedNever(); - eProviderOrganization.Property(c => c.Id).ValueGeneratedNever(); - eSend.Property(c => c.Id).ValueGeneratedNever(); - eTransaction.Property(c => c.Id).ValueGeneratedNever(); - eUser.Property(c => c.Id).ValueGeneratedNever(); - eOrganizationApiKey.Property(c => c.Id).ValueGeneratedNever(); - eOrganizationConnection.Property(c => c.Id).ValueGeneratedNever(); - - eCollectionCipher.HasKey(cc => new { cc.CollectionId, cc.CipherId }); - eCollectionUser.HasKey(cu => new { cu.CollectionId, cu.OrganizationUserId }); - eCollectionGroup.HasKey(cg => new { cg.CollectionId, cg.GroupId }); - eGrant.HasKey(x => x.Key); - eGroupUser.HasKey(gu => new { gu.GroupId, gu.OrganizationUserId }); - - - if (Database.IsNpgsql()) - { - // the postgres provider doesn't currently support database level non-deterministic collations. - // see https://www.npgsql.org/efcore/misc/collations-and-case-sensitivity.html#database-collation - builder.HasCollation(postgresIndetermanisticCollation, locale: "en-u-ks-primary", provider: "icu", deterministic: false); - eUser.Property(e => e.Email).UseCollation(postgresIndetermanisticCollation); - eSsoUser.Property(e => e.ExternalId).UseCollation(postgresIndetermanisticCollation); - eOrganization.Property(e => e.Identifier).UseCollation(postgresIndetermanisticCollation); - // - } - - eCipher.ToTable(nameof(Cipher)); - eCollection.ToTable(nameof(Collection)); - eCollectionCipher.ToTable(nameof(CollectionCipher)); - eDevice.ToTable(nameof(Device)); - eEmergencyAccess.ToTable(nameof(EmergencyAccess)); - eEvent.ToTable(nameof(Event)); - eFolder.ToTable(nameof(Folder)); - eGrant.ToTable(nameof(Grant)); - eGroup.ToTable(nameof(Group)); - eGroupUser.ToTable(nameof(GroupUser)); - eInstallation.ToTable(nameof(Installation)); - eOrganization.ToTable(nameof(Organization)); - eOrganizationSponsorship.ToTable(nameof(OrganizationSponsorship)); - eOrganizationUser.ToTable(nameof(OrganizationUser)); - ePolicy.ToTable(nameof(Policy)); - eProvider.ToTable(nameof(Provider)); - eProviderUser.ToTable(nameof(ProviderUser)); - eProviderOrganization.ToTable(nameof(ProviderOrganization)); - eSend.ToTable(nameof(Send)); - eSsoConfig.ToTable(nameof(SsoConfig)); - eSsoUser.ToTable(nameof(SsoUser)); - eTaxRate.ToTable(nameof(TaxRate)); - eTransaction.ToTable(nameof(Transaction)); - eUser.ToTable(nameof(User)); - eOrganizationApiKey.ToTable(nameof(OrganizationApiKey)); - eOrganizationConnection.ToTable(nameof(OrganizationConnection)); + // the postgres provider doesn't currently support database level non-deterministic collations. + // see https://www.npgsql.org/efcore/misc/collations-and-case-sensitivity.html#database-collation + builder.HasCollation(postgresIndetermanisticCollation, locale: "en-u-ks-primary", provider: "icu", deterministic: false); + eUser.Property(e => e.Email).UseCollation(postgresIndetermanisticCollation); + eSsoUser.Property(e => e.ExternalId).UseCollation(postgresIndetermanisticCollation); + eOrganization.Property(e => e.Identifier).UseCollation(postgresIndetermanisticCollation); + // } + + eCipher.ToTable(nameof(Cipher)); + eCollection.ToTable(nameof(Collection)); + eCollectionCipher.ToTable(nameof(CollectionCipher)); + eDevice.ToTable(nameof(Device)); + eEmergencyAccess.ToTable(nameof(EmergencyAccess)); + eEvent.ToTable(nameof(Event)); + eFolder.ToTable(nameof(Folder)); + eGrant.ToTable(nameof(Grant)); + eGroup.ToTable(nameof(Group)); + eGroupUser.ToTable(nameof(GroupUser)); + eInstallation.ToTable(nameof(Installation)); + eOrganization.ToTable(nameof(Organization)); + eOrganizationSponsorship.ToTable(nameof(OrganizationSponsorship)); + eOrganizationUser.ToTable(nameof(OrganizationUser)); + ePolicy.ToTable(nameof(Policy)); + eProvider.ToTable(nameof(Provider)); + eProviderUser.ToTable(nameof(ProviderUser)); + eProviderOrganization.ToTable(nameof(ProviderOrganization)); + eSend.ToTable(nameof(Send)); + eSsoConfig.ToTable(nameof(SsoConfig)); + eSsoUser.ToTable(nameof(SsoUser)); + eTaxRate.ToTable(nameof(TaxRate)); + eTransaction.ToTable(nameof(Transaction)); + eUser.ToTable(nameof(User)); + eOrganizationApiKey.ToTable(nameof(OrganizationApiKey)); + eOrganizationConnection.ToTable(nameof(OrganizationConnection)); } } diff --git a/src/Infrastructure.EntityFramework/Repositories/DeviceRepository.cs b/src/Infrastructure.EntityFramework/Repositories/DeviceRepository.cs index 79ad60818..cc664aa1b 100644 --- a/src/Infrastructure.EntityFramework/Repositories/DeviceRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/DeviceRepository.cs @@ -4,68 +4,67 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories -{ - public class DeviceRepository : Repository, IDeviceRepository - { - public DeviceRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Devices) - { } +namespace Bit.Infrastructure.EntityFramework.Repositories; - public async Task ClearPushTokenAsync(Guid id) +public class DeviceRepository : Repository, IDeviceRepository +{ + public DeviceRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Devices) + { } + + public async Task ClearPushTokenAsync(Guid id) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = dbContext.Devices.Where(d => d.Id == id); - dbContext.AttachRange(query); - await query.ForEachAsync(x => x.PushToken = null); - await dbContext.SaveChangesAsync(); - } + var dbContext = GetDatabaseContext(scope); + var query = dbContext.Devices.Where(d => d.Id == id); + dbContext.AttachRange(query); + await query.ForEachAsync(x => x.PushToken = null); + await dbContext.SaveChangesAsync(); + } + } + + public async Task GetByIdAsync(Guid id, Guid userId) + { + var device = await base.GetByIdAsync(id); + if (device == null || device.UserId != userId) + { + return null; } - public async Task GetByIdAsync(Guid id, Guid userId) - { - var device = await base.GetByIdAsync(id); - if (device == null || device.UserId != userId) - { - return null; - } + return Mapper.Map(device); + } + public async Task GetByIdentifierAsync(string identifier) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = dbContext.Devices.Where(d => d.Identifier == identifier); + var device = await query.FirstOrDefaultAsync(); return Mapper.Map(device); } + } - public async Task GetByIdentifierAsync(string identifier) + public async Task GetByIdentifierAsync(string identifier, Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = dbContext.Devices.Where(d => d.Identifier == identifier); - var device = await query.FirstOrDefaultAsync(); - return Mapper.Map(device); - } + var dbContext = GetDatabaseContext(scope); + var query = dbContext.Devices.Where(d => d.Identifier == identifier && d.UserId == userId); + var device = await query.FirstOrDefaultAsync(); + return Mapper.Map(device); } + } - public async Task GetByIdentifierAsync(string identifier, Guid userId) + public async Task> GetManyByUserIdAsync(Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = dbContext.Devices.Where(d => d.Identifier == identifier && d.UserId == userId); - var device = await query.FirstOrDefaultAsync(); - return Mapper.Map(device); - } - } - - public async Task> GetManyByUserIdAsync(Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = dbContext.Devices.Where(d => d.UserId == userId); - var devices = await query.ToListAsync(); - return Mapper.Map>(devices); - } + var dbContext = GetDatabaseContext(scope); + var query = dbContext.Devices.Where(d => d.UserId == userId); + var devices = await query.ToListAsync(); + return Mapper.Map>(devices); } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/EmergencyAccessRepository.cs b/src/Infrastructure.EntityFramework/Repositories/EmergencyAccessRepository.cs index 4ace88560..028ce222f 100644 --- a/src/Infrastructure.EntityFramework/Repositories/EmergencyAccessRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/EmergencyAccessRepository.cs @@ -7,102 +7,101 @@ using Bit.Infrastructure.EntityFramework.Repositories.Queries; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class EmergencyAccessRepository : Repository, IEmergencyAccessRepository { - public class EmergencyAccessRepository : Repository, IEmergencyAccessRepository + public EmergencyAccessRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.EmergencyAccesses) + { } + + public async Task GetCountByGrantorIdEmailAsync(Guid grantorId, string email, bool onlyRegisteredUsers) { - public EmergencyAccessRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.EmergencyAccesses) - { } + var query = new EmergencyAccessReadCountByGrantorIdEmailQuery(grantorId, email, onlyRegisteredUsers); + return await GetCountFromQuery(query); + } - public async Task GetCountByGrantorIdEmailAsync(Guid grantorId, string email, bool onlyRegisteredUsers) + public async Task GetDetailsByIdGrantorIdAsync(Guid id, Guid grantorId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - var query = new EmergencyAccessReadCountByGrantorIdEmailQuery(grantorId, email, onlyRegisteredUsers); - return await GetCountFromQuery(query); + var dbContext = GetDatabaseContext(scope); + var view = new EmergencyAccessDetailsViewQuery(); + var query = view.Run(dbContext).Where(ea => + ea.Id == id && + ea.GrantorId == grantorId + ); + return await query.FirstOrDefaultAsync(); } + } - public async Task GetDetailsByIdGrantorIdAsync(Guid id, Guid grantorId) + public async Task> GetExpiredRecoveriesAsync() + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var view = new EmergencyAccessDetailsViewQuery(); - var query = view.Run(dbContext).Where(ea => - ea.Id == id && - ea.GrantorId == grantorId - ); - return await query.FirstOrDefaultAsync(); - } + var dbContext = GetDatabaseContext(scope); + var view = new EmergencyAccessDetailsViewQuery(); + var query = view.Run(dbContext).Where(ea => + ea.Status == EmergencyAccessStatusType.RecoveryInitiated + ); + return await query.ToListAsync(); } + } - public async Task> GetExpiredRecoveriesAsync() + public async Task> GetManyDetailsByGranteeIdAsync(Guid granteeId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var view = new EmergencyAccessDetailsViewQuery(); - var query = view.Run(dbContext).Where(ea => - ea.Status == EmergencyAccessStatusType.RecoveryInitiated - ); - return await query.ToListAsync(); - } + var dbContext = GetDatabaseContext(scope); + var view = new EmergencyAccessDetailsViewQuery(); + var query = view.Run(dbContext).Where(ea => + ea.GranteeId == granteeId + ); + return await query.ToListAsync(); } + } - public async Task> GetManyDetailsByGranteeIdAsync(Guid granteeId) + public async Task> GetManyDetailsByGrantorIdAsync(Guid grantorId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var view = new EmergencyAccessDetailsViewQuery(); - var query = view.Run(dbContext).Where(ea => - ea.GranteeId == granteeId - ); - return await query.ToListAsync(); - } + var dbContext = GetDatabaseContext(scope); + var view = new EmergencyAccessDetailsViewQuery(); + var query = view.Run(dbContext).Where(ea => + ea.GrantorId == grantorId + ); + return await query.ToListAsync(); } + } - public async Task> GetManyDetailsByGrantorIdAsync(Guid grantorId) + public async Task> GetManyToNotifyAsync() + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + var view = new EmergencyAccessDetailsViewQuery(); + var query = view.Run(dbContext).Where(ea => + ea.Status == EmergencyAccessStatusType.RecoveryInitiated + ); + var notifies = await query.Select(ea => new EmergencyAccessNotify { - var dbContext = GetDatabaseContext(scope); - var view = new EmergencyAccessDetailsViewQuery(); - var query = view.Run(dbContext).Where(ea => - ea.GrantorId == grantorId - ); - return await query.ToListAsync(); - } - } - - public async Task> GetManyToNotifyAsync() - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var view = new EmergencyAccessDetailsViewQuery(); - var query = view.Run(dbContext).Where(ea => - ea.Status == EmergencyAccessStatusType.RecoveryInitiated - ); - var notifies = await query.Select(ea => new EmergencyAccessNotify - { - Id = ea.Id, - GrantorId = ea.GrantorId, - GranteeId = ea.GranteeId, - Email = ea.Email, - KeyEncrypted = ea.KeyEncrypted, - Type = ea.Type, - Status = ea.Status, - WaitTimeDays = ea.WaitTimeDays, - RecoveryInitiatedDate = ea.RecoveryInitiatedDate, - LastNotificationDate = ea.LastNotificationDate, - CreationDate = ea.CreationDate, - RevisionDate = ea.RevisionDate, - GranteeName = ea.GranteeName, - GranteeEmail = ea.GranteeEmail, - GrantorEmail = ea.GrantorEmail, - }).ToListAsync(); - return notifies; - } + Id = ea.Id, + GrantorId = ea.GrantorId, + GranteeId = ea.GranteeId, + Email = ea.Email, + KeyEncrypted = ea.KeyEncrypted, + Type = ea.Type, + Status = ea.Status, + WaitTimeDays = ea.WaitTimeDays, + RecoveryInitiatedDate = ea.RecoveryInitiatedDate, + LastNotificationDate = ea.LastNotificationDate, + CreationDate = ea.CreationDate, + RevisionDate = ea.RevisionDate, + GranteeName = ea.GranteeName, + GranteeEmail = ea.GranteeEmail, + GrantorEmail = ea.GrantorEmail, + }).ToListAsync(); + return notifies; } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/EventRepository.cs b/src/Infrastructure.EntityFramework/Repositories/EventRepository.cs index 712885245..cb49f8535 100644 --- a/src/Infrastructure.EntityFramework/Repositories/EventRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/EventRepository.cs @@ -8,196 +8,195 @@ using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; using Cipher = Bit.Core.Entities.Cipher; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class EventRepository : Repository, IEventRepository { - public class EventRepository : Repository, IEventRepository + public EventRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Events) + { } + + public async Task CreateAsync(IEvent e) { - public EventRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Events) - { } - - public async Task CreateAsync(IEvent e) + if (e is not Core.Entities.Event ev) { - if (e is not Core.Entities.Event ev) - { - ev = new Core.Entities.Event(e); - } - - await base.CreateAsync(ev); + ev = new Core.Entities.Event(e); } - public async Task CreateManyAsync(IEnumerable entities) + await base.CreateAsync(ev); + } + + public async Task CreateManyAsync(IEnumerable entities) + { + if (!entities?.Any() ?? true) { - if (!entities?.Any() ?? true) - { - return; - } - - if (!entities.Skip(1).Any()) - { - await CreateAsync(entities.First()); - return; - } - - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var tableEvents = entities.Select(e => e as Core.Entities.Event ?? new Core.Entities.Event(e)); - var entityEvents = Mapper.Map>(tableEvents); - entityEvents.ForEach(e => e.SetNewId()); - await dbContext.BulkCopyAsync(entityEvents); - } + return; } - public async Task> GetManyByCipherAsync(Cipher cipher, DateTime startDate, DateTime endDate, PageOptions pageOptions) + if (!entities.Skip(1).Any()) { - DateTime? beforeDate = null; - if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && - long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) - { - beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); - } - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = new EventReadPageByCipherIdQuery(cipher, startDate, endDate, beforeDate, pageOptions); - var events = await query.Run(dbContext).ToListAsync(); - - var result = new PagedResult(); - if (events.Any() && events.Count >= pageOptions.PageSize) - { - result.ContinuationToken = events.Last().Date.ToBinary().ToString(); - } - result.Data.AddRange(events); - return result; - } + await CreateAsync(entities.First()); + return; } - - public async Task> GetManyByOrganizationActingUserAsync(Guid organizationId, Guid actingUserId, DateTime startDate, DateTime endDate, PageOptions pageOptions) + using (var scope = ServiceScopeFactory.CreateScope()) { - DateTime? beforeDate = null; - if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && - long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) - { - beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); - } - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = new EventReadPageByOrganizationIdActingUserIdQuery(organizationId, actingUserId, - startDate, endDate, beforeDate, pageOptions); - var events = await query.Run(dbContext).ToListAsync(); - - var result = new PagedResult(); - if (events.Any() && events.Count >= pageOptions.PageSize) - { - result.ContinuationToken = events.Last().Date.ToBinary().ToString(); - } - result.Data.AddRange(events); - return result; - } + var dbContext = GetDatabaseContext(scope); + var tableEvents = entities.Select(e => e as Core.Entities.Event ?? new Core.Entities.Event(e)); + var entityEvents = Mapper.Map>(tableEvents); + entityEvents.ForEach(e => e.SetNewId()); + await dbContext.BulkCopyAsync(entityEvents); } + } - public async Task> GetManyByProviderAsync(Guid providerId, DateTime startDate, DateTime endDate, PageOptions pageOptions) + public async Task> GetManyByCipherAsync(Cipher cipher, DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + DateTime? beforeDate = null; + if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && + long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) { - DateTime? beforeDate = null; - if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && - long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) - { - beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); - } - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = new EventReadPageByProviderIdQuery(providerId, startDate, - endDate, beforeDate, pageOptions); - var events = await query.Run(dbContext).ToListAsync(); - - var result = new PagedResult(); - if (events.Any() && events.Count >= pageOptions.PageSize) - { - result.ContinuationToken = events.Last().Date.ToBinary().ToString(); - } - result.Data.AddRange(events); - return result; - } + beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); } - - public async Task> GetManyByProviderActingUserAsync(Guid providerId, Guid actingUserId, - DateTime startDate, DateTime endDate, PageOptions pageOptions) + using (var scope = ServiceScopeFactory.CreateScope()) { - DateTime? beforeDate = null; - if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && - long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) - { - beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); - } - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = new EventReadPageByProviderIdActingUserIdQuery(providerId, actingUserId, - startDate, endDate, beforeDate, pageOptions); - var events = await query.Run(dbContext).ToListAsync(); + var dbContext = GetDatabaseContext(scope); + var query = new EventReadPageByCipherIdQuery(cipher, startDate, endDate, beforeDate, pageOptions); + var events = await query.Run(dbContext).ToListAsync(); - var result = new PagedResult(); - if (events.Any() && events.Count >= pageOptions.PageSize) - { - result.ContinuationToken = events.Last().Date.ToBinary().ToString(); - } - result.Data.AddRange(events); - return result; + var result = new PagedResult(); + if (events.Any() && events.Count >= pageOptions.PageSize) + { + result.ContinuationToken = events.Last().Date.ToBinary().ToString(); } + result.Data.AddRange(events); + return result; } + } - public async Task> GetManyByOrganizationAsync(Guid organizationId, DateTime startDate, DateTime endDate, PageOptions pageOptions) + + public async Task> GetManyByOrganizationActingUserAsync(Guid organizationId, Guid actingUserId, DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + DateTime? beforeDate = null; + if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && + long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) { - DateTime? beforeDate = null; - if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && - long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) - { - beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); - } - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = new EventReadPageByOrganizationIdQuery(organizationId, startDate, - endDate, beforeDate, pageOptions); - var events = await query.Run(dbContext).ToListAsync(); - - var result = new PagedResult(); - if (events.Any() && events.Count >= pageOptions.PageSize) - { - result.ContinuationToken = events.Last().Date.ToBinary().ToString(); - } - result.Data.AddRange(events); - return result; - } + beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); } - - public async Task> GetManyByUserAsync(Guid userId, DateTime startDate, DateTime endDate, PageOptions pageOptions) + using (var scope = ServiceScopeFactory.CreateScope()) { - DateTime? beforeDate = null; - if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && - long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) - { - beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); - } - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = new EventReadPageByUserIdQuery(userId, startDate, - endDate, beforeDate, pageOptions); - var events = await query.Run(dbContext).ToListAsync(); + var dbContext = GetDatabaseContext(scope); + var query = new EventReadPageByOrganizationIdActingUserIdQuery(organizationId, actingUserId, + startDate, endDate, beforeDate, pageOptions); + var events = await query.Run(dbContext).ToListAsync(); - var result = new PagedResult(); - if (events.Any() && events.Count >= pageOptions.PageSize) - { - result.ContinuationToken = events.Last().Date.ToBinary().ToString(); - } - result.Data.AddRange(events); - return result; + var result = new PagedResult(); + if (events.Any() && events.Count >= pageOptions.PageSize) + { + result.ContinuationToken = events.Last().Date.ToBinary().ToString(); } + result.Data.AddRange(events); + return result; + } + } + + public async Task> GetManyByProviderAsync(Guid providerId, DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + DateTime? beforeDate = null; + if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && + long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) + { + beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); + } + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = new EventReadPageByProviderIdQuery(providerId, startDate, + endDate, beforeDate, pageOptions); + var events = await query.Run(dbContext).ToListAsync(); + + var result = new PagedResult(); + if (events.Any() && events.Count >= pageOptions.PageSize) + { + result.ContinuationToken = events.Last().Date.ToBinary().ToString(); + } + result.Data.AddRange(events); + return result; + } + } + + public async Task> GetManyByProviderActingUserAsync(Guid providerId, Guid actingUserId, + DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + DateTime? beforeDate = null; + if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && + long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) + { + beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); + } + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = new EventReadPageByProviderIdActingUserIdQuery(providerId, actingUserId, + startDate, endDate, beforeDate, pageOptions); + var events = await query.Run(dbContext).ToListAsync(); + + var result = new PagedResult(); + if (events.Any() && events.Count >= pageOptions.PageSize) + { + result.ContinuationToken = events.Last().Date.ToBinary().ToString(); + } + result.Data.AddRange(events); + return result; + } + } + + public async Task> GetManyByOrganizationAsync(Guid organizationId, DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + DateTime? beforeDate = null; + if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && + long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) + { + beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); + } + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = new EventReadPageByOrganizationIdQuery(organizationId, startDate, + endDate, beforeDate, pageOptions); + var events = await query.Run(dbContext).ToListAsync(); + + var result = new PagedResult(); + if (events.Any() && events.Count >= pageOptions.PageSize) + { + result.ContinuationToken = events.Last().Date.ToBinary().ToString(); + } + result.Data.AddRange(events); + return result; + } + } + + public async Task> GetManyByUserAsync(Guid userId, DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + DateTime? beforeDate = null; + if (!string.IsNullOrWhiteSpace(pageOptions.ContinuationToken) && + long.TryParse(pageOptions.ContinuationToken, out var binaryDate)) + { + beforeDate = DateTime.SpecifyKind(DateTime.FromBinary(binaryDate), DateTimeKind.Utc); + } + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = new EventReadPageByUserIdQuery(userId, startDate, + endDate, beforeDate, pageOptions); + var events = await query.Run(dbContext).ToListAsync(); + + var result = new PagedResult(); + if (events.Any() && events.Count >= pageOptions.PageSize) + { + result.ContinuationToken = events.Last().Date.ToBinary().ToString(); + } + result.Data.AddRange(events); + return result; } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/FolderRepository.cs b/src/Infrastructure.EntityFramework/Repositories/FolderRepository.cs index dae64f9c2..9f1f862bf 100644 --- a/src/Infrastructure.EntityFramework/Repositories/FolderRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/FolderRepository.cs @@ -4,36 +4,35 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class FolderRepository : Repository, IFolderRepository { - public class FolderRepository : Repository, IFolderRepository + public FolderRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Folders) + { } + + public async Task GetByIdAsync(Guid id, Guid userId) { - public FolderRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Folders) - { } - - public async Task GetByIdAsync(Guid id, Guid userId) + var folder = await base.GetByIdAsync(id); + if (folder == null || folder.UserId != userId) { - var folder = await base.GetByIdAsync(id); - if (folder == null || folder.UserId != userId) - { - return null; - } - - return folder; + return null; } - public async Task> GetManyByUserIdAsync(Guid userId) + return folder; + } + + public async Task> GetManyByUserIdAsync(Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from f in dbContext.Folders - where f.UserId == userId - select f; - var folders = await query.ToListAsync(); - return Mapper.Map>(folders); - } + var dbContext = GetDatabaseContext(scope); + var query = from f in dbContext.Folders + where f.UserId == userId + select f; + var folders = await query.ToListAsync(); + return Mapper.Map>(folders); } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/GrantRepository.cs b/src/Infrastructure.EntityFramework/Repositories/GrantRepository.cs index 0f8f197fe..2edb62d9c 100644 --- a/src/Infrastructure.EntityFramework/Repositories/GrantRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/GrantRepository.cs @@ -4,92 +4,91 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class GrantRepository : BaseEntityFrameworkRepository, IGrantRepository { - public class GrantRepository : BaseEntityFrameworkRepository, IGrantRepository + public GrantRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper) + { } + + public async Task DeleteByKeyAsync(string key) { - public GrantRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper) - { } - - public async Task DeleteByKeyAsync(string key) + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + var query = from g in dbContext.Grants + where g.Key == key + select g; + dbContext.Remove(query); + await dbContext.SaveChangesAsync(); + } + } + + public async Task DeleteManyAsync(string subjectId, string sessionId, string clientId, string type) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from g in dbContext.Grants + where g.SubjectId == subjectId && + g.ClientId == clientId && + g.SessionId == sessionId && + g.Type == type + select g; + dbContext.Remove(query); + await dbContext.SaveChangesAsync(); + } + } + + public async Task GetByKeyAsync(string key) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from g in dbContext.Grants + where g.Key == key + select g; + var grant = await query.FirstOrDefaultAsync(); + return grant; + } + } + + public async Task> GetManyAsync(string subjectId, string sessionId, string clientId, string type) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from g in dbContext.Grants + where g.SubjectId == subjectId && + g.ClientId == clientId && + g.SessionId == sessionId && + g.Type == type + select g; + var grants = await query.ToListAsync(); + return (ICollection)grants; + } + } + + public async Task SaveAsync(Core.Entities.Grant obj) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var existingGrant = await (from g in dbContext.Grants + where g.Key == obj.Key + select g).FirstOrDefaultAsync(); + if (existingGrant != null) { - var dbContext = GetDatabaseContext(scope); - var query = from g in dbContext.Grants - where g.Key == key - select g; - dbContext.Remove(query); + dbContext.Entry(existingGrant).CurrentValues.SetValues(obj); + } + else + { + var entity = Mapper.Map(obj); + await dbContext.AddAsync(entity); await dbContext.SaveChangesAsync(); } } - - public async Task DeleteManyAsync(string subjectId, string sessionId, string clientId, string type) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from g in dbContext.Grants - where g.SubjectId == subjectId && - g.ClientId == clientId && - g.SessionId == sessionId && - g.Type == type - select g; - dbContext.Remove(query); - await dbContext.SaveChangesAsync(); - } - } - - public async Task GetByKeyAsync(string key) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from g in dbContext.Grants - where g.Key == key - select g; - var grant = await query.FirstOrDefaultAsync(); - return grant; - } - } - - public async Task> GetManyAsync(string subjectId, string sessionId, string clientId, string type) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from g in dbContext.Grants - where g.SubjectId == subjectId && - g.ClientId == clientId && - g.SessionId == sessionId && - g.Type == type - select g; - var grants = await query.ToListAsync(); - return (ICollection)grants; - } - } - - public async Task SaveAsync(Core.Entities.Grant obj) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var existingGrant = await (from g in dbContext.Grants - where g.Key == obj.Key - select g).FirstOrDefaultAsync(); - if (existingGrant != null) - { - dbContext.Entry(existingGrant).CurrentValues.SetValues(obj); - } - else - { - var entity = Mapper.Map(obj); - await dbContext.AddAsync(entity); - await dbContext.SaveChangesAsync(); - } - } - } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/GroupRepository.cs b/src/Infrastructure.EntityFramework/Repositories/GroupRepository.cs index b471a5fdb..d41f07804 100644 --- a/src/Infrastructure.EntityFramework/Repositories/GroupRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/GroupRepository.cs @@ -5,164 +5,163 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class GroupRepository : Repository, IGroupRepository { - public class GroupRepository : Repository, IGroupRepository + public GroupRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Groups) + { } + + public async Task CreateAsync(Core.Entities.Group obj, IEnumerable collections) { - public GroupRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Groups) - { } - - public async Task CreateAsync(Core.Entities.Group obj, IEnumerable collections) + var grp = await base.CreateAsync(obj); + using (var scope = ServiceScopeFactory.CreateScope()) { - var grp = await base.CreateAsync(obj); - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + var availibleCollections = await ( + from c in dbContext.Collections + where c.OrganizationId == grp.OrganizationId + select c).ToListAsync(); + var filteredCollections = collections.Where(c => availibleCollections.Any(a => c.Id == a.Id)); + var collectionGroups = filteredCollections.Select(y => new CollectionGroup { - var dbContext = GetDatabaseContext(scope); - var availibleCollections = await ( - from c in dbContext.Collections - where c.OrganizationId == grp.OrganizationId - select c).ToListAsync(); - var filteredCollections = collections.Where(c => availibleCollections.Any(a => c.Id == a.Id)); - var collectionGroups = filteredCollections.Select(y => new CollectionGroup - { - CollectionId = y.Id, - GroupId = grp.Id, - ReadOnly = y.ReadOnly, - HidePasswords = y.HidePasswords, - }); - await dbContext.CollectionGroups.AddRangeAsync(collectionGroups); - await dbContext.SaveChangesAsync(); - } + CollectionId = y.Id, + GroupId = grp.Id, + ReadOnly = y.ReadOnly, + HidePasswords = y.HidePasswords, + }); + await dbContext.CollectionGroups.AddRangeAsync(collectionGroups); + await dbContext.SaveChangesAsync(); } + } - public async Task DeleteUserAsync(Guid groupId, Guid organizationUserId) + public async Task DeleteUserAsync(Guid groupId, Guid organizationUserId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from gu in dbContext.GroupUsers - where gu.GroupId == groupId && - gu.OrganizationUserId == organizationUserId - select gu; - dbContext.RemoveRange(await query.ToListAsync()); - await dbContext.SaveChangesAsync(); - } + var dbContext = GetDatabaseContext(scope); + var query = from gu in dbContext.GroupUsers + where gu.GroupId == groupId && + gu.OrganizationUserId == organizationUserId + select gu; + dbContext.RemoveRange(await query.ToListAsync()); + await dbContext.SaveChangesAsync(); } + } - public async Task>> GetByIdWithCollectionsAsync(Guid id) + public async Task>> GetByIdWithCollectionsAsync(Guid id) + { + var grp = await base.GetByIdAsync(id); + using (var scope = ServiceScopeFactory.CreateScope()) { - var grp = await base.GetByIdAsync(id); - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + var query = await ( + from cg in dbContext.CollectionGroups + where cg.GroupId == id + select cg).ToListAsync(); + var collections = query.Select(c => new SelectionReadOnly { - var dbContext = GetDatabaseContext(scope); - var query = await ( - from cg in dbContext.CollectionGroups - where cg.GroupId == id - select cg).ToListAsync(); - var collections = query.Select(c => new SelectionReadOnly - { - Id = c.CollectionId, - ReadOnly = c.ReadOnly, - HidePasswords = c.HidePasswords, - }).ToList(); - return new Tuple>( - grp, collections); - } + Id = c.CollectionId, + ReadOnly = c.ReadOnly, + HidePasswords = c.HidePasswords, + }).ToList(); + return new Tuple>( + grp, collections); } + } - public async Task> GetManyByOrganizationIdAsync(Guid organizationId) + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var data = await ( - from g in dbContext.Groups - where g.OrganizationId == organizationId - select g).ToListAsync(); - return Mapper.Map>(data); - } + var dbContext = GetDatabaseContext(scope); + var data = await ( + from g in dbContext.Groups + where g.OrganizationId == organizationId + select g).ToListAsync(); + return Mapper.Map>(data); } + } - public async Task> GetManyGroupUsersByOrganizationIdAsync(Guid organizationId) + public async Task> GetManyGroupUsersByOrganizationIdAsync(Guid organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = - from gu in dbContext.GroupUsers - join g in dbContext.Groups - on gu.GroupId equals g.Id - where g.OrganizationId == organizationId - select gu; - var groupUsers = await query.ToListAsync(); - return Mapper.Map>(groupUsers); - } + var dbContext = GetDatabaseContext(scope); + var query = + from gu in dbContext.GroupUsers + join g in dbContext.Groups + on gu.GroupId equals g.Id + where g.OrganizationId == organizationId + select gu; + var groupUsers = await query.ToListAsync(); + return Mapper.Map>(groupUsers); } + } - public async Task> GetManyIdsByUserIdAsync(Guid organizationUserId) + public async Task> GetManyIdsByUserIdAsync(Guid organizationUserId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = - from gu in dbContext.GroupUsers - where gu.OrganizationUserId == organizationUserId - select gu; - var groupIds = await query.Select(x => x.GroupId).ToListAsync(); - return groupIds; - } + var dbContext = GetDatabaseContext(scope); + var query = + from gu in dbContext.GroupUsers + where gu.OrganizationUserId == organizationUserId + select gu; + var groupIds = await query.Select(x => x.GroupId).ToListAsync(); + return groupIds; } + } - public async Task> GetManyUserIdsByIdAsync(Guid id) + public async Task> GetManyUserIdsByIdAsync(Guid id) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = - from gu in dbContext.GroupUsers - where gu.GroupId == id - select gu; - var groupIds = await query.Select(x => x.OrganizationUserId).ToListAsync(); - return groupIds; - } + var dbContext = GetDatabaseContext(scope); + var query = + from gu in dbContext.GroupUsers + where gu.GroupId == id + select gu; + var groupIds = await query.Select(x => x.OrganizationUserId).ToListAsync(); + return groupIds; } + } - public async Task ReplaceAsync(Core.Entities.Group obj, IEnumerable collections) + public async Task ReplaceAsync(Core.Entities.Group obj, IEnumerable collections) + { + await base.ReplaceAsync(obj); + using (var scope = ServiceScopeFactory.CreateScope()) { - await base.ReplaceAsync(obj); - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - await UserBumpAccountRevisionDateByOrganizationId(obj.OrganizationId); - } + var dbContext = GetDatabaseContext(scope); + await UserBumpAccountRevisionDateByOrganizationId(obj.OrganizationId); } + } - public async Task UpdateUsersAsync(Guid groupId, IEnumerable organizationUserIds) + public async Task UpdateUsersAsync(Guid groupId, IEnumerable organizationUserIds) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var orgId = (await dbContext.Groups.FindAsync(groupId)).OrganizationId; - var insert = from ou in dbContext.OrganizationUsers - where organizationUserIds.Contains(ou.Id) && - ou.OrganizationId == orgId && - !dbContext.GroupUsers.Any(gu => gu.GroupId == groupId && ou.Id == gu.OrganizationUserId) - select new GroupUser - { - GroupId = groupId, - OrganizationUserId = ou.Id, - }; - await dbContext.AddRangeAsync(insert); + var dbContext = GetDatabaseContext(scope); + var orgId = (await dbContext.Groups.FindAsync(groupId)).OrganizationId; + var insert = from ou in dbContext.OrganizationUsers + where organizationUserIds.Contains(ou.Id) && + ou.OrganizationId == orgId && + !dbContext.GroupUsers.Any(gu => gu.GroupId == groupId && ou.Id == gu.OrganizationUserId) + select new GroupUser + { + GroupId = groupId, + OrganizationUserId = ou.Id, + }; + await dbContext.AddRangeAsync(insert); - var delete = from gu in dbContext.GroupUsers - where gu.GroupId == groupId && - !organizationUserIds.Contains(gu.OrganizationUserId) - select gu; - dbContext.RemoveRange(delete); - await dbContext.SaveChangesAsync(); - await UserBumpAccountRevisionDateByOrganizationId(orgId); - } + var delete = from gu in dbContext.GroupUsers + where gu.GroupId == groupId && + !organizationUserIds.Contains(gu.OrganizationUserId) + select gu; + dbContext.RemoveRange(delete); + await dbContext.SaveChangesAsync(); + await UserBumpAccountRevisionDateByOrganizationId(orgId); } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/InstallationRepository.cs b/src/Infrastructure.EntityFramework/Repositories/InstallationRepository.cs index 1cc4808c3..292e98f85 100644 --- a/src/Infrastructure.EntityFramework/Repositories/InstallationRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/InstallationRepository.cs @@ -3,12 +3,11 @@ using Bit.Core.Repositories; using Bit.Infrastructure.EntityFramework.Models; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class InstallationRepository : Repository, IInstallationRepository { - public class InstallationRepository : Repository, IInstallationRepository - { - public InstallationRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Installations) - { } - } + public InstallationRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Installations) + { } } diff --git a/src/Infrastructure.EntityFramework/Repositories/MaintenanceRepository.cs b/src/Infrastructure.EntityFramework/Repositories/MaintenanceRepository.cs index e91d775cc..340834ca5 100644 --- a/src/Infrastructure.EntityFramework/Repositories/MaintenanceRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/MaintenanceRepository.cs @@ -2,53 +2,52 @@ using Bit.Core.Repositories; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class MaintenanceRepository : BaseEntityFrameworkRepository, IMaintenanceRepository { - public class MaintenanceRepository : BaseEntityFrameworkRepository, IMaintenanceRepository + public MaintenanceRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper) + { } + + public async Task DeleteExpiredGrantsAsync() { - public MaintenanceRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper) - { } - - public async Task DeleteExpiredGrantsAsync() + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from g in dbContext.Grants - where g.ExpirationDate < DateTime.UtcNow - select g; - dbContext.RemoveRange(query); - await dbContext.SaveChangesAsync(); - } - } - - public async Task DeleteExpiredSponsorshipsAsync(DateTime validUntilBeforeDate) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from s in dbContext.OrganizationSponsorships - where s.ValidUntil < validUntilBeforeDate - select s; - dbContext.RemoveRange(query); - await dbContext.SaveChangesAsync(); - } - } - - public Task DisableCipherAutoStatsAsync() - { - return Task.CompletedTask; - } - - public Task RebuildIndexesAsync() - { - return Task.CompletedTask; - } - - public Task UpdateStatisticsAsync() - { - return Task.CompletedTask; + var dbContext = GetDatabaseContext(scope); + var query = from g in dbContext.Grants + where g.ExpirationDate < DateTime.UtcNow + select g; + dbContext.RemoveRange(query); + await dbContext.SaveChangesAsync(); } } + + public async Task DeleteExpiredSponsorshipsAsync(DateTime validUntilBeforeDate) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from s in dbContext.OrganizationSponsorships + where s.ValidUntil < validUntilBeforeDate + select s; + dbContext.RemoveRange(query); + await dbContext.SaveChangesAsync(); + } + } + + public Task DisableCipherAutoStatsAsync() + { + return Task.CompletedTask; + } + + public Task RebuildIndexesAsync() + { + return Task.CompletedTask; + } + + public Task UpdateStatisticsAsync() + { + return Task.CompletedTask; + } } diff --git a/src/Infrastructure.EntityFramework/Repositories/OrganizationApiKeyRepository.cs b/src/Infrastructure.EntityFramework/Repositories/OrganizationApiKeyRepository.cs index 8bc462adf..52cf3d5e6 100644 --- a/src/Infrastructure.EntityFramework/Repositories/OrganizationApiKeyRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/OrganizationApiKeyRepository.cs @@ -5,26 +5,25 @@ using Bit.Core.Repositories; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class OrganizationApiKeyRepository : Repository, IOrganizationApiKeyRepository { - public class OrganizationApiKeyRepository : Repository, IOrganizationApiKeyRepository + public OrganizationApiKeyRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, db => db.OrganizationApiKeys) { - public OrganizationApiKeyRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, db => db.OrganizationApiKeys) - { - } + } - public async Task> GetManyByOrganizationIdTypeAsync(Guid organizationId, OrganizationApiKeyType? type = null) + public async Task> GetManyByOrganizationIdTypeAsync(Guid organizationId, OrganizationApiKeyType? type = null) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var apiKeys = await dbContext.OrganizationApiKeys - .Where(o => o.OrganizationId == organizationId && (type == null || o.Type == type)) - .ToListAsync(); - return Mapper.Map>(apiKeys); - } + var dbContext = GetDatabaseContext(scope); + var apiKeys = await dbContext.OrganizationApiKeys + .Where(o => o.OrganizationId == organizationId && (type == null || o.Type == type)) + .ToListAsync(); + return Mapper.Map>(apiKeys); } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/OrganizationConnectionRepository.cs b/src/Infrastructure.EntityFramework/Repositories/OrganizationConnectionRepository.cs index 5acd8807d..298e28e02 100644 --- a/src/Infrastructure.EntityFramework/Repositories/OrganizationConnectionRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/OrganizationConnectionRepository.cs @@ -5,38 +5,37 @@ using Bit.Core.Repositories; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class OrganizationConnectionRepository : Repository, IOrganizationConnectionRepository { - public class OrganizationConnectionRepository : Repository, IOrganizationConnectionRepository + public OrganizationConnectionRepository(IServiceScopeFactory serviceScopeFactory, + IMapper mapper) + : base(serviceScopeFactory, mapper, context => context.OrganizationConnections) { - public OrganizationConnectionRepository(IServiceScopeFactory serviceScopeFactory, - IMapper mapper) - : base(serviceScopeFactory, mapper, context => context.OrganizationConnections) - { - } + } - public async Task> GetByOrganizationIdTypeAsync(Guid organizationId, OrganizationConnectionType type) + public async Task> GetByOrganizationIdTypeAsync(Guid organizationId, OrganizationConnectionType type) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var connections = await dbContext.OrganizationConnections - .Where(oc => oc.OrganizationId == organizationId && oc.Type == type) - .ToListAsync(); - return Mapper.Map>(connections); - } + var dbContext = GetDatabaseContext(scope); + var connections = await dbContext.OrganizationConnections + .Where(oc => oc.OrganizationId == organizationId && oc.Type == type) + .ToListAsync(); + return Mapper.Map>(connections); } + } - public async Task> GetEnabledByOrganizationIdTypeAsync(Guid organizationId, OrganizationConnectionType type) + public async Task> GetEnabledByOrganizationIdTypeAsync(Guid organizationId, OrganizationConnectionType type) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var connections = await dbContext.OrganizationConnections - .Where(oc => oc.OrganizationId == organizationId && oc.Type == type && oc.Enabled) - .ToListAsync(); - return Mapper.Map>(connections); - } + var dbContext = GetDatabaseContext(scope); + var connections = await dbContext.OrganizationConnections + .Where(oc => oc.OrganizationId == organizationId && oc.Type == type && oc.Enabled) + .ToListAsync(); + return Mapper.Map>(connections); } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/OrganizationRepository.cs b/src/Infrastructure.EntityFramework/Repositories/OrganizationRepository.cs index b12ff65bf..bd60b53e8 100644 --- a/src/Infrastructure.EntityFramework/Repositories/OrganizationRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/OrganizationRepository.cs @@ -5,105 +5,104 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class OrganizationRepository : Repository, IOrganizationRepository { - public class OrganizationRepository : Repository, IOrganizationRepository + public OrganizationRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Organizations) + { } + + public async Task GetByIdentifierAsync(string identifier) { - public OrganizationRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Organizations) - { } - - public async Task GetByIdentifierAsync(string identifier) + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var organization = await GetDbSet(dbContext).Where(e => e.Identifier == identifier) - .FirstOrDefaultAsync(); - return organization; - } + var dbContext = GetDatabaseContext(scope); + var organization = await GetDbSet(dbContext).Where(e => e.Identifier == identifier) + .FirstOrDefaultAsync(); + return organization; } + } - public async Task> GetManyByEnabledAsync() + public async Task> GetManyByEnabledAsync() + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var organizations = await GetDbSet(dbContext).Where(e => e.Enabled).ToListAsync(); - return Mapper.Map>(organizations); - } + var dbContext = GetDatabaseContext(scope); + var organizations = await GetDbSet(dbContext).Where(e => e.Enabled).ToListAsync(); + return Mapper.Map>(organizations); } + } - public async Task> GetManyByUserIdAsync(Guid userId) + public async Task> GetManyByUserIdAsync(Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var organizations = await GetDbSet(dbContext) - .Select(e => e.OrganizationUsers - .Where(ou => ou.UserId == userId) - .Select(ou => ou.Organization)) - .ToListAsync(); - return Mapper.Map>(organizations); - } + var dbContext = GetDatabaseContext(scope); + var organizations = await GetDbSet(dbContext) + .Select(e => e.OrganizationUsers + .Where(ou => ou.UserId == userId) + .Select(ou => ou.Organization)) + .ToListAsync(); + return Mapper.Map>(organizations); } + } - public async Task> SearchAsync(string name, string userEmail, - bool? paid, int skip, int take) + public async Task> SearchAsync(string name, string userEmail, + bool? paid, int skip, int take) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var organizations = await GetDbSet(dbContext) - .Where(e => name == null || e.Name.Contains(name)) - .Where(e => userEmail == null || e.OrganizationUsers.Any(u => u.Email == userEmail)) - .Where(e => paid == null || - (paid == true && !string.IsNullOrWhiteSpace(e.GatewaySubscriptionId)) || - (paid == false && e.GatewaySubscriptionId == null)) - .OrderBy(e => e.CreationDate) - .Skip(skip).Take(take) - .ToListAsync(); - return Mapper.Map>(organizations); - } + var dbContext = GetDatabaseContext(scope); + var organizations = await GetDbSet(dbContext) + .Where(e => name == null || e.Name.Contains(name)) + .Where(e => userEmail == null || e.OrganizationUsers.Any(u => u.Email == userEmail)) + .Where(e => paid == null || + (paid == true && !string.IsNullOrWhiteSpace(e.GatewaySubscriptionId)) || + (paid == false && e.GatewaySubscriptionId == null)) + .OrderBy(e => e.CreationDate) + .Skip(skip).Take(take) + .ToListAsync(); + return Mapper.Map>(organizations); } + } - public async Task> GetManyAbilitiesAsync() + public async Task> GetManyAbilitiesAsync() + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + return await GetDbSet(dbContext) + .Select(e => new OrganizationAbility { - var dbContext = GetDatabaseContext(scope); - return await GetDbSet(dbContext) - .Select(e => new OrganizationAbility - { - Enabled = e.Enabled, - Id = e.Id, - Use2fa = e.Use2fa, - UseEvents = e.UseEvents, - UsersGetPremium = e.UsersGetPremium, - Using2fa = e.Use2fa && e.TwoFactorProviders != null, - UseSso = e.UseSso, - UseKeyConnector = e.UseKeyConnector, - UseResetPassword = e.UseResetPassword, - UseScim = e.UseScim, - }).ToListAsync(); - } + Enabled = e.Enabled, + Id = e.Id, + Use2fa = e.Use2fa, + UseEvents = e.UseEvents, + UsersGetPremium = e.UsersGetPremium, + Using2fa = e.Use2fa && e.TwoFactorProviders != null, + UseSso = e.UseSso, + UseKeyConnector = e.UseKeyConnector, + UseResetPassword = e.UseResetPassword, + UseScim = e.UseScim, + }).ToListAsync(); } + } - public async Task UpdateStorageAsync(Guid id) + public async Task UpdateStorageAsync(Guid id) + { + await OrganizationUpdateStorage(id); + } + + public override async Task DeleteAsync(Core.Entities.Organization organization) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - await OrganizationUpdateStorage(id); - } + var dbContext = GetDatabaseContext(scope); + var orgEntity = await dbContext.FindAsync(organization.Id); - public override async Task DeleteAsync(Core.Entities.Organization organization) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var orgEntity = await dbContext.FindAsync(organization.Id); - - dbContext.Remove(orgEntity); - await dbContext.SaveChangesAsync(); - } + dbContext.Remove(orgEntity); + await dbContext.SaveChangesAsync(); } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/OrganizationSponsorshipRepository.cs b/src/Infrastructure.EntityFramework/Repositories/OrganizationSponsorshipRepository.cs index 9e00d924d..de0af89df 100644 --- a/src/Infrastructure.EntityFramework/Repositories/OrganizationSponsorshipRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/OrganizationSponsorshipRepository.cs @@ -4,138 +4,137 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class OrganizationSponsorshipRepository : Repository, IOrganizationSponsorshipRepository { - public class OrganizationSponsorshipRepository : Repository, IOrganizationSponsorshipRepository + public OrganizationSponsorshipRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.OrganizationSponsorships) + { } + + public async Task> CreateManyAsync(IEnumerable organizationSponsorships) { - public OrganizationSponsorshipRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.OrganizationSponsorships) - { } - - public async Task> CreateManyAsync(IEnumerable organizationSponsorships) + if (!organizationSponsorships.Any()) { - if (!organizationSponsorships.Any()) - { - return new List(); - } - - foreach (var organizationSponsorship in organizationSponsorships) - { - organizationSponsorship.SetNewId(); - } - - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var entities = Mapper.Map>(organizationSponsorships); - await dbContext.AddRangeAsync(entities); - await dbContext.SaveChangesAsync(); - } - - return organizationSponsorships.Select(u => u.Id).ToList(); + return new List(); } - public async Task ReplaceManyAsync(IEnumerable organizationSponsorships) + foreach (var organizationSponsorship in organizationSponsorships) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - dbContext.UpdateRange(organizationSponsorships); - await dbContext.SaveChangesAsync(); - } + organizationSponsorship.SetNewId(); } - public async Task UpsertManyAsync(IEnumerable organizationSponsorships) + using (var scope = ServiceScopeFactory.CreateScope()) { - var createSponsorships = new List(); - var replaceSponsorships = new List(); - foreach (var organizationSponsorship in organizationSponsorships) - { - if (organizationSponsorship.Id.Equals(default)) - { - createSponsorships.Add(organizationSponsorship); - } - else - { - replaceSponsorships.Add(organizationSponsorship); - } - } - - await CreateManyAsync(createSponsorships); - await ReplaceManyAsync(replaceSponsorships); - } - - public async Task DeleteManyAsync(IEnumerable organizationSponsorshipIds) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var entities = await dbContext.OrganizationSponsorships - .Where(os => organizationSponsorshipIds.Contains(os.Id)) - .ToListAsync(); - - dbContext.OrganizationSponsorships.RemoveRange(entities); - await dbContext.SaveChangesAsync(); - } - } - - public async Task GetByOfferedToEmailAsync(string email) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var orgSponsorship = await GetDbSet(dbContext).Where(e => e.OfferedToEmail == email) - .FirstOrDefaultAsync(); - return orgSponsorship; - } - } - - public async Task GetBySponsoredOrganizationIdAsync(Guid sponsoredOrganizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var orgSponsorship = await GetDbSet(dbContext).Where(e => e.SponsoredOrganizationId == sponsoredOrganizationId) - .FirstOrDefaultAsync(); - return orgSponsorship; - } - } - - public async Task GetBySponsoringOrganizationUserIdAsync(Guid sponsoringOrganizationUserId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var orgSponsorship = await GetDbSet(dbContext).Where(e => e.SponsoringOrganizationUserId == sponsoringOrganizationUserId) - .FirstOrDefaultAsync(); - return orgSponsorship; - } - } - - public async Task GetLatestSyncDateBySponsoringOrganizationIdAsync(Guid sponsoringOrganizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - return await GetDbSet(dbContext).Where(e => e.SponsoringOrganizationId == sponsoringOrganizationId && e.LastSyncDate != null) - .OrderByDescending(e => e.LastSyncDate) - .Select(e => e.LastSyncDate) - .FirstOrDefaultAsync(); - - } - } - - public async Task> GetManyBySponsoringOrganizationAsync(Guid sponsoringOrganizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from os in dbContext.OrganizationSponsorships - where os.SponsoringOrganizationId == sponsoringOrganizationId - select os; - return Mapper.Map>(await query.ToListAsync()); - } + var dbContext = GetDatabaseContext(scope); + var entities = Mapper.Map>(organizationSponsorships); + await dbContext.AddRangeAsync(entities); + await dbContext.SaveChangesAsync(); } + return organizationSponsorships.Select(u => u.Id).ToList(); } + + public async Task ReplaceManyAsync(IEnumerable organizationSponsorships) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + dbContext.UpdateRange(organizationSponsorships); + await dbContext.SaveChangesAsync(); + } + } + + public async Task UpsertManyAsync(IEnumerable organizationSponsorships) + { + var createSponsorships = new List(); + var replaceSponsorships = new List(); + foreach (var organizationSponsorship in organizationSponsorships) + { + if (organizationSponsorship.Id.Equals(default)) + { + createSponsorships.Add(organizationSponsorship); + } + else + { + replaceSponsorships.Add(organizationSponsorship); + } + } + + await CreateManyAsync(createSponsorships); + await ReplaceManyAsync(replaceSponsorships); + } + + public async Task DeleteManyAsync(IEnumerable organizationSponsorshipIds) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var entities = await dbContext.OrganizationSponsorships + .Where(os => organizationSponsorshipIds.Contains(os.Id)) + .ToListAsync(); + + dbContext.OrganizationSponsorships.RemoveRange(entities); + await dbContext.SaveChangesAsync(); + } + } + + public async Task GetByOfferedToEmailAsync(string email) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var orgSponsorship = await GetDbSet(dbContext).Where(e => e.OfferedToEmail == email) + .FirstOrDefaultAsync(); + return orgSponsorship; + } + } + + public async Task GetBySponsoredOrganizationIdAsync(Guid sponsoredOrganizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var orgSponsorship = await GetDbSet(dbContext).Where(e => e.SponsoredOrganizationId == sponsoredOrganizationId) + .FirstOrDefaultAsync(); + return orgSponsorship; + } + } + + public async Task GetBySponsoringOrganizationUserIdAsync(Guid sponsoringOrganizationUserId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var orgSponsorship = await GetDbSet(dbContext).Where(e => e.SponsoringOrganizationUserId == sponsoringOrganizationUserId) + .FirstOrDefaultAsync(); + return orgSponsorship; + } + } + + public async Task GetLatestSyncDateBySponsoringOrganizationIdAsync(Guid sponsoringOrganizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + return await GetDbSet(dbContext).Where(e => e.SponsoringOrganizationId == sponsoringOrganizationId && e.LastSyncDate != null) + .OrderByDescending(e => e.LastSyncDate) + .Select(e => e.LastSyncDate) + .FirstOrDefaultAsync(); + + } + } + + public async Task> GetManyBySponsoringOrganizationAsync(Guid sponsoringOrganizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from os in dbContext.OrganizationSponsorships + where os.SponsoringOrganizationId == sponsoringOrganizationId + select os; + return Mapper.Map>(await query.ToListAsync()); + } + } + } diff --git a/src/Infrastructure.EntityFramework/Repositories/OrganizationUserRepository.cs b/src/Infrastructure.EntityFramework/Repositories/OrganizationUserRepository.cs index 0c0383182..70f4401ee 100644 --- a/src/Infrastructure.EntityFramework/Repositories/OrganizationUserRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/OrganizationUserRepository.cs @@ -8,456 +8,455 @@ using Bit.Infrastructure.EntityFramework.Repositories.Queries; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class OrganizationUserRepository : Repository, IOrganizationUserRepository { - public class OrganizationUserRepository : Repository, IOrganizationUserRepository + public OrganizationUserRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.OrganizationUsers) + { } + + public async Task CreateAsync(Core.Entities.OrganizationUser obj, IEnumerable collections) { - public OrganizationUserRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.OrganizationUsers) - { } - - public async Task CreateAsync(Core.Entities.OrganizationUser obj, IEnumerable collections) + var organizationUser = await base.CreateAsync(obj); + using (var scope = ServiceScopeFactory.CreateScope()) { - var organizationUser = await base.CreateAsync(obj); - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + var availibleCollections = await ( + from c in dbContext.Collections + where c.OrganizationId == organizationUser.OrganizationId + select c).ToListAsync(); + var filteredCollections = collections.Where(c => availibleCollections.Any(a => c.Id == a.Id)); + var collectionUsers = filteredCollections.Select(y => new CollectionUser { - var dbContext = GetDatabaseContext(scope); - var availibleCollections = await ( - from c in dbContext.Collections - where c.OrganizationId == organizationUser.OrganizationId - select c).ToListAsync(); - var filteredCollections = collections.Where(c => availibleCollections.Any(a => c.Id == a.Id)); - var collectionUsers = filteredCollections.Select(y => new CollectionUser + CollectionId = y.Id, + OrganizationUserId = organizationUser.Id, + ReadOnly = y.ReadOnly, + HidePasswords = y.HidePasswords, + }); + await dbContext.CollectionUsers.AddRangeAsync(collectionUsers); + await dbContext.SaveChangesAsync(); + } + + return organizationUser.Id; + } + + public async Task> CreateManyAsync(IEnumerable organizationUsers) + { + if (!organizationUsers.Any()) + { + return new List(); + } + + foreach (var organizationUser in organizationUsers) + { + organizationUser.SetNewId(); + } + + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var entities = Mapper.Map>(organizationUsers); + await dbContext.AddRangeAsync(entities); + await dbContext.SaveChangesAsync(); + } + + return organizationUsers.Select(u => u.Id).ToList(); + } + + public override async Task DeleteAsync(Core.Entities.OrganizationUser organizationUser) => await DeleteAsync(organizationUser.Id); + public async Task DeleteAsync(Guid organizationUserId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var orgUser = await dbContext.FindAsync(organizationUserId); + + dbContext.Remove(orgUser); + await dbContext.SaveChangesAsync(); + } + } + + public async Task DeleteManyAsync(IEnumerable organizationUserIds) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var entities = await dbContext.OrganizationUsers + .Where(ou => organizationUserIds.Contains(ou.Id)) + .ToListAsync(); + + dbContext.OrganizationUsers.RemoveRange(entities); + await dbContext.SaveChangesAsync(); + } + } + + public async Task>> GetByIdWithCollectionsAsync(Guid id) + { + var organizationUser = await base.GetByIdAsync(id); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = await ( + from ou in dbContext.OrganizationUsers + join cu in dbContext.CollectionUsers + on ou.Id equals cu.OrganizationUserId + where !ou.AccessAll && + ou.Id == id + select cu).ToListAsync(); + var collections = query.Select(cu => new SelectionReadOnly + { + Id = cu.CollectionId, + ReadOnly = cu.ReadOnly, + HidePasswords = cu.HidePasswords, + }); + return new Tuple>( + organizationUser, collections.ToList()); + } + } + + public async Task GetByOrganizationAsync(Guid organizationId, Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var entity = await GetDbSet(dbContext) + .FirstOrDefaultAsync(e => e.OrganizationId == organizationId && e.UserId == userId); + return entity; + } + } + + public async Task GetByOrganizationEmailAsync(Guid organizationId, string email) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var entity = await GetDbSet(dbContext) + .FirstOrDefaultAsync(ou => ou.OrganizationId == organizationId && + !string.IsNullOrWhiteSpace(ou.Email) && + ou.Email == email); + return entity; + } + } + + public async Task GetCountByFreeOrganizationAdminUserAsync(Guid userId) + { + var query = new OrganizationUserReadCountByFreeOrganizationAdminUserQuery(userId); + return await GetCountFromQuery(query); + } + + public async Task GetCountByOnlyOwnerAsync(Guid userId) + { + var query = new OrganizationUserReadCountByOnlyOwnerQuery(userId); + return await GetCountFromQuery(query); + } + + public async Task GetCountByOrganizationAsync(Guid organizationId, string email, bool onlyRegisteredUsers) + { + var query = new OrganizationUserReadCountByOrganizationIdEmailQuery(organizationId, email, onlyRegisteredUsers); + return await GetCountFromQuery(query); + } + + public async Task GetCountByOrganizationIdAsync(Guid organizationId) + { + var query = new OrganizationUserReadCountByOrganizationIdQuery(organizationId); + return await GetCountFromQuery(query); + } + + public async Task GetDetailsByIdAsync(Guid id) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var view = new OrganizationUserUserDetailsViewQuery(); + var entity = await view.Run(dbContext).FirstOrDefaultAsync(ou => ou.Id == id); + return entity; + } + } + + public async Task>> GetDetailsByIdWithCollectionsAsync(Guid id) + { + var organizationUserUserDetails = await GetDetailsByIdAsync(id); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from ou in dbContext.OrganizationUsers + join cu in dbContext.CollectionUsers on ou.Id equals cu.OrganizationUserId + where !ou.AccessAll && ou.Id == id + select cu; + var collections = await query.Select(cu => new SelectionReadOnly + { + Id = cu.CollectionId, + ReadOnly = cu.ReadOnly, + HidePasswords = cu.HidePasswords, + }).ToListAsync(); + return new Tuple>(organizationUserUserDetails, collections); + } + } + + public async Task GetDetailsByUserAsync(Guid userId, Guid organizationId, OrganizationUserStatusType? status = null) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var view = new OrganizationUserOrganizationDetailsViewQuery(); + var t = await (view.Run(dbContext)).ToArrayAsync(); + var entity = await view.Run(dbContext) + .FirstOrDefaultAsync(o => o.UserId == userId && + o.OrganizationId == organizationId && + (status == null || o.Status == status)); + return entity; + } + } + + public async Task> GetManyAsync(IEnumerable Ids) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from ou in dbContext.OrganizationUsers + where Ids.Contains(ou.Id) + select ou; + var data = await query.ToArrayAsync(); + return data; + } + } + + public async Task> GetManyByManyUsersAsync(IEnumerable userIds) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from ou in dbContext.OrganizationUsers + where userIds.Contains(ou.Id) + select ou; + return Mapper.Map>(await query.ToListAsync()); + } + } + + public async Task> GetManyByOrganizationAsync(Guid organizationId, OrganizationUserType? type) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from ou in dbContext.OrganizationUsers + where ou.OrganizationId == organizationId && + (type == null || ou.Type == type) + select ou; + return Mapper.Map>(await query.ToListAsync()); + } + } + + public async Task> GetManyByUserAsync(Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from ou in dbContext.OrganizationUsers + where ou.UserId == userId + select ou; + return Mapper.Map>(await query.ToListAsync()); + } + } + + public async Task> GetManyDetailsByOrganizationAsync(Guid organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var view = new OrganizationUserUserDetailsViewQuery(); + var query = from ou in view.Run(dbContext) + where ou.OrganizationId == organizationId + select ou; + return await query.ToListAsync(); + } + } + + public async Task> GetManyDetailsByUserAsync(Guid userId, + OrganizationUserStatusType? status = null) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var view = new OrganizationUserOrganizationDetailsViewQuery(); + var query = from ou in view.Run(dbContext) + where ou.UserId == userId && + (status == null || ou.Status == status) + select ou; + var organizationUsers = await query.ToListAsync(); + return organizationUsers; + } + } + + public async Task> GetManyPublicKeysByOrganizationUserAsync(Guid organizationId, IEnumerable Ids) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from ou in dbContext.OrganizationUsers + where Ids.Contains(ou.Id) && ou.Status == OrganizationUserStatusType.Accepted + join u in dbContext.Users + on ou.UserId equals u.Id + where ou.OrganizationId == organizationId + select new { ou, u }; + var data = await query + .Select(x => new OrganizationUserPublicKey() { - CollectionId = y.Id, - OrganizationUserId = organizationUser.Id, - ReadOnly = y.ReadOnly, - HidePasswords = y.HidePasswords, - }); - await dbContext.CollectionUsers.AddRangeAsync(collectionUsers); - await dbContext.SaveChangesAsync(); - } - - return organizationUser.Id; - } - - public async Task> CreateManyAsync(IEnumerable organizationUsers) - { - if (!organizationUsers.Any()) - { - return new List(); - } - - foreach (var organizationUser in organizationUsers) - { - organizationUser.SetNewId(); - } - - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var entities = Mapper.Map>(organizationUsers); - await dbContext.AddRangeAsync(entities); - await dbContext.SaveChangesAsync(); - } - - return organizationUsers.Select(u => u.Id).ToList(); - } - - public override async Task DeleteAsync(Core.Entities.OrganizationUser organizationUser) => await DeleteAsync(organizationUser.Id); - public async Task DeleteAsync(Guid organizationUserId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var orgUser = await dbContext.FindAsync(organizationUserId); - - dbContext.Remove(orgUser); - await dbContext.SaveChangesAsync(); - } - } - - public async Task DeleteManyAsync(IEnumerable organizationUserIds) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var entities = await dbContext.OrganizationUsers - .Where(ou => organizationUserIds.Contains(ou.Id)) - .ToListAsync(); - - dbContext.OrganizationUsers.RemoveRange(entities); - await dbContext.SaveChangesAsync(); - } - } - - public async Task>> GetByIdWithCollectionsAsync(Guid id) - { - var organizationUser = await base.GetByIdAsync(id); - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = await ( - from ou in dbContext.OrganizationUsers - join cu in dbContext.CollectionUsers - on ou.Id equals cu.OrganizationUserId - where !ou.AccessAll && - ou.Id == id - select cu).ToListAsync(); - var collections = query.Select(cu => new SelectionReadOnly - { - Id = cu.CollectionId, - ReadOnly = cu.ReadOnly, - HidePasswords = cu.HidePasswords, - }); - return new Tuple>( - organizationUser, collections.ToList()); - } - } - - public async Task GetByOrganizationAsync(Guid organizationId, Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var entity = await GetDbSet(dbContext) - .FirstOrDefaultAsync(e => e.OrganizationId == organizationId && e.UserId == userId); - return entity; - } - } - - public async Task GetByOrganizationEmailAsync(Guid organizationId, string email) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var entity = await GetDbSet(dbContext) - .FirstOrDefaultAsync(ou => ou.OrganizationId == organizationId && - !string.IsNullOrWhiteSpace(ou.Email) && - ou.Email == email); - return entity; - } - } - - public async Task GetCountByFreeOrganizationAdminUserAsync(Guid userId) - { - var query = new OrganizationUserReadCountByFreeOrganizationAdminUserQuery(userId); - return await GetCountFromQuery(query); - } - - public async Task GetCountByOnlyOwnerAsync(Guid userId) - { - var query = new OrganizationUserReadCountByOnlyOwnerQuery(userId); - return await GetCountFromQuery(query); - } - - public async Task GetCountByOrganizationAsync(Guid organizationId, string email, bool onlyRegisteredUsers) - { - var query = new OrganizationUserReadCountByOrganizationIdEmailQuery(organizationId, email, onlyRegisteredUsers); - return await GetCountFromQuery(query); - } - - public async Task GetCountByOrganizationIdAsync(Guid organizationId) - { - var query = new OrganizationUserReadCountByOrganizationIdQuery(organizationId); - return await GetCountFromQuery(query); - } - - public async Task GetDetailsByIdAsync(Guid id) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var view = new OrganizationUserUserDetailsViewQuery(); - var entity = await view.Run(dbContext).FirstOrDefaultAsync(ou => ou.Id == id); - return entity; - } - } - - public async Task>> GetDetailsByIdWithCollectionsAsync(Guid id) - { - var organizationUserUserDetails = await GetDetailsByIdAsync(id); - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from ou in dbContext.OrganizationUsers - join cu in dbContext.CollectionUsers on ou.Id equals cu.OrganizationUserId - where !ou.AccessAll && ou.Id == id - select cu; - var collections = await query.Select(cu => new SelectionReadOnly - { - Id = cu.CollectionId, - ReadOnly = cu.ReadOnly, - HidePasswords = cu.HidePasswords, + Id = x.ou.Id, + PublicKey = x.u.PublicKey, }).ToListAsync(); - return new Tuple>(organizationUserUserDetails, collections); + return data; + } + } + + public async Task ReplaceAsync(Core.Entities.OrganizationUser obj, IEnumerable collections) + { + await base.ReplaceAsync(obj); + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + + var procedure = new OrganizationUserUpdateWithCollectionsQuery(obj, collections); + + var update = procedure.Update.Run(dbContext); + dbContext.UpdateRange(await update.ToListAsync()); + + var insert = procedure.Insert.Run(dbContext); + await dbContext.AddRangeAsync(await insert.ToListAsync()); + + dbContext.RemoveRange(await procedure.Delete.Run(dbContext).ToListAsync()); + await dbContext.SaveChangesAsync(); + } + } + + public async Task ReplaceManyAsync(IEnumerable organizationUsers) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + dbContext.UpdateRange(organizationUsers); + await dbContext.SaveChangesAsync(); + await UserBumpManyAccountRevisionDates(organizationUsers + .Where(ou => ou.UserId.HasValue) + .Select(ou => ou.UserId.Value).ToArray()); + } + } + + public async Task> SelectKnownEmailsAsync(Guid organizationId, IEnumerable emails, bool onlyRegisteredUsers) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var usersQuery = from ou in dbContext.OrganizationUsers + join u in dbContext.Users + on ou.UserId equals u.Id into u_g + from u in u_g + where ou.OrganizationId == organizationId + select new { ou, u }; + var ouu = await usersQuery.ToListAsync(); + var ouEmails = ouu.Select(x => x.ou.Email); + var uEmails = ouu.Select(x => x.u.Email); + var knownEmails = from e in emails + where (ouEmails.Contains(e) || uEmails.Contains(e)) && + (!onlyRegisteredUsers && (uEmails.Contains(e) || ouEmails.Contains(e))) || + (onlyRegisteredUsers && uEmails.Contains(e)) + select e; + return knownEmails.ToList(); + } + } + + public async Task UpdateGroupsAsync(Guid orgUserId, IEnumerable groupIds) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + + var procedure = new GroupUserUpdateGroupsQuery(orgUserId, groupIds); + + var insert = procedure.Insert.Run(dbContext); + var data = await insert.ToListAsync(); + await dbContext.AddRangeAsync(data); + + var delete = procedure.Delete.Run(dbContext); + var deleteData = await delete.ToListAsync(); + dbContext.RemoveRange(deleteData); + await UserBumpAccountRevisionDateByOrganizationUserId(orgUserId); + await dbContext.SaveChangesAsync(); + } + } + + public async Task UpsertManyAsync(IEnumerable organizationUsers) + { + var createUsers = new List(); + var replaceUsers = new List(); + foreach (var organizationUser in organizationUsers) + { + if (organizationUser.Id.Equals(default)) + { + createUsers.Add(organizationUser); + } + else + { + replaceUsers.Add(organizationUser); } } - public async Task GetDetailsByUserAsync(Guid userId, Guid organizationId, OrganizationUserStatusType? status = null) + await CreateManyAsync(createUsers); + await ReplaceManyAsync(replaceUsers); + } + + public async Task> GetManyByMinimumRoleAsync(Guid organizationId, OrganizationUserType minRole) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var view = new OrganizationUserOrganizationDetailsViewQuery(); - var t = await (view.Run(dbContext)).ToArrayAsync(); - var entity = await view.Run(dbContext) - .FirstOrDefaultAsync(o => o.UserId == userId && - o.OrganizationId == organizationId && - (status == null || o.Status == status)); - return entity; - } - } - - public async Task> GetManyAsync(IEnumerable Ids) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from ou in dbContext.OrganizationUsers - where Ids.Contains(ou.Id) - select ou; - var data = await query.ToArrayAsync(); - return data; - } - } - - public async Task> GetManyByManyUsersAsync(IEnumerable userIds) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from ou in dbContext.OrganizationUsers - where userIds.Contains(ou.Id) - select ou; - return Mapper.Map>(await query.ToListAsync()); - } - } - - public async Task> GetManyByOrganizationAsync(Guid organizationId, OrganizationUserType? type) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from ou in dbContext.OrganizationUsers - where ou.OrganizationId == organizationId && - (type == null || ou.Type == type) - select ou; - return Mapper.Map>(await query.ToListAsync()); - } - } - - public async Task> GetManyByUserAsync(Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from ou in dbContext.OrganizationUsers - where ou.UserId == userId - select ou; - return Mapper.Map>(await query.ToListAsync()); - } - } - - public async Task> GetManyDetailsByOrganizationAsync(Guid organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var view = new OrganizationUserUserDetailsViewQuery(); - var query = from ou in view.Run(dbContext) - where ou.OrganizationId == organizationId - select ou; - return await query.ToListAsync(); - } - } - - public async Task> GetManyDetailsByUserAsync(Guid userId, - OrganizationUserStatusType? status = null) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var view = new OrganizationUserOrganizationDetailsViewQuery(); - var query = from ou in view.Run(dbContext) - where ou.UserId == userId && - (status == null || ou.Status == status) - select ou; - var organizationUsers = await query.ToListAsync(); - return organizationUsers; - } - } - - public async Task> GetManyPublicKeysByOrganizationUserAsync(Guid organizationId, IEnumerable Ids) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from ou in dbContext.OrganizationUsers - where Ids.Contains(ou.Id) && ou.Status == OrganizationUserStatusType.Accepted - join u in dbContext.Users - on ou.UserId equals u.Id - where ou.OrganizationId == organizationId - select new { ou, u }; - var data = await query - .Select(x => new OrganizationUserPublicKey() - { - Id = x.ou.Id, - PublicKey = x.u.PublicKey, - }).ToListAsync(); - return data; - } - } - - public async Task ReplaceAsync(Core.Entities.OrganizationUser obj, IEnumerable collections) - { - await base.ReplaceAsync(obj); - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - - var procedure = new OrganizationUserUpdateWithCollectionsQuery(obj, collections); - - var update = procedure.Update.Run(dbContext); - dbContext.UpdateRange(await update.ToListAsync()); - - var insert = procedure.Insert.Run(dbContext); - await dbContext.AddRangeAsync(await insert.ToListAsync()); - - dbContext.RemoveRange(await procedure.Delete.Run(dbContext).ToListAsync()); - await dbContext.SaveChangesAsync(); - } - } - - public async Task ReplaceManyAsync(IEnumerable organizationUsers) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - dbContext.UpdateRange(organizationUsers); - await dbContext.SaveChangesAsync(); - await UserBumpManyAccountRevisionDates(organizationUsers - .Where(ou => ou.UserId.HasValue) - .Select(ou => ou.UserId.Value).ToArray()); - } - } - - public async Task> SelectKnownEmailsAsync(Guid organizationId, IEnumerable emails, bool onlyRegisteredUsers) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var usersQuery = from ou in dbContext.OrganizationUsers - join u in dbContext.Users - on ou.UserId equals u.Id into u_g - from u in u_g - where ou.OrganizationId == organizationId - select new { ou, u }; - var ouu = await usersQuery.ToListAsync(); - var ouEmails = ouu.Select(x => x.ou.Email); - var uEmails = ouu.Select(x => x.u.Email); - var knownEmails = from e in emails - where (ouEmails.Contains(e) || uEmails.Contains(e)) && - (!onlyRegisteredUsers && (uEmails.Contains(e) || ouEmails.Contains(e))) || - (onlyRegisteredUsers && uEmails.Contains(e)) - select e; - return knownEmails.ToList(); - } - } - - public async Task UpdateGroupsAsync(Guid orgUserId, IEnumerable groupIds) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - - var procedure = new GroupUserUpdateGroupsQuery(orgUserId, groupIds); - - var insert = procedure.Insert.Run(dbContext); - var data = await insert.ToListAsync(); - await dbContext.AddRangeAsync(data); - - var delete = procedure.Delete.Run(dbContext); - var deleteData = await delete.ToListAsync(); - dbContext.RemoveRange(deleteData); - await UserBumpAccountRevisionDateByOrganizationUserId(orgUserId); - await dbContext.SaveChangesAsync(); - } - } - - public async Task UpsertManyAsync(IEnumerable organizationUsers) - { - var createUsers = new List(); - var replaceUsers = new List(); - foreach (var organizationUser in organizationUsers) - { - if (organizationUser.Id.Equals(default)) + var dbContext = GetDatabaseContext(scope); + var query = dbContext.OrganizationUsers + .Include(e => e.User) + .Where(e => e.OrganizationId.Equals(organizationId) && + e.Type <= minRole && + e.Status == OrganizationUserStatusType.Confirmed) + .Select(e => new OrganizationUserUserDetails() { - createUsers.Add(organizationUser); - } - else - { - replaceUsers.Add(organizationUser); - } - } - - await CreateManyAsync(createUsers); - await ReplaceManyAsync(replaceUsers); + Id = e.Id, + Email = e.Email ?? e.User.Email + }); + return await query.ToListAsync(); } + } - public async Task> GetManyByMinimumRoleAsync(Guid organizationId, OrganizationUserType minRole) + public async Task RevokeAsync(Guid id) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + var orgUser = await GetDbSet(dbContext).FindAsync(id); + if (orgUser != null) { - var dbContext = GetDatabaseContext(scope); - var query = dbContext.OrganizationUsers - .Include(e => e.User) - .Where(e => e.OrganizationId.Equals(organizationId) && - e.Type <= minRole && - e.Status == OrganizationUserStatusType.Confirmed) - .Select(e => new OrganizationUserUserDetails() - { - Id = e.Id, - Email = e.Email ?? e.User.Email - }); - return await query.ToListAsync(); - } - } - - public async Task RevokeAsync(Guid id) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var orgUser = await GetDbSet(dbContext).FindAsync(id); - if (orgUser != null) + dbContext.Update(orgUser); + orgUser.Status = OrganizationUserStatusType.Revoked; + await dbContext.SaveChangesAsync(); + if (orgUser.UserId.HasValue) { - dbContext.Update(orgUser); - orgUser.Status = OrganizationUserStatusType.Revoked; - await dbContext.SaveChangesAsync(); - if (orgUser.UserId.HasValue) - { - await UserBumpAccountRevisionDate(orgUser.UserId.Value); - } + await UserBumpAccountRevisionDate(orgUser.UserId.Value); } } } + } - public async Task RestoreAsync(Guid id, OrganizationUserStatusType status) + public async Task RestoreAsync(Guid id, OrganizationUserStatusType status) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + var orgUser = await GetDbSet(dbContext).FindAsync(id); + if (orgUser != null) { - var dbContext = GetDatabaseContext(scope); - var orgUser = await GetDbSet(dbContext).FindAsync(id); - if (orgUser != null) + dbContext.Update(orgUser); + orgUser.Status = status; + await dbContext.SaveChangesAsync(); + if (orgUser.UserId.HasValue) { - dbContext.Update(orgUser); - orgUser.Status = status; - await dbContext.SaveChangesAsync(); - if (orgUser.UserId.HasValue) - { - await UserBumpAccountRevisionDate(orgUser.UserId.Value); - } + await UserBumpAccountRevisionDate(orgUser.UserId.Value); } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/PolicyRepository.cs b/src/Infrastructure.EntityFramework/Repositories/PolicyRepository.cs index 8d6a92809..1a02c6aa7 100644 --- a/src/Infrastructure.EntityFramework/Repositories/PolicyRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/PolicyRepository.cs @@ -6,72 +6,71 @@ using Bit.Infrastructure.EntityFramework.Repositories.Queries; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class PolicyRepository : Repository, IPolicyRepository { - public class PolicyRepository : Repository, IPolicyRepository + public PolicyRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Policies) + { } + + public async Task GetByOrganizationIdTypeAsync(Guid organizationId, PolicyType type) { - public PolicyRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Policies) - { } - - public async Task GetByOrganizationIdTypeAsync(Guid organizationId, PolicyType type) + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var results = await dbContext.Policies - .FirstOrDefaultAsync(p => p.OrganizationId == organizationId && p.Type == type); - return Mapper.Map(results); - } + var dbContext = GetDatabaseContext(scope); + var results = await dbContext.Policies + .FirstOrDefaultAsync(p => p.OrganizationId == organizationId && p.Type == type); + return Mapper.Map(results); } + } - public async Task> GetManyByOrganizationIdAsync(Guid organizationId) + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var results = await dbContext.Policies - .Where(p => p.OrganizationId == organizationId) - .ToListAsync(); - return Mapper.Map>(results); - } + var dbContext = GetDatabaseContext(scope); + var results = await dbContext.Policies + .Where(p => p.OrganizationId == organizationId) + .ToListAsync(); + return Mapper.Map>(results); } + } - public async Task> GetManyByUserIdAsync(Guid userId) + public async Task> GetManyByUserIdAsync(Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); + var dbContext = GetDatabaseContext(scope); - var query = new PolicyReadByUserIdQuery(userId); - var results = await query.Run(dbContext).ToListAsync(); - return Mapper.Map>(results); - } + var query = new PolicyReadByUserIdQuery(userId); + var results = await query.Run(dbContext).ToListAsync(); + return Mapper.Map>(results); } + } - public async Task> GetManyByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType, - OrganizationUserStatusType minStatus) + public async Task> GetManyByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType, + OrganizationUserStatusType minStatus) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); + var dbContext = GetDatabaseContext(scope); - var query = new PolicyReadByTypeApplicableToUserQuery(userId, policyType, minStatus); - var results = await query.Run(dbContext).ToListAsync(); - return Mapper.Map>(results); - } + var query = new PolicyReadByTypeApplicableToUserQuery(userId, policyType, minStatus); + var results = await query.Run(dbContext).ToListAsync(); + return Mapper.Map>(results); } + } - public async Task GetCountByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType, - OrganizationUserStatusType minStatus) + public async Task GetCountByTypeApplicableToUserIdAsync(Guid userId, PolicyType policyType, + OrganizationUserStatusType minStatus) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); + var dbContext = GetDatabaseContext(scope); - var query = new PolicyReadByTypeApplicableToUserQuery(userId, policyType, minStatus); - return await GetCountFromQuery(query); - } + var query = new PolicyReadByTypeApplicableToUserQuery(userId, policyType, minStatus); + return await GetCountFromQuery(query); } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/ProviderOrganizationRepository.cs b/src/Infrastructure.EntityFramework/Repositories/ProviderOrganizationRepository.cs index dd5271fcd..5d17d38bb 100644 --- a/src/Infrastructure.EntityFramework/Repositories/ProviderOrganizationRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/ProviderOrganizationRepository.cs @@ -6,31 +6,30 @@ using Bit.Infrastructure.EntityFramework.Repositories.Queries; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class ProviderOrganizationRepository : + Repository, IProviderOrganizationRepository { - public class ProviderOrganizationRepository : - Repository, IProviderOrganizationRepository + public ProviderOrganizationRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, context => context.ProviderOrganizations) + { } + + public async Task> GetManyDetailsByProviderAsync(Guid providerId) { - public ProviderOrganizationRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, context => context.ProviderOrganizations) - { } - - public async Task> GetManyDetailsByProviderAsync(Guid providerId) + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = new ProviderOrganizationOrganizationDetailsReadByProviderIdQuery(providerId); - var data = await query.Run(dbContext).ToListAsync(); - return data; - } - } - - public async Task GetByOrganizationId(Guid organizationId) - { - using var scope = ServiceScopeFactory.CreateScope(); var dbContext = GetDatabaseContext(scope); - return await GetDbSet(dbContext).Where(po => po.OrganizationId == organizationId).FirstOrDefaultAsync(); + var query = new ProviderOrganizationOrganizationDetailsReadByProviderIdQuery(providerId); + var data = await query.Run(dbContext).ToListAsync(); + return data; } } + + public async Task GetByOrganizationId(Guid organizationId) + { + using var scope = ServiceScopeFactory.CreateScope(); + var dbContext = GetDatabaseContext(scope); + return await GetDbSet(dbContext).Where(po => po.OrganizationId == organizationId).FirstOrDefaultAsync(); + } } diff --git a/src/Infrastructure.EntityFramework/Repositories/ProviderRepository.cs b/src/Infrastructure.EntityFramework/Repositories/ProviderRepository.cs index 75c8788c6..cf015b273 100644 --- a/src/Infrastructure.EntityFramework/Repositories/ProviderRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/ProviderRepository.cs @@ -5,52 +5,51 @@ using Bit.Core.Repositories; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class ProviderRepository : Repository, IProviderRepository { - public class ProviderRepository : Repository, IProviderRepository + + public ProviderRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, context => context.Providers) + { } + + public async Task> SearchAsync(string name, string userEmail, int skip, int take) { - - public ProviderRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, context => context.Providers) - { } - - public async Task> SearchAsync(string name, string userEmail, int skip, int take) + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = !string.IsNullOrWhiteSpace(userEmail) ? - (from p in dbContext.Providers - join pu in dbContext.ProviderUsers - on p.Id equals pu.ProviderId - join u in dbContext.Users - on pu.UserId equals u.Id - where (string.IsNullOrWhiteSpace(name) || p.Name.Contains(name)) && - u.Email == userEmail - orderby p.CreationDate descending - select new { p, pu, u }).Skip(skip).Take(take).Select(x => x.p) : - (from p in dbContext.Providers - where string.IsNullOrWhiteSpace(name) || p.Name.Contains(name) - orderby p.CreationDate descending - select new { p }).Skip(skip).Take(take).Select(x => x.p); - var providers = await query.ToArrayAsync(); - return Mapper.Map>(providers); - } + var dbContext = GetDatabaseContext(scope); + var query = !string.IsNullOrWhiteSpace(userEmail) ? + (from p in dbContext.Providers + join pu in dbContext.ProviderUsers + on p.Id equals pu.ProviderId + join u in dbContext.Users + on pu.UserId equals u.Id + where (string.IsNullOrWhiteSpace(name) || p.Name.Contains(name)) && + u.Email == userEmail + orderby p.CreationDate descending + select new { p, pu, u }).Skip(skip).Take(take).Select(x => x.p) : + (from p in dbContext.Providers + where string.IsNullOrWhiteSpace(name) || p.Name.Contains(name) + orderby p.CreationDate descending + select new { p }).Skip(skip).Take(take).Select(x => x.p); + var providers = await query.ToArrayAsync(); + return Mapper.Map>(providers); } + } - public async Task> GetManyAbilitiesAsync() + public async Task> GetManyAbilitiesAsync() + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - return await GetDbSet(dbContext) - .Select(e => new ProviderAbility - { - Enabled = e.Enabled, - Id = e.Id, - UseEvents = e.UseEvents, - }).ToListAsync(); - } + var dbContext = GetDatabaseContext(scope); + return await GetDbSet(dbContext) + .Select(e => new ProviderAbility + { + Enabled = e.Enabled, + Id = e.Id, + UseEvents = e.UseEvents, + }).ToListAsync(); } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/ProviderUserRepository.cs b/src/Infrastructure.EntityFramework/Repositories/ProviderUserRepository.cs index 87d82a542..3aac5cca9 100644 --- a/src/Infrastructure.EntityFramework/Repositories/ProviderUserRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/ProviderUserRepository.cs @@ -7,154 +7,153 @@ using Bit.Infrastructure.EntityFramework.Repositories.Queries; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class ProviderUserRepository : + Repository, IProviderUserRepository { - public class ProviderUserRepository : - Repository, IProviderUserRepository + public ProviderUserRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.ProviderUsers) + { } + + public async Task GetCountByProviderAsync(Guid providerId, string email, bool onlyRegisteredUsers) { - public ProviderUserRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.ProviderUsers) - { } - - public async Task GetCountByProviderAsync(Guid providerId, string email, bool onlyRegisteredUsers) + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from pu in dbContext.ProviderUsers - join u in dbContext.Users - on pu.UserId equals u.Id into u_g - from u in u_g.DefaultIfEmpty() - where pu.ProviderId == providerId && - ((!onlyRegisteredUsers && (pu.Email == email || u.Email == email)) || - (onlyRegisteredUsers && u.Email == email)) - select new { pu, u }; - return await query.CountAsync(); - } - } - - public async Task> GetManyAsync(IEnumerable ids) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = dbContext.ProviderUsers.Where(item => ids.Contains(item.Id)); - return await query.ToArrayAsync(); - } - } - - public async Task> GetManyByProviderAsync(Guid providerId, ProviderUserType? type = null) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = dbContext.ProviderUsers.Where(pu => pu.ProviderId.Equals(providerId) && - (type != null && pu.Type.Equals(type))); - return await query.ToArrayAsync(); - } - } - - public async Task DeleteManyAsync(IEnumerable providerUserIds) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - await UserBumpAccountRevisionDateByProviderUserIds(providerUserIds.ToArray()); - var entities = dbContext.ProviderUsers.Where(pu => providerUserIds.Contains(pu.Id)); - dbContext.ProviderUsers.RemoveRange(entities); - await dbContext.SaveChangesAsync(); - } - } - - public async Task> GetManyByUserAsync(Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from pu in dbContext.ProviderUsers - where pu.UserId == userId - select pu; - return await query.ToArrayAsync(); - } - } - public async Task GetByProviderUserAsync(Guid providerId, Guid userId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = from pu in dbContext.ProviderUsers - where pu.UserId == userId && - pu.ProviderId == providerId - select pu; - return await query.FirstOrDefaultAsync(); - } - } - public async Task> GetManyDetailsByProviderAsync(Guid providerId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var view = from pu in dbContext.ProviderUsers - join u in dbContext.Users - on pu.UserId equals u.Id into u_g - from u in u_g.DefaultIfEmpty() - select new { pu, u }; - var data = await view.Where(e => e.pu.ProviderId == providerId).Select(e => new ProviderUserUserDetails - { - Id = e.pu.Id, - UserId = e.pu.UserId, - ProviderId = e.pu.ProviderId, - Name = e.u.Name, - Email = e.u.Email ?? e.pu.Email, - Status = e.pu.Status, - Type = e.pu.Type, - Permissions = e.pu.Permissions, - }).ToArrayAsync(); - return data; - } - } - - public async Task> GetManyDetailsByUserAsync(Guid userId, ProviderUserStatusType? status = null) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = new ProviderUserProviderDetailsReadByUserIdStatusQuery(userId, status); - var data = await query.Run(dbContext).ToArrayAsync(); - return data; - } - } - - public async Task> GetManyPublicKeysByProviderUserAsync(Guid providerId, IEnumerable Ids) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = new UserReadPublicKeysByProviderUserIdsQuery(providerId, Ids); - var data = await query.Run(dbContext).ToListAsync(); - return data; - } - } - - public async Task> GetManyOrganizationDetailsByUserAsync(Guid userId, ProviderUserStatusType? status = null) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var view = new ProviderUserOrganizationDetailsViewQuery(); - var query = from ou in view.Run(dbContext) - where ou.UserId == userId && - (status == null || ou.Status == status) - select ou; - var organizationUsers = await query.ToListAsync(); - return organizationUsers; - } - } - - public async Task GetCountByOnlyOwnerAsync(Guid userId) - { - var query = new ProviderUserReadCountByOnlyOwnerQuery(userId); - return await GetCountFromQuery(query); + var dbContext = GetDatabaseContext(scope); + var query = from pu in dbContext.ProviderUsers + join u in dbContext.Users + on pu.UserId equals u.Id into u_g + from u in u_g.DefaultIfEmpty() + where pu.ProviderId == providerId && + ((!onlyRegisteredUsers && (pu.Email == email || u.Email == email)) || + (onlyRegisteredUsers && u.Email == email)) + select new { pu, u }; + return await query.CountAsync(); } } + + public async Task> GetManyAsync(IEnumerable ids) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = dbContext.ProviderUsers.Where(item => ids.Contains(item.Id)); + return await query.ToArrayAsync(); + } + } + + public async Task> GetManyByProviderAsync(Guid providerId, ProviderUserType? type = null) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = dbContext.ProviderUsers.Where(pu => pu.ProviderId.Equals(providerId) && + (type != null && pu.Type.Equals(type))); + return await query.ToArrayAsync(); + } + } + + public async Task DeleteManyAsync(IEnumerable providerUserIds) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + await UserBumpAccountRevisionDateByProviderUserIds(providerUserIds.ToArray()); + var entities = dbContext.ProviderUsers.Where(pu => providerUserIds.Contains(pu.Id)); + dbContext.ProviderUsers.RemoveRange(entities); + await dbContext.SaveChangesAsync(); + } + } + + public async Task> GetManyByUserAsync(Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from pu in dbContext.ProviderUsers + where pu.UserId == userId + select pu; + return await query.ToArrayAsync(); + } + } + public async Task GetByProviderUserAsync(Guid providerId, Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = from pu in dbContext.ProviderUsers + where pu.UserId == userId && + pu.ProviderId == providerId + select pu; + return await query.FirstOrDefaultAsync(); + } + } + public async Task> GetManyDetailsByProviderAsync(Guid providerId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var view = from pu in dbContext.ProviderUsers + join u in dbContext.Users + on pu.UserId equals u.Id into u_g + from u in u_g.DefaultIfEmpty() + select new { pu, u }; + var data = await view.Where(e => e.pu.ProviderId == providerId).Select(e => new ProviderUserUserDetails + { + Id = e.pu.Id, + UserId = e.pu.UserId, + ProviderId = e.pu.ProviderId, + Name = e.u.Name, + Email = e.u.Email ?? e.pu.Email, + Status = e.pu.Status, + Type = e.pu.Type, + Permissions = e.pu.Permissions, + }).ToArrayAsync(); + return data; + } + } + + public async Task> GetManyDetailsByUserAsync(Guid userId, ProviderUserStatusType? status = null) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = new ProviderUserProviderDetailsReadByUserIdStatusQuery(userId, status); + var data = await query.Run(dbContext).ToArrayAsync(); + return data; + } + } + + public async Task> GetManyPublicKeysByProviderUserAsync(Guid providerId, IEnumerable Ids) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = new UserReadPublicKeysByProviderUserIdsQuery(providerId, Ids); + var data = await query.Run(dbContext).ToListAsync(); + return data; + } + } + + public async Task> GetManyOrganizationDetailsByUserAsync(Guid userId, ProviderUserStatusType? status = null) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var view = new ProviderUserOrganizationDetailsViewQuery(); + var query = from ou in view.Run(dbContext) + where ou.UserId == userId && + (status == null || ou.Status == status) + select ou; + var organizationUsers = await query.ToListAsync(); + return organizationUsers; + } + } + + public async Task GetCountByOnlyOwnerAsync(Guid userId) + { + var query = new ProviderUserReadCountByOnlyOwnerQuery(userId); + return await GetCountFromQuery(query); + } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/CipherDetailsQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/CipherDetailsQuery.cs index 38c451c3f..7d676c021 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/CipherDetailsQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/CipherDetailsQuery.cs @@ -1,37 +1,36 @@ using Bit.Core.Utilities; using Core.Models.Data; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class CipherDetailsQuery : IQuery { - public class CipherDetailsQuery : IQuery + private readonly Guid? _userId; + private readonly bool _ignoreFolders; + public CipherDetailsQuery(Guid? userId, bool ignoreFolders = false) { - private readonly Guid? _userId; - private readonly bool _ignoreFolders; - public CipherDetailsQuery(Guid? userId, bool ignoreFolders = false) - { - _userId = userId; - _ignoreFolders = ignoreFolders; - } - public virtual IQueryable Run(DatabaseContext dbContext) - { - var query = from c in dbContext.Ciphers - select new CipherDetails - { - Id = c.Id, - UserId = c.UserId, - OrganizationId = c.OrganizationId, - Type = c.Type, - Data = c.Data, - Attachments = c.Attachments, - CreationDate = c.CreationDate, - RevisionDate = c.RevisionDate, - DeletedDate = c.DeletedDate, - Favorite = _userId.HasValue && c.Favorites != null && c.Favorites.Contains($"\"{_userId}\":true"), - FolderId = (_ignoreFolders || !_userId.HasValue || c.Folders == null || !c.Folders.Contains(_userId.Value.ToString())) ? - null : - CoreHelpers.LoadClassFromJsonData>(c.Folders)[_userId.Value], - }; - return query; - } + _userId = userId; + _ignoreFolders = ignoreFolders; + } + public virtual IQueryable Run(DatabaseContext dbContext) + { + var query = from c in dbContext.Ciphers + select new CipherDetails + { + Id = c.Id, + UserId = c.UserId, + OrganizationId = c.OrganizationId, + Type = c.Type, + Data = c.Data, + Attachments = c.Attachments, + CreationDate = c.CreationDate, + RevisionDate = c.RevisionDate, + DeletedDate = c.DeletedDate, + Favorite = _userId.HasValue && c.Favorites != null && c.Favorites.Contains($"\"{_userId}\":true"), + FolderId = (_ignoreFolders || !_userId.HasValue || c.Folders == null || !c.Folders.Contains(_userId.Value.ToString())) ? + null : + CoreHelpers.LoadClassFromJsonData>(c.Folders)[_userId.Value], + }; + return query; } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/CipherOrganizationDetailsReadByIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/CipherOrganizationDetailsReadByIdQuery.cs index 2bca0c11a..b93954a52 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/CipherOrganizationDetailsReadByIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/CipherOrganizationDetailsReadByIdQuery.cs @@ -1,39 +1,38 @@ using Core.Models.Data; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class CipherOrganizationDetailsReadByIdQuery : IQuery { - public class CipherOrganizationDetailsReadByIdQuery : IQuery + private readonly Guid _cipherId; + + public CipherOrganizationDetailsReadByIdQuery(Guid cipherId) { - private readonly Guid _cipherId; + _cipherId = cipherId; + } - public CipherOrganizationDetailsReadByIdQuery(Guid cipherId) - { - _cipherId = cipherId; - } - - public virtual IQueryable Run(DatabaseContext dbContext) - { - var query = from c in dbContext.Ciphers - join o in dbContext.Organizations - on c.OrganizationId equals o.Id into o_g - from o in o_g.DefaultIfEmpty() - where c.Id == _cipherId - select new CipherOrganizationDetails - { - Id = c.Id, - UserId = c.UserId, - OrganizationId = c.OrganizationId, - Type = c.Type, - Data = c.Data, - Favorites = c.Favorites, - Folders = c.Folders, - Attachments = c.Attachments, - CreationDate = c.CreationDate, - RevisionDate = c.RevisionDate, - DeletedDate = c.DeletedDate, - OrganizationUseTotp = o.UseTotp, - }; - return query; - } + public virtual IQueryable Run(DatabaseContext dbContext) + { + var query = from c in dbContext.Ciphers + join o in dbContext.Organizations + on c.OrganizationId equals o.Id into o_g + from o in o_g.DefaultIfEmpty() + where c.Id == _cipherId + select new CipherOrganizationDetails + { + Id = c.Id, + UserId = c.UserId, + OrganizationId = c.OrganizationId, + Type = c.Type, + Data = c.Data, + Favorites = c.Favorites, + Folders = c.Folders, + Attachments = c.Attachments, + CreationDate = c.CreationDate, + RevisionDate = c.RevisionDate, + DeletedDate = c.DeletedDate, + OrganizationUseTotp = o.UseTotp, + }; + return query; } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/CipherOrganizationDetailsReadByOrgizationIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/CipherOrganizationDetailsReadByOrgizationIdQuery.cs index 84d2779a1..578bd7701 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/CipherOrganizationDetailsReadByOrgizationIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/CipherOrganizationDetailsReadByOrgizationIdQuery.cs @@ -1,38 +1,37 @@ using Core.Models.Data; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries -{ - public class CipherOrganizationDetailsReadByOrgizationIdQuery : IQuery - { - private readonly Guid _organizationId; +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - public CipherOrganizationDetailsReadByOrgizationIdQuery(Guid organizationId) - { - _organizationId = organizationId; - } - public virtual IQueryable Run(DatabaseContext dbContext) - { - var query = from c in dbContext.Ciphers - join o in dbContext.Organizations - on c.OrganizationId equals o.Id into o_g - from o in o_g.DefaultIfEmpty() - where c.OrganizationId == _organizationId - select new CipherOrganizationDetails - { - Id = c.Id, - UserId = c.UserId, - OrganizationId = c.OrganizationId, - Type = c.Type, - Data = c.Data, - Favorites = c.Favorites, - Folders = c.Folders, - Attachments = c.Attachments, - CreationDate = c.CreationDate, - RevisionDate = c.RevisionDate, - DeletedDate = c.DeletedDate, - OrganizationUseTotp = o.UseTotp, - }; - return query; - } +public class CipherOrganizationDetailsReadByOrgizationIdQuery : IQuery +{ + private readonly Guid _organizationId; + + public CipherOrganizationDetailsReadByOrgizationIdQuery(Guid organizationId) + { + _organizationId = organizationId; + } + public virtual IQueryable Run(DatabaseContext dbContext) + { + var query = from c in dbContext.Ciphers + join o in dbContext.Organizations + on c.OrganizationId equals o.Id into o_g + from o in o_g.DefaultIfEmpty() + where c.OrganizationId == _organizationId + select new CipherOrganizationDetails + { + Id = c.Id, + UserId = c.UserId, + OrganizationId = c.OrganizationId, + Type = c.Type, + Data = c.Data, + Favorites = c.Favorites, + Folders = c.Folders, + Attachments = c.Attachments, + CreationDate = c.CreationDate, + RevisionDate = c.RevisionDate, + DeletedDate = c.DeletedDate, + OrganizationUseTotp = o.UseTotp, + }; + return query; } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/CipherReadCanEditByIdUserIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/CipherReadCanEditByIdUserIdQuery.cs index ab9a32b52..4ac3718c7 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/CipherReadCanEditByIdUserIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/CipherReadCanEditByIdUserIdQuery.cs @@ -1,56 +1,55 @@ using Bit.Core.Enums; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class CipherReadCanEditByIdUserIdQuery : IQuery { - public class CipherReadCanEditByIdUserIdQuery : IQuery + private readonly Guid _userId; + private readonly Guid _cipherId; + + public CipherReadCanEditByIdUserIdQuery(Guid userId, Guid cipherId) { - private readonly Guid _userId; - private readonly Guid _cipherId; + _userId = userId; + _cipherId = cipherId; + } - public CipherReadCanEditByIdUserIdQuery(Guid userId, Guid cipherId) - { - _userId = userId; - _cipherId = cipherId; - } - - public virtual IQueryable Run(DatabaseContext dbContext) - { - var query = from c in dbContext.Ciphers - join o in dbContext.Organizations - on c.OrganizationId equals o.Id into o_g - from o in o_g.DefaultIfEmpty() - where !c.UserId.HasValue - join ou in dbContext.OrganizationUsers - on o.Id equals ou.OrganizationId into ou_g - from ou in ou_g.DefaultIfEmpty() - where ou.UserId == _userId - join cc in dbContext.CollectionCiphers - on c.Id equals cc.CipherId into cc_g - from cc in cc_g.DefaultIfEmpty() - where !c.UserId.HasValue && !ou.AccessAll - join cu in dbContext.CollectionUsers - on cc.CollectionId equals cu.CollectionId into cu_g - from cu in cu_g.DefaultIfEmpty() - where ou.Id == cu.OrganizationUserId - join gu in dbContext.GroupUsers - on ou.Id equals gu.OrganizationUserId into gu_g - from gu in gu_g.DefaultIfEmpty() - where !c.UserId.HasValue && cu.CollectionId == null && !ou.AccessAll - join g in dbContext.Groups - on gu.GroupId equals g.Id into g_g - from g in g_g.DefaultIfEmpty() - join cg in dbContext.CollectionGroups - on gu.GroupId equals cg.GroupId into cg_g - from cg in cg_g.DefaultIfEmpty() - where !g.AccessAll && cg.CollectionId == cc.CollectionId && - (c.Id == _cipherId && - (c.UserId == _userId || - (!c.UserId.HasValue && ou.Status == OrganizationUserStatusType.Confirmed && o.Enabled && - (ou.AccessAll || cu.CollectionId != null || g.AccessAll || cg.CollectionId != null)))) && - (c.UserId.HasValue || ou.AccessAll || !cu.ReadOnly || g.AccessAll || !cg.ReadOnly) - select c; - return query; - } + public virtual IQueryable Run(DatabaseContext dbContext) + { + var query = from c in dbContext.Ciphers + join o in dbContext.Organizations + on c.OrganizationId equals o.Id into o_g + from o in o_g.DefaultIfEmpty() + where !c.UserId.HasValue + join ou in dbContext.OrganizationUsers + on o.Id equals ou.OrganizationId into ou_g + from ou in ou_g.DefaultIfEmpty() + where ou.UserId == _userId + join cc in dbContext.CollectionCiphers + on c.Id equals cc.CipherId into cc_g + from cc in cc_g.DefaultIfEmpty() + where !c.UserId.HasValue && !ou.AccessAll + join cu in dbContext.CollectionUsers + on cc.CollectionId equals cu.CollectionId into cu_g + from cu in cu_g.DefaultIfEmpty() + where ou.Id == cu.OrganizationUserId + join gu in dbContext.GroupUsers + on ou.Id equals gu.OrganizationUserId into gu_g + from gu in gu_g.DefaultIfEmpty() + where !c.UserId.HasValue && cu.CollectionId == null && !ou.AccessAll + join g in dbContext.Groups + on gu.GroupId equals g.Id into g_g + from g in g_g.DefaultIfEmpty() + join cg in dbContext.CollectionGroups + on gu.GroupId equals cg.GroupId into cg_g + from cg in cg_g.DefaultIfEmpty() + where !g.AccessAll && cg.CollectionId == cc.CollectionId && + (c.Id == _cipherId && + (c.UserId == _userId || + (!c.UserId.HasValue && ou.Status == OrganizationUserStatusType.Confirmed && o.Enabled && + (ou.AccessAll || cu.CollectionId != null || g.AccessAll || cg.CollectionId != null)))) && + (c.UserId.HasValue || ou.AccessAll || !cu.ReadOnly || g.AccessAll || !cg.ReadOnly) + select c; + return query; } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/CipherUpdateCollectionsQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/CipherUpdateCollectionsQuery.cs index 25be9135a..859b26182 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/CipherUpdateCollectionsQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/CipherUpdateCollectionsQuery.cs @@ -2,65 +2,64 @@ using Bit.Core.Enums; using CollectionCipher = Bit.Infrastructure.EntityFramework.Models.CollectionCipher; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class CipherUpdateCollectionsQuery : IQuery { - public class CipherUpdateCollectionsQuery : IQuery + private readonly Cipher _cipher; + private readonly IEnumerable _collectionIds; + + public CipherUpdateCollectionsQuery(Cipher cipher, IEnumerable collectionIds) { - private readonly Cipher _cipher; - private readonly IEnumerable _collectionIds; + _cipher = cipher; + _collectionIds = collectionIds; + } - public CipherUpdateCollectionsQuery(Cipher cipher, IEnumerable collectionIds) + public virtual IQueryable Run(DatabaseContext dbContext) + { + if (!_cipher.OrganizationId.HasValue || !_collectionIds.Any()) { - _cipher = cipher; - _collectionIds = collectionIds; + return null; } - public virtual IQueryable Run(DatabaseContext dbContext) + var availibleCollections = !_cipher.UserId.HasValue ? + from c in dbContext.Collections + where c.OrganizationId == _cipher.OrganizationId + select c.Id : + from c in dbContext.Collections + join o in dbContext.Organizations + on c.OrganizationId equals o.Id + join ou in dbContext.OrganizationUsers + on o.Id equals ou.OrganizationId + where ou.UserId == _cipher.UserId + join cu in dbContext.CollectionUsers + on c.Id equals cu.CollectionId into cu_g + from cu in cu_g.DefaultIfEmpty() + where !ou.AccessAll && cu.OrganizationUserId == ou.Id + join gu in dbContext.GroupUsers + on ou.Id equals gu.OrganizationUserId into gu_g + from gu in gu_g.DefaultIfEmpty() + where cu.CollectionId == null && !ou.AccessAll + join g in dbContext.Groups + on gu.GroupId equals g.Id into g_g + from g in g_g.DefaultIfEmpty() + join cg in dbContext.CollectionGroups + on c.Id equals cg.CollectionId into cg_g + from cg in cg_g.DefaultIfEmpty() + where !g.AccessAll && gu.GroupId == cg.GroupId && + o.Id == _cipher.OrganizationId && + o.Enabled && + ou.Status == OrganizationUserStatusType.Confirmed && + (ou.AccessAll || !cu.ReadOnly || g.AccessAll || !cg.ReadOnly) + select c.Id; + + if (!availibleCollections.Any()) { - if (!_cipher.OrganizationId.HasValue || !_collectionIds.Any()) - { - return null; - } - - var availibleCollections = !_cipher.UserId.HasValue ? - from c in dbContext.Collections - where c.OrganizationId == _cipher.OrganizationId - select c.Id : - from c in dbContext.Collections - join o in dbContext.Organizations - on c.OrganizationId equals o.Id - join ou in dbContext.OrganizationUsers - on o.Id equals ou.OrganizationId - where ou.UserId == _cipher.UserId - join cu in dbContext.CollectionUsers - on c.Id equals cu.CollectionId into cu_g - from cu in cu_g.DefaultIfEmpty() - where !ou.AccessAll && cu.OrganizationUserId == ou.Id - join gu in dbContext.GroupUsers - on ou.Id equals gu.OrganizationUserId into gu_g - from gu in gu_g.DefaultIfEmpty() - where cu.CollectionId == null && !ou.AccessAll - join g in dbContext.Groups - on gu.GroupId equals g.Id into g_g - from g in g_g.DefaultIfEmpty() - join cg in dbContext.CollectionGroups - on c.Id equals cg.CollectionId into cg_g - from cg in cg_g.DefaultIfEmpty() - where !g.AccessAll && gu.GroupId == cg.GroupId && - o.Id == _cipher.OrganizationId && - o.Enabled && - ou.Status == OrganizationUserStatusType.Confirmed && - (ou.AccessAll || !cu.ReadOnly || g.AccessAll || !cg.ReadOnly) - select c.Id; - - if (!availibleCollections.Any()) - { - return null; - } - - var query = from c in availibleCollections - select new CollectionCipher { CollectionId = c, CipherId = _cipher.Id }; - return query; + return null; } + + var query = from c in availibleCollections + select new CollectionCipher { CollectionId = c, CipherId = _cipher.Id }; + return query; } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionCipherReadByUserIdCipherIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionCipherReadByUserIdCipherIdQuery.cs index 51fcb15fd..e494aec1f 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionCipherReadByUserIdCipherIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionCipherReadByUserIdCipherIdQuery.cs @@ -1,20 +1,19 @@ using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class CollectionCipherReadByUserIdCipherIdQuery : CollectionCipherReadByUserIdQuery { - public class CollectionCipherReadByUserIdCipherIdQuery : CollectionCipherReadByUserIdQuery + private readonly Guid _cipherId; + + public CollectionCipherReadByUserIdCipherIdQuery(Guid userId, Guid cipherId) : base(userId) { - private readonly Guid _cipherId; + _cipherId = cipherId; + } - public CollectionCipherReadByUserIdCipherIdQuery(Guid userId, Guid cipherId) : base(userId) - { - _cipherId = cipherId; - } - - public override IQueryable Run(DatabaseContext dbContext) - { - var query = base.Run(dbContext); - return query.Where(x => x.CipherId == _cipherId); - } + public override IQueryable Run(DatabaseContext dbContext) + { + var query = base.Run(dbContext); + return query.Where(x => x.CipherId == _cipherId); } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionCipherReadByUserIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionCipherReadByUserIdQuery.cs index 6c8e17372..156707b46 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionCipherReadByUserIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionCipherReadByUserIdQuery.cs @@ -1,44 +1,43 @@ using Bit.Core.Enums; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class CollectionCipherReadByUserIdQuery : IQuery { - public class CollectionCipherReadByUserIdQuery : IQuery + private readonly Guid _userId; + + public CollectionCipherReadByUserIdQuery(Guid userId) { - private readonly Guid _userId; + _userId = userId; + } - public CollectionCipherReadByUserIdQuery(Guid userId) - { - _userId = userId; - } - - public virtual IQueryable Run(DatabaseContext dbContext) - { - var query = from cc in dbContext.CollectionCiphers - join c in dbContext.Collections - on cc.CollectionId equals c.Id - join ou in dbContext.OrganizationUsers - on c.OrganizationId equals ou.OrganizationId - where ou.UserId == _userId - join cu in dbContext.CollectionUsers - on c.Id equals cu.CollectionId into cu_g - from cu in cu_g - where ou.AccessAll && cu.OrganizationUserId == ou.Id - join gu in dbContext.GroupUsers - on ou.Id equals gu.OrganizationUserId into gu_g - from gu in gu_g - where cu.CollectionId == null && !ou.AccessAll - join g in dbContext.Groups - on gu.GroupId equals g.Id into g_g - from g in g_g - join cg in dbContext.CollectionGroups - on cc.CollectionId equals cg.CollectionId into cg_g - from cg in cg_g - where g.AccessAll && cg.GroupId == gu.GroupId && - ou.Status == OrganizationUserStatusType.Confirmed && - (ou.AccessAll || cu.CollectionId != null || g.AccessAll || cg.CollectionId != null) - select cc; - return query; - } + public virtual IQueryable Run(DatabaseContext dbContext) + { + var query = from cc in dbContext.CollectionCiphers + join c in dbContext.Collections + on cc.CollectionId equals c.Id + join ou in dbContext.OrganizationUsers + on c.OrganizationId equals ou.OrganizationId + where ou.UserId == _userId + join cu in dbContext.CollectionUsers + on c.Id equals cu.CollectionId into cu_g + from cu in cu_g + where ou.AccessAll && cu.OrganizationUserId == ou.Id + join gu in dbContext.GroupUsers + on ou.Id equals gu.OrganizationUserId into gu_g + from gu in gu_g + where cu.CollectionId == null && !ou.AccessAll + join g in dbContext.Groups + on gu.GroupId equals g.Id into g_g + from g in g_g + join cg in dbContext.CollectionGroups + on cc.CollectionId equals cg.CollectionId into cg_g + from cg in cg_g + where g.AccessAll && cg.GroupId == gu.GroupId && + ou.Status == OrganizationUserStatusType.Confirmed && + (ou.AccessAll || cu.CollectionId != null || g.AccessAll || cg.CollectionId != null) + select cc; + return query; } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionReadCountByOrganizationIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionReadCountByOrganizationIdQuery.cs index de878db34..90e800398 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionReadCountByOrganizationIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionReadCountByOrganizationIdQuery.cs @@ -1,22 +1,21 @@ using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class CollectionReadCountByOrganizationIdQuery : IQuery { - public class CollectionReadCountByOrganizationIdQuery : IQuery + private readonly Guid _organizationId; + + public CollectionReadCountByOrganizationIdQuery(Guid organizationId) { - private readonly Guid _organizationId; + _organizationId = organizationId; + } - public CollectionReadCountByOrganizationIdQuery(Guid organizationId) - { - _organizationId = organizationId; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var query = from c in dbContext.Collections - where c.OrganizationId == _organizationId - select c; - return query; - } + public IQueryable Run(DatabaseContext dbContext) + { + var query = from c in dbContext.Collections + where c.OrganizationId == _organizationId + select c; + return query; } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionUserUpdateUsersQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionUserUpdateUsersQuery.cs index 45023772a..db2d91190 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionUserUpdateUsersQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionUserUpdateUsersQuery.cs @@ -2,116 +2,115 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class CollectionUserUpdateUsersQuery { - public class CollectionUserUpdateUsersQuery + public readonly CollectionUserUpdateUsersInsertQuery Insert; + public readonly CollectionUserUpdateUsersUpdateQuery Update; + public readonly CollectionUserUpdateUsersDeleteQuery Delete; + + public CollectionUserUpdateUsersQuery(Guid collectionId, IEnumerable users) { - public readonly CollectionUserUpdateUsersInsertQuery Insert; - public readonly CollectionUserUpdateUsersUpdateQuery Update; - public readonly CollectionUserUpdateUsersDeleteQuery Delete; - - public CollectionUserUpdateUsersQuery(Guid collectionId, IEnumerable users) - { - Insert = new CollectionUserUpdateUsersInsertQuery(collectionId, users); - Update = new CollectionUserUpdateUsersUpdateQuery(collectionId, users); - Delete = new CollectionUserUpdateUsersDeleteQuery(collectionId, users); - } - } - - public class CollectionUserUpdateUsersInsertQuery : IQuery - { - private readonly Guid _collectionId; - private readonly IEnumerable _users; - - public CollectionUserUpdateUsersInsertQuery(Guid collectionId, IEnumerable users) - { - _collectionId = collectionId; - _users = users; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var orgId = dbContext.Collections.FirstOrDefault(c => c.Id == _collectionId)?.OrganizationId; - var organizationUserIds = _users.Select(u => u.Id); - var insertQuery = from ou in dbContext.OrganizationUsers - where - organizationUserIds.Contains(ou.Id) && - ou.OrganizationId == orgId && - !dbContext.CollectionUsers.Any( - x => x.CollectionId != _collectionId && x.OrganizationUserId == ou.Id) - select ou; - return insertQuery; - } - - public async Task> BuildInMemory(DatabaseContext dbContext) - { - var data = await Run(dbContext).ToListAsync(); - var collectionUsers = data.Select(x => new CollectionUser() - { - CollectionId = _collectionId, - OrganizationUserId = x.Id, - ReadOnly = _users.FirstOrDefault(u => u.Id.Equals(x.Id)).ReadOnly, - HidePasswords = _users.FirstOrDefault(u => u.Id.Equals(x.Id)).HidePasswords, - }); - return collectionUsers; - } - } - - public class CollectionUserUpdateUsersUpdateQuery : IQuery - { - private readonly Guid _collectionId; - private readonly IEnumerable _users; - - public CollectionUserUpdateUsersUpdateQuery(Guid collectionId, IEnumerable users) - { - _collectionId = collectionId; - _users = users; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var orgId = dbContext.Collections.FirstOrDefault(c => c.Id == _collectionId)?.OrganizationId; - var ids = _users.Select(x => x.Id); - var updateQuery = from target in dbContext.CollectionUsers - where target.CollectionId == _collectionId && - ids.Contains(target.OrganizationUserId) - select target; - return updateQuery; - } - - public async Task> BuildInMemory(DatabaseContext dbContext) - { - var data = await Run(dbContext).ToListAsync(); - var collectionUsers = data.Select(x => new CollectionUser - { - CollectionId = _collectionId, - OrganizationUserId = x.OrganizationUserId, - ReadOnly = _users.FirstOrDefault(u => u.Id.Equals(x.OrganizationUserId)).ReadOnly, - HidePasswords = _users.FirstOrDefault(u => u.Id.Equals(x.OrganizationUserId)).HidePasswords, - }); - return collectionUsers; - } - } - - public class CollectionUserUpdateUsersDeleteQuery : IQuery - { - private readonly Guid _collectionId; - private readonly IEnumerable _users; - - public CollectionUserUpdateUsersDeleteQuery(Guid collectionId, IEnumerable users) - { - _collectionId = collectionId; - _users = users; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var orgId = dbContext.Collections.FirstOrDefault(c => c.Id == _collectionId)?.OrganizationId; - var deleteQuery = from cu in dbContext.CollectionUsers - where !dbContext.Users.Any( - u => u.Id == cu.OrganizationUserId) - select cu; - return deleteQuery; - } + Insert = new CollectionUserUpdateUsersInsertQuery(collectionId, users); + Update = new CollectionUserUpdateUsersUpdateQuery(collectionId, users); + Delete = new CollectionUserUpdateUsersDeleteQuery(collectionId, users); + } +} + +public class CollectionUserUpdateUsersInsertQuery : IQuery +{ + private readonly Guid _collectionId; + private readonly IEnumerable _users; + + public CollectionUserUpdateUsersInsertQuery(Guid collectionId, IEnumerable users) + { + _collectionId = collectionId; + _users = users; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var orgId = dbContext.Collections.FirstOrDefault(c => c.Id == _collectionId)?.OrganizationId; + var organizationUserIds = _users.Select(u => u.Id); + var insertQuery = from ou in dbContext.OrganizationUsers + where + organizationUserIds.Contains(ou.Id) && + ou.OrganizationId == orgId && + !dbContext.CollectionUsers.Any( + x => x.CollectionId != _collectionId && x.OrganizationUserId == ou.Id) + select ou; + return insertQuery; + } + + public async Task> BuildInMemory(DatabaseContext dbContext) + { + var data = await Run(dbContext).ToListAsync(); + var collectionUsers = data.Select(x => new CollectionUser() + { + CollectionId = _collectionId, + OrganizationUserId = x.Id, + ReadOnly = _users.FirstOrDefault(u => u.Id.Equals(x.Id)).ReadOnly, + HidePasswords = _users.FirstOrDefault(u => u.Id.Equals(x.Id)).HidePasswords, + }); + return collectionUsers; + } +} + +public class CollectionUserUpdateUsersUpdateQuery : IQuery +{ + private readonly Guid _collectionId; + private readonly IEnumerable _users; + + public CollectionUserUpdateUsersUpdateQuery(Guid collectionId, IEnumerable users) + { + _collectionId = collectionId; + _users = users; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var orgId = dbContext.Collections.FirstOrDefault(c => c.Id == _collectionId)?.OrganizationId; + var ids = _users.Select(x => x.Id); + var updateQuery = from target in dbContext.CollectionUsers + where target.CollectionId == _collectionId && + ids.Contains(target.OrganizationUserId) + select target; + return updateQuery; + } + + public async Task> BuildInMemory(DatabaseContext dbContext) + { + var data = await Run(dbContext).ToListAsync(); + var collectionUsers = data.Select(x => new CollectionUser + { + CollectionId = _collectionId, + OrganizationUserId = x.OrganizationUserId, + ReadOnly = _users.FirstOrDefault(u => u.Id.Equals(x.OrganizationUserId)).ReadOnly, + HidePasswords = _users.FirstOrDefault(u => u.Id.Equals(x.OrganizationUserId)).HidePasswords, + }); + return collectionUsers; + } +} + +public class CollectionUserUpdateUsersDeleteQuery : IQuery +{ + private readonly Guid _collectionId; + private readonly IEnumerable _users; + + public CollectionUserUpdateUsersDeleteQuery(Guid collectionId, IEnumerable users) + { + _collectionId = collectionId; + _users = users; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var orgId = dbContext.Collections.FirstOrDefault(c => c.Id == _collectionId)?.OrganizationId; + var deleteQuery = from cu in dbContext.CollectionUsers + where !dbContext.Users.Any( + u => u.Id == cu.OrganizationUserId) + select cu; + return deleteQuery; } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/EmergencyAccessDetailsViewQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/EmergencyAccessDetailsViewQuery.cs index 24c1bda8d..2ad2149ae 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/EmergencyAccessDetailsViewQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/EmergencyAccessDetailsViewQuery.cs @@ -1,38 +1,37 @@ using Bit.Core.Models.Data; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class EmergencyAccessDetailsViewQuery : IQuery { - public class EmergencyAccessDetailsViewQuery : IQuery + public IQueryable Run(DatabaseContext dbContext) { - public IQueryable Run(DatabaseContext dbContext) + var query = from ea in dbContext.EmergencyAccesses + join grantee in dbContext.Users + on ea.GranteeId equals grantee.Id into grantee_g + from grantee in grantee_g.DefaultIfEmpty() + join grantor in dbContext.Users + on ea.GrantorId equals grantor.Id into grantor_g + from grantor in grantor_g.DefaultIfEmpty() + select new { ea, grantee, grantor }; + return query.Select(x => new EmergencyAccessDetails { - var query = from ea in dbContext.EmergencyAccesses - join grantee in dbContext.Users - on ea.GranteeId equals grantee.Id into grantee_g - from grantee in grantee_g.DefaultIfEmpty() - join grantor in dbContext.Users - on ea.GrantorId equals grantor.Id into grantor_g - from grantor in grantor_g.DefaultIfEmpty() - select new { ea, grantee, grantor }; - return query.Select(x => new EmergencyAccessDetails - { - Id = x.ea.Id, - GrantorId = x.ea.GrantorId, - GranteeId = x.ea.GranteeId, - Email = x.ea.Email, - KeyEncrypted = x.ea.KeyEncrypted, - Type = x.ea.Type, - Status = x.ea.Status, - WaitTimeDays = x.ea.WaitTimeDays, - RecoveryInitiatedDate = x.ea.RecoveryInitiatedDate, - LastNotificationDate = x.ea.LastNotificationDate, - CreationDate = x.ea.CreationDate, - RevisionDate = x.ea.RevisionDate, - GranteeName = x.grantee.Name, - GranteeEmail = x.grantee.Email, - GrantorName = x.grantor.Name, - GrantorEmail = x.grantor.Email, - }); - } + Id = x.ea.Id, + GrantorId = x.ea.GrantorId, + GranteeId = x.ea.GranteeId, + Email = x.ea.Email, + KeyEncrypted = x.ea.KeyEncrypted, + Type = x.ea.Type, + Status = x.ea.Status, + WaitTimeDays = x.ea.WaitTimeDays, + RecoveryInitiatedDate = x.ea.RecoveryInitiatedDate, + LastNotificationDate = x.ea.LastNotificationDate, + CreationDate = x.ea.CreationDate, + RevisionDate = x.ea.RevisionDate, + GranteeName = x.grantee.Name, + GranteeEmail = x.grantee.Email, + GrantorName = x.grantor.Name, + GrantorEmail = x.grantor.Email, + }); } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/EmergencyAccessReadCountByGrantorIdEmailQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/EmergencyAccessReadCountByGrantorIdEmailQuery.cs index d28ce1372..3a09fa857 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/EmergencyAccessReadCountByGrantorIdEmailQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/EmergencyAccessReadCountByGrantorIdEmailQuery.cs @@ -1,31 +1,30 @@ using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class EmergencyAccessReadCountByGrantorIdEmailQuery : IQuery { - public class EmergencyAccessReadCountByGrantorIdEmailQuery : IQuery + private readonly Guid _grantorId; + private readonly string _email; + private readonly bool _onlyRegisteredUsers; + + public EmergencyAccessReadCountByGrantorIdEmailQuery(Guid grantorId, string email, bool onlyRegisteredUsers) { - private readonly Guid _grantorId; - private readonly string _email; - private readonly bool _onlyRegisteredUsers; + _grantorId = grantorId; + _email = email; + _onlyRegisteredUsers = onlyRegisteredUsers; + } - public EmergencyAccessReadCountByGrantorIdEmailQuery(Guid grantorId, string email, bool onlyRegisteredUsers) - { - _grantorId = grantorId; - _email = email; - _onlyRegisteredUsers = onlyRegisteredUsers; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var query = from ea in dbContext.EmergencyAccesses - join u in dbContext.Users - on ea.GranteeId equals u.Id into u_g - from u in u_g.DefaultIfEmpty() - where ea.GrantorId == _grantorId && - ((!_onlyRegisteredUsers && (ea.Email == _email || u.Email == _email)) - || (_onlyRegisteredUsers && u.Email == _email)) - select ea; - return query; - } + public IQueryable Run(DatabaseContext dbContext) + { + var query = from ea in dbContext.EmergencyAccesses + join u in dbContext.Users + on ea.GranteeId equals u.Id into u_g + from u in u_g.DefaultIfEmpty() + where ea.GrantorId == _grantorId && + ((!_onlyRegisteredUsers && (ea.Email == _email || u.Email == _email)) + || (_onlyRegisteredUsers && u.Email == _email)) + select ea; + return query; } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByCipherIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByCipherIdQuery.cs index d94d130ef..570f3a249 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByCipherIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByCipherIdQuery.cs @@ -2,47 +2,46 @@ using Bit.Core.Models.Data; using Event = Bit.Infrastructure.EntityFramework.Models.Event; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class EventReadPageByCipherIdQuery : IQuery { - public class EventReadPageByCipherIdQuery : IQuery + private readonly Cipher _cipher; + private readonly DateTime _startDate; + private readonly DateTime _endDate; + private readonly DateTime? _beforeDate; + private readonly PageOptions _pageOptions; + + public EventReadPageByCipherIdQuery(Cipher cipher, DateTime startDate, DateTime endDate, PageOptions pageOptions) { - private readonly Cipher _cipher; - private readonly DateTime _startDate; - private readonly DateTime _endDate; - private readonly DateTime? _beforeDate; - private readonly PageOptions _pageOptions; + _cipher = cipher; + _startDate = startDate; + _endDate = endDate; + _beforeDate = null; + _pageOptions = pageOptions; + } - public EventReadPageByCipherIdQuery(Cipher cipher, DateTime startDate, DateTime endDate, PageOptions pageOptions) - { - _cipher = cipher; - _startDate = startDate; - _endDate = endDate; - _beforeDate = null; - _pageOptions = pageOptions; - } + public EventReadPageByCipherIdQuery(Cipher cipher, DateTime startDate, DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) + { + _cipher = cipher; + _startDate = startDate; + _endDate = endDate; + _beforeDate = beforeDate; + _pageOptions = pageOptions; + } - public EventReadPageByCipherIdQuery(Cipher cipher, DateTime startDate, DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) - { - _cipher = cipher; - _startDate = startDate; - _endDate = endDate; - _beforeDate = beforeDate; - _pageOptions = pageOptions; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var q = from e in dbContext.Events - where e.Date >= _startDate && - (_beforeDate == null || e.Date < _beforeDate.Value) && - ((!_cipher.OrganizationId.HasValue && !e.OrganizationId.HasValue) || - (_cipher.OrganizationId.HasValue && _cipher.OrganizationId == e.OrganizationId)) && - ((!_cipher.UserId.HasValue && !e.UserId.HasValue) || - (_cipher.UserId.HasValue && _cipher.UserId == e.UserId)) && - _cipher.Id == e.CipherId - orderby e.Date descending - select e; - return q.Skip(0).Take(_pageOptions.PageSize); - } + public IQueryable Run(DatabaseContext dbContext) + { + var q = from e in dbContext.Events + where e.Date >= _startDate && + (_beforeDate == null || e.Date < _beforeDate.Value) && + ((!_cipher.OrganizationId.HasValue && !e.OrganizationId.HasValue) || + (_cipher.OrganizationId.HasValue && _cipher.OrganizationId == e.OrganizationId)) && + ((!_cipher.UserId.HasValue && !e.UserId.HasValue) || + (_cipher.UserId.HasValue && _cipher.UserId == e.UserId)) && + _cipher.Id == e.CipherId + orderby e.Date descending + select e; + return q.Skip(0).Take(_pageOptions.PageSize); } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByOrganizationIdActingUserIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByOrganizationIdActingUserIdQuery.cs index f9553dd38..8e49ca239 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByOrganizationIdActingUserIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByOrganizationIdActingUserIdQuery.cs @@ -1,39 +1,38 @@ using Bit.Core.Models.Data; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class EventReadPageByOrganizationIdActingUserIdQuery : IQuery { - public class EventReadPageByOrganizationIdActingUserIdQuery : IQuery + private readonly Guid _organizationId; + private readonly Guid _actingUserId; + private readonly DateTime _startDate; + private readonly DateTime _endDate; + private readonly DateTime? _beforeDate; + private readonly PageOptions _pageOptions; + + public EventReadPageByOrganizationIdActingUserIdQuery(Guid organizationId, Guid actingUserId, + DateTime startDate, DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) { - private readonly Guid _organizationId; - private readonly Guid _actingUserId; - private readonly DateTime _startDate; - private readonly DateTime _endDate; - private readonly DateTime? _beforeDate; - private readonly PageOptions _pageOptions; + _organizationId = organizationId; + _actingUserId = actingUserId; + _startDate = startDate; + _endDate = endDate; + _beforeDate = beforeDate; + _pageOptions = pageOptions; + } - public EventReadPageByOrganizationIdActingUserIdQuery(Guid organizationId, Guid actingUserId, - DateTime startDate, DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) - { - _organizationId = organizationId; - _actingUserId = actingUserId; - _startDate = startDate; - _endDate = endDate; - _beforeDate = beforeDate; - _pageOptions = pageOptions; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var q = from e in dbContext.Events - where e.Date >= _startDate && - (_beforeDate != null || e.Date <= _endDate) && - (_beforeDate == null || e.Date < _beforeDate.Value) && - e.OrganizationId == _organizationId && - e.ActingUserId == _actingUserId - orderby e.Date descending - select e; - return q.Skip(0).Take(_pageOptions.PageSize); - } + public IQueryable Run(DatabaseContext dbContext) + { + var q = from e in dbContext.Events + where e.Date >= _startDate && + (_beforeDate != null || e.Date <= _endDate) && + (_beforeDate == null || e.Date < _beforeDate.Value) && + e.OrganizationId == _organizationId && + e.ActingUserId == _actingUserId + orderby e.Date descending + select e; + return q.Skip(0).Take(_pageOptions.PageSize); } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByOrganizationIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByOrganizationIdQuery.cs index 261bef32a..ce0de6afc 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByOrganizationIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByOrganizationIdQuery.cs @@ -1,36 +1,35 @@ using Bit.Core.Models.Data; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class EventReadPageByOrganizationIdQuery : IQuery { - public class EventReadPageByOrganizationIdQuery : IQuery + private readonly Guid _organizationId; + private readonly DateTime _startDate; + private readonly DateTime _endDate; + private readonly DateTime? _beforeDate; + private readonly PageOptions _pageOptions; + + public EventReadPageByOrganizationIdQuery(Guid organizationId, DateTime startDate, + DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) { - private readonly Guid _organizationId; - private readonly DateTime _startDate; - private readonly DateTime _endDate; - private readonly DateTime? _beforeDate; - private readonly PageOptions _pageOptions; + _organizationId = organizationId; + _startDate = startDate; + _endDate = endDate; + _beforeDate = beforeDate; + _pageOptions = pageOptions; + } - public EventReadPageByOrganizationIdQuery(Guid organizationId, DateTime startDate, - DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) - { - _organizationId = organizationId; - _startDate = startDate; - _endDate = endDate; - _beforeDate = beforeDate; - _pageOptions = pageOptions; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var q = from e in dbContext.Events - where e.Date >= _startDate && - (_beforeDate != null || e.Date <= _endDate) && - (_beforeDate == null || e.Date < _beforeDate.Value) && - e.OrganizationId == _organizationId - orderby e.Date descending - select e; - return q.Skip(0).Take(_pageOptions.PageSize); - } + public IQueryable Run(DatabaseContext dbContext) + { + var q = from e in dbContext.Events + where e.Date >= _startDate && + (_beforeDate != null || e.Date <= _endDate) && + (_beforeDate == null || e.Date < _beforeDate.Value) && + e.OrganizationId == _organizationId + orderby e.Date descending + select e; + return q.Skip(0).Take(_pageOptions.PageSize); } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByProviderIdActingUserIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByProviderIdActingUserIdQuery.cs index 4b08ecf2b..171b4e26c 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByProviderIdActingUserIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByProviderIdActingUserIdQuery.cs @@ -1,39 +1,38 @@ using Bit.Core.Models.Data; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class EventReadPageByProviderIdActingUserIdQuery : IQuery { - public class EventReadPageByProviderIdActingUserIdQuery : IQuery + private readonly Guid _providerId; + private readonly Guid _actingUserId; + private readonly DateTime _startDate; + private readonly DateTime _endDate; + private readonly DateTime? _beforeDate; + private readonly PageOptions _pageOptions; + + public EventReadPageByProviderIdActingUserIdQuery(Guid providerId, Guid actingUserId, + DateTime startDate, DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) { - private readonly Guid _providerId; - private readonly Guid _actingUserId; - private readonly DateTime _startDate; - private readonly DateTime _endDate; - private readonly DateTime? _beforeDate; - private readonly PageOptions _pageOptions; + _providerId = providerId; + _actingUserId = actingUserId; + _startDate = startDate; + _endDate = endDate; + _beforeDate = beforeDate; + _pageOptions = pageOptions; + } - public EventReadPageByProviderIdActingUserIdQuery(Guid providerId, Guid actingUserId, - DateTime startDate, DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) - { - _providerId = providerId; - _actingUserId = actingUserId; - _startDate = startDate; - _endDate = endDate; - _beforeDate = beforeDate; - _pageOptions = pageOptions; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var q = from e in dbContext.Events - where e.Date >= _startDate && - (_beforeDate != null || e.Date <= _endDate) && - (_beforeDate == null || e.Date < _beforeDate.Value) && - e.ProviderId == _providerId && - e.ActingUserId == _actingUserId - orderby e.Date descending - select e; - return q.Skip(0).Take(_pageOptions.PageSize); - } + public IQueryable Run(DatabaseContext dbContext) + { + var q = from e in dbContext.Events + where e.Date >= _startDate && + (_beforeDate != null || e.Date <= _endDate) && + (_beforeDate == null || e.Date < _beforeDate.Value) && + e.ProviderId == _providerId && + e.ActingUserId == _actingUserId + orderby e.Date descending + select e; + return q.Skip(0).Take(_pageOptions.PageSize); } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByProviderIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByProviderIdQuery.cs index 49e8f518b..52421b9e9 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByProviderIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByProviderIdQuery.cs @@ -1,36 +1,35 @@ using Bit.Core.Models.Data; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class EventReadPageByProviderIdQuery : IQuery { - public class EventReadPageByProviderIdQuery : IQuery + private readonly Guid _providerId; + private readonly DateTime _startDate; + private readonly DateTime _endDate; + private readonly DateTime? _beforeDate; + private readonly PageOptions _pageOptions; + + public EventReadPageByProviderIdQuery(Guid providerId, DateTime startDate, + DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) { - private readonly Guid _providerId; - private readonly DateTime _startDate; - private readonly DateTime _endDate; - private readonly DateTime? _beforeDate; - private readonly PageOptions _pageOptions; + _providerId = providerId; + _startDate = startDate; + _endDate = endDate; + _beforeDate = beforeDate; + _pageOptions = pageOptions; + } - public EventReadPageByProviderIdQuery(Guid providerId, DateTime startDate, - DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) - { - _providerId = providerId; - _startDate = startDate; - _endDate = endDate; - _beforeDate = beforeDate; - _pageOptions = pageOptions; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var q = from e in dbContext.Events - where e.Date >= _startDate && - (_beforeDate != null || e.Date <= _endDate) && - (_beforeDate == null || e.Date < _beforeDate.Value) && - e.ProviderId == _providerId && e.OrganizationId == null - orderby e.Date descending - select e; - return q.Skip(0).Take(_pageOptions.PageSize); - } + public IQueryable Run(DatabaseContext dbContext) + { + var q = from e in dbContext.Events + where e.Date >= _startDate && + (_beforeDate != null || e.Date <= _endDate) && + (_beforeDate == null || e.Date < _beforeDate.Value) && + e.ProviderId == _providerId && e.OrganizationId == null + orderby e.Date descending + select e; + return q.Skip(0).Take(_pageOptions.PageSize); } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByUserIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByUserIdQuery.cs index 3e7ff4cc3..d173c4842 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByUserIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/EventReadPageByUserIdQuery.cs @@ -1,37 +1,36 @@ using Bit.Core.Models.Data; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class EventReadPageByUserIdQuery : IQuery { - public class EventReadPageByUserIdQuery : IQuery + private readonly Guid _userId; + private readonly DateTime _startDate; + private readonly DateTime _endDate; + private readonly DateTime? _beforeDate; + private readonly PageOptions _pageOptions; + + public EventReadPageByUserIdQuery(Guid userId, DateTime startDate, + DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) { - private readonly Guid _userId; - private readonly DateTime _startDate; - private readonly DateTime _endDate; - private readonly DateTime? _beforeDate; - private readonly PageOptions _pageOptions; + _userId = userId; + _startDate = startDate; + _endDate = endDate; + _beforeDate = beforeDate; + _pageOptions = pageOptions; + } - public EventReadPageByUserIdQuery(Guid userId, DateTime startDate, - DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) - { - _userId = userId; - _startDate = startDate; - _endDate = endDate; - _beforeDate = beforeDate; - _pageOptions = pageOptions; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var q = from e in dbContext.Events - where e.Date >= _startDate && - (_beforeDate != null || e.Date <= _endDate) && - (_beforeDate == null || e.Date < _beforeDate.Value) && - !e.OrganizationId.HasValue && - e.ActingUserId == _userId - orderby e.Date descending - select e; - return q.Skip(0).Take(_pageOptions.PageSize); - } + public IQueryable Run(DatabaseContext dbContext) + { + var q = from e in dbContext.Events + where e.Date >= _startDate && + (_beforeDate != null || e.Date <= _endDate) && + (_beforeDate == null || e.Date < _beforeDate.Value) && + !e.OrganizationId.HasValue && + e.ActingUserId == _userId + orderby e.Date descending + select e; + return q.Skip(0).Take(_pageOptions.PageSize); } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/GroupUserUpdateGroupsQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/GroupUserUpdateGroupsQuery.cs index 580199caf..dacbabb28 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/GroupUserUpdateGroupsQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/GroupUserUpdateGroupsQuery.cs @@ -1,69 +1,68 @@ using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class GroupUserUpdateGroupsQuery { - public class GroupUserUpdateGroupsQuery + public readonly GroupUserUpdateGroupsInsertQuery Insert; + public readonly GroupUserUpdateGroupsDeleteQuery Delete; + + public GroupUserUpdateGroupsQuery(Guid organizationUserId, IEnumerable groupIds) { - public readonly GroupUserUpdateGroupsInsertQuery Insert; - public readonly GroupUserUpdateGroupsDeleteQuery Delete; - - public GroupUserUpdateGroupsQuery(Guid organizationUserId, IEnumerable groupIds) - { - Insert = new GroupUserUpdateGroupsInsertQuery(organizationUserId, groupIds); - Delete = new GroupUserUpdateGroupsDeleteQuery(organizationUserId, groupIds); - } - } - - public class GroupUserUpdateGroupsInsertQuery : IQuery - { - private readonly Guid _organizationUserId; - private readonly IEnumerable _groupIds; - - public GroupUserUpdateGroupsInsertQuery(Guid organizationUserId, IEnumerable collections) - { - _organizationUserId = organizationUserId; - _groupIds = collections; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var orgUser = from ou in dbContext.OrganizationUsers - where ou.Id == _organizationUserId - select ou; - var groupIdEntities = dbContext.Groups.Where(x => _groupIds.Contains(x.Id)); - var query = from g in dbContext.Groups - join ou in orgUser - on g.OrganizationId equals ou.OrganizationId - join gie in groupIdEntities - on g.Id equals gie.Id - where !dbContext.GroupUsers.Any(gu => _groupIds.Contains(gu.GroupId) && gu.OrganizationUserId == _organizationUserId) - select g; - return query.Select(x => new GroupUser - { - GroupId = x.Id, - OrganizationUserId = _organizationUserId, - }); - } - } - - public class GroupUserUpdateGroupsDeleteQuery : IQuery - { - private readonly Guid _organizationUserId; - private readonly IEnumerable _groupIds; - - public GroupUserUpdateGroupsDeleteQuery(Guid organizationUserId, IEnumerable groupIds) - { - _organizationUserId = organizationUserId; - _groupIds = groupIds; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var deleteQuery = from gu in dbContext.GroupUsers - where gu.OrganizationUserId == _organizationUserId && - !_groupIds.Any(x => gu.GroupId == x) - select gu; - return deleteQuery; - } + Insert = new GroupUserUpdateGroupsInsertQuery(organizationUserId, groupIds); + Delete = new GroupUserUpdateGroupsDeleteQuery(organizationUserId, groupIds); + } +} + +public class GroupUserUpdateGroupsInsertQuery : IQuery +{ + private readonly Guid _organizationUserId; + private readonly IEnumerable _groupIds; + + public GroupUserUpdateGroupsInsertQuery(Guid organizationUserId, IEnumerable collections) + { + _organizationUserId = organizationUserId; + _groupIds = collections; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var orgUser = from ou in dbContext.OrganizationUsers + where ou.Id == _organizationUserId + select ou; + var groupIdEntities = dbContext.Groups.Where(x => _groupIds.Contains(x.Id)); + var query = from g in dbContext.Groups + join ou in orgUser + on g.OrganizationId equals ou.OrganizationId + join gie in groupIdEntities + on g.Id equals gie.Id + where !dbContext.GroupUsers.Any(gu => _groupIds.Contains(gu.GroupId) && gu.OrganizationUserId == _organizationUserId) + select g; + return query.Select(x => new GroupUser + { + GroupId = x.Id, + OrganizationUserId = _organizationUserId, + }); + } +} + +public class GroupUserUpdateGroupsDeleteQuery : IQuery +{ + private readonly Guid _organizationUserId; + private readonly IEnumerable _groupIds; + + public GroupUserUpdateGroupsDeleteQuery(Guid organizationUserId, IEnumerable groupIds) + { + _organizationUserId = organizationUserId; + _groupIds = groupIds; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var deleteQuery = from gu in dbContext.GroupUsers + where gu.OrganizationUserId == _organizationUserId && + !_groupIds.Any(x => gu.GroupId == x) + select gu; + return deleteQuery; } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/IQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/IQuery.cs index 8729f5b15..554efe0b7 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/IQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/IQuery.cs @@ -1,7 +1,6 @@ -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public interface IQuery { - public interface IQuery - { - IQueryable Run(DatabaseContext dbContext); - } + IQueryable Run(DatabaseContext dbContext); } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserOrganizationDetailsViewQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserOrganizationDetailsViewQuery.cs index cd0122970..84dc4a7ad 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserOrganizationDetailsViewQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserOrganizationDetailsViewQuery.cs @@ -1,65 +1,64 @@ using Bit.Core.Models.Data.Organizations.OrganizationUsers; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries -{ - public class OrganizationUserOrganizationDetailsViewQuery : IQuery - { - public IQueryable Run(DatabaseContext dbContext) - { - var query = from ou in dbContext.OrganizationUsers - join o in dbContext.Organizations on ou.OrganizationId equals o.Id - join su in dbContext.SsoUsers on ou.UserId equals su.UserId into su_g - from su in su_g.DefaultIfEmpty() - join po in dbContext.ProviderOrganizations on o.Id equals po.OrganizationId into po_g - from po in po_g.DefaultIfEmpty() - join p in dbContext.Providers on po.ProviderId equals p.Id into p_g - from p in p_g.DefaultIfEmpty() - join os in dbContext.OrganizationSponsorships on ou.Id equals os.SponsoringOrganizationUserId into os_g - from os in os_g.DefaultIfEmpty() - join ss in dbContext.SsoConfigs on ou.OrganizationId equals ss.OrganizationId into ss_g - from ss in ss_g.DefaultIfEmpty() - where ((su == null || !su.OrganizationId.HasValue) || su.OrganizationId == ou.OrganizationId) - select new { ou, o, su, p, ss, os }; +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - return query.Select(x => new OrganizationUserOrganizationDetails - { - OrganizationId = x.ou.OrganizationId, - UserId = x.ou.UserId, - Name = x.o.Name, - Enabled = x.o.Enabled, - PlanType = x.o.PlanType, - UsePolicies = x.o.UsePolicies, - UseSso = x.o.UseSso, - UseKeyConnector = x.o.UseKeyConnector, - UseScim = x.o.UseScim, - UseGroups = x.o.UseGroups, - UseDirectory = x.o.UseDirectory, - UseEvents = x.o.UseEvents, - UseTotp = x.o.UseTotp, - Use2fa = x.o.Use2fa, - UseApi = x.o.UseApi, - SelfHost = x.o.SelfHost, - UsersGetPremium = x.o.UsersGetPremium, - Seats = x.o.Seats, - MaxCollections = x.o.MaxCollections, - MaxStorageGb = x.o.MaxStorageGb, - Identifier = x.o.Identifier, - Key = x.ou.Key, - ResetPasswordKey = x.ou.ResetPasswordKey, - Status = x.ou.Status, - Type = x.ou.Type, - SsoExternalId = x.su.ExternalId, - Permissions = x.ou.Permissions, - PublicKey = x.o.PublicKey, - PrivateKey = x.o.PrivateKey, - ProviderId = x.p.Id, - ProviderName = x.p.Name, - SsoConfig = x.ss.Data, - FamilySponsorshipFriendlyName = x.os.FriendlyName, - FamilySponsorshipLastSyncDate = x.os.LastSyncDate, - FamilySponsorshipToDelete = x.os.ToDelete, - FamilySponsorshipValidUntil = x.os.ValidUntil - }); - } +public class OrganizationUserOrganizationDetailsViewQuery : IQuery +{ + public IQueryable Run(DatabaseContext dbContext) + { + var query = from ou in dbContext.OrganizationUsers + join o in dbContext.Organizations on ou.OrganizationId equals o.Id + join su in dbContext.SsoUsers on ou.UserId equals su.UserId into su_g + from su in su_g.DefaultIfEmpty() + join po in dbContext.ProviderOrganizations on o.Id equals po.OrganizationId into po_g + from po in po_g.DefaultIfEmpty() + join p in dbContext.Providers on po.ProviderId equals p.Id into p_g + from p in p_g.DefaultIfEmpty() + join os in dbContext.OrganizationSponsorships on ou.Id equals os.SponsoringOrganizationUserId into os_g + from os in os_g.DefaultIfEmpty() + join ss in dbContext.SsoConfigs on ou.OrganizationId equals ss.OrganizationId into ss_g + from ss in ss_g.DefaultIfEmpty() + where ((su == null || !su.OrganizationId.HasValue) || su.OrganizationId == ou.OrganizationId) + select new { ou, o, su, p, ss, os }; + + return query.Select(x => new OrganizationUserOrganizationDetails + { + OrganizationId = x.ou.OrganizationId, + UserId = x.ou.UserId, + Name = x.o.Name, + Enabled = x.o.Enabled, + PlanType = x.o.PlanType, + UsePolicies = x.o.UsePolicies, + UseSso = x.o.UseSso, + UseKeyConnector = x.o.UseKeyConnector, + UseScim = x.o.UseScim, + UseGroups = x.o.UseGroups, + UseDirectory = x.o.UseDirectory, + UseEvents = x.o.UseEvents, + UseTotp = x.o.UseTotp, + Use2fa = x.o.Use2fa, + UseApi = x.o.UseApi, + SelfHost = x.o.SelfHost, + UsersGetPremium = x.o.UsersGetPremium, + Seats = x.o.Seats, + MaxCollections = x.o.MaxCollections, + MaxStorageGb = x.o.MaxStorageGb, + Identifier = x.o.Identifier, + Key = x.ou.Key, + ResetPasswordKey = x.ou.ResetPasswordKey, + Status = x.ou.Status, + Type = x.ou.Type, + SsoExternalId = x.su.ExternalId, + Permissions = x.ou.Permissions, + PublicKey = x.o.PublicKey, + PrivateKey = x.o.PrivateKey, + ProviderId = x.p.Id, + ProviderName = x.p.Name, + SsoConfig = x.ss.Data, + FamilySponsorshipFriendlyName = x.os.FriendlyName, + FamilySponsorshipLastSyncDate = x.os.LastSyncDate, + FamilySponsorshipToDelete = x.os.ToDelete, + FamilySponsorshipValidUntil = x.os.ValidUntil + }); } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByFreeOrganizationAdminUserQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByFreeOrganizationAdminUserQuery.cs index 26c66fce9..c1656d3df 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByFreeOrganizationAdminUserQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByFreeOrganizationAdminUserQuery.cs @@ -1,29 +1,28 @@ using Bit.Core.Enums; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class OrganizationUserReadCountByFreeOrganizationAdminUserQuery : IQuery { - public class OrganizationUserReadCountByFreeOrganizationAdminUserQuery : IQuery + private readonly Guid _userId; + + public OrganizationUserReadCountByFreeOrganizationAdminUserQuery(Guid userId) { - private readonly Guid _userId; + _userId = userId; + } - public OrganizationUserReadCountByFreeOrganizationAdminUserQuery(Guid userId) - { - _userId = userId; - } + public IQueryable Run(DatabaseContext dbContext) + { + var query = from ou in dbContext.OrganizationUsers + join o in dbContext.Organizations + on ou.OrganizationId equals o.Id + where ou.UserId == _userId && + (ou.Type == OrganizationUserType.Owner || ou.Type == OrganizationUserType.Admin) && + o.PlanType == PlanType.Free && + ou.Status == OrganizationUserStatusType.Confirmed + select ou; - public IQueryable Run(DatabaseContext dbContext) - { - var query = from ou in dbContext.OrganizationUsers - join o in dbContext.Organizations - on ou.OrganizationId equals o.Id - where ou.UserId == _userId && - (ou.Type == OrganizationUserType.Owner || ou.Type == OrganizationUserType.Admin) && - o.PlanType == PlanType.Free && - ou.Status == OrganizationUserStatusType.Confirmed - select ou; - - return query; - } + return query; } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByOnlyOwnerQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByOnlyOwnerQuery.cs index 6cf8e3c3b..53977e8b7 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByOnlyOwnerQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByOnlyOwnerQuery.cs @@ -1,37 +1,36 @@ using Bit.Core.Enums; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class OrganizationUserReadCountByOnlyOwnerQuery : IQuery { - public class OrganizationUserReadCountByOnlyOwnerQuery : IQuery + private readonly Guid _userId; + + public OrganizationUserReadCountByOnlyOwnerQuery(Guid userId) { - private readonly Guid _userId; + _userId = userId; + } - public OrganizationUserReadCountByOnlyOwnerQuery(Guid userId) - { - _userId = userId; - } + public IQueryable Run(DatabaseContext dbContext) + { + var owners = from ou in dbContext.OrganizationUsers + where ou.Type == OrganizationUserType.Owner && + ou.Status == OrganizationUserStatusType.Confirmed + group ou by ou.OrganizationId into g + select new + { + OrgUser = g.Select(x => new { x.UserId, x.Id }).FirstOrDefault(), + ConfirmedOwnerCount = g.Count(), + }; - public IQueryable Run(DatabaseContext dbContext) - { - var owners = from ou in dbContext.OrganizationUsers - where ou.Type == OrganizationUserType.Owner && - ou.Status == OrganizationUserStatusType.Confirmed - group ou by ou.OrganizationId into g - select new - { - OrgUser = g.Select(x => new { x.UserId, x.Id }).FirstOrDefault(), - ConfirmedOwnerCount = g.Count(), - }; + var query = from owner in owners + join ou in dbContext.OrganizationUsers + on owner.OrgUser.Id equals ou.Id + where owner.OrgUser.UserId == _userId && + owner.ConfirmedOwnerCount == 1 + select ou; - var query = from owner in owners - join ou in dbContext.OrganizationUsers - on owner.OrgUser.Id equals ou.Id - where owner.OrgUser.UserId == _userId && - owner.ConfirmedOwnerCount == 1 - select ou; - - return query; - } + return query; } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByOrganizationIdEmailQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByOrganizationIdEmailQuery.cs index ed4c786c7..0cb2abc46 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByOrganizationIdEmailQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByOrganizationIdEmailQuery.cs @@ -1,31 +1,30 @@ using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class OrganizationUserReadCountByOrganizationIdEmailQuery : IQuery { - public class OrganizationUserReadCountByOrganizationIdEmailQuery : IQuery + private readonly Guid _organizationId; + private readonly string _email; + private readonly bool _onlyUsers; + + public OrganizationUserReadCountByOrganizationIdEmailQuery(Guid organizationId, string email, bool onlyUsers) { - private readonly Guid _organizationId; - private readonly string _email; - private readonly bool _onlyUsers; + _organizationId = organizationId; + _email = email; + _onlyUsers = onlyUsers; + } - public OrganizationUserReadCountByOrganizationIdEmailQuery(Guid organizationId, string email, bool onlyUsers) - { - _organizationId = organizationId; - _email = email; - _onlyUsers = onlyUsers; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var query = from ou in dbContext.OrganizationUsers - join u in dbContext.Users - on ou.UserId equals u.Id into u_g - from u in u_g.DefaultIfEmpty() - where ou.OrganizationId == _organizationId && - ((!_onlyUsers && (ou.Email == _email || u.Email == _email)) - || (_onlyUsers && u.Email == _email)) - select ou; - return query; - } + public IQueryable Run(DatabaseContext dbContext) + { + var query = from ou in dbContext.OrganizationUsers + join u in dbContext.Users + on ou.UserId equals u.Id into u_g + from u in u_g.DefaultIfEmpty() + where ou.OrganizationId == _organizationId && + ((!_onlyUsers && (ou.Email == _email || u.Email == _email)) + || (_onlyUsers && u.Email == _email)) + select ou; + return query; } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByOrganizationIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByOrganizationIdQuery.cs index 05c6dd049..a4ab7cb85 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByOrganizationIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserReadCountByOrganizationIdQuery.cs @@ -1,22 +1,21 @@ using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class OrganizationUserReadCountByOrganizationIdQuery : IQuery { - public class OrganizationUserReadCountByOrganizationIdQuery : IQuery + private readonly Guid _organizationId; + + public OrganizationUserReadCountByOrganizationIdQuery(Guid organizationId) { - private readonly Guid _organizationId; + _organizationId = organizationId; + } - public OrganizationUserReadCountByOrganizationIdQuery(Guid organizationId) - { - _organizationId = organizationId; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var query = from ou in dbContext.OrganizationUsers - where ou.OrganizationId == _organizationId - select ou; - return query; - } + public IQueryable Run(DatabaseContext dbContext) + { + var query = from ou in dbContext.OrganizationUsers + where ou.OrganizationId == _organizationId + select ou; + return query; } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserUpdateWithCollectionsQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserUpdateWithCollectionsQuery.cs index 10dbf88d7..0a21514d6 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserUpdateWithCollectionsQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserUpdateWithCollectionsQuery.cs @@ -2,105 +2,104 @@ using Bit.Core.Models.Data; using CollectionUser = Bit.Infrastructure.EntityFramework.Models.CollectionUser; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class OrganizationUserUpdateWithCollectionsQuery { - public class OrganizationUserUpdateWithCollectionsQuery + public OrganizationUserUpdateWithCollectionsInsertQuery Insert { get; set; } + public OrganizationUserUpdateWithCollectionsUpdateQuery Update { get; set; } + public OrganizationUserUpdateWithCollectionsDeleteQuery Delete { get; set; } + + public OrganizationUserUpdateWithCollectionsQuery(OrganizationUser organizationUser, + IEnumerable collections) { - public OrganizationUserUpdateWithCollectionsInsertQuery Insert { get; set; } - public OrganizationUserUpdateWithCollectionsUpdateQuery Update { get; set; } - public OrganizationUserUpdateWithCollectionsDeleteQuery Delete { get; set; } - - public OrganizationUserUpdateWithCollectionsQuery(OrganizationUser organizationUser, - IEnumerable collections) - { - Insert = new OrganizationUserUpdateWithCollectionsInsertQuery(organizationUser, collections); - Update = new OrganizationUserUpdateWithCollectionsUpdateQuery(organizationUser, collections); - Delete = new OrganizationUserUpdateWithCollectionsDeleteQuery(organizationUser, collections); - } - } - - public class OrganizationUserUpdateWithCollectionsInsertQuery : IQuery - { - private readonly OrganizationUser _organizationUser; - private readonly IEnumerable _collections; - - public OrganizationUserUpdateWithCollectionsInsertQuery(OrganizationUser organizationUser, IEnumerable collections) - { - _organizationUser = organizationUser; - _collections = collections; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var collectionIds = _collections.Select(c => c.Id).ToArray(); - var t = (from cu in dbContext.CollectionUsers - where collectionIds.Contains(cu.CollectionId) && - cu.OrganizationUserId == _organizationUser.Id - select cu).AsEnumerable(); - var insertQuery = (from c in dbContext.Collections - where collectionIds.Contains(c.Id) && - c.OrganizationId == _organizationUser.OrganizationId && - !t.Any() - select c).AsEnumerable(); - return insertQuery.Select(x => new CollectionUser - { - CollectionId = x.Id, - OrganizationUserId = _organizationUser.Id, - ReadOnly = _collections.FirstOrDefault(c => c.Id == x.Id).ReadOnly, - HidePasswords = _collections.FirstOrDefault(c => c.Id == x.Id).HidePasswords, - }).AsQueryable(); - } - } - - public class OrganizationUserUpdateWithCollectionsUpdateQuery : IQuery - { - private readonly OrganizationUser _organizationUser; - private readonly IEnumerable _collections; - - public OrganizationUserUpdateWithCollectionsUpdateQuery(OrganizationUser organizationUser, IEnumerable collections) - { - _organizationUser = organizationUser; - _collections = collections; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var collectionIds = _collections.Select(c => c.Id).ToArray(); - var updateQuery = (from target in dbContext.CollectionUsers - where collectionIds.Contains(target.CollectionId) && - target.OrganizationUserId == _organizationUser.Id - select new { target }).AsEnumerable(); - updateQuery = updateQuery.Where(cu => - cu.target.ReadOnly == _collections.FirstOrDefault(u => u.Id == cu.target.CollectionId).ReadOnly && - cu.target.HidePasswords == _collections.FirstOrDefault(u => u.Id == cu.target.CollectionId).HidePasswords); - return updateQuery.Select(x => new CollectionUser - { - CollectionId = x.target.CollectionId, - OrganizationUserId = _organizationUser.Id, - ReadOnly = x.target.ReadOnly, - HidePasswords = x.target.HidePasswords, - }).AsQueryable(); - } - } - - public class OrganizationUserUpdateWithCollectionsDeleteQuery : IQuery - { - private readonly OrganizationUser _organizationUser; - private readonly IEnumerable _collections; - - public OrganizationUserUpdateWithCollectionsDeleteQuery(OrganizationUser organizationUser, IEnumerable collections) - { - _organizationUser = organizationUser; - _collections = collections; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var deleteQuery = from cu in dbContext.CollectionUsers - where !_collections.Any( - c => c.Id == cu.CollectionId) - select cu; - return deleteQuery; - } + Insert = new OrganizationUserUpdateWithCollectionsInsertQuery(organizationUser, collections); + Update = new OrganizationUserUpdateWithCollectionsUpdateQuery(organizationUser, collections); + Delete = new OrganizationUserUpdateWithCollectionsDeleteQuery(organizationUser, collections); + } +} + +public class OrganizationUserUpdateWithCollectionsInsertQuery : IQuery +{ + private readonly OrganizationUser _organizationUser; + private readonly IEnumerable _collections; + + public OrganizationUserUpdateWithCollectionsInsertQuery(OrganizationUser organizationUser, IEnumerable collections) + { + _organizationUser = organizationUser; + _collections = collections; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var collectionIds = _collections.Select(c => c.Id).ToArray(); + var t = (from cu in dbContext.CollectionUsers + where collectionIds.Contains(cu.CollectionId) && + cu.OrganizationUserId == _organizationUser.Id + select cu).AsEnumerable(); + var insertQuery = (from c in dbContext.Collections + where collectionIds.Contains(c.Id) && + c.OrganizationId == _organizationUser.OrganizationId && + !t.Any() + select c).AsEnumerable(); + return insertQuery.Select(x => new CollectionUser + { + CollectionId = x.Id, + OrganizationUserId = _organizationUser.Id, + ReadOnly = _collections.FirstOrDefault(c => c.Id == x.Id).ReadOnly, + HidePasswords = _collections.FirstOrDefault(c => c.Id == x.Id).HidePasswords, + }).AsQueryable(); + } +} + +public class OrganizationUserUpdateWithCollectionsUpdateQuery : IQuery +{ + private readonly OrganizationUser _organizationUser; + private readonly IEnumerable _collections; + + public OrganizationUserUpdateWithCollectionsUpdateQuery(OrganizationUser organizationUser, IEnumerable collections) + { + _organizationUser = organizationUser; + _collections = collections; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var collectionIds = _collections.Select(c => c.Id).ToArray(); + var updateQuery = (from target in dbContext.CollectionUsers + where collectionIds.Contains(target.CollectionId) && + target.OrganizationUserId == _organizationUser.Id + select new { target }).AsEnumerable(); + updateQuery = updateQuery.Where(cu => + cu.target.ReadOnly == _collections.FirstOrDefault(u => u.Id == cu.target.CollectionId).ReadOnly && + cu.target.HidePasswords == _collections.FirstOrDefault(u => u.Id == cu.target.CollectionId).HidePasswords); + return updateQuery.Select(x => new CollectionUser + { + CollectionId = x.target.CollectionId, + OrganizationUserId = _organizationUser.Id, + ReadOnly = x.target.ReadOnly, + HidePasswords = x.target.HidePasswords, + }).AsQueryable(); + } +} + +public class OrganizationUserUpdateWithCollectionsDeleteQuery : IQuery +{ + private readonly OrganizationUser _organizationUser; + private readonly IEnumerable _collections; + + public OrganizationUserUpdateWithCollectionsDeleteQuery(OrganizationUser organizationUser, IEnumerable collections) + { + _organizationUser = organizationUser; + _collections = collections; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var deleteQuery = from cu in dbContext.CollectionUsers + where !_collections.Any( + c => c.Id == cu.CollectionId) + select cu; + return deleteQuery; } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserUserViewQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserUserViewQuery.cs index 248957196..2a5bf06fd 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserUserViewQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserUserViewQuery.cs @@ -1,35 +1,34 @@ using Bit.Core.Models.Data.Organizations.OrganizationUsers; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class OrganizationUserUserDetailsViewQuery : IQuery { - public class OrganizationUserUserDetailsViewQuery : IQuery + public IQueryable Run(DatabaseContext dbContext) { - public IQueryable Run(DatabaseContext dbContext) + var query = from ou in dbContext.OrganizationUsers + join u in dbContext.Users on ou.UserId equals u.Id into u_g + from u in u_g.DefaultIfEmpty() + join su in dbContext.SsoUsers on u.Id equals su.UserId into su_g + from su in su_g.DefaultIfEmpty() + select new { ou, u, su }; + return query.Select(x => new OrganizationUserUserDetails { - var query = from ou in dbContext.OrganizationUsers - join u in dbContext.Users on ou.UserId equals u.Id into u_g - from u in u_g.DefaultIfEmpty() - join su in dbContext.SsoUsers on u.Id equals su.UserId into su_g - from su in su_g.DefaultIfEmpty() - select new { ou, u, su }; - return query.Select(x => new OrganizationUserUserDetails - { - Id = x.ou.Id, - OrganizationId = x.ou.OrganizationId, - UserId = x.ou.UserId, - Name = x.u.Name, - Email = x.u.Email ?? x.ou.Email, - TwoFactorProviders = x.u.TwoFactorProviders, - Premium = x.u.Premium, - Status = x.ou.Status, - Type = x.ou.Type, - AccessAll = x.ou.AccessAll, - ExternalId = x.ou.ExternalId, - SsoExternalId = x.su.ExternalId, - Permissions = x.ou.Permissions, - ResetPasswordKey = x.ou.ResetPasswordKey, - UsesKeyConnector = x.u != null && x.u.UsesKeyConnector, - }); - } + Id = x.ou.Id, + OrganizationId = x.ou.OrganizationId, + UserId = x.ou.UserId, + Name = x.u.Name, + Email = x.u.Email ?? x.ou.Email, + TwoFactorProviders = x.u.TwoFactorProviders, + Premium = x.u.Premium, + Status = x.ou.Status, + Type = x.ou.Type, + AccessAll = x.ou.AccessAll, + ExternalId = x.ou.ExternalId, + SsoExternalId = x.su.ExternalId, + Permissions = x.ou.Permissions, + ResetPasswordKey = x.ou.ResetPasswordKey, + UsesKeyConnector = x.u != null && x.u.UsesKeyConnector, + }); } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/PolicyReadByTypeApplicableToUserQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/PolicyReadByTypeApplicableToUserQuery.cs index bd69570ec..21e5f9a28 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/PolicyReadByTypeApplicableToUserQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/PolicyReadByTypeApplicableToUserQuery.cs @@ -1,51 +1,50 @@ using Bit.Core.Enums; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class PolicyReadByTypeApplicableToUserQuery : IQuery { - public class PolicyReadByTypeApplicableToUserQuery : IQuery + private readonly Guid _userId; + private readonly PolicyType _policyType; + private readonly OrganizationUserStatusType _minimumStatus; + + public PolicyReadByTypeApplicableToUserQuery(Guid userId, PolicyType policyType, OrganizationUserStatusType minimumStatus) { - private readonly Guid _userId; - private readonly PolicyType _policyType; - private readonly OrganizationUserStatusType _minimumStatus; + _userId = userId; + _policyType = policyType; + _minimumStatus = minimumStatus; + } - public PolicyReadByTypeApplicableToUserQuery(Guid userId, PolicyType policyType, OrganizationUserStatusType minimumStatus) + public IQueryable Run(DatabaseContext dbContext) + { + var providerOrganizations = from pu in dbContext.ProviderUsers + where pu.UserId == _userId + join po in dbContext.ProviderOrganizations + on pu.ProviderId equals po.ProviderId + select po; + + string userEmail = null; + if (_minimumStatus == OrganizationUserStatusType.Invited) { - _userId = userId; - _policyType = policyType; - _minimumStatus = minimumStatus; + // Invited orgUsers do not have a UserId associated with them, so we have to match up their email + userEmail = dbContext.Users.Find(_userId)?.Email; } - public IQueryable Run(DatabaseContext dbContext) - { - var providerOrganizations = from pu in dbContext.ProviderUsers - where pu.UserId == _userId - join po in dbContext.ProviderOrganizations - on pu.ProviderId equals po.ProviderId - select po; - - string userEmail = null; - if (_minimumStatus == OrganizationUserStatusType.Invited) - { - // Invited orgUsers do not have a UserId associated with them, so we have to match up their email - userEmail = dbContext.Users.Find(_userId)?.Email; - } - - var query = from p in dbContext.Policies - join ou in dbContext.OrganizationUsers - on p.OrganizationId equals ou.OrganizationId - where - ((_minimumStatus > OrganizationUserStatusType.Invited && ou.UserId == _userId) || - (_minimumStatus == OrganizationUserStatusType.Invited && ou.Email == userEmail)) && - p.Type == _policyType && - p.Enabled && - ou.Status >= _minimumStatus && - ou.Type >= OrganizationUserType.User && - (ou.Permissions == null || - ou.Permissions.Contains($"\"managePolicies\":false")) && - !providerOrganizations.Any(po => po.OrganizationId == p.OrganizationId) - select p; - return query; - } + var query = from p in dbContext.Policies + join ou in dbContext.OrganizationUsers + on p.OrganizationId equals ou.OrganizationId + where + ((_minimumStatus > OrganizationUserStatusType.Invited && ou.UserId == _userId) || + (_minimumStatus == OrganizationUserStatusType.Invited && ou.Email == userEmail)) && + p.Type == _policyType && + p.Enabled && + ou.Status >= _minimumStatus && + ou.Type >= OrganizationUserType.User && + (ou.Permissions == null || + ou.Permissions.Contains($"\"managePolicies\":false")) && + !providerOrganizations.Any(po => po.OrganizationId == p.OrganizationId) + select p; + return query; } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/PolicyReadByUserIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/PolicyReadByUserIdQuery.cs index e910e2f75..58c06395a 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/PolicyReadByUserIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/PolicyReadByUserIdQuery.cs @@ -1,30 +1,29 @@ using Bit.Core.Enums; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class PolicyReadByUserIdQuery : IQuery { - public class PolicyReadByUserIdQuery : IQuery + private readonly Guid _userId; + + public PolicyReadByUserIdQuery(Guid userId) { - private readonly Guid _userId; + _userId = userId; + } - public PolicyReadByUserIdQuery(Guid userId) - { - _userId = userId; - } + public IQueryable Run(DatabaseContext dbContext) + { + var query = from p in dbContext.Policies + join ou in dbContext.OrganizationUsers + on p.OrganizationId equals ou.OrganizationId + join o in dbContext.Organizations + on ou.OrganizationId equals o.Id + where ou.UserId == _userId && + ou.Status == OrganizationUserStatusType.Confirmed && + o.Enabled == true + select p; - public IQueryable Run(DatabaseContext dbContext) - { - var query = from p in dbContext.Policies - join ou in dbContext.OrganizationUsers - on p.OrganizationId equals ou.OrganizationId - join o in dbContext.Organizations - on ou.OrganizationId equals o.Id - where ou.UserId == _userId && - ou.Status == OrganizationUserStatusType.Confirmed && - o.Enabled == true - select p; - - return query; - } + return query; } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderOrganizationOrganizationDetailsReadByProviderIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderOrganizationOrganizationDetailsReadByProviderIdQuery.cs index 03ada03c3..1429a136c 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderOrganizationOrganizationDetailsReadByProviderIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderOrganizationOrganizationDetailsReadByProviderIdQuery.cs @@ -1,38 +1,37 @@ using Bit.Core.Models.Data; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries -{ - public class ProviderOrganizationOrganizationDetailsReadByProviderIdQuery : IQuery - { - private readonly Guid _providerId; - public ProviderOrganizationOrganizationDetailsReadByProviderIdQuery(Guid providerId) - { - _providerId = providerId; - } +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - public IQueryable Run(DatabaseContext dbContext) +public class ProviderOrganizationOrganizationDetailsReadByProviderIdQuery : IQuery +{ + private readonly Guid _providerId; + public ProviderOrganizationOrganizationDetailsReadByProviderIdQuery(Guid providerId) + { + _providerId = providerId; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var query = from po in dbContext.ProviderOrganizations + join o in dbContext.Organizations + on po.OrganizationId equals o.Id + join ou in dbContext.OrganizationUsers + on po.OrganizationId equals ou.OrganizationId + where po.ProviderId == _providerId + select new { po, o }; + return query.Select(x => new ProviderOrganizationOrganizationDetails() { - var query = from po in dbContext.ProviderOrganizations - join o in dbContext.Organizations - on po.OrganizationId equals o.Id - join ou in dbContext.OrganizationUsers - on po.OrganizationId equals ou.OrganizationId - where po.ProviderId == _providerId - select new { po, o }; - return query.Select(x => new ProviderOrganizationOrganizationDetails() - { - Id = x.po.Id, - ProviderId = x.po.ProviderId, - OrganizationId = x.po.OrganizationId, - OrganizationName = x.o.Name, - Key = x.po.Key, - Settings = x.po.Settings, - CreationDate = x.po.CreationDate, - RevisionDate = x.po.RevisionDate, - UserCount = x.o.OrganizationUsers.Count(ou => ou.Status == Core.Enums.OrganizationUserStatusType.Confirmed), - Seats = x.o.Seats, - Plan = x.o.Plan - }); - } + Id = x.po.Id, + ProviderId = x.po.ProviderId, + OrganizationId = x.po.OrganizationId, + OrganizationName = x.o.Name, + Key = x.po.Key, + Settings = x.po.Settings, + CreationDate = x.po.CreationDate, + RevisionDate = x.po.RevisionDate, + UserCount = x.o.OrganizationUsers.Count(ou => ou.Status == Core.Enums.OrganizationUserStatusType.Confirmed), + Seats = x.o.Seats, + Plan = x.o.Plan + }); } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserOrganizationDetailsViewQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserOrganizationDetailsViewQuery.cs index 8f3e71861..dfd5f6192 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserOrganizationDetailsViewQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserOrganizationDetailsViewQuery.cs @@ -1,46 +1,45 @@ using Bit.Core.Models.Data; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class ProviderUserOrganizationDetailsViewQuery : IQuery { - public class ProviderUserOrganizationDetailsViewQuery : IQuery + public IQueryable Run(DatabaseContext dbContext) { - public IQueryable Run(DatabaseContext dbContext) + var query = from pu in dbContext.ProviderUsers + join po in dbContext.ProviderOrganizations on pu.ProviderId equals po.ProviderId + join o in dbContext.Organizations on po.OrganizationId equals o.Id + join p in dbContext.Providers on pu.ProviderId equals p.Id + select new { pu, po, o, p }; + return query.Select(x => new ProviderUserOrganizationDetails { - var query = from pu in dbContext.ProviderUsers - join po in dbContext.ProviderOrganizations on pu.ProviderId equals po.ProviderId - join o in dbContext.Organizations on po.OrganizationId equals o.Id - join p in dbContext.Providers on pu.ProviderId equals p.Id - select new { pu, po, o, p }; - return query.Select(x => new ProviderUserOrganizationDetails - { - OrganizationId = x.po.OrganizationId, - UserId = x.pu.UserId, - Name = x.o.Name, - Enabled = x.o.Enabled, - UsePolicies = x.o.UsePolicies, - UseSso = x.o.UseSso, - UseKeyConnector = x.o.UseKeyConnector, - UseScim = x.o.UseScim, - UseGroups = x.o.UseGroups, - UseDirectory = x.o.UseDirectory, - UseEvents = x.o.UseEvents, - UseTotp = x.o.UseTotp, - Use2fa = x.o.Use2fa, - UseApi = x.o.UseApi, - SelfHost = x.o.SelfHost, - UsersGetPremium = x.o.UsersGetPremium, - Seats = x.o.Seats, - MaxCollections = x.o.MaxCollections, - MaxStorageGb = x.o.MaxStorageGb, - Identifier = x.o.Identifier, - Key = x.po.Key, - Status = x.pu.Status, - Type = x.pu.Type, - PublicKey = x.o.PublicKey, - PrivateKey = x.o.PrivateKey, - ProviderId = x.p.Id, - ProviderName = x.p.Name, - }); - } + OrganizationId = x.po.OrganizationId, + UserId = x.pu.UserId, + Name = x.o.Name, + Enabled = x.o.Enabled, + UsePolicies = x.o.UsePolicies, + UseSso = x.o.UseSso, + UseKeyConnector = x.o.UseKeyConnector, + UseScim = x.o.UseScim, + UseGroups = x.o.UseGroups, + UseDirectory = x.o.UseDirectory, + UseEvents = x.o.UseEvents, + UseTotp = x.o.UseTotp, + Use2fa = x.o.Use2fa, + UseApi = x.o.UseApi, + SelfHost = x.o.SelfHost, + UsersGetPremium = x.o.UsersGetPremium, + Seats = x.o.Seats, + MaxCollections = x.o.MaxCollections, + MaxStorageGb = x.o.MaxStorageGb, + Identifier = x.o.Identifier, + Key = x.po.Key, + Status = x.pu.Status, + Type = x.pu.Type, + PublicKey = x.o.PublicKey, + PrivateKey = x.o.PrivateKey, + ProviderId = x.p.Id, + ProviderName = x.p.Name, + }); } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserProviderDetailsReadByUserIdStatusQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserProviderDetailsReadByUserIdStatusQuery.cs index efa768123..1cae8437a 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserProviderDetailsReadByUserIdStatusQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserProviderDetailsReadByUserIdStatusQuery.cs @@ -1,39 +1,38 @@ using Bit.Core.Enums.Provider; using Bit.Core.Models.Data; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries -{ - public class ProviderUserProviderDetailsReadByUserIdStatusQuery : IQuery - { - private readonly Guid _userId; - private readonly ProviderUserStatusType? _status; - public ProviderUserProviderDetailsReadByUserIdStatusQuery(Guid userId, ProviderUserStatusType? status) - { - _userId = userId; - _status = status; - } +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - public IQueryable Run(DatabaseContext dbContext) +public class ProviderUserProviderDetailsReadByUserIdStatusQuery : IQuery +{ + private readonly Guid _userId; + private readonly ProviderUserStatusType? _status; + public ProviderUserProviderDetailsReadByUserIdStatusQuery(Guid userId, ProviderUserStatusType? status) + { + _userId = userId; + _status = status; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var query = from pu in dbContext.ProviderUsers + join p in dbContext.Providers + on pu.ProviderId equals p.Id into p_g + from p in p_g.DefaultIfEmpty() + where pu.UserId == _userId && p.Status != ProviderStatusType.Pending && (_status == null || pu.Status == _status) + select new { pu, p }; + return query.Select(x => new ProviderUserProviderDetails() { - var query = from pu in dbContext.ProviderUsers - join p in dbContext.Providers - on pu.ProviderId equals p.Id into p_g - from p in p_g.DefaultIfEmpty() - where pu.UserId == _userId && p.Status != ProviderStatusType.Pending && (_status == null || pu.Status == _status) - select new { pu, p }; - return query.Select(x => new ProviderUserProviderDetails() - { - UserId = x.pu.UserId, - ProviderId = x.pu.ProviderId, - Name = x.p.Name, - Key = x.pu.Key, - Status = x.pu.Status, - Type = x.pu.Type, - Enabled = x.p.Enabled, - Permissions = x.pu.Permissions, - UseEvents = x.p.UseEvents, - ProviderStatus = x.p.Status, - }); - } + UserId = x.pu.UserId, + ProviderId = x.pu.ProviderId, + Name = x.p.Name, + Key = x.pu.Key, + Status = x.pu.Status, + Type = x.pu.Type, + Enabled = x.p.Enabled, + Permissions = x.pu.Permissions, + UseEvents = x.p.UseEvents, + ProviderStatus = x.p.Status, + }); } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserReadCountByOnlyOwnerQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserReadCountByOnlyOwnerQuery.cs index fd6e8521d..899c78b54 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserReadCountByOnlyOwnerQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/ProviderUserReadCountByOnlyOwnerQuery.cs @@ -1,37 +1,36 @@ using Bit.Core.Enums.Provider; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class ProviderUserReadCountByOnlyOwnerQuery : IQuery { - public class ProviderUserReadCountByOnlyOwnerQuery : IQuery + private readonly Guid _userId; + + public ProviderUserReadCountByOnlyOwnerQuery(Guid userId) { - private readonly Guid _userId; + _userId = userId; + } - public ProviderUserReadCountByOnlyOwnerQuery(Guid userId) - { - _userId = userId; - } + public IQueryable Run(DatabaseContext dbContext) + { + var owners = from pu in dbContext.ProviderUsers + where pu.Type == ProviderUserType.ProviderAdmin && + pu.Status == ProviderUserStatusType.Confirmed + group pu by pu.ProviderId into g + select new + { + ProviderUser = g.Select(x => new { x.UserId, x.Id }).FirstOrDefault(), + ConfirmedOwnerCount = g.Count(), + }; - public IQueryable Run(DatabaseContext dbContext) - { - var owners = from pu in dbContext.ProviderUsers - where pu.Type == ProviderUserType.ProviderAdmin && - pu.Status == ProviderUserStatusType.Confirmed - group pu by pu.ProviderId into g - select new - { - ProviderUser = g.Select(x => new { x.UserId, x.Id }).FirstOrDefault(), - ConfirmedOwnerCount = g.Count(), - }; + var query = from owner in owners + join pu in dbContext.ProviderUsers + on owner.ProviderUser.Id equals pu.Id + where owner.ProviderUser.UserId == _userId && + owner.ConfirmedOwnerCount == 1 + select pu; - var query = from owner in owners - join pu in dbContext.ProviderUsers - on owner.ProviderUser.Id equals pu.Id - where owner.ProviderUser.UserId == _userId && - owner.ConfirmedOwnerCount == 1 - select pu; - - return query; - } + return query; } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/UserBumpAccountRevisionDateByCipherIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/UserBumpAccountRevisionDateByCipherIdQuery.cs index 6fcf15ab4..00369c594 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/UserBumpAccountRevisionDateByCipherIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/UserBumpAccountRevisionDateByCipherIdQuery.cs @@ -2,51 +2,50 @@ using Bit.Core.Enums; using User = Bit.Infrastructure.EntityFramework.Models.User; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class UserBumpAccountRevisionDateByCipherIdQuery : IQuery { - public class UserBumpAccountRevisionDateByCipherIdQuery : IQuery + private readonly Cipher _cipher; + + public UserBumpAccountRevisionDateByCipherIdQuery(Cipher cipher) { - private readonly Cipher _cipher; + _cipher = cipher; + } - public UserBumpAccountRevisionDateByCipherIdQuery(Cipher cipher) - { - _cipher = cipher; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var query = from u in dbContext.Users - join ou in dbContext.OrganizationUsers - on u.Id equals ou.UserId - join collectionCipher in dbContext.CollectionCiphers - on _cipher.Id equals collectionCipher.CipherId into cc_g - from cc in cc_g.DefaultIfEmpty() - join collectionUser in dbContext.CollectionUsers - on cc.CollectionId equals collectionUser.CollectionId into cu_g - from cu in cu_g.DefaultIfEmpty() - where ou.AccessAll && - cu.OrganizationUserId == ou.Id - join groupUser in dbContext.GroupUsers - on ou.Id equals groupUser.OrganizationUserId into gu_g - from gu in gu_g.DefaultIfEmpty() - where cu.CollectionId == null && - !ou.AccessAll - join grp in dbContext.Groups - on gu.GroupId equals grp.Id into g_g - from g in g_g.DefaultIfEmpty() - join collectionGroup in dbContext.CollectionGroups - on cc.CollectionId equals collectionGroup.CollectionId into cg_g - from cg in cg_g.DefaultIfEmpty() - where !g.AccessAll && - cg.GroupId == gu.GroupId - where ou.OrganizationId == _cipher.OrganizationId && - ou.Status == OrganizationUserStatusType.Confirmed && - (cu.CollectionId != null || - cg.CollectionId != null || - ou.AccessAll || - g.AccessAll) - select u; - return query; - } + public IQueryable Run(DatabaseContext dbContext) + { + var query = from u in dbContext.Users + join ou in dbContext.OrganizationUsers + on u.Id equals ou.UserId + join collectionCipher in dbContext.CollectionCiphers + on _cipher.Id equals collectionCipher.CipherId into cc_g + from cc in cc_g.DefaultIfEmpty() + join collectionUser in dbContext.CollectionUsers + on cc.CollectionId equals collectionUser.CollectionId into cu_g + from cu in cu_g.DefaultIfEmpty() + where ou.AccessAll && + cu.OrganizationUserId == ou.Id + join groupUser in dbContext.GroupUsers + on ou.Id equals groupUser.OrganizationUserId into gu_g + from gu in gu_g.DefaultIfEmpty() + where cu.CollectionId == null && + !ou.AccessAll + join grp in dbContext.Groups + on gu.GroupId equals grp.Id into g_g + from g in g_g.DefaultIfEmpty() + join collectionGroup in dbContext.CollectionGroups + on cc.CollectionId equals collectionGroup.CollectionId into cg_g + from cg in cg_g.DefaultIfEmpty() + where !g.AccessAll && + cg.GroupId == gu.GroupId + where ou.OrganizationId == _cipher.OrganizationId && + ou.Status == OrganizationUserStatusType.Confirmed && + (cu.CollectionId != null || + cg.CollectionId != null || + ou.AccessAll || + g.AccessAll) + select u; + return query; } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/UserBumpAccountRevisionDateByOrganizationIdQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/UserBumpAccountRevisionDateByOrganizationIdQuery.cs index 87c2bcf08..d18cdc064 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/UserBumpAccountRevisionDateByOrganizationIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/UserBumpAccountRevisionDateByOrganizationIdQuery.cs @@ -1,27 +1,26 @@ using Bit.Core.Enums; using Bit.Infrastructure.EntityFramework.Models; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class UserBumpAccountRevisionDateByOrganizationIdQuery : IQuery { - public class UserBumpAccountRevisionDateByOrganizationIdQuery : IQuery + private readonly Guid _organizationId; + + public UserBumpAccountRevisionDateByOrganizationIdQuery(Guid organizationId) { - private readonly Guid _organizationId; + _organizationId = organizationId; + } - public UserBumpAccountRevisionDateByOrganizationIdQuery(Guid organizationId) - { - _organizationId = organizationId; - } + public IQueryable Run(DatabaseContext dbContext) + { + var query = from u in dbContext.Users + join ou in dbContext.OrganizationUsers + on u.Id equals ou.UserId + where ou.OrganizationId == _organizationId && + ou.Status == OrganizationUserStatusType.Confirmed + select u; - public IQueryable Run(DatabaseContext dbContext) - { - var query = from u in dbContext.Users - join ou in dbContext.OrganizationUsers - on u.Id equals ou.UserId - where ou.OrganizationId == _organizationId && - ou.Status == OrganizationUserStatusType.Confirmed - select u; - - return query; - } + return query; } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/UserCipherDetailsQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/UserCipherDetailsQuery.cs index 3417abd30..b74060ba3 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/UserCipherDetailsQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/UserCipherDetailsQuery.cs @@ -2,71 +2,70 @@ using Core.Models.Data; using Newtonsoft.Json.Linq; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class UserCipherDetailsQuery : IQuery { - public class UserCipherDetailsQuery : IQuery + private readonly Guid? _userId; + public UserCipherDetailsQuery(Guid? userId) { - private readonly Guid? _userId; - public UserCipherDetailsQuery(Guid? userId) - { - _userId = userId; - } - public virtual IQueryable Run(DatabaseContext dbContext) - { - var query = from c in dbContext.Ciphers - join ou in dbContext.OrganizationUsers - on c.OrganizationId equals ou.OrganizationId - where ou.UserId == _userId && - ou.Status == OrganizationUserStatusType.Confirmed - join o in dbContext.Organizations - on c.OrganizationId equals o.Id - where o.Id == ou.OrganizationId && o.Enabled - join cc in dbContext.CollectionCiphers - on c.Id equals cc.CipherId into cc_g - from cc in cc_g.DefaultIfEmpty() - where ou.AccessAll - join cu in dbContext.CollectionUsers - on cc.CollectionId equals cu.CollectionId into cu_g - from cu in cu_g.DefaultIfEmpty() - where cu.OrganizationUserId == ou.Id - join gu in dbContext.GroupUsers - on ou.Id equals gu.OrganizationUserId into gu_g - from gu in gu_g.DefaultIfEmpty() - where cu.CollectionId == null && !ou.AccessAll - join g in dbContext.Groups - on gu.GroupId equals g.Id into g_g - from g in g_g.DefaultIfEmpty() - join cg in dbContext.CollectionGroups - on cc.CollectionId equals cg.CollectionId into cg_g - from cg in cg_g.DefaultIfEmpty() - where !g.AccessAll && cg.GroupId == gu.GroupId && - ou.AccessAll || cu.CollectionId != null || g.AccessAll || cg.CollectionId != null - select new { c, ou, o, cc, cu, gu, g, cg }.c; + _userId = userId; + } + public virtual IQueryable Run(DatabaseContext dbContext) + { + var query = from c in dbContext.Ciphers + join ou in dbContext.OrganizationUsers + on c.OrganizationId equals ou.OrganizationId + where ou.UserId == _userId && + ou.Status == OrganizationUserStatusType.Confirmed + join o in dbContext.Organizations + on c.OrganizationId equals o.Id + where o.Id == ou.OrganizationId && o.Enabled + join cc in dbContext.CollectionCiphers + on c.Id equals cc.CipherId into cc_g + from cc in cc_g.DefaultIfEmpty() + where ou.AccessAll + join cu in dbContext.CollectionUsers + on cc.CollectionId equals cu.CollectionId into cu_g + from cu in cu_g.DefaultIfEmpty() + where cu.OrganizationUserId == ou.Id + join gu in dbContext.GroupUsers + on ou.Id equals gu.OrganizationUserId into gu_g + from gu in gu_g.DefaultIfEmpty() + where cu.CollectionId == null && !ou.AccessAll + join g in dbContext.Groups + on gu.GroupId equals g.Id into g_g + from g in g_g.DefaultIfEmpty() + join cg in dbContext.CollectionGroups + on cc.CollectionId equals cg.CollectionId into cg_g + from cg in cg_g.DefaultIfEmpty() + where !g.AccessAll && cg.GroupId == gu.GroupId && + ou.AccessAll || cu.CollectionId != null || g.AccessAll || cg.CollectionId != null + select new { c, ou, o, cc, cu, gu, g, cg }.c; - var query2 = from c in dbContext.Ciphers - where c.UserId == _userId - select c; + var query2 = from c in dbContext.Ciphers + where c.UserId == _userId + select c; - var union = query.Union(query2).Select(c => new CipherDetails - { - Id = c.Id, - UserId = c.UserId, - OrganizationId = c.OrganizationId, - Type = c.Type, - Data = c.Data, - Attachments = c.Attachments, - CreationDate = c.CreationDate, - RevisionDate = c.RevisionDate, - DeletedDate = c.DeletedDate, - Favorite = _userId.HasValue && c.Favorites != null && c.Favorites.Contains($"\"{_userId}\":true"), - FolderId = _userId.HasValue && !string.IsNullOrWhiteSpace(c.Folders) ? - Guid.Parse(JObject.Parse(c.Folders)[_userId.Value.ToString()].Value()) : - null, - Edit = true, - ViewPassword = true, - OrganizationUseTotp = false, - }); - return union; - } + var union = query.Union(query2).Select(c => new CipherDetails + { + Id = c.Id, + UserId = c.UserId, + OrganizationId = c.OrganizationId, + Type = c.Type, + Data = c.Data, + Attachments = c.Attachments, + CreationDate = c.CreationDate, + RevisionDate = c.RevisionDate, + DeletedDate = c.DeletedDate, + Favorite = _userId.HasValue && c.Favorites != null && c.Favorites.Contains($"\"{_userId}\":true"), + FolderId = _userId.HasValue && !string.IsNullOrWhiteSpace(c.Folders) ? + Guid.Parse(JObject.Parse(c.Folders)[_userId.Value.ToString()].Value()) : + null, + Edit = true, + ViewPassword = true, + OrganizationUseTotp = false, + }); + return union; } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/UserCollectionDetailsQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/UserCollectionDetailsQuery.cs index c2325f10f..7004a6f75 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/UserCollectionDetailsQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/UserCollectionDetailsQuery.cs @@ -1,53 +1,52 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class UserCollectionDetailsQuery : IQuery { - public class UserCollectionDetailsQuery : IQuery + private readonly Guid? _userId; + public UserCollectionDetailsQuery(Guid? userId) { - private readonly Guid? _userId; - public UserCollectionDetailsQuery(Guid? userId) + _userId = userId; + } + public virtual IQueryable Run(DatabaseContext dbContext) + { + var query = from c in dbContext.Collections + join ou in dbContext.OrganizationUsers + on c.OrganizationId equals ou.OrganizationId + join o in dbContext.Organizations + on c.OrganizationId equals o.Id + join cu in dbContext.CollectionUsers + on c.Id equals cu.CollectionId into cu_g + from cu in cu_g.DefaultIfEmpty() + where ou.AccessAll && cu.OrganizationUserId == ou.Id + join gu in dbContext.GroupUsers + on ou.Id equals gu.OrganizationUserId into gu_g + from gu in gu_g.DefaultIfEmpty() + where cu.CollectionId == null && !ou.AccessAll + join g in dbContext.Groups + on gu.GroupId equals g.Id into g_g + from g in g_g.DefaultIfEmpty() + join cg in dbContext.CollectionGroups + on gu.GroupId equals cg.GroupId into cg_g + from cg in cg_g.DefaultIfEmpty() + where !g.AccessAll && cg.CollectionId == c.Id && + ou.UserId == _userId && + ou.Status == OrganizationUserStatusType.Confirmed && + o.Enabled && + (ou.AccessAll || cu.CollectionId != null || g.AccessAll || cg.CollectionId != null) + select new { c, ou, o, cu, gu, g, cg }; + return query.Select(x => new CollectionDetails { - _userId = userId; - } - public virtual IQueryable Run(DatabaseContext dbContext) - { - var query = from c in dbContext.Collections - join ou in dbContext.OrganizationUsers - on c.OrganizationId equals ou.OrganizationId - join o in dbContext.Organizations - on c.OrganizationId equals o.Id - join cu in dbContext.CollectionUsers - on c.Id equals cu.CollectionId into cu_g - from cu in cu_g.DefaultIfEmpty() - where ou.AccessAll && cu.OrganizationUserId == ou.Id - join gu in dbContext.GroupUsers - on ou.Id equals gu.OrganizationUserId into gu_g - from gu in gu_g.DefaultIfEmpty() - where cu.CollectionId == null && !ou.AccessAll - join g in dbContext.Groups - on gu.GroupId equals g.Id into g_g - from g in g_g.DefaultIfEmpty() - join cg in dbContext.CollectionGroups - on gu.GroupId equals cg.GroupId into cg_g - from cg in cg_g.DefaultIfEmpty() - where !g.AccessAll && cg.CollectionId == c.Id && - ou.UserId == _userId && - ou.Status == OrganizationUserStatusType.Confirmed && - o.Enabled && - (ou.AccessAll || cu.CollectionId != null || g.AccessAll || cg.CollectionId != null) - select new { c, ou, o, cu, gu, g, cg }; - return query.Select(x => new CollectionDetails - { - Id = x.c.Id, - OrganizationId = x.c.OrganizationId, - Name = x.c.Name, - ExternalId = x.c.ExternalId, - CreationDate = x.c.CreationDate, - RevisionDate = x.c.RevisionDate, - ReadOnly = !x.ou.AccessAll || !x.g.AccessAll || (x.cu.ReadOnly || x.cg.ReadOnly), - HidePasswords = !x.ou.AccessAll || !x.g.AccessAll || (x.cu.HidePasswords || x.cg.HidePasswords), - }); - } + Id = x.c.Id, + OrganizationId = x.c.OrganizationId, + Name = x.c.Name, + ExternalId = x.c.ExternalId, + CreationDate = x.c.CreationDate, + RevisionDate = x.c.RevisionDate, + ReadOnly = !x.ou.AccessAll || !x.g.AccessAll || (x.cu.ReadOnly || x.cg.ReadOnly), + HidePasswords = !x.ou.AccessAll || !x.g.AccessAll || (x.cu.HidePasswords || x.cg.HidePasswords), + }); } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/UserReadPublicKeysByProviderUserIdsQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/UserReadPublicKeysByProviderUserIdsQuery.cs index 10782e0ea..db347b99b 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/UserReadPublicKeysByProviderUserIdsQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/UserReadPublicKeysByProviderUserIdsQuery.cs @@ -1,33 +1,32 @@ using Bit.Core.Enums.Provider; using Bit.Core.Models.Data; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class UserReadPublicKeysByProviderUserIdsQuery : IQuery { - public class UserReadPublicKeysByProviderUserIdsQuery : IQuery + private readonly Guid _providerId; + private readonly IEnumerable _ids; + + public UserReadPublicKeysByProviderUserIdsQuery(Guid providerId, IEnumerable Ids) { - private readonly Guid _providerId; - private readonly IEnumerable _ids; + _providerId = providerId; + _ids = Ids; + } - public UserReadPublicKeysByProviderUserIdsQuery(Guid providerId, IEnumerable Ids) + public virtual IQueryable Run(DatabaseContext dbContext) + { + var query = from pu in dbContext.ProviderUsers + join u in dbContext.Users + on pu.UserId equals u.Id + where _ids.Contains(pu.Id) && + pu.Status == ProviderUserStatusType.Accepted && + pu.ProviderId == _providerId + select new { pu, u }; + return query.Select(x => new ProviderUserPublicKey { - _providerId = providerId; - _ids = Ids; - } - - public virtual IQueryable Run(DatabaseContext dbContext) - { - var query = from pu in dbContext.ProviderUsers - join u in dbContext.Users - on pu.UserId equals u.Id - where _ids.Contains(pu.Id) && - pu.Status == ProviderUserStatusType.Accepted && - pu.ProviderId == _providerId - select new { pu, u }; - return query.Select(x => new ProviderUserPublicKey - { - Id = x.pu.Id, - PublicKey = x.u.PublicKey, - }); - } + Id = x.pu.Id, + PublicKey = x.u.PublicKey, + }); } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Repository.cs b/src/Infrastructure.EntityFramework/Repositories/Repository.cs index 2d933da80..4c509540d 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Repository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Repository.cs @@ -5,118 +5,117 @@ using Bit.Infrastructure.EntityFramework.Repositories.Queries; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public abstract class Repository : BaseEntityFrameworkRepository, IRepository + where TId : IEquatable + where T : class, ITableObject + where TEntity : class, ITableObject { - public abstract class Repository : BaseEntityFrameworkRepository, IRepository - where TId : IEquatable - where T : class, ITableObject - where TEntity : class, ITableObject + public Repository(IServiceScopeFactory serviceScopeFactory, IMapper mapper, Func> getDbSet) + : base(serviceScopeFactory, mapper) { - public Repository(IServiceScopeFactory serviceScopeFactory, IMapper mapper, Func> getDbSet) - : base(serviceScopeFactory, mapper) + GetDbSet = getDbSet; + } + + protected Func> GetDbSet { get; private set; } + + public virtual async Task GetByIdAsync(TId id) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - GetDbSet = getDbSet; + var dbContext = GetDatabaseContext(scope); + var entity = await GetDbSet(dbContext).FindAsync(id); + return Mapper.Map(entity); } + } - protected Func> GetDbSet { get; private set; } - - public virtual async Task GetByIdAsync(TId id) + public virtual async Task CreateAsync(T obj) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var entity = await GetDbSet(dbContext).FindAsync(id); - return Mapper.Map(entity); - } + var dbContext = GetDatabaseContext(scope); + obj.SetNewId(); + var entity = Mapper.Map(obj); + await dbContext.AddAsync(entity); + await dbContext.SaveChangesAsync(); + obj.Id = entity.Id; + return obj; } + } - public virtual async Task CreateAsync(T obj) + public virtual async Task ReplaceAsync(T obj) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + var entity = await GetDbSet(dbContext).FindAsync(obj.Id); + if (entity != null) { - var dbContext = GetDatabaseContext(scope); - obj.SetNewId(); - var entity = Mapper.Map(obj); - await dbContext.AddAsync(entity); + var mappedEntity = Mapper.Map(obj); + dbContext.Entry(entity).CurrentValues.SetValues(mappedEntity); await dbContext.SaveChangesAsync(); - obj.Id = entity.Id; - return obj; - } - } - - public virtual async Task ReplaceAsync(T obj) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var entity = await GetDbSet(dbContext).FindAsync(obj.Id); - if (entity != null) - { - var mappedEntity = Mapper.Map(obj); - dbContext.Entry(entity).CurrentValues.SetValues(mappedEntity); - await dbContext.SaveChangesAsync(); - } - } - } - - public virtual async Task UpsertAsync(T obj) - { - if (obj.Id.Equals(default(TId))) - { - await CreateAsync(obj); - } - else - { - await ReplaceAsync(obj); - } - } - - public virtual async Task DeleteAsync(T obj) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var entity = Mapper.Map(obj); - dbContext.Remove(entity); - await dbContext.SaveChangesAsync(); - } - } - - public virtual async Task RefreshDb() - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var context = GetDatabaseContext(scope); - await context.Database.EnsureDeletedAsync(); - await context.Database.EnsureCreatedAsync(); - } - } - - public virtual async Task> CreateMany(List objs) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var entities = new List(); - foreach (var o in objs) - { - o.SetNewId(); - var entity = Mapper.Map(o); - entities.Add(entity); - } - var dbContext = GetDatabaseContext(scope); - await GetDbSet(dbContext).AddRangeAsync(entities); - await dbContext.SaveChangesAsync(); - return objs; - } - } - - public IQueryable Run(IQuery query) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - return query.Run(dbContext); } } } + + public virtual async Task UpsertAsync(T obj) + { + if (obj.Id.Equals(default(TId))) + { + await CreateAsync(obj); + } + else + { + await ReplaceAsync(obj); + } + } + + public virtual async Task DeleteAsync(T obj) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var entity = Mapper.Map(obj); + dbContext.Remove(entity); + await dbContext.SaveChangesAsync(); + } + } + + public virtual async Task RefreshDb() + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var context = GetDatabaseContext(scope); + await context.Database.EnsureDeletedAsync(); + await context.Database.EnsureCreatedAsync(); + } + } + + public virtual async Task> CreateMany(List objs) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var entities = new List(); + foreach (var o in objs) + { + o.SetNewId(); + var entity = Mapper.Map(o); + entities.Add(entity); + } + var dbContext = GetDatabaseContext(scope); + await GetDbSet(dbContext).AddRangeAsync(entities); + await dbContext.SaveChangesAsync(); + return objs; + } + } + + public IQueryable Run(IQuery query) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + return query.Run(dbContext); + } + } } diff --git a/src/Infrastructure.EntityFramework/Repositories/SendRepository.cs b/src/Infrastructure.EntityFramework/Repositories/SendRepository.cs index 691a86c3b..e102ddea7 100644 --- a/src/Infrastructure.EntityFramework/Repositories/SendRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/SendRepository.cs @@ -4,43 +4,42 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class SendRepository : Repository, ISendRepository { - public class SendRepository : Repository, ISendRepository + public SendRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Sends) + { } + + public override async Task CreateAsync(Core.Entities.Send send) { - public SendRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Sends) - { } - - public override async Task CreateAsync(Core.Entities.Send send) + send = await base.CreateAsync(send); + if (send.UserId.HasValue) { - send = await base.CreateAsync(send); - if (send.UserId.HasValue) - { - await UserUpdateStorage(send.UserId.Value); - await UserBumpAccountRevisionDate(send.UserId.Value); - } - return send; + await UserUpdateStorage(send.UserId.Value); + await UserBumpAccountRevisionDate(send.UserId.Value); } + return send; + } - public async Task> GetManyByDeletionDateAsync(DateTime deletionDateBefore) + public async Task> GetManyByDeletionDateAsync(DateTime deletionDateBefore) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var results = await dbContext.Sends.Where(s => s.DeletionDate < deletionDateBefore).ToListAsync(); - return Mapper.Map>(results); - } + var dbContext = GetDatabaseContext(scope); + var results = await dbContext.Sends.Where(s => s.DeletionDate < deletionDateBefore).ToListAsync(); + return Mapper.Map>(results); } + } - public async Task> GetManyByUserIdAsync(Guid userId) + public async Task> GetManyByUserIdAsync(Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var results = await dbContext.Sends.Where(s => s.UserId == userId).ToListAsync(); - return Mapper.Map>(results); - } + var dbContext = GetDatabaseContext(scope); + var results = await dbContext.Sends.Where(s => s.UserId == userId).ToListAsync(); + return Mapper.Map>(results); } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/SsoConfigRepository.cs b/src/Infrastructure.EntityFramework/Repositories/SsoConfigRepository.cs index 8c0d22165..c9a772e9a 100644 --- a/src/Infrastructure.EntityFramework/Repositories/SsoConfigRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/SsoConfigRepository.cs @@ -4,43 +4,42 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class SsoConfigRepository : Repository, ISsoConfigRepository { - public class SsoConfigRepository : Repository, ISsoConfigRepository + public SsoConfigRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.SsoConfigs) + { } + + public async Task GetByOrganizationIdAsync(Guid organizationId) { - public SsoConfigRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.SsoConfigs) - { } - - public async Task GetByOrganizationIdAsync(Guid organizationId) + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var ssoConfig = await GetDbSet(dbContext).SingleOrDefaultAsync(sc => sc.OrganizationId == organizationId); - return Mapper.Map(ssoConfig); - } + var dbContext = GetDatabaseContext(scope); + var ssoConfig = await GetDbSet(dbContext).SingleOrDefaultAsync(sc => sc.OrganizationId == organizationId); + return Mapper.Map(ssoConfig); } + } - public async Task GetByIdentifierAsync(string identifier) + public async Task GetByIdentifierAsync(string identifier) + { + + using (var scope = ServiceScopeFactory.CreateScope()) { - - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var ssoConfig = await GetDbSet(dbContext).SingleOrDefaultAsync(sc => sc.Organization.Identifier == identifier); - return Mapper.Map(ssoConfig); - } + var dbContext = GetDatabaseContext(scope); + var ssoConfig = await GetDbSet(dbContext).SingleOrDefaultAsync(sc => sc.Organization.Identifier == identifier); + return Mapper.Map(ssoConfig); } + } - public async Task> GetManyByRevisionNotBeforeDate(DateTime? notBefore) + public async Task> GetManyByRevisionNotBeforeDate(DateTime? notBefore) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var ssoConfigs = await GetDbSet(dbContext).Where(sc => sc.Enabled && sc.RevisionDate >= notBefore).ToListAsync(); - return Mapper.Map>(ssoConfigs); - } + var dbContext = GetDatabaseContext(scope); + var ssoConfigs = await GetDbSet(dbContext).Where(sc => sc.Enabled && sc.RevisionDate >= notBefore).ToListAsync(); + return Mapper.Map>(ssoConfigs); } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/SsoUserRepository.cs b/src/Infrastructure.EntityFramework/Repositories/SsoUserRepository.cs index c413648dc..f41f0d540 100644 --- a/src/Infrastructure.EntityFramework/Repositories/SsoUserRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/SsoUserRepository.cs @@ -4,34 +4,33 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class SsoUserRepository : Repository, ISsoUserRepository { - public class SsoUserRepository : Repository, ISsoUserRepository + public SsoUserRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.SsoUsers) + { } + + public async Task DeleteAsync(Guid userId, Guid? organizationId) { - public SsoUserRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.SsoUsers) - { } - - public async Task DeleteAsync(Guid userId, Guid? organizationId) + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var entity = await GetDbSet(dbContext).SingleOrDefaultAsync(su => su.UserId == userId && su.OrganizationId == organizationId); - dbContext.Entry(entity).State = EntityState.Deleted; - await dbContext.SaveChangesAsync(); - } + var dbContext = GetDatabaseContext(scope); + var entity = await GetDbSet(dbContext).SingleOrDefaultAsync(su => su.UserId == userId && su.OrganizationId == organizationId); + dbContext.Entry(entity).State = EntityState.Deleted; + await dbContext.SaveChangesAsync(); } + } - public async Task GetByUserIdOrganizationIdAsync(Guid organizationId, Guid userId) + public async Task GetByUserIdOrganizationIdAsync(Guid organizationId, Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var entity = await GetDbSet(dbContext) - .FirstOrDefaultAsync(e => e.OrganizationId == organizationId && e.UserId == userId); - return entity; - } + var dbContext = GetDatabaseContext(scope); + var entity = await GetDbSet(dbContext) + .FirstOrDefaultAsync(e => e.OrganizationId == organizationId && e.UserId == userId); + return entity; } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/TaxRateRepository.cs b/src/Infrastructure.EntityFramework/Repositories/TaxRateRepository.cs index a575892d4..fcf4014a1 100644 --- a/src/Infrastructure.EntityFramework/Repositories/TaxRateRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/TaxRateRepository.cs @@ -4,64 +4,63 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class TaxRateRepository : Repository, ITaxRateRepository { - public class TaxRateRepository : Repository, ITaxRateRepository + public TaxRateRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.TaxRates) + { } + + public async Task ArchiveAsync(Core.Entities.TaxRate model) { - public TaxRateRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.TaxRates) - { } - - public async Task ArchiveAsync(Core.Entities.TaxRate model) + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var entity = await dbContext.FindAsync(model); - entity.Active = false; - await dbContext.SaveChangesAsync(); - } + var dbContext = GetDatabaseContext(scope); + var entity = await dbContext.FindAsync(model); + entity.Active = false; + await dbContext.SaveChangesAsync(); } + } - public async Task> GetAllActiveAsync() + public async Task> GetAllActiveAsync() + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var results = await dbContext.TaxRates - .Where(t => t.Active) - .ToListAsync(); - return Mapper.Map>(results); - } + var dbContext = GetDatabaseContext(scope); + var results = await dbContext.TaxRates + .Where(t => t.Active) + .ToListAsync(); + return Mapper.Map>(results); } + } - public async Task> GetByLocationAsync(Core.Entities.TaxRate taxRate) + public async Task> GetByLocationAsync(Core.Entities.TaxRate taxRate) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var results = await dbContext.TaxRates - .Where(t => t.Active && - t.Country == taxRate.Country && - t.PostalCode == taxRate.PostalCode) - .ToListAsync(); - return Mapper.Map>(results); - } + var dbContext = GetDatabaseContext(scope); + var results = await dbContext.TaxRates + .Where(t => t.Active && + t.Country == taxRate.Country && + t.PostalCode == taxRate.PostalCode) + .ToListAsync(); + return Mapper.Map>(results); } + } - public async Task> SearchAsync(int skip, int count) + public async Task> SearchAsync(int skip, int count) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var results = await dbContext.TaxRates - .Skip(skip) - .Take(count) - .Where(t => t.Active) - .OrderBy(t => t.Country).ThenByDescending(t => t.PostalCode) - .ToListAsync(); - return Mapper.Map>(results); - } + var dbContext = GetDatabaseContext(scope); + var results = await dbContext.TaxRates + .Skip(skip) + .Take(count) + .Where(t => t.Active) + .OrderBy(t => t.Country).ThenByDescending(t => t.PostalCode) + .ToListAsync(); + return Mapper.Map>(results); } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/TransactionRepository.cs b/src/Infrastructure.EntityFramework/Repositories/TransactionRepository.cs index fad4389ed..45c052cbb 100644 --- a/src/Infrastructure.EntityFramework/Repositories/TransactionRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/TransactionRepository.cs @@ -5,47 +5,46 @@ using Bit.Infrastructure.EntityFramework.Models; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class TransactionRepository : Repository, ITransactionRepository { - public class TransactionRepository : Repository, ITransactionRepository + public TransactionRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Transactions) + { } + + public async Task GetByGatewayIdAsync(GatewayType gatewayType, string gatewayId) { - public TransactionRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Transactions) - { } - - public async Task GetByGatewayIdAsync(GatewayType gatewayType, string gatewayId) + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var results = await dbContext.Transactions - .FirstOrDefaultAsync(t => (t.GatewayId == gatewayId && t.Gateway == gatewayType)); - return Mapper.Map(results); - } + var dbContext = GetDatabaseContext(scope); + var results = await dbContext.Transactions + .FirstOrDefaultAsync(t => (t.GatewayId == gatewayId && t.Gateway == gatewayType)); + return Mapper.Map(results); } + } - public async Task> GetManyByOrganizationIdAsync(Guid organizationId) + public async Task> GetManyByOrganizationIdAsync(Guid organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var results = await dbContext.Transactions - .Where(t => (t.OrganizationId == organizationId && !t.UserId.HasValue)) - .ToListAsync(); - return Mapper.Map>(results); - } + var dbContext = GetDatabaseContext(scope); + var results = await dbContext.Transactions + .Where(t => (t.OrganizationId == organizationId && !t.UserId.HasValue)) + .ToListAsync(); + return Mapper.Map>(results); } + } - public async Task> GetManyByUserIdAsync(Guid userId) + public async Task> GetManyByUserIdAsync(Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var results = await dbContext.Transactions - .Where(t => (t.UserId == userId)) - .ToListAsync(); - return Mapper.Map>(results); - } + var dbContext = GetDatabaseContext(scope); + var results = await dbContext.Transactions + .Where(t => (t.UserId == userId)) + .ToListAsync(); + return Mapper.Map>(results); } } } diff --git a/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs b/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs index 4074debf2..be75e85ff 100644 --- a/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs @@ -5,142 +5,141 @@ using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; using DataModel = Bit.Core.Models.Data; -namespace Bit.Infrastructure.EntityFramework.Repositories +namespace Bit.Infrastructure.EntityFramework.Repositories; + +public class UserRepository : Repository, IUserRepository { - public class UserRepository : Repository, IUserRepository + public UserRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Users) + { } + + public async Task GetByEmailAsync(string email) { - public UserRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.Users) - { } - - public async Task GetByEmailAsync(string email) + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var entity = await GetDbSet(dbContext).FirstOrDefaultAsync(e => e.Email == email); - return Mapper.Map(entity); - } + var dbContext = GetDatabaseContext(scope); + var entity = await GetDbSet(dbContext).FirstOrDefaultAsync(e => e.Email == email); + return Mapper.Map(entity); } + } - public async Task GetKdfInformationByEmailAsync(string email) + public async Task GetKdfInformationByEmailAsync(string email) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - return await GetDbSet(dbContext).Where(e => e.Email == email) - .Select(e => new DataModel.UserKdfInformation - { - Kdf = e.Kdf, - KdfIterations = e.KdfIterations - }).SingleOrDefaultAsync(); - } - } - - public async Task> SearchAsync(string email, int skip, int take) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - List users; - if (dbContext.Database.IsNpgsql()) + var dbContext = GetDatabaseContext(scope); + return await GetDbSet(dbContext).Where(e => e.Email == email) + .Select(e => new DataModel.UserKdfInformation { - users = await GetDbSet(dbContext) - .Where(e => e.Email == null || - EF.Functions.ILike(EF.Functions.Collate(e.Email, "default"), "a%")) - .OrderBy(e => e.Email) - .Skip(skip).Take(take) - .ToListAsync(); - } - else - { - users = await GetDbSet(dbContext) - .Where(e => email == null || e.Email.StartsWith(email)) - .OrderBy(e => e.Email) - .Skip(skip).Take(take) - .ToListAsync(); - } - return Mapper.Map>(users); - } + Kdf = e.Kdf, + KdfIterations = e.KdfIterations + }).SingleOrDefaultAsync(); } + } - public async Task> GetManyByPremiumAsync(bool premium) + public async Task> SearchAsync(string email, int skip, int take) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + List users; + if (dbContext.Database.IsNpgsql()) { - var dbContext = GetDatabaseContext(scope); - var users = await GetDbSet(dbContext).Where(e => e.Premium == premium).ToListAsync(); - return Mapper.Map>(users); + users = await GetDbSet(dbContext) + .Where(e => e.Email == null || + EF.Functions.ILike(EF.Functions.Collate(e.Email, "default"), "a%")) + .OrderBy(e => e.Email) + .Skip(skip).Take(take) + .ToListAsync(); } - } - - public async Task GetPublicKeyAsync(Guid id) - { - using (var scope = ServiceScopeFactory.CreateScope()) + else { - var dbContext = GetDatabaseContext(scope); - return await GetDbSet(dbContext).Where(e => e.Id == id).Select(e => e.PublicKey).SingleOrDefaultAsync(); + users = await GetDbSet(dbContext) + .Where(e => email == null || e.Email.StartsWith(email)) + .OrderBy(e => e.Email) + .Skip(skip).Take(take) + .ToListAsync(); } + return Mapper.Map>(users); } + } - public async Task GetAccountRevisionDateAsync(Guid id) + public async Task> GetManyByPremiumAsync(bool premium) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) + var dbContext = GetDatabaseContext(scope); + var users = await GetDbSet(dbContext).Where(e => e.Premium == premium).ToListAsync(); + return Mapper.Map>(users); + } + } + + public async Task GetPublicKeyAsync(Guid id) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + return await GetDbSet(dbContext).Where(e => e.Id == id).Select(e => e.PublicKey).SingleOrDefaultAsync(); + } + } + + public async Task GetAccountRevisionDateAsync(Guid id) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + return await GetDbSet(dbContext).Where(e => e.Id == id).Select(e => e.AccountRevisionDate) + .SingleOrDefaultAsync(); + } + } + + public async Task UpdateStorageAsync(Guid id) + { + await base.UserUpdateStorage(id); + } + + public async Task UpdateRenewalReminderDateAsync(Guid id, DateTime renewalReminderDate) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var user = new User { - var dbContext = GetDatabaseContext(scope); - return await GetDbSet(dbContext).Where(e => e.Id == id).Select(e => e.AccountRevisionDate) - .SingleOrDefaultAsync(); - } + Id = id, + RenewalReminderDate = renewalReminderDate, + }; + var set = GetDbSet(dbContext); + set.Attach(user); + dbContext.Entry(user).Property(e => e.RenewalReminderDate).IsModified = true; + await dbContext.SaveChangesAsync(); } + } - public async Task UpdateStorageAsync(Guid id) + public async Task GetBySsoUserAsync(string externalId, Guid? organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - await base.UserUpdateStorage(id); - } + var dbContext = GetDatabaseContext(scope); + var ssoUser = await dbContext.SsoUsers.SingleOrDefaultAsync(e => + e.OrganizationId == organizationId && e.ExternalId == externalId); - public async Task UpdateRenewalReminderDateAsync(Guid id, DateTime renewalReminderDate) - { - using (var scope = ServiceScopeFactory.CreateScope()) + if (ssoUser == null) { - var dbContext = GetDatabaseContext(scope); - var user = new User - { - Id = id, - RenewalReminderDate = renewalReminderDate, - }; - var set = GetDbSet(dbContext); - set.Attach(user); - dbContext.Entry(user).Property(e => e.RenewalReminderDate).IsModified = true; - await dbContext.SaveChangesAsync(); + return null; } + + var entity = await dbContext.Users.SingleOrDefaultAsync(e => e.Id == ssoUser.UserId); + return Mapper.Map(entity); } + } - public async Task GetBySsoUserAsync(string externalId, Guid? organizationId) + public async Task> GetManyAsync(IEnumerable ids) + { + using (var scope = ServiceScopeFactory.CreateScope()) { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var ssoUser = await dbContext.SsoUsers.SingleOrDefaultAsync(e => - e.OrganizationId == organizationId && e.ExternalId == externalId); - - if (ssoUser == null) - { - return null; - } - - var entity = await dbContext.Users.SingleOrDefaultAsync(e => e.Id == ssoUser.UserId); - return Mapper.Map(entity); - } - } - - public async Task> GetManyAsync(IEnumerable ids) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var users = dbContext.Users.Where(x => ids.Contains(x.Id)); - return await users.ToListAsync(); - } + var dbContext = GetDatabaseContext(scope); + var users = dbContext.Users.Where(x => ids.Contains(x.Id)); + return await users.ToListAsync(); } } } diff --git a/src/Notifications/AzureQueueHostedService.cs b/src/Notifications/AzureQueueHostedService.cs index edc735d86..ba2e38d2c 100644 --- a/src/Notifications/AzureQueueHostedService.cs +++ b/src/Notifications/AzureQueueHostedService.cs @@ -3,91 +3,90 @@ using Bit.Core.Settings; using Bit.Core.Utilities; using Microsoft.AspNetCore.SignalR; -namespace Bit.Notifications +namespace Bit.Notifications; + +public class AzureQueueHostedService : IHostedService, IDisposable { - public class AzureQueueHostedService : IHostedService, IDisposable + private readonly ILogger _logger; + private readonly IHubContext _hubContext; + private readonly GlobalSettings _globalSettings; + + private Task _executingTask; + private CancellationTokenSource _cts; + private QueueClient _queueClient; + + public AzureQueueHostedService( + ILogger logger, + IHubContext hubContext, + GlobalSettings globalSettings) { - private readonly ILogger _logger; - private readonly IHubContext _hubContext; - private readonly GlobalSettings _globalSettings; + _logger = logger; + _hubContext = hubContext; + _globalSettings = globalSettings; + } - private Task _executingTask; - private CancellationTokenSource _cts; - private QueueClient _queueClient; + public Task StartAsync(CancellationToken cancellationToken) + { + _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + _executingTask = ExecuteAsync(_cts.Token); + return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask; + } - public AzureQueueHostedService( - ILogger logger, - IHubContext hubContext, - GlobalSettings globalSettings) + public async Task StopAsync(CancellationToken cancellationToken) + { + if (_executingTask == null) { - _logger = logger; - _hubContext = hubContext; - _globalSettings = globalSettings; + return; } + _logger.LogWarning("Stopping service."); + _cts.Cancel(); + await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); + cancellationToken.ThrowIfCancellationRequested(); + } - public Task StartAsync(CancellationToken cancellationToken) - { - _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - _executingTask = ExecuteAsync(_cts.Token); - return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask; - } + public void Dispose() + { } - public async Task StopAsync(CancellationToken cancellationToken) + private async Task ExecuteAsync(CancellationToken cancellationToken) + { + _queueClient = new QueueClient(_globalSettings.Notifications.ConnectionString, "notifications"); + while (!cancellationToken.IsCancellationRequested) { - if (_executingTask == null) + try { - return; - } - _logger.LogWarning("Stopping service."); - _cts.Cancel(); - await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); - cancellationToken.ThrowIfCancellationRequested(); - } - - public void Dispose() - { } - - private async Task ExecuteAsync(CancellationToken cancellationToken) - { - _queueClient = new QueueClient(_globalSettings.Notifications.ConnectionString, "notifications"); - while (!cancellationToken.IsCancellationRequested) - { - try + var messages = await _queueClient.ReceiveMessagesAsync(32); + if (messages.Value?.Any() ?? false) { - var messages = await _queueClient.ReceiveMessagesAsync(32); - if (messages.Value?.Any() ?? false) + foreach (var message in messages.Value) { - foreach (var message in messages.Value) + try { - try + await HubHelpers.SendNotificationToHubAsync( + message.DecodeMessageText(), _hubContext, cancellationToken); + await _queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); + } + catch (Exception e) + { + _logger.LogError("Error processing dequeued message: " + + $"{message.MessageId} x{message.DequeueCount}. {e.Message}", e); + if (message.DequeueCount > 2) { - await HubHelpers.SendNotificationToHubAsync( - message.DecodeMessageText(), _hubContext, cancellationToken); await _queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); } - catch (Exception e) - { - _logger.LogError("Error processing dequeued message: " + - $"{message.MessageId} x{message.DequeueCount}. {e.Message}", e); - if (message.DequeueCount > 2) - { - await _queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); - } - } } } - else - { - await Task.Delay(TimeSpan.FromSeconds(5), cancellationToken); - } } - catch (Exception e) + else { - _logger.LogError("Error processing messages.", e); + await Task.Delay(TimeSpan.FromSeconds(5), cancellationToken); } } - - _logger.LogWarning("Done processing."); + catch (Exception e) + { + _logger.LogError("Error processing messages.", e); + } } + + _logger.LogWarning("Done processing."); } } diff --git a/src/Notifications/ConnectionCounter.cs b/src/Notifications/ConnectionCounter.cs index 330b8ee6d..25d315616 100644 --- a/src/Notifications/ConnectionCounter.cs +++ b/src/Notifications/ConnectionCounter.cs @@ -1,27 +1,26 @@ -namespace Bit.Notifications +namespace Bit.Notifications; + +public class ConnectionCounter { - public class ConnectionCounter + private int _count = 0; + + public void Increment() { - private int _count = 0; + Interlocked.Increment(ref _count); + } - public void Increment() - { - Interlocked.Increment(ref _count); - } + public void Decrement() + { + Interlocked.Decrement(ref _count); + } - public void Decrement() - { - Interlocked.Decrement(ref _count); - } + public void Reset() + { + _count = 0; + } - public void Reset() - { - _count = 0; - } - - public int GetCount() - { - return _count; - } + public int GetCount() + { + return _count; } } diff --git a/src/Notifications/Controllers/InfoController.cs b/src/Notifications/Controllers/InfoController.cs index 402fc4937..6a8eaf282 100644 --- a/src/Notifications/Controllers/InfoController.cs +++ b/src/Notifications/Controllers/InfoController.cs @@ -1,21 +1,20 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Mvc; -namespace Bit.Notifications.Controllers -{ - public class InfoController : Controller - { - [HttpGet("~/alive")] - [HttpGet("~/now")] - public DateTime GetAlive() - { - return DateTime.UtcNow; - } +namespace Bit.Notifications.Controllers; - [HttpGet("~/version")] - public JsonResult GetVersion() - { - return Json(CoreHelpers.GetVersion()); - } +public class InfoController : Controller +{ + [HttpGet("~/alive")] + [HttpGet("~/now")] + public DateTime GetAlive() + { + return DateTime.UtcNow; + } + + [HttpGet("~/version")] + public JsonResult GetVersion() + { + return Json(CoreHelpers.GetVersion()); } } diff --git a/src/Notifications/Controllers/SendController.cs b/src/Notifications/Controllers/SendController.cs index 81698c911..90fdac7d0 100644 --- a/src/Notifications/Controllers/SendController.cs +++ b/src/Notifications/Controllers/SendController.cs @@ -4,29 +4,28 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.SignalR; -namespace Bit.Notifications +namespace Bit.Notifications; + +[Authorize("Internal")] +public class SendController : Controller { - [Authorize("Internal")] - public class SendController : Controller + private readonly IHubContext _hubContext; + + public SendController(IHubContext hubContext) { - private readonly IHubContext _hubContext; + _hubContext = hubContext; + } - public SendController(IHubContext hubContext) + [HttpPost("~/send")] + [SelfHosted(SelfHostedOnly = true)] + public async Task PostSend() + { + using (var reader = new StreamReader(Request.Body, Encoding.UTF8)) { - _hubContext = hubContext; - } - - [HttpPost("~/send")] - [SelfHosted(SelfHostedOnly = true)] - public async Task PostSend() - { - using (var reader = new StreamReader(Request.Body, Encoding.UTF8)) + var notificationJson = await reader.ReadToEndAsync(); + if (!string.IsNullOrWhiteSpace(notificationJson)) { - var notificationJson = await reader.ReadToEndAsync(); - if (!string.IsNullOrWhiteSpace(notificationJson)) - { - await HubHelpers.SendNotificationToHubAsync(notificationJson, _hubContext); - } + await HubHelpers.SendNotificationToHubAsync(notificationJson, _hubContext); } } } diff --git a/src/Notifications/HeartbeatHostedService.cs b/src/Notifications/HeartbeatHostedService.cs index e91666926..717fdeb78 100644 --- a/src/Notifications/HeartbeatHostedService.cs +++ b/src/Notifications/HeartbeatHostedService.cs @@ -1,57 +1,56 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.SignalR; -namespace Bit.Notifications +namespace Bit.Notifications; + +public class HeartbeatHostedService : IHostedService, IDisposable { - public class HeartbeatHostedService : IHostedService, IDisposable + private readonly ILogger _logger; + private readonly IHubContext _hubContext; + private readonly GlobalSettings _globalSettings; + + private Task _executingTask; + private CancellationTokenSource _cts; + + public HeartbeatHostedService( + ILogger logger, + IHubContext hubContext, + GlobalSettings globalSettings) { - private readonly ILogger _logger; - private readonly IHubContext _hubContext; - private readonly GlobalSettings _globalSettings; + _logger = logger; + _hubContext = hubContext; + _globalSettings = globalSettings; + } - private Task _executingTask; - private CancellationTokenSource _cts; + public Task StartAsync(CancellationToken cancellationToken) + { + _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + _executingTask = ExecuteAsync(_cts.Token); + return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask; + } - public HeartbeatHostedService( - ILogger logger, - IHubContext hubContext, - GlobalSettings globalSettings) + public async Task StopAsync(CancellationToken cancellationToken) + { + if (_executingTask == null) { - _logger = logger; - _hubContext = hubContext; - _globalSettings = globalSettings; + return; } + _logger.LogWarning("Stopping service."); + _cts.Cancel(); + await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); + cancellationToken.ThrowIfCancellationRequested(); + } - public Task StartAsync(CancellationToken cancellationToken) + public void Dispose() + { } + + private async Task ExecuteAsync(CancellationToken cancellationToken) + { + while (!cancellationToken.IsCancellationRequested) { - _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - _executingTask = ExecuteAsync(_cts.Token); - return _executingTask.IsCompleted ? _executingTask : Task.CompletedTask; - } - - public async Task StopAsync(CancellationToken cancellationToken) - { - if (_executingTask == null) - { - return; - } - _logger.LogWarning("Stopping service."); - _cts.Cancel(); - await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); - cancellationToken.ThrowIfCancellationRequested(); - } - - public void Dispose() - { } - - private async Task ExecuteAsync(CancellationToken cancellationToken) - { - while (!cancellationToken.IsCancellationRequested) - { - await _hubContext.Clients.All.SendAsync("Heartbeat"); - await Task.Delay(120000); - } - _logger.LogWarning("Done with heartbeat."); + await _hubContext.Clients.All.SendAsync("Heartbeat"); + await Task.Delay(120000); } + _logger.LogWarning("Done with heartbeat."); } } diff --git a/src/Notifications/HubHelpers.cs b/src/Notifications/HubHelpers.cs index 2ba0037f4..38b87e227 100644 --- a/src/Notifications/HubHelpers.cs +++ b/src/Notifications/HubHelpers.cs @@ -3,67 +3,66 @@ using Bit.Core.Enums; using Bit.Core.Models; using Microsoft.AspNetCore.SignalR; -namespace Bit.Notifications +namespace Bit.Notifications; + +public static class HubHelpers { - public static class HubHelpers + public static async Task SendNotificationToHubAsync(string notificationJson, + IHubContext hubContext, CancellationToken cancellationToken = default(CancellationToken)) { - public static async Task SendNotificationToHubAsync(string notificationJson, - IHubContext hubContext, CancellationToken cancellationToken = default(CancellationToken)) + var notification = JsonSerializer.Deserialize>(notificationJson); + switch (notification.Type) { - var notification = JsonSerializer.Deserialize>(notificationJson); - switch (notification.Type) - { - case PushType.SyncCipherUpdate: - case PushType.SyncCipherCreate: - case PushType.SyncCipherDelete: - case PushType.SyncLoginDelete: - var cipherNotification = - JsonSerializer.Deserialize>( + case PushType.SyncCipherUpdate: + case PushType.SyncCipherCreate: + case PushType.SyncCipherDelete: + case PushType.SyncLoginDelete: + var cipherNotification = + JsonSerializer.Deserialize>( + notificationJson); + if (cipherNotification.Payload.UserId.HasValue) + { + await hubContext.Clients.User(cipherNotification.Payload.UserId.ToString()) + .SendAsync("ReceiveMessage", cipherNotification, cancellationToken); + } + else if (cipherNotification.Payload.OrganizationId.HasValue) + { + await hubContext.Clients.Group( + $"Organization_{cipherNotification.Payload.OrganizationId}") + .SendAsync("ReceiveMessage", cipherNotification, cancellationToken); + } + break; + case PushType.SyncFolderUpdate: + case PushType.SyncFolderCreate: + case PushType.SyncFolderDelete: + var folderNotification = + JsonSerializer.Deserialize>( + notificationJson); + await hubContext.Clients.User(folderNotification.Payload.UserId.ToString()) + .SendAsync("ReceiveMessage", folderNotification, cancellationToken); + break; + case PushType.SyncCiphers: + case PushType.SyncVault: + case PushType.SyncOrgKeys: + case PushType.SyncSettings: + case PushType.LogOut: + var userNotification = + JsonSerializer.Deserialize>( + notificationJson); + await hubContext.Clients.User(userNotification.Payload.UserId.ToString()) + .SendAsync("ReceiveMessage", userNotification, cancellationToken); + break; + case PushType.SyncSendCreate: + case PushType.SyncSendUpdate: + case PushType.SyncSendDelete: + var sendNotification = + JsonSerializer.Deserialize>( notificationJson); - if (cipherNotification.Payload.UserId.HasValue) - { - await hubContext.Clients.User(cipherNotification.Payload.UserId.ToString()) - .SendAsync("ReceiveMessage", cipherNotification, cancellationToken); - } - else if (cipherNotification.Payload.OrganizationId.HasValue) - { - await hubContext.Clients.Group( - $"Organization_{cipherNotification.Payload.OrganizationId}") - .SendAsync("ReceiveMessage", cipherNotification, cancellationToken); - } - break; - case PushType.SyncFolderUpdate: - case PushType.SyncFolderCreate: - case PushType.SyncFolderDelete: - var folderNotification = - JsonSerializer.Deserialize>( - notificationJson); - await hubContext.Clients.User(folderNotification.Payload.UserId.ToString()) - .SendAsync("ReceiveMessage", folderNotification, cancellationToken); - break; - case PushType.SyncCiphers: - case PushType.SyncVault: - case PushType.SyncOrgKeys: - case PushType.SyncSettings: - case PushType.LogOut: - var userNotification = - JsonSerializer.Deserialize>( - notificationJson); - await hubContext.Clients.User(userNotification.Payload.UserId.ToString()) - .SendAsync("ReceiveMessage", userNotification, cancellationToken); - break; - case PushType.SyncSendCreate: - case PushType.SyncSendUpdate: - case PushType.SyncSendDelete: - var sendNotification = - JsonSerializer.Deserialize>( - notificationJson); - await hubContext.Clients.User(sendNotification.Payload.UserId.ToString()) - .SendAsync("ReceiveMessage", sendNotification, cancellationToken); - break; - default: - break; - } + await hubContext.Clients.User(sendNotification.Payload.UserId.ToString()) + .SendAsync("ReceiveMessage", sendNotification, cancellationToken); + break; + default: + break; } } } diff --git a/src/Notifications/Jobs/JobsHostedService.cs b/src/Notifications/Jobs/JobsHostedService.cs index 326e8dce0..a1f84e18b 100644 --- a/src/Notifications/Jobs/JobsHostedService.cs +++ b/src/Notifications/Jobs/JobsHostedService.cs @@ -2,36 +2,35 @@ using Bit.Core.Settings; using Quartz; -namespace Bit.Notifications.Jobs +namespace Bit.Notifications.Jobs; + +public class JobsHostedService : BaseJobsHostedService { - public class JobsHostedService : BaseJobsHostedService + public JobsHostedService( + GlobalSettings globalSettings, + IServiceProvider serviceProvider, + ILogger logger, + ILogger listenerLogger) + : base(globalSettings, serviceProvider, logger, listenerLogger) { } + + public override async Task StartAsync(CancellationToken cancellationToken) { - public JobsHostedService( - GlobalSettings globalSettings, - IServiceProvider serviceProvider, - ILogger logger, - ILogger listenerLogger) - : base(globalSettings, serviceProvider, logger, listenerLogger) { } + var everyFiveMinutesTrigger = TriggerBuilder.Create() + .WithIdentity("EveryFiveMinutesTrigger") + .StartNow() + .WithCronSchedule("0 */30 * * * ?") + .Build(); - public override async Task StartAsync(CancellationToken cancellationToken) + Jobs = new List> { - var everyFiveMinutesTrigger = TriggerBuilder.Create() - .WithIdentity("EveryFiveMinutesTrigger") - .StartNow() - .WithCronSchedule("0 */30 * * * ?") - .Build(); + new Tuple(typeof(LogConnectionCounterJob), everyFiveMinutesTrigger) + }; - Jobs = new List> - { - new Tuple(typeof(LogConnectionCounterJob), everyFiveMinutesTrigger) - }; + await base.StartAsync(cancellationToken); + } - await base.StartAsync(cancellationToken); - } - - public static void AddJobsServices(IServiceCollection services) - { - services.AddTransient(); - } + public static void AddJobsServices(IServiceCollection services) + { + services.AddTransient(); } } diff --git a/src/Notifications/Jobs/LogConnectionCounterJob.cs b/src/Notifications/Jobs/LogConnectionCounterJob.cs index 6b7bc70ff..9b4e2ee4f 100644 --- a/src/Notifications/Jobs/LogConnectionCounterJob.cs +++ b/src/Notifications/Jobs/LogConnectionCounterJob.cs @@ -2,25 +2,24 @@ using Bit.Core.Jobs; using Quartz; -namespace Bit.Notifications.Jobs +namespace Bit.Notifications.Jobs; + +public class LogConnectionCounterJob : BaseJob { - public class LogConnectionCounterJob : BaseJob + private readonly ConnectionCounter _connectionCounter; + + public LogConnectionCounterJob( + ILogger logger, + ConnectionCounter connectionCounter) + : base(logger) { - private readonly ConnectionCounter _connectionCounter; + _connectionCounter = connectionCounter; + } - public LogConnectionCounterJob( - ILogger logger, - ConnectionCounter connectionCounter) - : base(logger) - { - _connectionCounter = connectionCounter; - } - - protected override Task ExecuteJobAsync(IJobExecutionContext context) - { - _logger.LogInformation(Constants.BypassFiltersEventId, - "Connection count for server {0}: {1}", Environment.MachineName, _connectionCounter.GetCount()); - return Task.FromResult(0); - } + protected override Task ExecuteJobAsync(IJobExecutionContext context) + { + _logger.LogInformation(Constants.BypassFiltersEventId, + "Connection count for server {0}: {1}", Environment.MachineName, _connectionCounter.GetCount()); + return Task.FromResult(0); } } diff --git a/src/Notifications/NotificationsHub.cs b/src/Notifications/NotificationsHub.cs index 7d6e94a42..6d7a66b89 100644 --- a/src/Notifications/NotificationsHub.cs +++ b/src/Notifications/NotificationsHub.cs @@ -2,48 +2,47 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Authorization; -namespace Bit.Notifications +namespace Bit.Notifications; + +[Authorize("Application")] +public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub { - [Authorize("Application")] - public class NotificationsHub : Microsoft.AspNetCore.SignalR.Hub + private readonly ConnectionCounter _connectionCounter; + private readonly GlobalSettings _globalSettings; + + public NotificationsHub(ConnectionCounter connectionCounter, GlobalSettings globalSettings) { - private readonly ConnectionCounter _connectionCounter; - private readonly GlobalSettings _globalSettings; + _connectionCounter = connectionCounter; + _globalSettings = globalSettings; + } - public NotificationsHub(ConnectionCounter connectionCounter, GlobalSettings globalSettings) + public override async Task OnConnectedAsync() + { + var currentContext = new CurrentContext(null); + await currentContext.BuildAsync(Context.User, _globalSettings); + if (currentContext.Organizations != null) { - _connectionCounter = connectionCounter; - _globalSettings = globalSettings; - } - - public override async Task OnConnectedAsync() - { - var currentContext = new CurrentContext(null); - await currentContext.BuildAsync(Context.User, _globalSettings); - if (currentContext.Organizations != null) + foreach (var org in currentContext.Organizations) { - foreach (var org in currentContext.Organizations) - { - await Groups.AddToGroupAsync(Context.ConnectionId, $"Organization_{org.Id}"); - } + await Groups.AddToGroupAsync(Context.ConnectionId, $"Organization_{org.Id}"); } - _connectionCounter.Increment(); - await base.OnConnectedAsync(); } + _connectionCounter.Increment(); + await base.OnConnectedAsync(); + } - public override async Task OnDisconnectedAsync(Exception exception) + public override async Task OnDisconnectedAsync(Exception exception) + { + var currentContext = new CurrentContext(null); + await currentContext.BuildAsync(Context.User, _globalSettings); + if (currentContext.Organizations != null) { - var currentContext = new CurrentContext(null); - await currentContext.BuildAsync(Context.User, _globalSettings); - if (currentContext.Organizations != null) + foreach (var org in currentContext.Organizations) { - foreach (var org in currentContext.Organizations) - { - await Groups.RemoveFromGroupAsync(Context.ConnectionId, $"Organization_{org.Id}"); - } + await Groups.RemoveFromGroupAsync(Context.ConnectionId, $"Organization_{org.Id}"); } - _connectionCounter.Decrement(); - await base.OnDisconnectedAsync(exception); } + _connectionCounter.Decrement(); + await base.OnDisconnectedAsync(exception); } } diff --git a/src/Notifications/Program.cs b/src/Notifications/Program.cs index 8ea3a5a1b..4834972ab 100644 --- a/src/Notifications/Program.cs +++ b/src/Notifications/Program.cs @@ -1,51 +1,50 @@ using Bit.Core.Utilities; using Serilog.Events; -namespace Bit.Notifications +namespace Bit.Notifications; + +public class Program { - public class Program + public static void Main(string[] args) { - public static void Main(string[] args) - { - Host - .CreateDefaultBuilder(args) - .ConfigureCustomAppConfiguration(args) - .ConfigureWebHostDefaults(webBuilder => - { - webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, e => + Host + .CreateDefaultBuilder(args) + .ConfigureCustomAppConfiguration(args) + .ConfigureWebHostDefaults(webBuilder => + { + webBuilder.UseStartup(); + webBuilder.ConfigureLogging((hostingContext, logging) => + logging.AddSerilog(hostingContext, e => + { + var context = e.Properties["SourceContext"].ToString(); + if (context.Contains("IdentityServer4.Validation.TokenValidator") || + context.Contains("IdentityServer4.Validation.TokenRequestValidator")) { - var context = e.Properties["SourceContext"].ToString(); - if (context.Contains("IdentityServer4.Validation.TokenValidator") || - context.Contains("IdentityServer4.Validation.TokenRequestValidator")) - { - return e.Level > LogEventLevel.Error; - } + return e.Level > LogEventLevel.Error; + } - if (e.Level == LogEventLevel.Error && - e.MessageTemplate.Text == "Failed connection handshake.") - { - return false; - } + if (e.Level == LogEventLevel.Error && + e.MessageTemplate.Text == "Failed connection handshake.") + { + return false; + } - if (e.Level == LogEventLevel.Error && - e.MessageTemplate.Text.StartsWith("Failed writing message.")) - { - return false; - } + if (e.Level == LogEventLevel.Error && + e.MessageTemplate.Text.StartsWith("Failed writing message.")) + { + return false; + } - if (e.Level == LogEventLevel.Warning && - e.MessageTemplate.Text.StartsWith("Heartbeat took longer")) - { - return false; - } + if (e.Level == LogEventLevel.Warning && + e.MessageTemplate.Text.StartsWith("Heartbeat took longer")) + { + return false; + } - return e.Level >= LogEventLevel.Warning; - })); - }) - .Build() - .Run(); - } + return e.Level >= LogEventLevel.Warning; + })); + }) + .Build() + .Run(); } } diff --git a/src/Notifications/Startup.cs b/src/Notifications/Startup.cs index c42b9fcd2..e2509e46c 100644 --- a/src/Notifications/Startup.cs +++ b/src/Notifications/Startup.cs @@ -6,115 +6,114 @@ using IdentityModel; using Microsoft.AspNetCore.SignalR; using Microsoft.IdentityModel.Logging; -namespace Bit.Notifications +namespace Bit.Notifications; + +public class Startup { - public class Startup + public Startup(IWebHostEnvironment env, IConfiguration configuration) { - public Startup(IWebHostEnvironment env, IConfiguration configuration) + CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); + Configuration = configuration; + Environment = env; + } + + public IConfiguration Configuration { get; } + public IWebHostEnvironment Environment { get; set; } + + public void ConfigureServices(IServiceCollection services) + { + // Options + services.AddOptions(); + + // Settings + var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); + + // Identity + services.AddIdentityAuthenticationServices(globalSettings, Environment, config => { - CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); - Configuration = configuration; - Environment = env; + config.AddPolicy("Application", policy => + { + policy.RequireAuthenticatedUser(); + policy.RequireClaim(JwtClaimTypes.AuthenticationMethod, "Application", "external"); + policy.RequireClaim(JwtClaimTypes.Scope, "api"); + }); + config.AddPolicy("Internal", policy => + { + policy.RequireAuthenticatedUser(); + policy.RequireClaim(JwtClaimTypes.Scope, "internal"); + }); + }); + + // SignalR + var signalRServerBuilder = services.AddSignalR().AddMessagePackProtocol(options => + { + options.SerializerOptions = MessagePack.MessagePackSerializerOptions.Standard + .WithResolver(MessagePack.Resolvers.ContractlessStandardResolver.Instance); + }); + if (CoreHelpers.SettingHasValue(globalSettings.Notifications?.RedisConnectionString)) + { + signalRServerBuilder.AddStackExchangeRedis(globalSettings.Notifications.RedisConnectionString, + options => + { + options.Configuration.ChannelPrefix = "Notifications"; + }); } + services.AddSingleton(); + services.AddSingleton(); - public IConfiguration Configuration { get; } - public IWebHostEnvironment Environment { get; set; } + // Mvc + services.AddMvc(); - public void ConfigureServices(IServiceCollection services) + services.AddHostedService(); + if (!globalSettings.SelfHosted) { - // Options - services.AddOptions(); - - // Settings - var globalSettings = services.AddGlobalSettingsServices(Configuration, Environment); - - // Identity - services.AddIdentityAuthenticationServices(globalSettings, Environment, config => + // Hosted Services + Jobs.JobsHostedService.AddJobsServices(services); + services.AddHostedService(); + if (CoreHelpers.SettingHasValue(globalSettings.Notifications?.ConnectionString)) { - config.AddPolicy("Application", policy => - { - policy.RequireAuthenticatedUser(); - policy.RequireClaim(JwtClaimTypes.AuthenticationMethod, "Application", "external"); - policy.RequireClaim(JwtClaimTypes.Scope, "api"); - }); - config.AddPolicy("Internal", policy => - { - policy.RequireAuthenticatedUser(); - policy.RequireClaim(JwtClaimTypes.Scope, "internal"); - }); - }); - - // SignalR - var signalRServerBuilder = services.AddSignalR().AddMessagePackProtocol(options => - { - options.SerializerOptions = MessagePack.MessagePackSerializerOptions.Standard - .WithResolver(MessagePack.Resolvers.ContractlessStandardResolver.Instance); - }); - if (CoreHelpers.SettingHasValue(globalSettings.Notifications?.RedisConnectionString)) - { - signalRServerBuilder.AddStackExchangeRedis(globalSettings.Notifications.RedisConnectionString, - options => - { - options.Configuration.ChannelPrefix = "Notifications"; - }); + services.AddHostedService(); } - services.AddSingleton(); - services.AddSingleton(); - - // Mvc - services.AddMvc(); - - services.AddHostedService(); - if (!globalSettings.SelfHosted) - { - // Hosted Services - Jobs.JobsHostedService.AddJobsServices(services); - services.AddHostedService(); - if (CoreHelpers.SettingHasValue(globalSettings.Notifications?.ConnectionString)) - { - services.AddHostedService(); - } - } - } - - public void Configure( - IApplicationBuilder app, - IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, - GlobalSettings globalSettings) - { - IdentityModelEventSource.ShowPII = true; - app.UseSerilog(env, appLifetime, globalSettings); - - // Add general security headers - app.UseMiddleware(); - - if (env.IsDevelopment()) - { - app.UseDeveloperExceptionPage(); - } - - // Add routing - app.UseRouting(); - - // Add Cors - app.UseCors(policy => policy.SetIsOriginAllowed(o => CoreHelpers.IsCorsOriginAllowed(o, globalSettings)) - .AllowAnyMethod().AllowAnyHeader().AllowCredentials()); - - // Add authentication to the request pipeline. - app.UseAuthentication(); - app.UseAuthorization(); - - // Add endpoints to the request pipeline. - app.UseEndpoints(endpoints => - { - endpoints.MapHub("/hub", options => - { - options.ApplicationMaxBufferSize = 2048; // client => server messages are not even used - options.TransportMaxBufferSize = 4096; - }); - endpoints.MapDefaultControllerRoute(); - }); } } + + public void Configure( + IApplicationBuilder app, + IWebHostEnvironment env, + IHostApplicationLifetime appLifetime, + GlobalSettings globalSettings) + { + IdentityModelEventSource.ShowPII = true; + app.UseSerilog(env, appLifetime, globalSettings); + + // Add general security headers + app.UseMiddleware(); + + if (env.IsDevelopment()) + { + app.UseDeveloperExceptionPage(); + } + + // Add routing + app.UseRouting(); + + // Add Cors + app.UseCors(policy => policy.SetIsOriginAllowed(o => CoreHelpers.IsCorsOriginAllowed(o, globalSettings)) + .AllowAnyMethod().AllowAnyHeader().AllowCredentials()); + + // Add authentication to the request pipeline. + app.UseAuthentication(); + app.UseAuthorization(); + + // Add endpoints to the request pipeline. + app.UseEndpoints(endpoints => + { + endpoints.MapHub("/hub", options => + { + options.ApplicationMaxBufferSize = 2048; // client => server messages are not even used + options.TransportMaxBufferSize = 4096; + }); + endpoints.MapDefaultControllerRoute(); + }); + } } diff --git a/src/Notifications/SubjectUserIdProvider.cs b/src/Notifications/SubjectUserIdProvider.cs index ee6ab6be5..261394d06 100644 --- a/src/Notifications/SubjectUserIdProvider.cs +++ b/src/Notifications/SubjectUserIdProvider.cs @@ -1,13 +1,12 @@ using IdentityModel; using Microsoft.AspNetCore.SignalR; -namespace Bit.Notifications +namespace Bit.Notifications; + +public class SubjectUserIdProvider : IUserIdProvider { - public class SubjectUserIdProvider : IUserIdProvider + public string GetUserId(HubConnectionContext connection) { - public string GetUserId(HubConnectionContext connection) - { - return connection.User?.FindFirst(JwtClaimTypes.Subject)?.Value; - } + return connection.User?.FindFirst(JwtClaimTypes.Subject)?.Value; } } diff --git a/src/SharedWeb/Utilities/ExceptionHandlerFilterAttribute.cs b/src/SharedWeb/Utilities/ExceptionHandlerFilterAttribute.cs index 2de2a4d73..f43544bca 100644 --- a/src/SharedWeb/Utilities/ExceptionHandlerFilterAttribute.cs +++ b/src/SharedWeb/Utilities/ExceptionHandlerFilterAttribute.cs @@ -8,84 +8,83 @@ using Microsoft.Extensions.Logging; using Microsoft.IdentityModel.Tokens; using InternalApi = Bit.Core.Models.Api; -namespace Bit.SharedWeb.Utilities +namespace Bit.SharedWeb.Utilities; + +public class ExceptionHandlerFilterAttribute : ExceptionFilterAttribute { - public class ExceptionHandlerFilterAttribute : ExceptionFilterAttribute + public ExceptionHandlerFilterAttribute() { - public ExceptionHandlerFilterAttribute() + } + + public override void OnException(ExceptionContext context) + { + var errorMessage = "An error has occurred."; + + var exception = context.Exception; + if (exception == null) { + // Should never happen. + return; } - public override void OnException(ExceptionContext context) + InternalApi.ErrorResponseModel internalErrorModel = null; + if (exception is BadRequestException badRequestException) { - var errorMessage = "An error has occurred."; - - var exception = context.Exception; - if (exception == null) + context.HttpContext.Response.StatusCode = 400; + if (badRequestException.ModelState != null) { - // Should never happen. - return; - } - - InternalApi.ErrorResponseModel internalErrorModel = null; - if (exception is BadRequestException badRequestException) - { - context.HttpContext.Response.StatusCode = 400; - if (badRequestException.ModelState != null) - { - internalErrorModel = new InternalApi.ErrorResponseModel(badRequestException.ModelState); - } - else - { - errorMessage = badRequestException.Message; - } - } - else if (exception is GatewayException) - { - errorMessage = exception.Message; - context.HttpContext.Response.StatusCode = 400; - } - else if (exception is NotSupportedException && !string.IsNullOrWhiteSpace(exception.Message)) - { - errorMessage = exception.Message; - context.HttpContext.Response.StatusCode = 400; - } - else if (exception is ApplicationException) - { - context.HttpContext.Response.StatusCode = 402; - } - else if (exception is NotFoundException) - { - errorMessage = "Resource not found."; - context.HttpContext.Response.StatusCode = 404; - } - else if (exception is SecurityTokenValidationException) - { - errorMessage = "Invalid token."; - context.HttpContext.Response.StatusCode = 403; - } - else if (exception is UnauthorizedAccessException) - { - errorMessage = "Unauthorized."; - context.HttpContext.Response.StatusCode = 401; + internalErrorModel = new InternalApi.ErrorResponseModel(badRequestException.ModelState); } else { - var logger = context.HttpContext.RequestServices.GetRequiredService>(); - logger.LogError(0, exception, exception.Message); - errorMessage = "An unhandled server error has occurred."; - context.HttpContext.Response.StatusCode = 500; + errorMessage = badRequestException.Message; } - - var errorModel = internalErrorModel ?? new InternalApi.ErrorResponseModel(errorMessage); - var env = context.HttpContext.RequestServices.GetRequiredService(); - if (env.IsDevelopment()) - { - errorModel.ExceptionMessage = exception.Message; - errorModel.ExceptionStackTrace = exception.StackTrace; - errorModel.InnerExceptionMessage = exception?.InnerException?.Message; - } - context.Result = new ObjectResult(errorModel); } + else if (exception is GatewayException) + { + errorMessage = exception.Message; + context.HttpContext.Response.StatusCode = 400; + } + else if (exception is NotSupportedException && !string.IsNullOrWhiteSpace(exception.Message)) + { + errorMessage = exception.Message; + context.HttpContext.Response.StatusCode = 400; + } + else if (exception is ApplicationException) + { + context.HttpContext.Response.StatusCode = 402; + } + else if (exception is NotFoundException) + { + errorMessage = "Resource not found."; + context.HttpContext.Response.StatusCode = 404; + } + else if (exception is SecurityTokenValidationException) + { + errorMessage = "Invalid token."; + context.HttpContext.Response.StatusCode = 403; + } + else if (exception is UnauthorizedAccessException) + { + errorMessage = "Unauthorized."; + context.HttpContext.Response.StatusCode = 401; + } + else + { + var logger = context.HttpContext.RequestServices.GetRequiredService>(); + logger.LogError(0, exception, exception.Message); + errorMessage = "An unhandled server error has occurred."; + context.HttpContext.Response.StatusCode = 500; + } + + var errorModel = internalErrorModel ?? new InternalApi.ErrorResponseModel(errorMessage); + var env = context.HttpContext.RequestServices.GetRequiredService(); + if (env.IsDevelopment()) + { + errorModel.ExceptionMessage = exception.Message; + errorModel.ExceptionStackTrace = exception.StackTrace; + errorModel.InnerExceptionMessage = exception?.InnerException?.Message; + } + context.Result = new ObjectResult(errorModel); } } diff --git a/src/SharedWeb/Utilities/ModelStateValidationFilterAttribute.cs b/src/SharedWeb/Utilities/ModelStateValidationFilterAttribute.cs index 11d642f32..c4dfbfb89 100644 --- a/src/SharedWeb/Utilities/ModelStateValidationFilterAttribute.cs +++ b/src/SharedWeb/Utilities/ModelStateValidationFilterAttribute.cs @@ -2,31 +2,30 @@ using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc.Filters; -namespace Bit.SharedWeb.Utilities +namespace Bit.SharedWeb.Utilities; + +public class ModelStateValidationFilterAttribute : ActionFilterAttribute { - public class ModelStateValidationFilterAttribute : ActionFilterAttribute + public ModelStateValidationFilterAttribute() { - public ModelStateValidationFilterAttribute() + } + + public override void OnActionExecuting(ActionExecutingContext context) + { + var model = context.ActionArguments.FirstOrDefault(a => a.Key == "model"); + if (model.Key == "model" && model.Value == null) { + context.ModelState.AddModelError(string.Empty, "Body is empty."); } - public override void OnActionExecuting(ActionExecutingContext context) + if (!context.ModelState.IsValid) { - var model = context.ActionArguments.FirstOrDefault(a => a.Key == "model"); - if (model.Key == "model" && model.Value == null) - { - context.ModelState.AddModelError(string.Empty, "Body is empty."); - } - - if (!context.ModelState.IsValid) - { - OnModelStateInvalid(context); - } - } - - protected virtual void OnModelStateInvalid(ActionExecutingContext context) - { - context.Result = new BadRequestObjectResult(new ErrorResponseModel(context.ModelState)); + OnModelStateInvalid(context); } } + + protected virtual void OnModelStateInvalid(ActionExecutingContext context) + { + context.Result = new BadRequestObjectResult(new ErrorResponseModel(context.ModelState)); + } } diff --git a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs index 9102a3a18..b2efe511d 100644 --- a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs +++ b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs @@ -43,613 +43,612 @@ using NoopRepos = Bit.Core.Repositories.Noop; using Role = Bit.Core.Entities.Role; using TableStorageRepos = Bit.Core.Repositories.TableStorage; -namespace Bit.SharedWeb.Utilities +namespace Bit.SharedWeb.Utilities; + +public static class ServiceCollectionExtensions { - public static class ServiceCollectionExtensions + public static void AddSqlServerRepositories(this IServiceCollection services, GlobalSettings globalSettings) { - public static void AddSqlServerRepositories(this IServiceCollection services, GlobalSettings globalSettings) + var selectedDatabaseProvider = globalSettings.DatabaseProvider; + var provider = SupportedDatabaseProviders.SqlServer; + var connectionString = string.Empty; + if (!string.IsNullOrWhiteSpace(selectedDatabaseProvider)) { - var selectedDatabaseProvider = globalSettings.DatabaseProvider; - var provider = SupportedDatabaseProviders.SqlServer; - var connectionString = string.Empty; - if (!string.IsNullOrWhiteSpace(selectedDatabaseProvider)) + switch (selectedDatabaseProvider.ToLowerInvariant()) { - switch (selectedDatabaseProvider.ToLowerInvariant()) - { - case "postgres": - case "postgresql": - provider = SupportedDatabaseProviders.Postgres; - connectionString = globalSettings.PostgreSql.ConnectionString; - break; - case "mysql": - case "mariadb": - provider = SupportedDatabaseProviders.MySql; - connectionString = globalSettings.MySql.ConnectionString; - break; - default: - break; - } - } - - var useEf = (provider != SupportedDatabaseProviders.SqlServer); - - if (useEf) - { - services.AddEFRepositories(globalSettings.SelfHosted, connectionString, provider); - } - else - { - services.AddDapperRepositories(globalSettings.SelfHosted); - } - - if (globalSettings.SelfHosted) - { - services.AddSingleton(); - services.AddSingleton(); - } - else - { - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); + case "postgres": + case "postgresql": + provider = SupportedDatabaseProviders.Postgres; + connectionString = globalSettings.PostgreSql.ConnectionString; + break; + case "mysql": + case "mariadb": + provider = SupportedDatabaseProviders.MySql; + connectionString = globalSettings.MySql.ConnectionString; + break; + default: + break; } } - public static void AddBaseServices(this IServiceCollection services, IGlobalSettings globalSettings) + var useEf = (provider != SupportedDatabaseProviders.SqlServer); + + if (useEf) { - services.AddScoped(); - services.AddScoped(); - services.AddOrganizationServices(globalSettings); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddSingleton(); - services.AddSingleton(); - services.AddScoped(); - services.AddScoped(); + services.AddEFRepositories(globalSettings.SelfHosted, connectionString, provider); + } + else + { + services.AddDapperRepositories(globalSettings.SelfHosted); } - public static void AddTokenizers(this IServiceCollection services) + if (globalSettings.SelfHosted) { - services.AddSingleton>(serviceProvider => - new DataProtectorTokenFactory( - EmergencyAccessInviteTokenable.ClearTextPrefix, - EmergencyAccessInviteTokenable.DataProtectorPurpose, - serviceProvider.GetDataProtectionProvider()) - ); - services.AddSingleton>(serviceProvider => - new DataProtectorTokenFactory( - HCaptchaTokenable.ClearTextPrefix, - HCaptchaTokenable.DataProtectorPurpose, - serviceProvider.GetDataProtectionProvider()) - ); - services.AddSingleton>(serviceProvider => - new DataProtectorTokenFactory( - SsoTokenable.ClearTextPrefix, - SsoTokenable.DataProtectorPurpose, - serviceProvider.GetDataProtectionProvider())); + services.AddSingleton(); + services.AddSingleton(); + } + else + { + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + } + } + + public static void AddBaseServices(this IServiceCollection services, IGlobalSettings globalSettings) + { + services.AddScoped(); + services.AddScoped(); + services.AddOrganizationServices(globalSettings); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddSingleton(); + services.AddSingleton(); + services.AddScoped(); + services.AddScoped(); + } + + public static void AddTokenizers(this IServiceCollection services) + { + services.AddSingleton>(serviceProvider => + new DataProtectorTokenFactory( + EmergencyAccessInviteTokenable.ClearTextPrefix, + EmergencyAccessInviteTokenable.DataProtectorPurpose, + serviceProvider.GetDataProtectionProvider()) + ); + services.AddSingleton>(serviceProvider => + new DataProtectorTokenFactory( + HCaptchaTokenable.ClearTextPrefix, + HCaptchaTokenable.DataProtectorPurpose, + serviceProvider.GetDataProtectionProvider()) + ); + services.AddSingleton>(serviceProvider => + new DataProtectorTokenFactory( + SsoTokenable.ClearTextPrefix, + SsoTokenable.DataProtectorPurpose, + serviceProvider.GetDataProtectionProvider())); + } + + public static void AddDefaultServices(this IServiceCollection services, GlobalSettings globalSettings) + { + // Required for UserService + services.AddWebAuthn(globalSettings); + // Required for HTTP calls + services.AddHttpClient(); + + services.AddSingleton(); + services.AddSingleton((serviceProvider) => + { + return new Braintree.BraintreeGateway + { + Environment = globalSettings.Braintree.Production ? + Braintree.Environment.PRODUCTION : Braintree.Environment.SANDBOX, + MerchantId = globalSettings.Braintree.MerchantId, + PublicKey = globalSettings.Braintree.PublicKey, + PrivateKey = globalSettings.Braintree.PrivateKey + }; + }); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddTokenizers(); + + if (CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ConnectionString) && + CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ApplicationCacheTopicName)) + { + services.AddSingleton(); + } + else + { + services.AddSingleton(); } - public static void AddDefaultServices(this IServiceCollection services, GlobalSettings globalSettings) + var awsConfigured = CoreHelpers.SettingHasValue(globalSettings.Amazon?.AccessKeySecret); + if (awsConfigured && CoreHelpers.SettingHasValue(globalSettings.Mail?.SendGridApiKey)) { - // Required for UserService - services.AddWebAuthn(globalSettings); - // Required for HTTP calls - services.AddHttpClient(); - - services.AddSingleton(); - services.AddSingleton((serviceProvider) => - { - return new Braintree.BraintreeGateway - { - Environment = globalSettings.Braintree.Production ? - Braintree.Environment.PRODUCTION : Braintree.Environment.SANDBOX, - MerchantId = globalSettings.Braintree.MerchantId, - PublicKey = globalSettings.Braintree.PublicKey, - PrivateKey = globalSettings.Braintree.PrivateKey - }; - }); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); - services.AddTokenizers(); - - if (CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ConnectionString) && - CoreHelpers.SettingHasValue(globalSettings.ServiceBus.ApplicationCacheTopicName)) - { - services.AddSingleton(); - } - else - { - services.AddSingleton(); - } - - var awsConfigured = CoreHelpers.SettingHasValue(globalSettings.Amazon?.AccessKeySecret); - if (awsConfigured && CoreHelpers.SettingHasValue(globalSettings.Mail?.SendGridApiKey)) - { - services.AddSingleton(); - } - else if (awsConfigured) - { - services.AddSingleton(); - } - else if (CoreHelpers.SettingHasValue(globalSettings.Mail?.Smtp?.Host)) - { - services.AddSingleton(); - } - else - { - services.AddSingleton(); - } - - services.AddSingleton(); - if (globalSettings.SelfHosted && - CoreHelpers.SettingHasValue(globalSettings.PushRelayBaseUri) && - globalSettings.Installation?.Id != null && - CoreHelpers.SettingHasValue(globalSettings.Installation?.Key)) - { - services.AddSingleton(); - } - else if (!globalSettings.SelfHosted) - { - services.AddSingleton(); - } - else - { - services.AddSingleton(); - } - - if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Storage?.ConnectionString)) - { - services.AddSingleton(); - } - else if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Amazon?.AccessKeySecret)) - { - services.AddSingleton(); - } - else - { - services.AddSingleton(); - } - - if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Mail.ConnectionString)) - { - services.AddSingleton(); - } - else - { - services.AddSingleton(); - } - - if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Events.ConnectionString)) - { - services.AddSingleton(); - } - else if (globalSettings.SelfHosted) - { - services.AddSingleton(); - } - else - { - services.AddSingleton(); - } - - if (CoreHelpers.SettingHasValue(globalSettings.Attachment.ConnectionString)) - { - services.AddSingleton(); - } - else if (CoreHelpers.SettingHasValue(globalSettings.Attachment.BaseDirectory)) - { - services.AddSingleton(); - } - else - { - services.AddSingleton(); - } - - if (CoreHelpers.SettingHasValue(globalSettings.Send.ConnectionString)) - { - services.AddSingleton(); - } - else if (CoreHelpers.SettingHasValue(globalSettings.Send.BaseDirectory)) - { - services.AddSingleton(); - } - else - { - services.AddSingleton(); - } - - if (globalSettings.SelfHosted) - { - services.AddSingleton(); - } - else - { - services.AddSingleton(); - } - - if (CoreHelpers.SettingHasValue(globalSettings.Captcha?.HCaptchaSecretKey) && - CoreHelpers.SettingHasValue(globalSettings.Captcha?.HCaptchaSiteKey)) - { - services.AddSingleton(); - } - else - { - services.AddSingleton(); - } + services.AddSingleton(); } - - public static void AddOosServices(this IServiceCollection services) + else if (awsConfigured) { - services.AddScoped(); + services.AddSingleton(); } - - public static void AddNoopServices(this IServiceCollection services) + else if (CoreHelpers.SettingHasValue(globalSettings.Mail?.Smtp?.Host)) + { + services.AddSingleton(); + } + else { - services.AddSingleton(); services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); + } + + services.AddSingleton(); + if (globalSettings.SelfHosted && + CoreHelpers.SettingHasValue(globalSettings.PushRelayBaseUri) && + globalSettings.Installation?.Id != null && + CoreHelpers.SettingHasValue(globalSettings.Installation?.Key)) + { + services.AddSingleton(); + } + else if (!globalSettings.SelfHosted) + { + services.AddSingleton(); + } + else + { services.AddSingleton(); - services.AddSingleton(); - services.AddSingleton(); + } + + if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Storage?.ConnectionString)) + { + services.AddSingleton(); + } + else if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Amazon?.AccessKeySecret)) + { + services.AddSingleton(); + } + else + { + services.AddSingleton(); + } + + if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Mail.ConnectionString)) + { + services.AddSingleton(); + } + else + { + services.AddSingleton(); + } + + if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Events.ConnectionString)) + { + services.AddSingleton(); + } + else if (globalSettings.SelfHosted) + { + services.AddSingleton(); + } + else + { services.AddSingleton(); } - public static IdentityBuilder AddCustomIdentityServices( - this IServiceCollection services, GlobalSettings globalSettings) + if (CoreHelpers.SettingHasValue(globalSettings.Attachment.ConnectionString)) { - services.AddSingleton(); - services.Configure(options => options.IterationCount = 100000); - services.Configure(options => - { - options.TokenLifespan = TimeSpan.FromDays(30); - }); - - var identityBuilder = services.AddIdentityWithoutCookieAuth(options => - { - options.User = new UserOptions - { - RequireUniqueEmail = true, - AllowedUserNameCharacters = null // all - }; - options.Password = new PasswordOptions - { - RequireDigit = false, - RequireLowercase = false, - RequiredLength = 8, - RequireNonAlphanumeric = false, - RequireUppercase = false - }; - options.ClaimsIdentity = new ClaimsIdentityOptions - { - SecurityStampClaimType = "sstamp", - UserNameClaimType = JwtClaimTypes.Email, - UserIdClaimType = JwtClaimTypes.Subject - }; - options.Tokens.ChangeEmailTokenProvider = TokenOptions.DefaultEmailProvider; - }); - - identityBuilder - .AddUserStore() - .AddRoleStore() - .AddTokenProvider>(TokenOptions.DefaultProvider) - .AddTokenProvider( - CoreHelpers.CustomProviderName(TwoFactorProviderType.Authenticator)) - .AddTokenProvider( - CoreHelpers.CustomProviderName(TwoFactorProviderType.Email)) - .AddTokenProvider( - CoreHelpers.CustomProviderName(TwoFactorProviderType.YubiKey)) - .AddTokenProvider( - CoreHelpers.CustomProviderName(TwoFactorProviderType.Duo)) - .AddTokenProvider( - CoreHelpers.CustomProviderName(TwoFactorProviderType.Remember)) - .AddTokenProvider>(TokenOptions.DefaultEmailProvider) - .AddTokenProvider( - CoreHelpers.CustomProviderName(TwoFactorProviderType.WebAuthn)); - - return identityBuilder; + services.AddSingleton(); + } + else if (CoreHelpers.SettingHasValue(globalSettings.Attachment.BaseDirectory)) + { + services.AddSingleton(); + } + else + { + services.AddSingleton(); } - public static Tuple AddPasswordlessIdentityServices( - this IServiceCollection services, GlobalSettings globalSettings) where TUserStore : class + if (CoreHelpers.SettingHasValue(globalSettings.Send.ConnectionString)) { - services.TryAddTransient(); - services.Configure(options => - { - options.TokenLifespan = TimeSpan.FromMinutes(15); - }); - - var passwordlessIdentityBuilder = services.AddIdentity() - .AddUserStore() - .AddRoleStore() - .AddDefaultTokenProviders(); - - var regularIdentityBuilder = services.AddIdentityCore() - .AddUserStore(); - - services.TryAddScoped, PasswordlessSignInManager>(); - - services.ConfigureApplicationCookie(options => - { - options.LoginPath = "/login"; - options.LogoutPath = "/"; - options.AccessDeniedPath = "/login?accessDenied=true"; - options.Cookie.Name = $"Bitwarden_{globalSettings.ProjectName}"; - options.Cookie.HttpOnly = true; - options.ExpireTimeSpan = TimeSpan.FromDays(2); - options.ReturnUrlParameter = "returnUrl"; - options.SlidingExpiration = true; - }); - - return new Tuple(passwordlessIdentityBuilder, regularIdentityBuilder); + services.AddSingleton(); + } + else if (CoreHelpers.SettingHasValue(globalSettings.Send.BaseDirectory)) + { + services.AddSingleton(); + } + else + { + services.AddSingleton(); } - public static void AddIdentityAuthenticationServices( - this IServiceCollection services, GlobalSettings globalSettings, IWebHostEnvironment environment, - Action addAuthorization) + if (globalSettings.SelfHosted) { - services - .AddAuthentication(IdentityServerAuthenticationDefaults.AuthenticationScheme) - .AddIdentityServerAuthentication(options => - { - options.Authority = globalSettings.BaseServiceUri.InternalIdentity; - options.RequireHttpsMetadata = !environment.IsDevelopment() && - globalSettings.BaseServiceUri.InternalIdentity.StartsWith("https"); - options.TokenRetriever = TokenRetrieval.FromAuthorizationHeaderOrQueryString(); - options.NameClaimType = ClaimTypes.Email; - options.SupportedTokens = SupportedTokens.Jwt; - }); - - if (addAuthorization != null) - { - services.AddAuthorization(config => - { - addAuthorization.Invoke(config); - }); - } - - if (environment.IsDevelopment()) - { - Microsoft.IdentityModel.Logging.IdentityModelEventSource.ShowPII = true; - } + services.AddSingleton(); + } + else + { + services.AddSingleton(); } - public static void AddCustomDataProtectionServices( - this IServiceCollection services, IWebHostEnvironment env, GlobalSettings globalSettings) + if (CoreHelpers.SettingHasValue(globalSettings.Captcha?.HCaptchaSecretKey) && + CoreHelpers.SettingHasValue(globalSettings.Captcha?.HCaptchaSiteKey)) { - var builder = services.AddDataProtection(options => options.ApplicationDiscriminator = "Bitwarden"); - if (env.IsDevelopment()) - { - return; - } - - if (globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.DataProtection.Directory)) - { - builder.PersistKeysToFileSystem(new DirectoryInfo(globalSettings.DataProtection.Directory)); - } - - if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Storage?.ConnectionString)) - { - X509Certificate2 dataProtectionCert = null; - if (CoreHelpers.SettingHasValue(globalSettings.DataProtection.CertificateThumbprint)) - { - dataProtectionCert = CoreHelpers.GetCertificate( - globalSettings.DataProtection.CertificateThumbprint); - } - else if (CoreHelpers.SettingHasValue(globalSettings.DataProtection.CertificatePassword)) - { - dataProtectionCert = CoreHelpers.GetBlobCertificateAsync(globalSettings.Storage.ConnectionString, "certificates", - "dataprotection.pfx", globalSettings.DataProtection.CertificatePassword) - .GetAwaiter().GetResult(); - } - //TODO djsmith85 Check if this is the correct container name - builder - .PersistKeysToAzureBlobStorage(globalSettings.Storage.ConnectionString, "aspnet-dataprotection", "keys.xml") - .ProtectKeysWithCertificate(dataProtectionCert); - } + services.AddSingleton(); } - - public static IIdentityServerBuilder AddIdentityServerCertificate( - this IIdentityServerBuilder identityServerBuilder, IWebHostEnvironment env, GlobalSettings globalSettings) + else { - var certificate = CoreHelpers.GetIdentityServerCertificate(globalSettings); - if (certificate != null) - { - identityServerBuilder.AddSigningCredential(certificate); - } - else if (env.IsDevelopment()) - { - identityServerBuilder.AddDeveloperSigningCredential(false); - } - else - { - throw new Exception("No identity certificate to use."); - } - return identityServerBuilder; - } - - public static GlobalSettings AddGlobalSettingsServices(this IServiceCollection services, - IConfiguration configuration, IWebHostEnvironment environment) - { - var globalSettings = new GlobalSettings(); - ConfigurationBinder.Bind(configuration.GetSection("GlobalSettings"), globalSettings); - - if (environment.IsDevelopment() && configuration.GetValue("developSelfHosted")) - { - // Override settings with selfHostedOverride settings - ConfigurationBinder.Bind(configuration.GetSection("Dev:SelfHostOverride:GlobalSettings"), globalSettings); - } - - services.AddSingleton(s => globalSettings); - services.AddSingleton(s => globalSettings); - return globalSettings; - } - - public static void UseDefaultMiddleware(this IApplicationBuilder app, - IWebHostEnvironment env, GlobalSettings globalSettings) - { - string GetHeaderValue(HttpContext httpContext, string header) - { - if (httpContext.Request.Headers.ContainsKey(header)) - { - return httpContext.Request.Headers[header]; - } - return null; - } - - // Add version information to response headers - app.Use(async (httpContext, next) => - { - using (LogContext.PushProperty("IPAddress", httpContext.GetIpAddress(globalSettings))) - using (LogContext.PushProperty("UserAgent", GetHeaderValue(httpContext, "user-agent"))) - using (LogContext.PushProperty("DeviceType", GetHeaderValue(httpContext, "device-type"))) - using (LogContext.PushProperty("Origin", GetHeaderValue(httpContext, "origin"))) - { - httpContext.Response.OnStarting((state) => - { - httpContext.Response.Headers.Append("Server-Version", CoreHelpers.GetVersion()); - return Task.FromResult(0); - }, null); - await next.Invoke(); - } - }); - } - - public static void UseForwardedHeaders(this IApplicationBuilder app, GlobalSettings globalSettings) - { - var options = new ForwardedHeadersOptions - { - ForwardedHeaders = ForwardedHeaders.XForwardedFor | ForwardedHeaders.XForwardedProto - }; - if (!string.IsNullOrWhiteSpace(globalSettings.KnownProxies)) - { - var proxies = globalSettings.KnownProxies.Split(','); - foreach (var proxy in proxies) - { - if (System.Net.IPAddress.TryParse(proxy.Trim(), out var ip)) - { - options.KnownProxies.Add(ip); - } - } - } - if (options.KnownProxies.Count > 1) - { - options.ForwardLimit = null; - } - app.UseForwardedHeaders(options); - } - - public static void AddCoreLocalizationServices(this IServiceCollection services) - { - services.AddTransient(); - services.AddLocalization(options => options.ResourcesPath = "Resources"); - } - - public static IApplicationBuilder UseCoreLocalization(this IApplicationBuilder app) - { - var supportedCultures = new[] { "en" }; - return app.UseRequestLocalization(options => options - .SetDefaultCulture(supportedCultures[0]) - .AddSupportedCultures(supportedCultures) - .AddSupportedUICultures(supportedCultures)); - } - - public static IMvcBuilder AddViewAndDataAnnotationLocalization(this IMvcBuilder mvc) - { - mvc.Services.AddTransient(); - return mvc.AddViewLocalization(options => options.ResourcesPath = "Resources") - .AddDataAnnotationsLocalization(options => - options.DataAnnotationLocalizerProvider = (type, factory) => - { - var assemblyName = new AssemblyName(typeof(SharedResources).GetTypeInfo().Assembly.FullName); - return factory.Create("SharedResources", assemblyName.Name); - }); - } - - public static IServiceCollection AddDistributedIdentityServices(this IServiceCollection services, GlobalSettings globalSettings) - { - services.AddOidcStateDataFormatterCache(); - services.AddSession(); - services.ConfigureApplicationCookie(configure => configure.CookieManager = new DistributedCacheCookieManager()); - services.ConfigureExternalCookie(configure => configure.CookieManager = new DistributedCacheCookieManager()); - services.AddSingleton>( - svcs => new ConfigureOpenIdConnectDistributedOptions( - svcs.GetRequiredService(), - globalSettings, - svcs.GetRequiredService()) - ); - - return services; - } - - public static void AddWebAuthn(this IServiceCollection services, GlobalSettings globalSettings) - { - services.AddFido2(options => - { - options.ServerDomain = new Uri(globalSettings.BaseServiceUri.Vault).Host; - options.ServerName = "Bitwarden"; - options.Origins = new HashSet { globalSettings.BaseServiceUri.Vault, }; - options.TimestampDriftTolerance = 300000; - }); - } - - /// - /// Adds either an in-memory or distributed IP rate limiter depending if a Redis connection string is available. - /// - /// - /// - public static void AddIpRateLimiting(this IServiceCollection services, - GlobalSettings globalSettings) - { - services.AddHostedService(); - services.AddSingleton(); - - if (string.IsNullOrEmpty(globalSettings.Redis.ConnectionString)) - { - services.AddInMemoryRateLimiting(); - } - else - { - services.AddRedisRateLimiting(); // Requires a registered IConnectionMultiplexer - } - } - - /// - /// Adds an implementation of to the service collection. Uses a memory - /// cache if self hosted or no Redis connection string is available in GlobalSettings. - /// - public static void AddDistributedCache( - this IServiceCollection services, - GlobalSettings globalSettings) - { - if (globalSettings.SelfHosted || string.IsNullOrEmpty(globalSettings.Redis.ConnectionString)) - { - services.AddDistributedMemoryCache(); - return; - } - - // Register the IConnectionMultiplexer explicitly so it can be accessed via DI - // (e.g. for the IP rate limiting store) - services.AddSingleton( - _ => ConnectionMultiplexer.Connect(globalSettings.Redis.ConnectionString)); - - // Explicitly register IDistributedCache to re-use existing IConnectionMultiplexer - // to reduce the number of redundant connections to the Redis instance - services.AddSingleton(s => - { - return new RedisCache(new RedisCacheOptions - { - // Use "ProjectName:" as an instance name to namespace keys and avoid conflicts between projects - InstanceName = $"{globalSettings.ProjectName}:", - ConnectionMultiplexerFactory = () => - Task.FromResult(s.GetRequiredService()) - }); - }); + services.AddSingleton(); } } + + public static void AddOosServices(this IServiceCollection services) + { + services.AddScoped(); + } + + public static void AddNoopServices(this IServiceCollection services) + { + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + } + + public static IdentityBuilder AddCustomIdentityServices( + this IServiceCollection services, GlobalSettings globalSettings) + { + services.AddSingleton(); + services.Configure(options => options.IterationCount = 100000); + services.Configure(options => + { + options.TokenLifespan = TimeSpan.FromDays(30); + }); + + var identityBuilder = services.AddIdentityWithoutCookieAuth(options => + { + options.User = new UserOptions + { + RequireUniqueEmail = true, + AllowedUserNameCharacters = null // all + }; + options.Password = new PasswordOptions + { + RequireDigit = false, + RequireLowercase = false, + RequiredLength = 8, + RequireNonAlphanumeric = false, + RequireUppercase = false + }; + options.ClaimsIdentity = new ClaimsIdentityOptions + { + SecurityStampClaimType = "sstamp", + UserNameClaimType = JwtClaimTypes.Email, + UserIdClaimType = JwtClaimTypes.Subject + }; + options.Tokens.ChangeEmailTokenProvider = TokenOptions.DefaultEmailProvider; + }); + + identityBuilder + .AddUserStore() + .AddRoleStore() + .AddTokenProvider>(TokenOptions.DefaultProvider) + .AddTokenProvider( + CoreHelpers.CustomProviderName(TwoFactorProviderType.Authenticator)) + .AddTokenProvider( + CoreHelpers.CustomProviderName(TwoFactorProviderType.Email)) + .AddTokenProvider( + CoreHelpers.CustomProviderName(TwoFactorProviderType.YubiKey)) + .AddTokenProvider( + CoreHelpers.CustomProviderName(TwoFactorProviderType.Duo)) + .AddTokenProvider( + CoreHelpers.CustomProviderName(TwoFactorProviderType.Remember)) + .AddTokenProvider>(TokenOptions.DefaultEmailProvider) + .AddTokenProvider( + CoreHelpers.CustomProviderName(TwoFactorProviderType.WebAuthn)); + + return identityBuilder; + } + + public static Tuple AddPasswordlessIdentityServices( + this IServiceCollection services, GlobalSettings globalSettings) where TUserStore : class + { + services.TryAddTransient(); + services.Configure(options => + { + options.TokenLifespan = TimeSpan.FromMinutes(15); + }); + + var passwordlessIdentityBuilder = services.AddIdentity() + .AddUserStore() + .AddRoleStore() + .AddDefaultTokenProviders(); + + var regularIdentityBuilder = services.AddIdentityCore() + .AddUserStore(); + + services.TryAddScoped, PasswordlessSignInManager>(); + + services.ConfigureApplicationCookie(options => + { + options.LoginPath = "/login"; + options.LogoutPath = "/"; + options.AccessDeniedPath = "/login?accessDenied=true"; + options.Cookie.Name = $"Bitwarden_{globalSettings.ProjectName}"; + options.Cookie.HttpOnly = true; + options.ExpireTimeSpan = TimeSpan.FromDays(2); + options.ReturnUrlParameter = "returnUrl"; + options.SlidingExpiration = true; + }); + + return new Tuple(passwordlessIdentityBuilder, regularIdentityBuilder); + } + + public static void AddIdentityAuthenticationServices( + this IServiceCollection services, GlobalSettings globalSettings, IWebHostEnvironment environment, + Action addAuthorization) + { + services + .AddAuthentication(IdentityServerAuthenticationDefaults.AuthenticationScheme) + .AddIdentityServerAuthentication(options => + { + options.Authority = globalSettings.BaseServiceUri.InternalIdentity; + options.RequireHttpsMetadata = !environment.IsDevelopment() && + globalSettings.BaseServiceUri.InternalIdentity.StartsWith("https"); + options.TokenRetriever = TokenRetrieval.FromAuthorizationHeaderOrQueryString(); + options.NameClaimType = ClaimTypes.Email; + options.SupportedTokens = SupportedTokens.Jwt; + }); + + if (addAuthorization != null) + { + services.AddAuthorization(config => + { + addAuthorization.Invoke(config); + }); + } + + if (environment.IsDevelopment()) + { + Microsoft.IdentityModel.Logging.IdentityModelEventSource.ShowPII = true; + } + } + + public static void AddCustomDataProtectionServices( + this IServiceCollection services, IWebHostEnvironment env, GlobalSettings globalSettings) + { + var builder = services.AddDataProtection(options => options.ApplicationDiscriminator = "Bitwarden"); + if (env.IsDevelopment()) + { + return; + } + + if (globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.DataProtection.Directory)) + { + builder.PersistKeysToFileSystem(new DirectoryInfo(globalSettings.DataProtection.Directory)); + } + + if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Storage?.ConnectionString)) + { + X509Certificate2 dataProtectionCert = null; + if (CoreHelpers.SettingHasValue(globalSettings.DataProtection.CertificateThumbprint)) + { + dataProtectionCert = CoreHelpers.GetCertificate( + globalSettings.DataProtection.CertificateThumbprint); + } + else if (CoreHelpers.SettingHasValue(globalSettings.DataProtection.CertificatePassword)) + { + dataProtectionCert = CoreHelpers.GetBlobCertificateAsync(globalSettings.Storage.ConnectionString, "certificates", + "dataprotection.pfx", globalSettings.DataProtection.CertificatePassword) + .GetAwaiter().GetResult(); + } + //TODO djsmith85 Check if this is the correct container name + builder + .PersistKeysToAzureBlobStorage(globalSettings.Storage.ConnectionString, "aspnet-dataprotection", "keys.xml") + .ProtectKeysWithCertificate(dataProtectionCert); + } + } + + public static IIdentityServerBuilder AddIdentityServerCertificate( + this IIdentityServerBuilder identityServerBuilder, IWebHostEnvironment env, GlobalSettings globalSettings) + { + var certificate = CoreHelpers.GetIdentityServerCertificate(globalSettings); + if (certificate != null) + { + identityServerBuilder.AddSigningCredential(certificate); + } + else if (env.IsDevelopment()) + { + identityServerBuilder.AddDeveloperSigningCredential(false); + } + else + { + throw new Exception("No identity certificate to use."); + } + return identityServerBuilder; + } + + public static GlobalSettings AddGlobalSettingsServices(this IServiceCollection services, + IConfiguration configuration, IWebHostEnvironment environment) + { + var globalSettings = new GlobalSettings(); + ConfigurationBinder.Bind(configuration.GetSection("GlobalSettings"), globalSettings); + + if (environment.IsDevelopment() && configuration.GetValue("developSelfHosted")) + { + // Override settings with selfHostedOverride settings + ConfigurationBinder.Bind(configuration.GetSection("Dev:SelfHostOverride:GlobalSettings"), globalSettings); + } + + services.AddSingleton(s => globalSettings); + services.AddSingleton(s => globalSettings); + return globalSettings; + } + + public static void UseDefaultMiddleware(this IApplicationBuilder app, + IWebHostEnvironment env, GlobalSettings globalSettings) + { + string GetHeaderValue(HttpContext httpContext, string header) + { + if (httpContext.Request.Headers.ContainsKey(header)) + { + return httpContext.Request.Headers[header]; + } + return null; + } + + // Add version information to response headers + app.Use(async (httpContext, next) => + { + using (LogContext.PushProperty("IPAddress", httpContext.GetIpAddress(globalSettings))) + using (LogContext.PushProperty("UserAgent", GetHeaderValue(httpContext, "user-agent"))) + using (LogContext.PushProperty("DeviceType", GetHeaderValue(httpContext, "device-type"))) + using (LogContext.PushProperty("Origin", GetHeaderValue(httpContext, "origin"))) + { + httpContext.Response.OnStarting((state) => + { + httpContext.Response.Headers.Append("Server-Version", CoreHelpers.GetVersion()); + return Task.FromResult(0); + }, null); + await next.Invoke(); + } + }); + } + + public static void UseForwardedHeaders(this IApplicationBuilder app, GlobalSettings globalSettings) + { + var options = new ForwardedHeadersOptions + { + ForwardedHeaders = ForwardedHeaders.XForwardedFor | ForwardedHeaders.XForwardedProto + }; + if (!string.IsNullOrWhiteSpace(globalSettings.KnownProxies)) + { + var proxies = globalSettings.KnownProxies.Split(','); + foreach (var proxy in proxies) + { + if (System.Net.IPAddress.TryParse(proxy.Trim(), out var ip)) + { + options.KnownProxies.Add(ip); + } + } + } + if (options.KnownProxies.Count > 1) + { + options.ForwardLimit = null; + } + app.UseForwardedHeaders(options); + } + + public static void AddCoreLocalizationServices(this IServiceCollection services) + { + services.AddTransient(); + services.AddLocalization(options => options.ResourcesPath = "Resources"); + } + + public static IApplicationBuilder UseCoreLocalization(this IApplicationBuilder app) + { + var supportedCultures = new[] { "en" }; + return app.UseRequestLocalization(options => options + .SetDefaultCulture(supportedCultures[0]) + .AddSupportedCultures(supportedCultures) + .AddSupportedUICultures(supportedCultures)); + } + + public static IMvcBuilder AddViewAndDataAnnotationLocalization(this IMvcBuilder mvc) + { + mvc.Services.AddTransient(); + return mvc.AddViewLocalization(options => options.ResourcesPath = "Resources") + .AddDataAnnotationsLocalization(options => + options.DataAnnotationLocalizerProvider = (type, factory) => + { + var assemblyName = new AssemblyName(typeof(SharedResources).GetTypeInfo().Assembly.FullName); + return factory.Create("SharedResources", assemblyName.Name); + }); + } + + public static IServiceCollection AddDistributedIdentityServices(this IServiceCollection services, GlobalSettings globalSettings) + { + services.AddOidcStateDataFormatterCache(); + services.AddSession(); + services.ConfigureApplicationCookie(configure => configure.CookieManager = new DistributedCacheCookieManager()); + services.ConfigureExternalCookie(configure => configure.CookieManager = new DistributedCacheCookieManager()); + services.AddSingleton>( + svcs => new ConfigureOpenIdConnectDistributedOptions( + svcs.GetRequiredService(), + globalSettings, + svcs.GetRequiredService()) + ); + + return services; + } + + public static void AddWebAuthn(this IServiceCollection services, GlobalSettings globalSettings) + { + services.AddFido2(options => + { + options.ServerDomain = new Uri(globalSettings.BaseServiceUri.Vault).Host; + options.ServerName = "Bitwarden"; + options.Origins = new HashSet { globalSettings.BaseServiceUri.Vault, }; + options.TimestampDriftTolerance = 300000; + }); + } + + /// + /// Adds either an in-memory or distributed IP rate limiter depending if a Redis connection string is available. + /// + /// + /// + public static void AddIpRateLimiting(this IServiceCollection services, + GlobalSettings globalSettings) + { + services.AddHostedService(); + services.AddSingleton(); + + if (string.IsNullOrEmpty(globalSettings.Redis.ConnectionString)) + { + services.AddInMemoryRateLimiting(); + } + else + { + services.AddRedisRateLimiting(); // Requires a registered IConnectionMultiplexer + } + } + + /// + /// Adds an implementation of to the service collection. Uses a memory + /// cache if self hosted or no Redis connection string is available in GlobalSettings. + /// + public static void AddDistributedCache( + this IServiceCollection services, + GlobalSettings globalSettings) + { + if (globalSettings.SelfHosted || string.IsNullOrEmpty(globalSettings.Redis.ConnectionString)) + { + services.AddDistributedMemoryCache(); + return; + } + + // Register the IConnectionMultiplexer explicitly so it can be accessed via DI + // (e.g. for the IP rate limiting store) + services.AddSingleton( + _ => ConnectionMultiplexer.Connect(globalSettings.Redis.ConnectionString)); + + // Explicitly register IDistributedCache to re-use existing IConnectionMultiplexer + // to reduce the number of redundant connections to the Redis instance + services.AddSingleton(s => + { + return new RedisCache(new RedisCacheOptions + { + // Use "ProjectName:" as an instance name to namespace keys and avoid conflicts between projects + InstanceName = $"{globalSettings.ProjectName}:", + ConnectionMultiplexerFactory = () => + Task.FromResult(s.GetRequiredService()) + }); + }); + } } diff --git a/test/Api.IntegrationTest/Factories/ApiApplicationFactory.cs b/test/Api.IntegrationTest/Factories/ApiApplicationFactory.cs index 8d678c3f4..693e53083 100644 --- a/test/Api.IntegrationTest/Factories/ApiApplicationFactory.cs +++ b/test/Api.IntegrationTest/Factories/ApiApplicationFactory.cs @@ -3,42 +3,41 @@ using Bit.IntegrationTestCommon.Factories; using IdentityServer4.AccessTokenValidation; using Microsoft.AspNetCore.TestHost; -namespace Bit.Api.IntegrationTest.Factories +namespace Bit.Api.IntegrationTest.Factories; + +public class ApiApplicationFactory : WebApplicationFactoryBase { - public class ApiApplicationFactory : WebApplicationFactoryBase + private readonly IdentityApplicationFactory _identityApplicationFactory; + + public ApiApplicationFactory() { - private readonly IdentityApplicationFactory _identityApplicationFactory; + _identityApplicationFactory = new IdentityApplicationFactory(); + } - public ApiApplicationFactory() + protected override void ConfigureWebHost(IWebHostBuilder builder) + { + base.ConfigureWebHost(builder); + + builder.ConfigureTestServices(services => { - _identityApplicationFactory = new IdentityApplicationFactory(); - } - - protected override void ConfigureWebHost(IWebHostBuilder builder) - { - base.ConfigureWebHost(builder); - - builder.ConfigureTestServices(services => + services.PostConfigure(IdentityServerAuthenticationDefaults.AuthenticationScheme, options => { - services.PostConfigure(IdentityServerAuthenticationDefaults.AuthenticationScheme, options => - { - options.JwtBackChannelHandler = _identityApplicationFactory.Server.CreateHandler(); - }); + options.JwtBackChannelHandler = _identityApplicationFactory.Server.CreateHandler(); }); - } + }); + } - /// - /// Helper for registering and logging in to a new account - /// - public async Task<(string Token, string RefreshToken)> LoginWithNewAccount(string email = "integration-test@bitwarden.com", string masterPasswordHash = "master_password_hash") + /// + /// Helper for registering and logging in to a new account + /// + public async Task<(string Token, string RefreshToken)> LoginWithNewAccount(string email = "integration-test@bitwarden.com", string masterPasswordHash = "master_password_hash") + { + await _identityApplicationFactory.RegisterAsync(new RegisterRequestModel { - await _identityApplicationFactory.RegisterAsync(new RegisterRequestModel - { - Email = email, - MasterPasswordHash = masterPasswordHash, - }); + Email = email, + MasterPasswordHash = masterPasswordHash, + }); - return await _identityApplicationFactory.TokenFromPasswordAsync(email, masterPasswordHash); - } + return await _identityApplicationFactory.TokenFromPasswordAsync(email, masterPasswordHash); } } diff --git a/test/Api.Test/Controllers/AccountsControllerTests.cs b/test/Api.Test/Controllers/AccountsControllerTests.cs index cd33de4c8..0b2747c38 100644 --- a/test/Api.Test/Controllers/AccountsControllerTests.cs +++ b/test/Api.Test/Controllers/AccountsControllerTests.cs @@ -13,413 +13,412 @@ using Microsoft.AspNetCore.Identity; using NSubstitute; using Xunit; -namespace Bit.Api.Test.Controllers +namespace Bit.Api.Test.Controllers; + +public class AccountsControllerTests : IDisposable { - public class AccountsControllerTests : IDisposable + + private readonly AccountsController _sut; + private readonly GlobalSettings _globalSettings; + private readonly ICipherRepository _cipherRepository; + private readonly IFolderRepository _folderRepository; + private readonly IOrganizationService _organizationService; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IPaymentService _paymentService; + private readonly IUserRepository _userRepository; + private readonly IUserService _userService; + private readonly ISendRepository _sendRepository; + private readonly ISendService _sendService; + private readonly IProviderUserRepository _providerUserRepository; + + public AccountsControllerTests() { + _userService = Substitute.For(); + _userRepository = Substitute.For(); + _cipherRepository = Substitute.For(); + _folderRepository = Substitute.For(); + _organizationService = Substitute.For(); + _organizationUserRepository = Substitute.For(); + _providerUserRepository = Substitute.For(); + _paymentService = Substitute.For(); + _globalSettings = new GlobalSettings(); + _sendRepository = Substitute.For(); + _sendService = Substitute.For(); + _sut = new AccountsController( + _globalSettings, + _cipherRepository, + _folderRepository, + _organizationService, + _organizationUserRepository, + _providerUserRepository, + _paymentService, + _userRepository, + _userService, + _sendRepository, + _sendService + ); + } - private readonly AccountsController _sut; - private readonly GlobalSettings _globalSettings; - private readonly ICipherRepository _cipherRepository; - private readonly IFolderRepository _folderRepository; - private readonly IOrganizationService _organizationService; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IPaymentService _paymentService; - private readonly IUserRepository _userRepository; - private readonly IUserService _userService; - private readonly ISendRepository _sendRepository; - private readonly ISendService _sendService; - private readonly IProviderUserRepository _providerUserRepository; + public void Dispose() + { + _sut?.Dispose(); + } - public AccountsControllerTests() + [Fact] + public async Task PostPrelogin_WhenUserExists_ShouldReturnUserKdfInfo() + { + var userKdfInfo = new UserKdfInformation { - _userService = Substitute.For(); - _userRepository = Substitute.For(); - _cipherRepository = Substitute.For(); - _folderRepository = Substitute.For(); - _organizationService = Substitute.For(); - _organizationUserRepository = Substitute.For(); - _providerUserRepository = Substitute.For(); - _paymentService = Substitute.For(); - _globalSettings = new GlobalSettings(); - _sendRepository = Substitute.For(); - _sendService = Substitute.For(); - _sut = new AccountsController( - _globalSettings, - _cipherRepository, - _folderRepository, - _organizationService, - _organizationUserRepository, - _providerUserRepository, - _paymentService, - _userRepository, - _userService, - _sendRepository, - _sendService - ); - } + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 5000 + }; + _userRepository.GetKdfInformationByEmailAsync(Arg.Any()).Returns(Task.FromResult(userKdfInfo)); - public void Dispose() + var response = await _sut.PostPrelogin(new PreloginRequestModel { Email = "user@example.com" }); + + Assert.Equal(userKdfInfo.Kdf, response.Kdf); + Assert.Equal(userKdfInfo.KdfIterations, response.KdfIterations); + } + + [Fact] + public async Task PostPrelogin_WhenUserDoesNotExist_ShouldDefaultToSha256And100000Iterations() + { + _userRepository.GetKdfInformationByEmailAsync(Arg.Any()).Returns(Task.FromResult((UserKdfInformation)null)); + + var response = await _sut.PostPrelogin(new PreloginRequestModel { Email = "user@example.com" }); + + Assert.Equal(KdfType.PBKDF2_SHA256, response.Kdf); + Assert.Equal(100000, response.KdfIterations); + } + + [Fact] + public async Task PostRegister_ShouldRegisterUser() + { + var passwordHash = "abcdef"; + var token = "123456"; + var userGuid = new Guid(); + _userService.RegisterUserAsync(Arg.Any(), passwordHash, token, userGuid) + .Returns(Task.FromResult(IdentityResult.Success)); + var request = new RegisterRequestModel { - _sut?.Dispose(); - } + Name = "Example User", + Email = "user@example.com", + MasterPasswordHash = passwordHash, + MasterPasswordHint = "example", + Token = token, + OrganizationUserId = userGuid + }; - [Fact] - public async Task PostPrelogin_WhenUserExists_ShouldReturnUserKdfInfo() + await _sut.PostRegister(request); + + await _userService.Received(1).RegisterUserAsync(Arg.Any(), passwordHash, token, userGuid); + } + + [Fact] + public async Task PostRegister_WhenUserServiceFails_ShouldThrowBadRequestException() + { + var passwordHash = "abcdef"; + var token = "123456"; + var userGuid = new Guid(); + _userService.RegisterUserAsync(Arg.Any(), passwordHash, token, userGuid) + .Returns(Task.FromResult(IdentityResult.Failed())); + var request = new RegisterRequestModel { - var userKdfInfo = new UserKdfInformation - { - Kdf = KdfType.PBKDF2_SHA256, - KdfIterations = 5000 - }; - _userRepository.GetKdfInformationByEmailAsync(Arg.Any()).Returns(Task.FromResult(userKdfInfo)); + Name = "Example User", + Email = "user@example.com", + MasterPasswordHash = passwordHash, + MasterPasswordHint = "example", + Token = token, + OrganizationUserId = userGuid + }; - var response = await _sut.PostPrelogin(new PreloginRequestModel { Email = "user@example.com" }); + await Assert.ThrowsAsync(() => _sut.PostRegister(request)); + } - Assert.Equal(userKdfInfo.Kdf, response.Kdf); - Assert.Equal(userKdfInfo.KdfIterations, response.KdfIterations); - } + [Fact] + public async Task PostPasswordHint_ShouldNotifyUserService() + { + var email = "user@example.com"; - [Fact] - public async Task PostPrelogin_WhenUserDoesNotExist_ShouldDefaultToSha256And100000Iterations() + await _sut.PostPasswordHint(new PasswordHintRequestModel { Email = email }); + + await _userService.Received(1).SendMasterPasswordHintAsync(email); + } + + [Fact] + public async Task PostEmailToken_ShouldInitiateEmailChange() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidPrincipalFor(user); + ConfigureUserServiceToAcceptPasswordFor(user); + var newEmail = "example@user.com"; + + await _sut.PostEmailToken(new EmailTokenRequestModel { NewEmail = newEmail }); + + await _userService.Received(1).InitiateEmailChangeAsync(user, newEmail); + } + + [Fact] + public async Task PostEmailToken_WhenNotAuthorized_ShouldThrowUnauthorizedAccessException() + { + ConfigureUserServiceToReturnNullPrincipal(); + + await Assert.ThrowsAsync( + () => _sut.PostEmailToken(new EmailTokenRequestModel()) + ); + } + + [Fact] + public async Task PostEmailToken_WhenInvalidPasssword_ShouldThrowBadRequestException() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidPrincipalFor(user); + ConfigureUserServiceToRejectPasswordFor(user); + + await Assert.ThrowsAsync( + () => _sut.PostEmailToken(new EmailTokenRequestModel()) + ); + } + + [Fact] + public async Task PostEmail_ShouldChangeUserEmail() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidPrincipalFor(user); + _userService.ChangeEmailAsync(user, default, default, default, default, default) + .Returns(Task.FromResult(IdentityResult.Success)); + + await _sut.PostEmail(new EmailRequestModel()); + + await _userService.Received(1).ChangeEmailAsync(user, default, default, default, default, default); + } + + [Fact] + public async Task PostEmail_WhenNotAuthorized_ShouldThrownUnauthorizedAccessException() + { + ConfigureUserServiceToReturnNullPrincipal(); + + await Assert.ThrowsAsync( + () => _sut.PostEmail(new EmailRequestModel()) + ); + } + + [Fact] + public async Task PostEmail_WhenEmailCannotBeChanged_ShouldThrowBadRequestException() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidPrincipalFor(user); + _userService.ChangeEmailAsync(user, default, default, default, default, default) + .Returns(Task.FromResult(IdentityResult.Failed())); + + await Assert.ThrowsAsync( + () => _sut.PostEmail(new EmailRequestModel()) + ); + } + + [Fact] + public async Task PostVerifyEmail_ShouldSendEmailVerification() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidPrincipalFor(user); + + await _sut.PostVerifyEmail(); + + await _userService.Received(1).SendEmailVerificationAsync(user); + } + + [Fact] + public async Task PostVerifyEmail_WhenNotAuthorized_ShouldThrownUnauthorizedAccessException() + { + ConfigureUserServiceToReturnNullPrincipal(); + + await Assert.ThrowsAsync( + () => _sut.PostVerifyEmail() + ); + } + + [Fact] + public async Task PostVerifyEmailToken_ShouldConfirmEmail() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidIdFor(user); + _userService.ConfirmEmailAsync(user, Arg.Any()) + .Returns(Task.FromResult(IdentityResult.Success)); + + await _sut.PostVerifyEmailToken(new VerifyEmailRequestModel { UserId = "12345678-1234-1234-1234-123456789012" }); + + await _userService.Received(1).ConfirmEmailAsync(user, Arg.Any()); + } + + [Fact] + public async Task PostVerifyEmailToken_WhenUserDoesNotExist_ShouldThrowUnauthorizedAccessException() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnNullUserId(); + + await Assert.ThrowsAsync( + () => _sut.PostVerifyEmailToken(new VerifyEmailRequestModel { UserId = "12345678-1234-1234-1234-123456789012" }) + ); + } + + [Fact] + public async Task PostVerifyEmailToken_WhenEmailConfirmationFails_ShouldThrowBadRequestException() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidIdFor(user); + _userService.ConfirmEmailAsync(user, Arg.Any()) + .Returns(Task.FromResult(IdentityResult.Failed())); + + await Assert.ThrowsAsync( + () => _sut.PostVerifyEmailToken(new VerifyEmailRequestModel { UserId = "12345678-1234-1234-1234-123456789012" }) + ); + } + + [Fact] + public async Task PostPassword_ShouldChangePassword() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidPrincipalFor(user); + _userService.ChangePasswordAsync(user, default, default, default, default) + .Returns(Task.FromResult(IdentityResult.Success)); + + await _sut.PostPassword(new PasswordRequestModel()); + + await _userService.Received(1).ChangePasswordAsync(user, default, default, default, default); + } + + [Fact] + public async Task PostPassword_WhenNotAuthorized_ShouldThrowUnauthorizedAccessException() + { + ConfigureUserServiceToReturnNullPrincipal(); + + await Assert.ThrowsAsync( + () => _sut.PostPassword(new PasswordRequestModel()) + ); + } + + [Fact] + public async Task PostPassword_WhenPasswordChangeFails_ShouldBadRequestException() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidPrincipalFor(user); + _userService.ChangePasswordAsync(user, default, default, default, default) + .Returns(Task.FromResult(IdentityResult.Failed())); + + await Assert.ThrowsAsync( + () => _sut.PostPassword(new PasswordRequestModel()) + ); + } + + [Fact] + public async Task GetApiKey_ShouldReturnApiKeyResponse() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidPrincipalFor(user); + ConfigureUserServiceToAcceptPasswordFor(user); + await _sut.ApiKey(new SecretVerificationRequestModel()); + } + + [Fact] + public async Task GetApiKey_WhenUserDoesNotExist_ShouldThrowUnauthorizedAccessException() + { + ConfigureUserServiceToReturnNullPrincipal(); + + await Assert.ThrowsAsync( + () => _sut.ApiKey(new SecretVerificationRequestModel()) + ); + } + + [Fact] + public async Task GetApiKey_WhenPasswordCheckFails_ShouldThrowBadRequestException() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidPrincipalFor(user); + ConfigureUserServiceToRejectPasswordFor(user); + await Assert.ThrowsAsync( + () => _sut.ApiKey(new SecretVerificationRequestModel()) + ); + } + + [Fact] + public async Task PostRotateApiKey_ShouldRotateApiKey() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidPrincipalFor(user); + ConfigureUserServiceToAcceptPasswordFor(user); + await _sut.RotateApiKey(new SecretVerificationRequestModel()); + } + + [Fact] + public async Task PostRotateApiKey_WhenUserDoesNotExist_ShouldThrowUnauthorizedAccessException() + { + ConfigureUserServiceToReturnNullPrincipal(); + + await Assert.ThrowsAsync( + () => _sut.ApiKey(new SecretVerificationRequestModel()) + ); + } + + [Fact] + public async Task PostRotateApiKey_WhenPasswordCheckFails_ShouldThrowBadRequestException() + { + var user = GenerateExampleUser(); + ConfigureUserServiceToReturnValidPrincipalFor(user); + ConfigureUserServiceToRejectPasswordFor(user); + await Assert.ThrowsAsync( + () => _sut.ApiKey(new SecretVerificationRequestModel()) + ); + } + + // Below are helper functions that currently belong to this + // test class, but ultimately may need to be split out into + // something greater in order to share common test steps with + // other test suites. They are included here for the time being + // until that day comes. + private User GenerateExampleUser() + { + return new User { - _userRepository.GetKdfInformationByEmailAsync(Arg.Any()).Returns(Task.FromResult((UserKdfInformation)null)); + Email = "user@example.com" + }; + } - var response = await _sut.PostPrelogin(new PreloginRequestModel { Email = "user@example.com" }); + private void ConfigureUserServiceToReturnNullPrincipal() + { + _userService.GetUserByPrincipalAsync(Arg.Any()) + .Returns(Task.FromResult((User)null)); + } - Assert.Equal(KdfType.PBKDF2_SHA256, response.Kdf); - Assert.Equal(100000, response.KdfIterations); - } + private void ConfigureUserServiceToReturnValidPrincipalFor(User user) + { + _userService.GetUserByPrincipalAsync(Arg.Any()) + .Returns(Task.FromResult(user)); + } - [Fact] - public async Task PostRegister_ShouldRegisterUser() - { - var passwordHash = "abcdef"; - var token = "123456"; - var userGuid = new Guid(); - _userService.RegisterUserAsync(Arg.Any(), passwordHash, token, userGuid) - .Returns(Task.FromResult(IdentityResult.Success)); - var request = new RegisterRequestModel - { - Name = "Example User", - Email = "user@example.com", - MasterPasswordHash = passwordHash, - MasterPasswordHint = "example", - Token = token, - OrganizationUserId = userGuid - }; + private void ConfigureUserServiceToRejectPasswordFor(User user) + { + _userService.CheckPasswordAsync(user, Arg.Any()) + .Returns(Task.FromResult(false)); + } - await _sut.PostRegister(request); + private void ConfigureUserServiceToAcceptPasswordFor(User user) + { + _userService.CheckPasswordAsync(user, Arg.Any()) + .Returns(Task.FromResult(true)); + _userService.VerifySecretAsync(user, Arg.Any()) + .Returns(Task.FromResult(true)); + } - await _userService.Received(1).RegisterUserAsync(Arg.Any(), passwordHash, token, userGuid); - } + private void ConfigureUserServiceToReturnValidIdFor(User user) + { + _userService.GetUserByIdAsync(Arg.Any()) + .Returns(Task.FromResult(user)); + } - [Fact] - public async Task PostRegister_WhenUserServiceFails_ShouldThrowBadRequestException() - { - var passwordHash = "abcdef"; - var token = "123456"; - var userGuid = new Guid(); - _userService.RegisterUserAsync(Arg.Any(), passwordHash, token, userGuid) - .Returns(Task.FromResult(IdentityResult.Failed())); - var request = new RegisterRequestModel - { - Name = "Example User", - Email = "user@example.com", - MasterPasswordHash = passwordHash, - MasterPasswordHint = "example", - Token = token, - OrganizationUserId = userGuid - }; - - await Assert.ThrowsAsync(() => _sut.PostRegister(request)); - } - - [Fact] - public async Task PostPasswordHint_ShouldNotifyUserService() - { - var email = "user@example.com"; - - await _sut.PostPasswordHint(new PasswordHintRequestModel { Email = email }); - - await _userService.Received(1).SendMasterPasswordHintAsync(email); - } - - [Fact] - public async Task PostEmailToken_ShouldInitiateEmailChange() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidPrincipalFor(user); - ConfigureUserServiceToAcceptPasswordFor(user); - var newEmail = "example@user.com"; - - await _sut.PostEmailToken(new EmailTokenRequestModel { NewEmail = newEmail }); - - await _userService.Received(1).InitiateEmailChangeAsync(user, newEmail); - } - - [Fact] - public async Task PostEmailToken_WhenNotAuthorized_ShouldThrowUnauthorizedAccessException() - { - ConfigureUserServiceToReturnNullPrincipal(); - - await Assert.ThrowsAsync( - () => _sut.PostEmailToken(new EmailTokenRequestModel()) - ); - } - - [Fact] - public async Task PostEmailToken_WhenInvalidPasssword_ShouldThrowBadRequestException() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidPrincipalFor(user); - ConfigureUserServiceToRejectPasswordFor(user); - - await Assert.ThrowsAsync( - () => _sut.PostEmailToken(new EmailTokenRequestModel()) - ); - } - - [Fact] - public async Task PostEmail_ShouldChangeUserEmail() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidPrincipalFor(user); - _userService.ChangeEmailAsync(user, default, default, default, default, default) - .Returns(Task.FromResult(IdentityResult.Success)); - - await _sut.PostEmail(new EmailRequestModel()); - - await _userService.Received(1).ChangeEmailAsync(user, default, default, default, default, default); - } - - [Fact] - public async Task PostEmail_WhenNotAuthorized_ShouldThrownUnauthorizedAccessException() - { - ConfigureUserServiceToReturnNullPrincipal(); - - await Assert.ThrowsAsync( - () => _sut.PostEmail(new EmailRequestModel()) - ); - } - - [Fact] - public async Task PostEmail_WhenEmailCannotBeChanged_ShouldThrowBadRequestException() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidPrincipalFor(user); - _userService.ChangeEmailAsync(user, default, default, default, default, default) - .Returns(Task.FromResult(IdentityResult.Failed())); - - await Assert.ThrowsAsync( - () => _sut.PostEmail(new EmailRequestModel()) - ); - } - - [Fact] - public async Task PostVerifyEmail_ShouldSendEmailVerification() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidPrincipalFor(user); - - await _sut.PostVerifyEmail(); - - await _userService.Received(1).SendEmailVerificationAsync(user); - } - - [Fact] - public async Task PostVerifyEmail_WhenNotAuthorized_ShouldThrownUnauthorizedAccessException() - { - ConfigureUserServiceToReturnNullPrincipal(); - - await Assert.ThrowsAsync( - () => _sut.PostVerifyEmail() - ); - } - - [Fact] - public async Task PostVerifyEmailToken_ShouldConfirmEmail() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidIdFor(user); - _userService.ConfirmEmailAsync(user, Arg.Any()) - .Returns(Task.FromResult(IdentityResult.Success)); - - await _sut.PostVerifyEmailToken(new VerifyEmailRequestModel { UserId = "12345678-1234-1234-1234-123456789012" }); - - await _userService.Received(1).ConfirmEmailAsync(user, Arg.Any()); - } - - [Fact] - public async Task PostVerifyEmailToken_WhenUserDoesNotExist_ShouldThrowUnauthorizedAccessException() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnNullUserId(); - - await Assert.ThrowsAsync( - () => _sut.PostVerifyEmailToken(new VerifyEmailRequestModel { UserId = "12345678-1234-1234-1234-123456789012" }) - ); - } - - [Fact] - public async Task PostVerifyEmailToken_WhenEmailConfirmationFails_ShouldThrowBadRequestException() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidIdFor(user); - _userService.ConfirmEmailAsync(user, Arg.Any()) - .Returns(Task.FromResult(IdentityResult.Failed())); - - await Assert.ThrowsAsync( - () => _sut.PostVerifyEmailToken(new VerifyEmailRequestModel { UserId = "12345678-1234-1234-1234-123456789012" }) - ); - } - - [Fact] - public async Task PostPassword_ShouldChangePassword() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidPrincipalFor(user); - _userService.ChangePasswordAsync(user, default, default, default, default) - .Returns(Task.FromResult(IdentityResult.Success)); - - await _sut.PostPassword(new PasswordRequestModel()); - - await _userService.Received(1).ChangePasswordAsync(user, default, default, default, default); - } - - [Fact] - public async Task PostPassword_WhenNotAuthorized_ShouldThrowUnauthorizedAccessException() - { - ConfigureUserServiceToReturnNullPrincipal(); - - await Assert.ThrowsAsync( - () => _sut.PostPassword(new PasswordRequestModel()) - ); - } - - [Fact] - public async Task PostPassword_WhenPasswordChangeFails_ShouldBadRequestException() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidPrincipalFor(user); - _userService.ChangePasswordAsync(user, default, default, default, default) - .Returns(Task.FromResult(IdentityResult.Failed())); - - await Assert.ThrowsAsync( - () => _sut.PostPassword(new PasswordRequestModel()) - ); - } - - [Fact] - public async Task GetApiKey_ShouldReturnApiKeyResponse() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidPrincipalFor(user); - ConfigureUserServiceToAcceptPasswordFor(user); - await _sut.ApiKey(new SecretVerificationRequestModel()); - } - - [Fact] - public async Task GetApiKey_WhenUserDoesNotExist_ShouldThrowUnauthorizedAccessException() - { - ConfigureUserServiceToReturnNullPrincipal(); - - await Assert.ThrowsAsync( - () => _sut.ApiKey(new SecretVerificationRequestModel()) - ); - } - - [Fact] - public async Task GetApiKey_WhenPasswordCheckFails_ShouldThrowBadRequestException() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidPrincipalFor(user); - ConfigureUserServiceToRejectPasswordFor(user); - await Assert.ThrowsAsync( - () => _sut.ApiKey(new SecretVerificationRequestModel()) - ); - } - - [Fact] - public async Task PostRotateApiKey_ShouldRotateApiKey() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidPrincipalFor(user); - ConfigureUserServiceToAcceptPasswordFor(user); - await _sut.RotateApiKey(new SecretVerificationRequestModel()); - } - - [Fact] - public async Task PostRotateApiKey_WhenUserDoesNotExist_ShouldThrowUnauthorizedAccessException() - { - ConfigureUserServiceToReturnNullPrincipal(); - - await Assert.ThrowsAsync( - () => _sut.ApiKey(new SecretVerificationRequestModel()) - ); - } - - [Fact] - public async Task PostRotateApiKey_WhenPasswordCheckFails_ShouldThrowBadRequestException() - { - var user = GenerateExampleUser(); - ConfigureUserServiceToReturnValidPrincipalFor(user); - ConfigureUserServiceToRejectPasswordFor(user); - await Assert.ThrowsAsync( - () => _sut.ApiKey(new SecretVerificationRequestModel()) - ); - } - - // Below are helper functions that currently belong to this - // test class, but ultimately may need to be split out into - // something greater in order to share common test steps with - // other test suites. They are included here for the time being - // until that day comes. - private User GenerateExampleUser() - { - return new User - { - Email = "user@example.com" - }; - } - - private void ConfigureUserServiceToReturnNullPrincipal() - { - _userService.GetUserByPrincipalAsync(Arg.Any()) - .Returns(Task.FromResult((User)null)); - } - - private void ConfigureUserServiceToReturnValidPrincipalFor(User user) - { - _userService.GetUserByPrincipalAsync(Arg.Any()) - .Returns(Task.FromResult(user)); - } - - private void ConfigureUserServiceToRejectPasswordFor(User user) - { - _userService.CheckPasswordAsync(user, Arg.Any()) - .Returns(Task.FromResult(false)); - } - - private void ConfigureUserServiceToAcceptPasswordFor(User user) - { - _userService.CheckPasswordAsync(user, Arg.Any()) - .Returns(Task.FromResult(true)); - _userService.VerifySecretAsync(user, Arg.Any()) - .Returns(Task.FromResult(true)); - } - - private void ConfigureUserServiceToReturnValidIdFor(User user) - { - _userService.GetUserByIdAsync(Arg.Any()) - .Returns(Task.FromResult(user)); - } - - private void ConfigureUserServiceToReturnNullUserId() - { - _userService.GetUserByIdAsync(Arg.Any()) - .Returns(Task.FromResult((User)null)); - } + private void ConfigureUserServiceToReturnNullUserId() + { + _userService.GetUserByIdAsync(Arg.Any()) + .Returns(Task.FromResult((User)null)); } } diff --git a/test/Api.Test/Controllers/CollectionsControllerTests.cs b/test/Api.Test/Controllers/CollectionsControllerTests.cs index ba9620a00..b5d304f7a 100644 --- a/test/Api.Test/Controllers/CollectionsControllerTests.cs +++ b/test/Api.Test/Controllers/CollectionsControllerTests.cs @@ -11,79 +11,78 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Api.Test.Controllers +namespace Bit.Api.Test.Controllers; + +[ControllerCustomize(typeof(CollectionsController))] +[SutProviderCustomize] +public class CollectionsControllerTests { - [ControllerCustomize(typeof(CollectionsController))] - [SutProviderCustomize] - public class CollectionsControllerTests + [Theory, BitAutoData] + public async Task Post_Success(Guid orgId, SutProvider sutProvider) { - [Theory, BitAutoData] - public async Task Post_Success(Guid orgId, SutProvider sutProvider) + sutProvider.GetDependency() + .CreateNewCollections(orgId) + .Returns(true); + + sutProvider.GetDependency() + .EditAnyCollection(orgId) + .Returns(false); + + var collectionRequest = new CollectionRequestModel { - sutProvider.GetDependency() - .CreateNewCollections(orgId) - .Returns(true); + Name = "encrypted_string", + ExternalId = "my_external_id" + }; - sutProvider.GetDependency() - .EditAnyCollection(orgId) - .Returns(false); + _ = await sutProvider.Sut.Post(orgId, collectionRequest); - var collectionRequest = new CollectionRequestModel + await sutProvider.GetDependency() + .Received(1) + .SaveAsync(Arg.Any(), Arg.Any>(), null); + } + + [Theory, BitAutoData] + public async Task Put_Success(Guid orgId, Guid collectionId, Guid userId, CollectionRequestModel collectionRequest, + SutProvider sutProvider) + { + sutProvider.GetDependency() + .ViewAssignedCollections(orgId) + .Returns(true); + + sutProvider.GetDependency() + .EditAssignedCollections(orgId) + .Returns(true); + + sutProvider.GetDependency() + .UserId + .Returns(userId); + + sutProvider.GetDependency() + .GetByIdAsync(collectionId, userId) + .Returns(new CollectionDetails { - Name = "encrypted_string", - ExternalId = "my_external_id" - }; + OrganizationId = orgId, + }); - _ = await sutProvider.Sut.Post(orgId, collectionRequest); + _ = await sutProvider.Sut.Put(orgId, collectionId, collectionRequest); + } - await sutProvider.GetDependency() - .Received(1) - .SaveAsync(Arg.Any(), Arg.Any>(), null); - } + [Theory, BitAutoData] + public async Task Put_CanNotEditAssignedCollection_ThrowsNotFound(Guid orgId, Guid collectionId, Guid userId, CollectionRequestModel collectionRequest, + SutProvider sutProvider) + { + sutProvider.GetDependency() + .EditAssignedCollections(orgId) + .Returns(true); - [Theory, BitAutoData] - public async Task Put_Success(Guid orgId, Guid collectionId, Guid userId, CollectionRequestModel collectionRequest, - SutProvider sutProvider) - { - sutProvider.GetDependency() - .ViewAssignedCollections(orgId) - .Returns(true); + sutProvider.GetDependency() + .UserId + .Returns(userId); - sutProvider.GetDependency() - .EditAssignedCollections(orgId) - .Returns(true); + sutProvider.GetDependency() + .GetByIdAsync(collectionId, userId) + .Returns(Task.FromResult(null)); - sutProvider.GetDependency() - .UserId - .Returns(userId); - - sutProvider.GetDependency() - .GetByIdAsync(collectionId, userId) - .Returns(new CollectionDetails - { - OrganizationId = orgId, - }); - - _ = await sutProvider.Sut.Put(orgId, collectionId, collectionRequest); - } - - [Theory, BitAutoData] - public async Task Put_CanNotEditAssignedCollection_ThrowsNotFound(Guid orgId, Guid collectionId, Guid userId, CollectionRequestModel collectionRequest, - SutProvider sutProvider) - { - sutProvider.GetDependency() - .EditAssignedCollections(orgId) - .Returns(true); - - sutProvider.GetDependency() - .UserId - .Returns(userId); - - sutProvider.GetDependency() - .GetByIdAsync(collectionId, userId) - .Returns(Task.FromResult(null)); - - _ = await Assert.ThrowsAsync(async () => await sutProvider.Sut.Put(orgId, collectionId, collectionRequest)); - } + _ = await Assert.ThrowsAsync(async () => await sutProvider.Sut.Put(orgId, collectionId, collectionRequest)); } } diff --git a/test/Api.Test/Controllers/OrganizationConnectionsControllerTests.cs b/test/Api.Test/Controllers/OrganizationConnectionsControllerTests.cs index 88834621a..80bfcfe00 100644 --- a/test/Api.Test/Controllers/OrganizationConnectionsControllerTests.cs +++ b/test/Api.Test/Controllers/OrganizationConnectionsControllerTests.cs @@ -18,302 +18,301 @@ using Bit.Test.Common.Helpers; using NSubstitute; using Xunit; -namespace Bit.Api.Test.Controllers +namespace Bit.Api.Test.Controllers; + +[ControllerCustomize(typeof(OrganizationConnectionsController))] +[SutProviderCustomize] +[JsonDocumentCustomize] +public class OrganizationConnectionsControllerTests { - [ControllerCustomize(typeof(OrganizationConnectionsController))] - [SutProviderCustomize] - [JsonDocumentCustomize] - public class OrganizationConnectionsControllerTests + public static IEnumerable ConnectionTypes => + Enum.GetValues().Select(p => new object[] { p }); + + + [Theory] + [BitAutoData(true, true)] + [BitAutoData(false, true)] + [BitAutoData(true, false)] + [BitAutoData(false, false)] + public void ConnectionEnabled_RequiresBothSelfHostAndCommunications(bool selfHosted, bool enableCloudCommunication, SutProvider sutProvider) { - public static IEnumerable ConnectionTypes => - Enum.GetValues().Select(p => new object[] { p }); + var globalSettingsMock = sutProvider.GetDependency(); + globalSettingsMock.SelfHosted.Returns(selfHosted); + globalSettingsMock.EnableCloudCommunication.Returns(enableCloudCommunication); + Action assert = selfHosted && enableCloudCommunication ? Assert.True : Assert.False; - [Theory] - [BitAutoData(true, true)] - [BitAutoData(false, true)] - [BitAutoData(true, false)] - [BitAutoData(false, false)] - public void ConnectionEnabled_RequiresBothSelfHostAndCommunications(bool selfHosted, bool enableCloudCommunication, SutProvider sutProvider) - { - var globalSettingsMock = sutProvider.GetDependency(); - globalSettingsMock.SelfHosted.Returns(selfHosted); - globalSettingsMock.EnableCloudCommunication.Returns(enableCloudCommunication); + var result = sutProvider.Sut.ConnectionsEnabled(); - Action assert = selfHosted && enableCloudCommunication ? Assert.True : Assert.False; - - var result = sutProvider.Sut.ConnectionsEnabled(); - - assert(result); - } - - [Theory] - [BitAutoData] - public async Task CreateConnection_CloudBillingSync_RequiresOwnerPermissions(SutProvider sutProvider) - { - var model = new OrganizationConnectionRequestModel - { - Type = OrganizationConnectionType.CloudBillingSync, - }; - var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.CreateConnection(model)); - - Assert.Contains($"You do not have permission to create a connection of type", exception.Message); - } - - [Theory] - [BitMemberAutoData(nameof(ConnectionTypes))] - public async Task CreateConnection_OnlyOneConnectionOfEachType(OrganizationConnectionType type, - OrganizationConnectionRequestModel model, BillingSyncConfig config, Guid existingEntityId, - SutProvider sutProvider) - { - model.Type = type; - model.Config = JsonDocumentFromObject(config); - var typedModel = new OrganizationConnectionRequestModel(model); - var existing = typedModel.ToData(existingEntityId).ToEntity(); - - sutProvider.GetDependency().OrganizationOwner(model.OrganizationId).Returns(true); - - sutProvider.GetDependency().GetByOrganizationIdTypeAsync(model.OrganizationId, type).Returns(new[] { existing }); - - var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.CreateConnection(model)); - - Assert.Contains($"The requested organization already has a connection of type {model.Type}. Only one of each connection type may exist per organization.", exception.Message); - } - - [Theory] - [BitAutoData] - public async Task CreateConnection_BillingSyncType_InvalidLicense_Throws(OrganizationConnectionRequestModel model, - BillingSyncConfig config, Guid cloudOrgId, OrganizationLicense organizationLicense, - SutProvider sutProvider) - { - model.Type = OrganizationConnectionType.CloudBillingSync; - organizationLicense.Id = cloudOrgId; - - model.Config = JsonDocumentFromObject(config); - var typedModel = new OrganizationConnectionRequestModel(model); - typedModel.ParsedConfig.CloudOrganizationId = cloudOrgId; - - sutProvider.GetDependency() - .OrganizationOwner(model.OrganizationId) - .Returns(true); - - sutProvider.GetDependency() - .ReadOrganizationLicenseAsync(model.OrganizationId) - .Returns(organizationLicense); - - sutProvider.GetDependency() - .VerifyLicense(organizationLicense) - .Returns(false); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateConnection(model)); - } - - [Theory] - [BitAutoData] - public async Task CreateConnection_Success(OrganizationConnectionRequestModel model, BillingSyncConfig config, - Guid cloudOrgId, OrganizationLicense organizationLicense, SutProvider sutProvider) - { - organizationLicense.Id = cloudOrgId; - - model.Config = JsonDocumentFromObject(config); - var typedModel = new OrganizationConnectionRequestModel(model); - typedModel.ParsedConfig.CloudOrganizationId = cloudOrgId; - - sutProvider.GetDependency().SelfHosted.Returns(true); - sutProvider.GetDependency().CreateAsync(default) - .ReturnsForAnyArgs(typedModel.ToData(Guid.NewGuid()).ToEntity()); - sutProvider.GetDependency().OrganizationOwner(model.OrganizationId).Returns(true); - sutProvider.GetDependency() - .ReadOrganizationLicenseAsync(Arg.Any()) - .Returns(organizationLicense); - - sutProvider.GetDependency() - .VerifyLicense(organizationLicense) - .Returns(true); - - await sutProvider.Sut.CreateConnection(model); - - await sutProvider.GetDependency().Received(1) - .CreateAsync(Arg.Is(AssertHelper.AssertPropertyEqual(typedModel.ToData()))); - } - - [Theory] - [BitAutoData] - public async Task UpdateConnection_RequiresOwnerPermissions(SutProvider sutProvider) - { - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(new OrganizationConnection()); - - var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateConnection(default, null)); - - Assert.Contains("You do not have permission to update this connection.", exception.Message); - } - - [Theory] - [BitAutoData(OrganizationConnectionType.CloudBillingSync)] - public async Task UpdateConnection_BillingSync_OnlyOneConnectionOfEachType(OrganizationConnectionType type, - OrganizationConnection existing1, OrganizationConnection existing2, BillingSyncConfig config, - SutProvider sutProvider) - { - existing1.Type = existing2.Type = type; - existing1.Config = JsonSerializer.Serialize(config); - var typedModel = RequestModelFromEntity(existing1); - - sutProvider.GetDependency().OrganizationOwner(typedModel.OrganizationId).Returns(true); - - var orgConnectionRepository = sutProvider.GetDependency(); - orgConnectionRepository.GetByIdAsync(existing1.Id).Returns(existing1); - orgConnectionRepository.GetByIdAsync(existing2.Id).Returns(existing2); - orgConnectionRepository.GetByOrganizationIdTypeAsync(typedModel.OrganizationId, type).Returns(new[] { existing1, existing2 }); - - var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateConnection(existing1.Id, typedModel)); - - Assert.Contains($"The requested organization already has a connection of type {typedModel.Type}. Only one of each connection type may exist per organization.", exception.Message); - } - - [Theory] - [BitAutoData(OrganizationConnectionType.Scim)] - public async Task UpdateConnection_Scim_OnlyOneConnectionOfEachType(OrganizationConnectionType type, - OrganizationConnection existing1, OrganizationConnection existing2, ScimConfig config, - SutProvider sutProvider) - { - existing1.Type = existing2.Type = type; - existing1.Config = JsonSerializer.Serialize(config); - var typedModel = RequestModelFromEntity(existing1); - - sutProvider.GetDependency().OrganizationOwner(typedModel.OrganizationId).Returns(true); - - sutProvider.GetDependency() - .GetByIdAsync(existing1.Id) - .Returns(existing1); - - sutProvider.GetDependency().ManageScim(typedModel.OrganizationId).Returns(true); - - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(typedModel.OrganizationId, type) - .Returns(new[] { existing1, existing2 }); - - var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateConnection(existing1.Id, typedModel)); - - Assert.Contains($"The requested organization already has a connection of type {typedModel.Type}. Only one of each connection type may exist per organization.", exception.Message); - } - - [Theory] - [BitAutoData] - public async Task UpdateConnection_Success(OrganizationConnection existing, BillingSyncConfig config, - OrganizationConnection updated, - SutProvider sutProvider) - { - existing.SetConfig(new BillingSyncConfig - { - CloudOrganizationId = config.CloudOrganizationId, - }); - updated.Config = JsonSerializer.Serialize(config); - updated.Id = existing.Id; - updated.Type = OrganizationConnectionType.CloudBillingSync; - var model = RequestModelFromEntity(updated); - - sutProvider.GetDependency().OrganizationOwner(model.OrganizationId).Returns(true); - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(model.OrganizationId, model.Type) - .Returns(new[] { existing }); - sutProvider.GetDependency() - .UpdateAsync(default) - .ReturnsForAnyArgs(updated); - sutProvider.GetDependency() - .GetByIdAsync(existing.Id) - .Returns(existing); - - var expected = new OrganizationConnectionResponseModel(updated, typeof(BillingSyncConfig)); - var result = await sutProvider.Sut.UpdateConnection(existing.Id, model); - - AssertHelper.AssertPropertyEqual(expected, result); - await sutProvider.GetDependency().Received(1) - .UpdateAsync(Arg.Is(AssertHelper.AssertPropertyEqual(model.ToData(updated.Id)))); - } - - [Theory] - [BitAutoData] - public async Task UpdateConnection_DoesNotExist_ThrowsNotFound(SutProvider sutProvider) - { - await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateConnection(Guid.NewGuid(), null)); - } - - [Theory] - [BitAutoData] - public async Task GetConnection_RequiresOwnerPermissions(Guid connectionId, SutProvider sutProvider) - { - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.GetConnection(connectionId, OrganizationConnectionType.CloudBillingSync)); - - Assert.Contains("You do not have permission to retrieve a connection of type", exception.Message); - } - - [Theory] - [BitAutoData] - public async Task GetConnection_Success(OrganizationConnection connection, BillingSyncConfig config, - SutProvider sutProvider) - { - connection.Config = JsonSerializer.Serialize(config); - - sutProvider.GetDependency().SelfHosted.Returns(true); - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(connection.OrganizationId, connection.Type) - .Returns(new[] { connection }); - sutProvider.GetDependency().OrganizationOwner(connection.OrganizationId).Returns(true); - - var expected = new OrganizationConnectionResponseModel(connection, typeof(BillingSyncConfig)); - var actual = await sutProvider.Sut.GetConnection(connection.OrganizationId, connection.Type); - - AssertHelper.AssertPropertyEqual(expected, actual); - } - - [Theory] - [BitAutoData] - public async Task DeleteConnection_NotFound(Guid connectionId, - SutProvider sutProvider) - { - await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteConnection(connectionId)); - } - - [Theory] - [BitAutoData] - public async Task DeleteConnection_RequiresOwnerPermissions(OrganizationConnection connection, - SutProvider sutProvider) - { - sutProvider.GetDependency().GetByIdAsync(connection.Id).Returns(connection); - - var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteConnection(connection.Id)); - - Assert.Contains("You do not have permission to remove this connection of type", exception.Message); - } - - [Theory] - [BitAutoData] - public async Task DeleteConnection_Success(OrganizationConnection connection, - SutProvider sutProvider) - { - sutProvider.GetDependency().GetByIdAsync(connection.Id).Returns(connection); - sutProvider.GetDependency().OrganizationOwner(connection.OrganizationId).Returns(true); - - await sutProvider.Sut.DeleteConnection(connection.Id); - - await sutProvider.GetDependency().DeleteAsync(connection); - } - - private static OrganizationConnectionRequestModel RequestModelFromEntity(OrganizationConnection entity) - where T : new() - { - return new(new OrganizationConnectionRequestModel() - { - Type = entity.Type, - OrganizationId = entity.OrganizationId, - Enabled = entity.Enabled, - Config = JsonDocument.Parse(entity.Config), - }); - } - - private static JsonDocument JsonDocumentFromObject(T obj) => JsonDocument.Parse(JsonSerializer.Serialize(obj)); + assert(result); } + + [Theory] + [BitAutoData] + public async Task CreateConnection_CloudBillingSync_RequiresOwnerPermissions(SutProvider sutProvider) + { + var model = new OrganizationConnectionRequestModel + { + Type = OrganizationConnectionType.CloudBillingSync, + }; + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.CreateConnection(model)); + + Assert.Contains($"You do not have permission to create a connection of type", exception.Message); + } + + [Theory] + [BitMemberAutoData(nameof(ConnectionTypes))] + public async Task CreateConnection_OnlyOneConnectionOfEachType(OrganizationConnectionType type, + OrganizationConnectionRequestModel model, BillingSyncConfig config, Guid existingEntityId, + SutProvider sutProvider) + { + model.Type = type; + model.Config = JsonDocumentFromObject(config); + var typedModel = new OrganizationConnectionRequestModel(model); + var existing = typedModel.ToData(existingEntityId).ToEntity(); + + sutProvider.GetDependency().OrganizationOwner(model.OrganizationId).Returns(true); + + sutProvider.GetDependency().GetByOrganizationIdTypeAsync(model.OrganizationId, type).Returns(new[] { existing }); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.CreateConnection(model)); + + Assert.Contains($"The requested organization already has a connection of type {model.Type}. Only one of each connection type may exist per organization.", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task CreateConnection_BillingSyncType_InvalidLicense_Throws(OrganizationConnectionRequestModel model, + BillingSyncConfig config, Guid cloudOrgId, OrganizationLicense organizationLicense, + SutProvider sutProvider) + { + model.Type = OrganizationConnectionType.CloudBillingSync; + organizationLicense.Id = cloudOrgId; + + model.Config = JsonDocumentFromObject(config); + var typedModel = new OrganizationConnectionRequestModel(model); + typedModel.ParsedConfig.CloudOrganizationId = cloudOrgId; + + sutProvider.GetDependency() + .OrganizationOwner(model.OrganizationId) + .Returns(true); + + sutProvider.GetDependency() + .ReadOrganizationLicenseAsync(model.OrganizationId) + .Returns(organizationLicense); + + sutProvider.GetDependency() + .VerifyLicense(organizationLicense) + .Returns(false); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateConnection(model)); + } + + [Theory] + [BitAutoData] + public async Task CreateConnection_Success(OrganizationConnectionRequestModel model, BillingSyncConfig config, + Guid cloudOrgId, OrganizationLicense organizationLicense, SutProvider sutProvider) + { + organizationLicense.Id = cloudOrgId; + + model.Config = JsonDocumentFromObject(config); + var typedModel = new OrganizationConnectionRequestModel(model); + typedModel.ParsedConfig.CloudOrganizationId = cloudOrgId; + + sutProvider.GetDependency().SelfHosted.Returns(true); + sutProvider.GetDependency().CreateAsync(default) + .ReturnsForAnyArgs(typedModel.ToData(Guid.NewGuid()).ToEntity()); + sutProvider.GetDependency().OrganizationOwner(model.OrganizationId).Returns(true); + sutProvider.GetDependency() + .ReadOrganizationLicenseAsync(Arg.Any()) + .Returns(organizationLicense); + + sutProvider.GetDependency() + .VerifyLicense(organizationLicense) + .Returns(true); + + await sutProvider.Sut.CreateConnection(model); + + await sutProvider.GetDependency().Received(1) + .CreateAsync(Arg.Is(AssertHelper.AssertPropertyEqual(typedModel.ToData()))); + } + + [Theory] + [BitAutoData] + public async Task UpdateConnection_RequiresOwnerPermissions(SutProvider sutProvider) + { + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(new OrganizationConnection()); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateConnection(default, null)); + + Assert.Contains("You do not have permission to update this connection.", exception.Message); + } + + [Theory] + [BitAutoData(OrganizationConnectionType.CloudBillingSync)] + public async Task UpdateConnection_BillingSync_OnlyOneConnectionOfEachType(OrganizationConnectionType type, + OrganizationConnection existing1, OrganizationConnection existing2, BillingSyncConfig config, + SutProvider sutProvider) + { + existing1.Type = existing2.Type = type; + existing1.Config = JsonSerializer.Serialize(config); + var typedModel = RequestModelFromEntity(existing1); + + sutProvider.GetDependency().OrganizationOwner(typedModel.OrganizationId).Returns(true); + + var orgConnectionRepository = sutProvider.GetDependency(); + orgConnectionRepository.GetByIdAsync(existing1.Id).Returns(existing1); + orgConnectionRepository.GetByIdAsync(existing2.Id).Returns(existing2); + orgConnectionRepository.GetByOrganizationIdTypeAsync(typedModel.OrganizationId, type).Returns(new[] { existing1, existing2 }); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateConnection(existing1.Id, typedModel)); + + Assert.Contains($"The requested organization already has a connection of type {typedModel.Type}. Only one of each connection type may exist per organization.", exception.Message); + } + + [Theory] + [BitAutoData(OrganizationConnectionType.Scim)] + public async Task UpdateConnection_Scim_OnlyOneConnectionOfEachType(OrganizationConnectionType type, + OrganizationConnection existing1, OrganizationConnection existing2, ScimConfig config, + SutProvider sutProvider) + { + existing1.Type = existing2.Type = type; + existing1.Config = JsonSerializer.Serialize(config); + var typedModel = RequestModelFromEntity(existing1); + + sutProvider.GetDependency().OrganizationOwner(typedModel.OrganizationId).Returns(true); + + sutProvider.GetDependency() + .GetByIdAsync(existing1.Id) + .Returns(existing1); + + sutProvider.GetDependency().ManageScim(typedModel.OrganizationId).Returns(true); + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(typedModel.OrganizationId, type) + .Returns(new[] { existing1, existing2 }); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateConnection(existing1.Id, typedModel)); + + Assert.Contains($"The requested organization already has a connection of type {typedModel.Type}. Only one of each connection type may exist per organization.", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task UpdateConnection_Success(OrganizationConnection existing, BillingSyncConfig config, + OrganizationConnection updated, + SutProvider sutProvider) + { + existing.SetConfig(new BillingSyncConfig + { + CloudOrganizationId = config.CloudOrganizationId, + }); + updated.Config = JsonSerializer.Serialize(config); + updated.Id = existing.Id; + updated.Type = OrganizationConnectionType.CloudBillingSync; + var model = RequestModelFromEntity(updated); + + sutProvider.GetDependency().OrganizationOwner(model.OrganizationId).Returns(true); + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(model.OrganizationId, model.Type) + .Returns(new[] { existing }); + sutProvider.GetDependency() + .UpdateAsync(default) + .ReturnsForAnyArgs(updated); + sutProvider.GetDependency() + .GetByIdAsync(existing.Id) + .Returns(existing); + + var expected = new OrganizationConnectionResponseModel(updated, typeof(BillingSyncConfig)); + var result = await sutProvider.Sut.UpdateConnection(existing.Id, model); + + AssertHelper.AssertPropertyEqual(expected, result); + await sutProvider.GetDependency().Received(1) + .UpdateAsync(Arg.Is(AssertHelper.AssertPropertyEqual(model.ToData(updated.Id)))); + } + + [Theory] + [BitAutoData] + public async Task UpdateConnection_DoesNotExist_ThrowsNotFound(SutProvider sutProvider) + { + await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateConnection(Guid.NewGuid(), null)); + } + + [Theory] + [BitAutoData] + public async Task GetConnection_RequiresOwnerPermissions(Guid connectionId, SutProvider sutProvider) + { + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.GetConnection(connectionId, OrganizationConnectionType.CloudBillingSync)); + + Assert.Contains("You do not have permission to retrieve a connection of type", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task GetConnection_Success(OrganizationConnection connection, BillingSyncConfig config, + SutProvider sutProvider) + { + connection.Config = JsonSerializer.Serialize(config); + + sutProvider.GetDependency().SelfHosted.Returns(true); + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(connection.OrganizationId, connection.Type) + .Returns(new[] { connection }); + sutProvider.GetDependency().OrganizationOwner(connection.OrganizationId).Returns(true); + + var expected = new OrganizationConnectionResponseModel(connection, typeof(BillingSyncConfig)); + var actual = await sutProvider.Sut.GetConnection(connection.OrganizationId, connection.Type); + + AssertHelper.AssertPropertyEqual(expected, actual); + } + + [Theory] + [BitAutoData] + public async Task DeleteConnection_NotFound(Guid connectionId, + SutProvider sutProvider) + { + await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteConnection(connectionId)); + } + + [Theory] + [BitAutoData] + public async Task DeleteConnection_RequiresOwnerPermissions(OrganizationConnection connection, + SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(connection.Id).Returns(connection); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteConnection(connection.Id)); + + Assert.Contains("You do not have permission to remove this connection of type", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task DeleteConnection_Success(OrganizationConnection connection, + SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(connection.Id).Returns(connection); + sutProvider.GetDependency().OrganizationOwner(connection.OrganizationId).Returns(true); + + await sutProvider.Sut.DeleteConnection(connection.Id); + + await sutProvider.GetDependency().DeleteAsync(connection); + } + + private static OrganizationConnectionRequestModel RequestModelFromEntity(OrganizationConnection entity) + where T : new() + { + return new(new OrganizationConnectionRequestModel() + { + Type = entity.Type, + OrganizationId = entity.OrganizationId, + Enabled = entity.Enabled, + Config = JsonDocument.Parse(entity.Config), + }); + } + + private static JsonDocument JsonDocumentFromObject(T obj) => JsonDocument.Parse(JsonSerializer.Serialize(obj)); } diff --git a/test/Api.Test/Controllers/OrganizationSponsorshipsControllerTests.cs b/test/Api.Test/Controllers/OrganizationSponsorshipsControllerTests.cs index 0cafdf9ff..e58add5ef 100644 --- a/test/Api.Test/Controllers/OrganizationSponsorshipsControllerTests.cs +++ b/test/Api.Test/Controllers/OrganizationSponsorshipsControllerTests.cs @@ -13,136 +13,135 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Api.Test.Controllers +namespace Bit.Api.Test.Controllers; + +[ControllerCustomize(typeof(OrganizationSponsorshipsController))] +[SutProviderCustomize] +public class OrganizationSponsorshipsControllerTests { - [ControllerCustomize(typeof(OrganizationSponsorshipsController))] - [SutProviderCustomize] - public class OrganizationSponsorshipsControllerTests + public static IEnumerable EnterprisePlanTypes => + Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product == ProductType.Enterprise).Select(p => new object[] { p }); + public static IEnumerable NonEnterprisePlanTypes => + Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product != ProductType.Enterprise).Select(p => new object[] { p }); + public static IEnumerable NonFamiliesPlanTypes => + Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product != ProductType.Families).Select(p => new object[] { p }); + + public static IEnumerable NonConfirmedOrganizationUsersStatuses => + Enum.GetValues() + .Where(s => s != OrganizationUserStatusType.Confirmed) + .Select(s => new object[] { s }); + + + [Theory] + [BitAutoData] + public async Task RedeemSponsorship_BadToken_ThrowsBadRequest(string sponsorshipToken, User user, + OrganizationSponsorshipRedeemRequestModel model, SutProvider sutProvider) { - public static IEnumerable EnterprisePlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product == ProductType.Enterprise).Select(p => new object[] { p }); - public static IEnumerable NonEnterprisePlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product != ProductType.Enterprise).Select(p => new object[] { p }); - public static IEnumerable NonFamiliesPlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product != ProductType.Families).Select(p => new object[] { p }); + sutProvider.GetDependency().UserId.Returns(user.Id); + sutProvider.GetDependency().GetUserByIdAsync(user.Id) + .Returns(user); + sutProvider.GetDependency().ValidateRedemptionTokenAsync(sponsorshipToken, + user.Email).Returns((false, null)); - public static IEnumerable NonConfirmedOrganizationUsersStatuses => - Enum.GetValues() - .Where(s => s != OrganizationUserStatusType.Confirmed) - .Select(s => new object[] { s }); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RedeemSponsorship(sponsorshipToken, model)); + Assert.Contains("Failed to parse sponsorship token.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SetUpSponsorshipAsync(default, default); + } - [Theory] - [BitAutoData] - public async Task RedeemSponsorship_BadToken_ThrowsBadRequest(string sponsorshipToken, User user, - OrganizationSponsorshipRedeemRequestModel model, SutProvider sutProvider) - { - sutProvider.GetDependency().UserId.Returns(user.Id); - sutProvider.GetDependency().GetUserByIdAsync(user.Id) - .Returns(user); - sutProvider.GetDependency().ValidateRedemptionTokenAsync(sponsorshipToken, - user.Email).Returns((false, null)); + [Theory] + [BitAutoData] + public async Task RedeemSponsorship_NotSponsoredOrgOwner_ThrowsBadRequest(string sponsorshipToken, User user, + OrganizationSponsorship sponsorship, OrganizationSponsorshipRedeemRequestModel model, + SutProvider sutProvider) + { + sutProvider.GetDependency().UserId.Returns(user.Id); + sutProvider.GetDependency().GetUserByIdAsync(user.Id) + .Returns(user); + sutProvider.GetDependency().ValidateRedemptionTokenAsync(sponsorshipToken, + user.Email).Returns((true, sponsorship)); + sutProvider.GetDependency().OrganizationOwner(model.SponsoredOrganizationId).Returns(false); - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.RedeemSponsorship(sponsorshipToken, model)); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RedeemSponsorship(sponsorshipToken, model)); - Assert.Contains("Failed to parse sponsorship token.", exception.Message); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .SetUpSponsorshipAsync(default, default); - } + Assert.Contains("Can only redeem sponsorship for an organization you own.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SetUpSponsorshipAsync(default, default); + } - [Theory] - [BitAutoData] - public async Task RedeemSponsorship_NotSponsoredOrgOwner_ThrowsBadRequest(string sponsorshipToken, User user, - OrganizationSponsorship sponsorship, OrganizationSponsorshipRedeemRequestModel model, - SutProvider sutProvider) - { - sutProvider.GetDependency().UserId.Returns(user.Id); - sutProvider.GetDependency().GetUserByIdAsync(user.Id) - .Returns(user); - sutProvider.GetDependency().ValidateRedemptionTokenAsync(sponsorshipToken, - user.Email).Returns((true, sponsorship)); - sutProvider.GetDependency().OrganizationOwner(model.SponsoredOrganizationId).Returns(false); + [Theory] + [BitAutoData] + public async Task RedeemSponsorship_NotSponsoredOrgOwner_Success(string sponsorshipToken, User user, + OrganizationSponsorship sponsorship, Organization sponsoringOrganization, + OrganizationSponsorshipRedeemRequestModel model, SutProvider sutProvider) + { + sutProvider.GetDependency().UserId.Returns(user.Id); + sutProvider.GetDependency().GetUserByIdAsync(user.Id) + .Returns(user); + sutProvider.GetDependency().ValidateRedemptionTokenAsync(sponsorshipToken, + user.Email).Returns((true, sponsorship)); + sutProvider.GetDependency().OrganizationOwner(model.SponsoredOrganizationId).Returns(true); + sutProvider.GetDependency().GetByIdAsync(model.SponsoredOrganizationId).Returns(sponsoringOrganization); - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.RedeemSponsorship(sponsorshipToken, model)); + await sutProvider.Sut.RedeemSponsorship(sponsorshipToken, model); - Assert.Contains("Can only redeem sponsorship for an organization you own.", exception.Message); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .SetUpSponsorshipAsync(default, default); - } + await sutProvider.GetDependency().Received(1) + .SetUpSponsorshipAsync(sponsorship, sponsoringOrganization); + } - [Theory] - [BitAutoData] - public async Task RedeemSponsorship_NotSponsoredOrgOwner_Success(string sponsorshipToken, User user, - OrganizationSponsorship sponsorship, Organization sponsoringOrganization, - OrganizationSponsorshipRedeemRequestModel model, SutProvider sutProvider) - { - sutProvider.GetDependency().UserId.Returns(user.Id); - sutProvider.GetDependency().GetUserByIdAsync(user.Id) - .Returns(user); - sutProvider.GetDependency().ValidateRedemptionTokenAsync(sponsorshipToken, - user.Email).Returns((true, sponsorship)); - sutProvider.GetDependency().OrganizationOwner(model.SponsoredOrganizationId).Returns(true); - sutProvider.GetDependency().GetByIdAsync(model.SponsoredOrganizationId).Returns(sponsoringOrganization); + [Theory] + [BitAutoData] + public async Task PreValidateSponsorshipToken_ValidatesToken_Success(string sponsorshipToken, User user, + OrganizationSponsorship sponsorship, SutProvider sutProvider) + { + sutProvider.GetDependency().UserId.Returns(user.Id); + sutProvider.GetDependency().GetUserByIdAsync(user.Id) + .Returns(user); + sutProvider.GetDependency() + .ValidateRedemptionTokenAsync(sponsorshipToken, user.Email).Returns((true, sponsorship)); - await sutProvider.Sut.RedeemSponsorship(sponsorshipToken, model); + await sutProvider.Sut.PreValidateSponsorshipToken(sponsorshipToken); - await sutProvider.GetDependency().Received(1) - .SetUpSponsorshipAsync(sponsorship, sponsoringOrganization); - } + await sutProvider.GetDependency().Received(1) + .ValidateRedemptionTokenAsync(sponsorshipToken, user.Email); + } - [Theory] - [BitAutoData] - public async Task PreValidateSponsorshipToken_ValidatesToken_Success(string sponsorshipToken, User user, - OrganizationSponsorship sponsorship, SutProvider sutProvider) - { - sutProvider.GetDependency().UserId.Returns(user.Id); - sutProvider.GetDependency().GetUserByIdAsync(user.Id) - .Returns(user); - sutProvider.GetDependency() - .ValidateRedemptionTokenAsync(sponsorshipToken, user.Email).Returns((true, sponsorship)); + [Theory] + [BitAutoData] + public async Task RevokeSponsorship_WrongSponsoringUser_ThrowsBadRequest(OrganizationUser sponsoringOrgUser, + Guid currentUserId, SutProvider sutProvider) + { + sutProvider.GetDependency().UserId.Returns(currentUserId); + sutProvider.GetDependency().GetByIdAsync(sponsoringOrgUser.Id) + .Returns(sponsoringOrgUser); - await sutProvider.Sut.PreValidateSponsorshipToken(sponsorshipToken); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RevokeSponsorship(sponsoringOrgUser.Id)); - await sutProvider.GetDependency().Received(1) - .ValidateRedemptionTokenAsync(sponsorshipToken, user.Email); - } + Assert.Contains("Can only revoke a sponsorship you granted.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .RemoveSponsorshipAsync(default); + } - [Theory] - [BitAutoData] - public async Task RevokeSponsorship_WrongSponsoringUser_ThrowsBadRequest(OrganizationUser sponsoringOrgUser, - Guid currentUserId, SutProvider sutProvider) - { - sutProvider.GetDependency().UserId.Returns(currentUserId); - sutProvider.GetDependency().GetByIdAsync(sponsoringOrgUser.Id) - .Returns(sponsoringOrgUser); + [Theory] + [BitAutoData] + public async Task RemoveSponsorship_WrongOrgUserType_ThrowsBadRequest(Organization sponsoredOrg, + SutProvider sutProvider) + { + sutProvider.GetDependency().OrganizationOwner(Arg.Any()).Returns(false); - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.RevokeSponsorship(sponsoringOrgUser.Id)); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RemoveSponsorship(sponsoredOrg.Id)); - Assert.Contains("Can only revoke a sponsorship you granted.", exception.Message); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .RemoveSponsorshipAsync(default); - } - - [Theory] - [BitAutoData] - public async Task RemoveSponsorship_WrongOrgUserType_ThrowsBadRequest(Organization sponsoredOrg, - SutProvider sutProvider) - { - sutProvider.GetDependency().OrganizationOwner(Arg.Any()).Returns(false); - - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.RemoveSponsorship(sponsoredOrg.Id)); - - Assert.Contains("Only the owner of an organization can remove sponsorship.", exception.Message); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .RemoveSponsorshipAsync(default); - } + Assert.Contains("Only the owner of an organization can remove sponsorship.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .RemoveSponsorshipAsync(default); } } diff --git a/test/Api.Test/Controllers/OrganizationUsersControllerTests.cs b/test/Api.Test/Controllers/OrganizationUsersControllerTests.cs index 585508c66..c5c1019df 100644 --- a/test/Api.Test/Controllers/OrganizationUsersControllerTests.cs +++ b/test/Api.Test/Controllers/OrganizationUsersControllerTests.cs @@ -10,57 +10,56 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Api.Test.Controllers +namespace Bit.Api.Test.Controllers; + +[ControllerCustomize(typeof(OrganizationUsersController))] +[SutProviderCustomize] +public class OrganizationUsersControllerTests { - [ControllerCustomize(typeof(OrganizationUsersController))] - [SutProviderCustomize] - public class OrganizationUsersControllerTests + [Theory] + [BitAutoData] + public async Task Accept_RequiresKnownUser(Guid orgId, Guid orgUserId, OrganizationUserAcceptRequestModel model, + SutProvider sutProvider) { - [Theory] - [BitAutoData] - public async Task Accept_RequiresKnownUser(Guid orgId, Guid orgUserId, OrganizationUserAcceptRequestModel model, - SutProvider sutProvider) + sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs((User)null); + + await Assert.ThrowsAsync(() => sutProvider.Sut.Accept(orgId, orgUserId, model)); + } + + [Theory] + [BitAutoData] + public async Task Accept_NoMasterPasswordReset(Guid orgId, Guid orgUserId, + OrganizationUserAcceptRequestModel model, User user, SutProvider sutProvider) + { + sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(user); + + await sutProvider.Sut.Accept(orgId, orgUserId, model); + + await sutProvider.GetDependency().Received(1) + .AcceptUserAsync(orgUserId, user, model.Token, sutProvider.GetDependency()); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpdateUserResetPasswordEnrollmentAsync(default, default, default, default); + } + + [Theory] + [BitAutoData] + public async Task Accept_RequireMasterPasswordReset(Guid orgId, Guid orgUserId, + OrganizationUserAcceptRequestModel model, User user, SutProvider sutProvider) + { + var policy = new Policy { - sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs((User)null); + Enabled = true, + Data = CoreHelpers.ClassToJsonData(new ResetPasswordDataModel { AutoEnrollEnabled = true, }), + }; + sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(user); + sutProvider.GetDependency().GetByOrganizationIdTypeAsync(orgId, + Core.Enums.PolicyType.ResetPassword).Returns(policy); - await Assert.ThrowsAsync(() => sutProvider.Sut.Accept(orgId, orgUserId, model)); - } + await sutProvider.Sut.Accept(orgId, orgUserId, model); - [Theory] - [BitAutoData] - public async Task Accept_NoMasterPasswordReset(Guid orgId, Guid orgUserId, - OrganizationUserAcceptRequestModel model, User user, SutProvider sutProvider) - { - sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(user); - - await sutProvider.Sut.Accept(orgId, orgUserId, model); - - await sutProvider.GetDependency().Received(1) - .AcceptUserAsync(orgUserId, user, model.Token, sutProvider.GetDependency()); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpdateUserResetPasswordEnrollmentAsync(default, default, default, default); - } - - [Theory] - [BitAutoData] - public async Task Accept_RequireMasterPasswordReset(Guid orgId, Guid orgUserId, - OrganizationUserAcceptRequestModel model, User user, SutProvider sutProvider) - { - var policy = new Policy - { - Enabled = true, - Data = CoreHelpers.ClassToJsonData(new ResetPasswordDataModel { AutoEnrollEnabled = true, }), - }; - sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(user); - sutProvider.GetDependency().GetByOrganizationIdTypeAsync(orgId, - Core.Enums.PolicyType.ResetPassword).Returns(policy); - - await sutProvider.Sut.Accept(orgId, orgUserId, model); - - await sutProvider.GetDependency().Received(1) - .AcceptUserAsync(orgUserId, user, model.Token, sutProvider.GetDependency()); - await sutProvider.GetDependency().Received(1) - .UpdateUserResetPasswordEnrollmentAsync(orgId, user.Id, model.ResetPasswordKey, user.Id); - } + await sutProvider.GetDependency().Received(1) + .AcceptUserAsync(orgUserId, user, model.Token, sutProvider.GetDependency()); + await sutProvider.GetDependency().Received(1) + .UpdateUserResetPasswordEnrollmentAsync(orgId, user.Id, model.ResetPasswordKey, user.Id); } } diff --git a/test/Api.Test/Controllers/OrganizationsControllerTests.cs b/test/Api.Test/Controllers/OrganizationsControllerTests.cs index f31c31927..dddd9c5f0 100644 --- a/test/Api.Test/Controllers/OrganizationsControllerTests.cs +++ b/test/Api.Test/Controllers/OrganizationsControllerTests.cs @@ -12,109 +12,108 @@ using Bit.Core.Settings; using NSubstitute; using Xunit; -namespace Bit.Api.Test.Controllers +namespace Bit.Api.Test.Controllers; + +public class OrganizationsControllerTests : IDisposable { - public class OrganizationsControllerTests : IDisposable + private readonly GlobalSettings _globalSettings; + private readonly ICurrentContext _currentContext; + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationService _organizationService; + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IPaymentService _paymentService; + private readonly IPolicyRepository _policyRepository; + private readonly ISsoConfigRepository _ssoConfigRepository; + private readonly ISsoConfigService _ssoConfigService; + private readonly IUserService _userService; + private readonly IGetOrganizationApiKeyCommand _getOrganizationApiKeyCommand; + private readonly IRotateOrganizationApiKeyCommand _rotateOrganizationApiKeyCommand; + private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; + + private readonly OrganizationsController _sut; + + public OrganizationsControllerTests() { - private readonly GlobalSettings _globalSettings; - private readonly ICurrentContext _currentContext; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationService _organizationService; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IPaymentService _paymentService; - private readonly IPolicyRepository _policyRepository; - private readonly ISsoConfigRepository _ssoConfigRepository; - private readonly ISsoConfigService _ssoConfigService; - private readonly IUserService _userService; - private readonly IGetOrganizationApiKeyCommand _getOrganizationApiKeyCommand; - private readonly IRotateOrganizationApiKeyCommand _rotateOrganizationApiKeyCommand; - private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; + _currentContext = Substitute.For(); + _globalSettings = Substitute.For(); + _organizationRepository = Substitute.For(); + _organizationService = Substitute.For(); + _organizationUserRepository = Substitute.For(); + _paymentService = Substitute.For(); + _policyRepository = Substitute.For(); + _ssoConfigRepository = Substitute.For(); + _ssoConfigService = Substitute.For(); + _getOrganizationApiKeyCommand = Substitute.For(); + _rotateOrganizationApiKeyCommand = Substitute.For(); + _organizationApiKeyRepository = Substitute.For(); + _userService = Substitute.For(); - private readonly OrganizationsController _sut; + _sut = new OrganizationsController(_organizationRepository, _organizationUserRepository, + _policyRepository, _organizationService, _userService, _paymentService, _currentContext, + _ssoConfigRepository, _ssoConfigService, _getOrganizationApiKeyCommand, _rotateOrganizationApiKeyCommand, + _organizationApiKeyRepository, _globalSettings); + } - public OrganizationsControllerTests() + public void Dispose() + { + _sut?.Dispose(); + } + + [Theory, AutoData] + public async Task OrganizationsController_UserCannotLeaveOrganizationThatProvidesKeyConnector( + Guid orgId, User user) + { + var ssoConfig = new SsoConfig { - _currentContext = Substitute.For(); - _globalSettings = Substitute.For(); - _organizationRepository = Substitute.For(); - _organizationService = Substitute.For(); - _organizationUserRepository = Substitute.For(); - _paymentService = Substitute.For(); - _policyRepository = Substitute.For(); - _ssoConfigRepository = Substitute.For(); - _ssoConfigService = Substitute.For(); - _getOrganizationApiKeyCommand = Substitute.For(); - _rotateOrganizationApiKeyCommand = Substitute.For(); - _organizationApiKeyRepository = Substitute.For(); - _userService = Substitute.For(); - - _sut = new OrganizationsController(_organizationRepository, _organizationUserRepository, - _policyRepository, _organizationService, _userService, _paymentService, _currentContext, - _ssoConfigRepository, _ssoConfigService, _getOrganizationApiKeyCommand, _rotateOrganizationApiKeyCommand, - _organizationApiKeyRepository, _globalSettings); - } - - public void Dispose() - { - _sut?.Dispose(); - } - - [Theory, AutoData] - public async Task OrganizationsController_UserCannotLeaveOrganizationThatProvidesKeyConnector( - Guid orgId, User user) - { - var ssoConfig = new SsoConfig + Id = default, + Data = new SsoConfigurationData { - Id = default, - Data = new SsoConfigurationData - { - KeyConnectorEnabled = true, - }.Serialize(), - Enabled = true, - OrganizationId = orgId, - }; + KeyConnectorEnabled = true, + }.Serialize(), + Enabled = true, + OrganizationId = orgId, + }; - user.UsesKeyConnector = true; + user.UsesKeyConnector = true; - _currentContext.OrganizationUser(orgId).Returns(true); - _ssoConfigRepository.GetByOrganizationIdAsync(orgId).Returns(ssoConfig); - _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); + _currentContext.OrganizationUser(orgId).Returns(true); + _ssoConfigRepository.GetByOrganizationIdAsync(orgId).Returns(ssoConfig); + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); - var exception = await Assert.ThrowsAsync( - () => _sut.Leave(orgId.ToString())); + var exception = await Assert.ThrowsAsync( + () => _sut.Leave(orgId.ToString())); - Assert.Contains("Your organization's Single Sign-On settings prevent you from leaving.", - exception.Message); + Assert.Contains("Your organization's Single Sign-On settings prevent you from leaving.", + exception.Message); - await _organizationService.DidNotReceiveWithAnyArgs().DeleteUserAsync(default, default); - } + await _organizationService.DidNotReceiveWithAnyArgs().DeleteUserAsync(default, default); + } - [Theory] - [InlineAutoData(true, false)] - [InlineAutoData(false, true)] - [InlineAutoData(false, false)] - public async Task OrganizationsController_UserCanLeaveOrganizationThatDoesntProvideKeyConnector( - bool keyConnectorEnabled, bool userUsesKeyConnector, Guid orgId, User user) + [Theory] + [InlineAutoData(true, false)] + [InlineAutoData(false, true)] + [InlineAutoData(false, false)] + public async Task OrganizationsController_UserCanLeaveOrganizationThatDoesntProvideKeyConnector( + bool keyConnectorEnabled, bool userUsesKeyConnector, Guid orgId, User user) + { + var ssoConfig = new SsoConfig { - var ssoConfig = new SsoConfig + Id = default, + Data = new SsoConfigurationData { - Id = default, - Data = new SsoConfigurationData - { - KeyConnectorEnabled = keyConnectorEnabled, - }.Serialize(), - Enabled = true, - OrganizationId = orgId, - }; + KeyConnectorEnabled = keyConnectorEnabled, + }.Serialize(), + Enabled = true, + OrganizationId = orgId, + }; - user.UsesKeyConnector = userUsesKeyConnector; + user.UsesKeyConnector = userUsesKeyConnector; - _currentContext.OrganizationUser(orgId).Returns(true); - _ssoConfigRepository.GetByOrganizationIdAsync(orgId).Returns(ssoConfig); - _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); + _currentContext.OrganizationUser(orgId).Returns(true); + _ssoConfigRepository.GetByOrganizationIdAsync(orgId).Returns(ssoConfig); + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); - await _organizationService.DeleteUserAsync(orgId, user.Id); - await _organizationService.Received(1).DeleteUserAsync(orgId, user.Id); - } + await _organizationService.DeleteUserAsync(orgId, user.Id); + await _organizationService.Received(1).DeleteUserAsync(orgId, user.Id); } } diff --git a/test/Api.Test/Controllers/SendsControllerTests.cs b/test/Api.Test/Controllers/SendsControllerTests.cs index e86d95c55..07ca95a85 100644 --- a/test/Api.Test/Controllers/SendsControllerTests.cs +++ b/test/Api.Test/Controllers/SendsControllerTests.cs @@ -15,66 +15,65 @@ using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Api.Test.Controllers +namespace Bit.Api.Test.Controllers; + +public class SendsControllerTests : IDisposable { - public class SendsControllerTests : IDisposable + + private readonly SendsController _sut; + private readonly GlobalSettings _globalSettings; + private readonly IUserService _userService; + private readonly ISendRepository _sendRepository; + private readonly ISendService _sendService; + private readonly ISendFileStorageService _sendFileStorageService; + private readonly ILogger _logger; + private readonly ICurrentContext _currentContext; + + public SendsControllerTests() { + _userService = Substitute.For(); + _sendRepository = Substitute.For(); + _sendService = Substitute.For(); + _sendFileStorageService = Substitute.For(); + _globalSettings = new GlobalSettings(); + _logger = Substitute.For>(); + _currentContext = Substitute.For(); - private readonly SendsController _sut; - private readonly GlobalSettings _globalSettings; - private readonly IUserService _userService; - private readonly ISendRepository _sendRepository; - private readonly ISendService _sendService; - private readonly ISendFileStorageService _sendFileStorageService; - private readonly ILogger _logger; - private readonly ICurrentContext _currentContext; + _sut = new SendsController( + _sendRepository, + _userService, + _sendService, + _sendFileStorageService, + _logger, + _globalSettings, + _currentContext + ); + } - public SendsControllerTests() - { - _userService = Substitute.For(); - _sendRepository = Substitute.For(); - _sendService = Substitute.For(); - _sendFileStorageService = Substitute.For(); - _globalSettings = new GlobalSettings(); - _logger = Substitute.For>(); - _currentContext = Substitute.For(); + public void Dispose() + { + _sut?.Dispose(); + } - _sut = new SendsController( - _sendRepository, - _userService, - _sendService, - _sendFileStorageService, - _logger, - _globalSettings, - _currentContext - ); - } + [Theory, AutoData] + public async Task SendsController_WhenSendHidesEmail_CreatorIdentifierShouldBeNull( + Guid id, Send send, User user) + { + var accessId = CoreHelpers.Base64UrlEncode(id.ToByteArray()); - public void Dispose() - { - _sut?.Dispose(); - } + send.Id = default; + send.Type = SendType.Text; + send.Data = JsonSerializer.Serialize(new Dictionary()); + send.HideEmail = true; - [Theory, AutoData] - public async Task SendsController_WhenSendHidesEmail_CreatorIdentifierShouldBeNull( - Guid id, Send send, User user) - { - var accessId = CoreHelpers.Base64UrlEncode(id.ToByteArray()); + _sendService.AccessAsync(id, null).Returns((send, false, false)); + _userService.GetUserByIdAsync(Arg.Any()).Returns(user); - send.Id = default; - send.Type = SendType.Text; - send.Data = JsonSerializer.Serialize(new Dictionary()); - send.HideEmail = true; + var request = new SendAccessRequestModel(); + var actionResult = await _sut.Access(accessId, request); + var response = (actionResult as ObjectResult)?.Value as SendAccessResponseModel; - _sendService.AccessAsync(id, null).Returns((send, false, false)); - _userService.GetUserByIdAsync(Arg.Any()).Returns(user); - - var request = new SendAccessRequestModel(); - var actionResult = await _sut.Access(accessId, request); - var response = (actionResult as ObjectResult)?.Value as SendAccessResponseModel; - - Assert.NotNull(response); - Assert.Null(response.CreatorIdentifier); - } + Assert.NotNull(response); + Assert.Null(response.CreatorIdentifier); } } diff --git a/test/Api.Test/Models/Request/Accounts/PremiumRequestModelTests.cs b/test/Api.Test/Models/Request/Accounts/PremiumRequestModelTests.cs index 6d01e0a15..9f0870531 100644 --- a/test/Api.Test/Models/Request/Accounts/PremiumRequestModelTests.cs +++ b/test/Api.Test/Models/Request/Accounts/PremiumRequestModelTests.cs @@ -3,63 +3,62 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Http; using Xunit; -namespace Bit.Api.Test.Models.Request.Accounts +namespace Bit.Api.Test.Models.Request.Accounts; + +public class PremiumRequestModelTests { - public class PremiumRequestModelTests + public static IEnumerable GetValidateData() { - public static IEnumerable GetValidateData() - { - // 1. selfHosted - // 2. formFile - // 3. country - // 4. expected + // 1. selfHosted + // 2. formFile + // 3. country + // 4. expected - yield return new object[] { true, null, null, false }; - yield return new object[] { true, null, "US", false }; - yield return new object[] { true, new NotImplementedFormFile(), null, false }; - yield return new object[] { true, new NotImplementedFormFile(), "US", false }; + yield return new object[] { true, null, null, false }; + yield return new object[] { true, null, "US", false }; + yield return new object[] { true, new NotImplementedFormFile(), null, false }; + yield return new object[] { true, new NotImplementedFormFile(), "US", false }; - yield return new object[] { false, null, null, false }; - yield return new object[] { false, null, "US", true }; // Only true, cloud with null license AND a Country - yield return new object[] { false, new NotImplementedFormFile(), null, false }; - yield return new object[] { false, new NotImplementedFormFile(), "US", false }; - } - - [Theory] - [MemberData(nameof(GetValidateData))] - public void Validate_Success(bool selfHosted, IFormFile formFile, string country, bool expected) - { - var gs = new GlobalSettings - { - SelfHosted = selfHosted - }; - - var sut = new PremiumRequestModel - { - License = formFile, - Country = country, - }; - - Assert.Equal(expected, sut.Validate(gs)); - } + yield return new object[] { false, null, null, false }; + yield return new object[] { false, null, "US", true }; // Only true, cloud with null license AND a Country + yield return new object[] { false, new NotImplementedFormFile(), null, false }; + yield return new object[] { false, new NotImplementedFormFile(), "US", false }; } - public class NotImplementedFormFile : IFormFile + [Theory] + [MemberData(nameof(GetValidateData))] + public void Validate_Success(bool selfHosted, IFormFile formFile, string country, bool expected) { - public string ContentType => throw new NotImplementedException(); + var gs = new GlobalSettings + { + SelfHosted = selfHosted + }; - public string ContentDisposition => throw new NotImplementedException(); + var sut = new PremiumRequestModel + { + License = formFile, + Country = country, + }; - public IHeaderDictionary Headers => throw new NotImplementedException(); - - public long Length => throw new NotImplementedException(); - - public string Name => throw new NotImplementedException(); - - public string FileName => throw new NotImplementedException(); - - public void CopyTo(Stream target) => throw new NotImplementedException(); - public Task CopyToAsync(Stream target, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Stream OpenReadStream() => throw new NotImplementedException(); + Assert.Equal(expected, sut.Validate(gs)); } } + +public class NotImplementedFormFile : IFormFile +{ + public string ContentType => throw new NotImplementedException(); + + public string ContentDisposition => throw new NotImplementedException(); + + public IHeaderDictionary Headers => throw new NotImplementedException(); + + public long Length => throw new NotImplementedException(); + + public string Name => throw new NotImplementedException(); + + public string FileName => throw new NotImplementedException(); + + public void CopyTo(Stream target) => throw new NotImplementedException(); + public Task CopyToAsync(Stream target, CancellationToken cancellationToken = default) => throw new NotImplementedException(); + public Stream OpenReadStream() => throw new NotImplementedException(); +} diff --git a/test/Api.Test/Models/Request/SendRequestModelTests.cs b/test/Api.Test/Models/Request/SendRequestModelTests.cs index ffcf043bd..7ad858d2e 100644 --- a/test/Api.Test/Models/Request/SendRequestModelTests.cs +++ b/test/Api.Test/Models/Request/SendRequestModelTests.cs @@ -7,54 +7,53 @@ using Bit.Test.Common.Helpers; using NSubstitute; using Xunit; -namespace Bit.Api.Test.Models.Request +namespace Bit.Api.Test.Models.Request; + +public class SendRequestModelTests { - public class SendRequestModelTests + [Fact] + public void ToSend_Text_Success() { - [Fact] - public void ToSend_Text_Success() + var deletionDate = DateTime.UtcNow.AddDays(5); + var sendRequest = new SendRequestModel { - var deletionDate = DateTime.UtcNow.AddDays(5); - var sendRequest = new SendRequestModel + DeletionDate = deletionDate, + Disabled = false, + ExpirationDate = null, + HideEmail = false, + Key = "encrypted_key", + MaxAccessCount = null, + Name = "encrypted_name", + Notes = null, + Password = "Password", + Text = new SendTextModel() { - DeletionDate = deletionDate, - Disabled = false, - ExpirationDate = null, - HideEmail = false, - Key = "encrypted_key", - MaxAccessCount = null, - Name = "encrypted_name", - Notes = null, - Password = "Password", - Text = new SendTextModel() - { - Hidden = false, - Text = "encrypted_text" - }, - Type = SendType.Text, - }; + Hidden = false, + Text = "encrypted_text" + }, + Type = SendType.Text, + }; - var sendService = Substitute.For(); - sendService.HashPassword(Arg.Any()) - .Returns((info) => $"hashed_{(string)info[0]}"); + var sendService = Substitute.For(); + sendService.HashPassword(Arg.Any()) + .Returns((info) => $"hashed_{(string)info[0]}"); - var send = sendRequest.ToSend(Guid.NewGuid(), sendService); + var send = sendRequest.ToSend(Guid.NewGuid(), sendService); - Assert.Equal(deletionDate, send.DeletionDate); - Assert.False(send.Disabled); - Assert.Null(send.ExpirationDate); - Assert.False(send.HideEmail); - Assert.Equal("encrypted_key", send.Key); - Assert.Equal("hashed_Password", send.Password); + Assert.Equal(deletionDate, send.DeletionDate); + Assert.False(send.Disabled); + Assert.Null(send.ExpirationDate); + Assert.False(send.HideEmail); + Assert.Equal("encrypted_key", send.Key); + Assert.Equal("hashed_Password", send.Password); - using var jsonDocument = JsonDocument.Parse(send.Data); - var root = jsonDocument.RootElement; - var text = AssertHelper.AssertJsonProperty(root, "Text", JsonValueKind.String).GetString(); - Assert.Equal("encrypted_text", text); - AssertHelper.AssertJsonProperty(root, "Hidden", JsonValueKind.False); - Assert.False(root.TryGetProperty("Notes", out var _)); - var name = AssertHelper.AssertJsonProperty(root, "Name", JsonValueKind.String).GetString(); - Assert.Equal("encrypted_name", name); - } + using var jsonDocument = JsonDocument.Parse(send.Data); + var root = jsonDocument.RootElement; + var text = AssertHelper.AssertJsonProperty(root, "Text", JsonValueKind.String).GetString(); + Assert.Equal("encrypted_text", text); + AssertHelper.AssertJsonProperty(root, "Hidden", JsonValueKind.False); + Assert.False(root.TryGetProperty("Notes", out var _)); + var name = AssertHelper.AssertJsonProperty(root, "Name", JsonValueKind.String).GetString(); + Assert.Equal("encrypted_name", name); } } diff --git a/test/Api.Test/Utilities/ApiHelpersTests.cs b/test/Api.Test/Utilities/ApiHelpersTests.cs index 718ec2eeb..4013a2222 100644 --- a/test/Api.Test/Utilities/ApiHelpersTests.cs +++ b/test/Api.Test/Utilities/ApiHelpersTests.cs @@ -5,23 +5,22 @@ using Microsoft.AspNetCore.Http; using NSubstitute; using Xunit; -namespace Bit.Api.Test.Utilities +namespace Bit.Api.Test.Utilities; + +public class ApiHelpersTests { - public class ApiHelpersTests + [Fact] + public async Task ReadJsonFileFromBody_Success() { - [Fact] - public async Task ReadJsonFileFromBody_Success() - { - var context = Substitute.For(); - context.Request.ContentLength.Returns(200); - var bytes = Encoding.UTF8.GetBytes(testFile); - var formFile = new FormFile(new MemoryStream(bytes), 0, bytes.Length, "bitwarden_organization_license", "bitwarden_organization_license.json"); + var context = Substitute.For(); + context.Request.ContentLength.Returns(200); + var bytes = Encoding.UTF8.GetBytes(testFile); + var formFile = new FormFile(new MemoryStream(bytes), 0, bytes.Length, "bitwarden_organization_license", "bitwarden_organization_license.json"); - var license = await ApiHelpers.ReadJsonFileFromBody(context, formFile); - Assert.Equal(8, license.Version); - } - - const string testFile = "{\"licenseKey\": \"licenseKey\", \"installationId\": \"6285f891-b2ec-4047-84c5-2eb7f7747e74\", \"id\": \"1065216d-5854-4326-838d-635487f30b43\",\"name\": \"Test Org\",\"billingEmail\": \"test@email.com\",\"businessName\": null,\"enabled\": true, \"plan\": \"Enterprise (Annually)\",\"planType\": 11,\"seats\": 6,\"maxCollections\": null,\"usePolicies\": true,\"useSso\": true,\"useKeyConnector\": false,\"useGroups\": true,\"useEvents\": true,\"useDirectory\": true,\"useTotp\": true,\"use2fa\": true,\"useApi\": true,\"useResetPassword\": true,\"maxStorageGb\": 1,\"selfHost\": true,\"usersGetPremium\": true,\"version\": 8,\"issued\": \"2022-01-25T21:58:38.9454581Z\",\"refresh\": \"2022-01-28T14:26:31Z\",\"expires\": \"2022-01-28T14:26:31Z\",\"trial\": true,\"hash\": \"testvalue\",\"signature\": \"signature\"}"; + var license = await ApiHelpers.ReadJsonFileFromBody(context, formFile); + Assert.Equal(8, license.Version); } + + const string testFile = "{\"licenseKey\": \"licenseKey\", \"installationId\": \"6285f891-b2ec-4047-84c5-2eb7f7747e74\", \"id\": \"1065216d-5854-4326-838d-635487f30b43\",\"name\": \"Test Org\",\"billingEmail\": \"test@email.com\",\"businessName\": null,\"enabled\": true, \"plan\": \"Enterprise (Annually)\",\"planType\": 11,\"seats\": 6,\"maxCollections\": null,\"usePolicies\": true,\"useSso\": true,\"useKeyConnector\": false,\"useGroups\": true,\"useEvents\": true,\"useDirectory\": true,\"useTotp\": true,\"use2fa\": true,\"useApi\": true,\"useResetPassword\": true,\"maxStorageGb\": 1,\"selfHost\": true,\"usersGetPremium\": true,\"version\": 8,\"issued\": \"2022-01-25T21:58:38.9454581Z\",\"refresh\": \"2022-01-28T14:26:31Z\",\"expires\": \"2022-01-28T14:26:31Z\",\"trial\": true,\"hash\": \"testvalue\",\"signature\": \"signature\"}"; } diff --git a/test/Billing.Test/Controllers/FreshdeskControllerTests.cs b/test/Billing.Test/Controllers/FreshdeskControllerTests.cs index 368896002..94f9e2849 100644 --- a/test/Billing.Test/Controllers/FreshdeskControllerTests.cs +++ b/test/Billing.Test/Controllers/FreshdeskControllerTests.cs @@ -10,71 +10,70 @@ using Microsoft.Extensions.Options; using NSubstitute; using Xunit; -namespace Bit.Billing.Test.Controllers +namespace Bit.Billing.Test.Controllers; + +[ControllerCustomize(typeof(FreshdeskController))] +[SutProviderCustomize] +public class FreshdeskControllerTests { - [ControllerCustomize(typeof(FreshdeskController))] - [SutProviderCustomize] - public class FreshdeskControllerTests + private const string ApiKey = "TESTFRESHDESKAPIKEY"; + private const string WebhookKey = "TESTKEY"; + + [Theory] + [BitAutoData((string)null, null)] + [BitAutoData((string)null)] + [BitAutoData(WebhookKey, null)] + public async Task PostWebhook_NullRequiredParameters_BadRequest(string freshdeskWebhookKey, FreshdeskWebhookModel model, + BillingSettings billingSettings, SutProvider sutProvider) { - private const string ApiKey = "TESTFRESHDESKAPIKEY"; - private const string WebhookKey = "TESTKEY"; + sutProvider.GetDependency>().Value.FreshdeskWebhookKey.Returns(billingSettings.FreshdeskWebhookKey); - [Theory] - [BitAutoData((string)null, null)] - [BitAutoData((string)null)] - [BitAutoData(WebhookKey, null)] - public async Task PostWebhook_NullRequiredParameters_BadRequest(string freshdeskWebhookKey, FreshdeskWebhookModel model, - BillingSettings billingSettings, SutProvider sutProvider) + var response = await sutProvider.Sut.PostWebhook(freshdeskWebhookKey, model); + + var statusCodeResult = Assert.IsAssignableFrom(response); + Assert.Equal(StatusCodes.Status400BadRequest, statusCodeResult.StatusCode); + } + + [Theory] + [BitAutoData] + public async Task PostWebhook_Success(User user, FreshdeskWebhookModel model, + List organizations, SutProvider sutProvider) + { + model.TicketContactEmail = user.Email; + + sutProvider.GetDependency().GetByEmailAsync(user.Email).Returns(user); + sutProvider.GetDependency().GetManyByUserIdAsync(user.Id).Returns(organizations); + + var mockHttpMessageHandler = Substitute.ForPartsOf(); + var mockResponse = new HttpResponseMessage(System.Net.HttpStatusCode.OK); + mockHttpMessageHandler.Send(Arg.Any(), Arg.Any()) + .Returns(mockResponse); + var httpClient = new HttpClient(mockHttpMessageHandler); + + sutProvider.GetDependency().CreateClient("FreshdeskApi").Returns(httpClient); + + sutProvider.GetDependency>().Value.FreshdeskWebhookKey.Returns(WebhookKey); + sutProvider.GetDependency>().Value.FreshdeskApiKey.Returns(ApiKey); + + var response = await sutProvider.Sut.PostWebhook(WebhookKey, model); + + var statusCodeResult = Assert.IsAssignableFrom(response); + Assert.Equal(StatusCodes.Status200OK, statusCodeResult.StatusCode); + + _ = mockHttpMessageHandler.Received(1).Send(Arg.Is(m => m.Method == HttpMethod.Put && m.RequestUri.ToString().EndsWith(model.TicketId)), Arg.Any()); + _ = mockHttpMessageHandler.Received(1).Send(Arg.Is(m => m.Method == HttpMethod.Post && m.RequestUri.ToString().EndsWith($"{model.TicketId}/notes")), Arg.Any()); + } + + public class MockHttpMessageHandler : HttpMessageHandler + { + protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) { - sutProvider.GetDependency>().Value.FreshdeskWebhookKey.Returns(billingSettings.FreshdeskWebhookKey); - - var response = await sutProvider.Sut.PostWebhook(freshdeskWebhookKey, model); - - var statusCodeResult = Assert.IsAssignableFrom(response); - Assert.Equal(StatusCodes.Status400BadRequest, statusCodeResult.StatusCode); + return Send(request, cancellationToken); } - [Theory] - [BitAutoData] - public async Task PostWebhook_Success(User user, FreshdeskWebhookModel model, - List organizations, SutProvider sutProvider) + public virtual Task Send(HttpRequestMessage request, CancellationToken cancellationToken) { - model.TicketContactEmail = user.Email; - - sutProvider.GetDependency().GetByEmailAsync(user.Email).Returns(user); - sutProvider.GetDependency().GetManyByUserIdAsync(user.Id).Returns(organizations); - - var mockHttpMessageHandler = Substitute.ForPartsOf(); - var mockResponse = new HttpResponseMessage(System.Net.HttpStatusCode.OK); - mockHttpMessageHandler.Send(Arg.Any(), Arg.Any()) - .Returns(mockResponse); - var httpClient = new HttpClient(mockHttpMessageHandler); - - sutProvider.GetDependency().CreateClient("FreshdeskApi").Returns(httpClient); - - sutProvider.GetDependency>().Value.FreshdeskWebhookKey.Returns(WebhookKey); - sutProvider.GetDependency>().Value.FreshdeskApiKey.Returns(ApiKey); - - var response = await sutProvider.Sut.PostWebhook(WebhookKey, model); - - var statusCodeResult = Assert.IsAssignableFrom(response); - Assert.Equal(StatusCodes.Status200OK, statusCodeResult.StatusCode); - - _ = mockHttpMessageHandler.Received(1).Send(Arg.Is(m => m.Method == HttpMethod.Put && m.RequestUri.ToString().EndsWith(model.TicketId)), Arg.Any()); - _ = mockHttpMessageHandler.Received(1).Send(Arg.Is(m => m.Method == HttpMethod.Post && m.RequestUri.ToString().EndsWith($"{model.TicketId}/notes")), Arg.Any()); - } - - public class MockHttpMessageHandler : HttpMessageHandler - { - protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) - { - return Send(request, cancellationToken); - } - - public virtual Task Send(HttpRequestMessage request, CancellationToken cancellationToken) - { - throw new NotImplementedException(); - } + throw new NotImplementedException(); } } } diff --git a/test/Billing.Test/Controllers/FreshsalesControllerTests.cs b/test/Billing.Test/Controllers/FreshsalesControllerTests.cs index 490f4051a..3a5cf3bf1 100644 --- a/test/Billing.Test/Controllers/FreshsalesControllerTests.cs +++ b/test/Billing.Test/Controllers/FreshsalesControllerTests.cs @@ -10,73 +10,72 @@ using Microsoft.Extensions.Options; using NSubstitute; using Xunit; -namespace Bit.Billing.Test.Controllers +namespace Bit.Billing.Test.Controllers; + +public class FreshsalesControllerTests { - public class FreshsalesControllerTests + private const string ApiKey = "TEST_FRESHSALES_APIKEY"; + private const string TestLead = "TEST_FRESHSALES_TESTLEAD"; + + private static (FreshsalesController, IUserRepository, IOrganizationRepository) CreateSut( + string freshsalesApiKey) { - private const string ApiKey = "TEST_FRESHSALES_APIKEY"; - private const string TestLead = "TEST_FRESHSALES_TESTLEAD"; + var userRepository = Substitute.For(); + var organizationRepository = Substitute.For(); - private static (FreshsalesController, IUserRepository, IOrganizationRepository) CreateSut( - string freshsalesApiKey) + var billingSettings = Options.Create(new BillingSettings { - var userRepository = Substitute.For(); - var organizationRepository = Substitute.For(); + FreshsalesApiKey = freshsalesApiKey, + }); + var globalSettings = new GlobalSettings(); + globalSettings.BaseServiceUri.Admin = "https://test.com"; - var billingSettings = Options.Create(new BillingSettings - { - FreshsalesApiKey = freshsalesApiKey, - }); - var globalSettings = new GlobalSettings(); - globalSettings.BaseServiceUri.Admin = "https://test.com"; + var sut = new FreshsalesController( + userRepository, + organizationRepository, + billingSettings, + Substitute.For>(), + globalSettings + ); - var sut = new FreshsalesController( - userRepository, - organizationRepository, - billingSettings, - Substitute.For>(), - globalSettings - ); + return (sut, userRepository, organizationRepository); + } - return (sut, userRepository, organizationRepository); - } + [RequiredEnvironmentTheory(ApiKey, TestLead), EnvironmentData(ApiKey, TestLead)] + public async Task PostWebhook_Success(string freshsalesApiKey, long leadId) + { + // This test is only for development to use: + // `export TEST_FRESHSALES_APIKEY=[apikey]` + // `export TEST_FRESHSALES_TESTLEAD=[lead id]` + // `dotnet test --filter "FullyQualifiedName~FreshsalesControllerTests.PostWebhook_Success"` + var (sut, userRepository, organizationRepository) = CreateSut(freshsalesApiKey); - [RequiredEnvironmentTheory(ApiKey, TestLead), EnvironmentData(ApiKey, TestLead)] - public async Task PostWebhook_Success(string freshsalesApiKey, long leadId) + var user = new User { - // This test is only for development to use: - // `export TEST_FRESHSALES_APIKEY=[apikey]` - // `export TEST_FRESHSALES_TESTLEAD=[lead id]` - // `dotnet test --filter "FullyQualifiedName~FreshsalesControllerTests.PostWebhook_Success"` - var (sut, userRepository, organizationRepository) = CreateSut(freshsalesApiKey); + Id = Guid.NewGuid(), + Email = "test@email.com", + Premium = true, + }; - var user = new User + userRepository.GetByEmailAsync(user.Email) + .Returns(user); + + organizationRepository.GetManyByUserIdAsync(user.Id) + .Returns(new List { - Id = Guid.NewGuid(), - Email = "test@email.com", - Premium = true, - }; - - userRepository.GetByEmailAsync(user.Email) - .Returns(user); - - organizationRepository.GetManyByUserIdAsync(user.Id) - .Returns(new List + new Organization { - new Organization - { - Id = Guid.NewGuid(), - Name = "Test Org", - } - }); + Id = Guid.NewGuid(), + Name = "Test Org", + } + }); - var response = await sut.PostWebhook(freshsalesApiKey, new CustomWebhookRequestModel - { - LeadId = leadId, - }, new CancellationToken(false)); + var response = await sut.PostWebhook(freshsalesApiKey, new CustomWebhookRequestModel + { + LeadId = leadId, + }, new CancellationToken(false)); - var statusCodeResult = Assert.IsAssignableFrom(response); - Assert.Equal(StatusCodes.Status204NoContent, statusCodeResult.StatusCode); - } + var statusCodeResult = Assert.IsAssignableFrom(response); + Assert.Equal(StatusCodes.Status204NoContent, statusCodeResult.StatusCode); } } diff --git a/test/Common/AutoFixture/Attributes/BitAutoDataAttribute.cs b/test/Common/AutoFixture/Attributes/BitAutoDataAttribute.cs index 7185468ea..d859f81fc 100644 --- a/test/Common/AutoFixture/Attributes/BitAutoDataAttribute.cs +++ b/test/Common/AutoFixture/Attributes/BitAutoDataAttribute.cs @@ -3,26 +3,25 @@ using AutoFixture; using Bit.Test.Common.Helpers; using Xunit.Sdk; -namespace Bit.Test.Common.AutoFixture.Attributes +namespace Bit.Test.Common.AutoFixture.Attributes; + +[DataDiscoverer("AutoFixture.Xunit2.NoPreDiscoveryDataDiscoverer", "AutoFixture.Xunit2")] +public class BitAutoDataAttribute : DataAttribute { - [DataDiscoverer("AutoFixture.Xunit2.NoPreDiscoveryDataDiscoverer", "AutoFixture.Xunit2")] - public class BitAutoDataAttribute : DataAttribute + private readonly Func _createFixture; + private readonly object[] _fixedTestParameters; + + public BitAutoDataAttribute(params object[] fixedTestParameters) : + this(() => new Fixture(), fixedTestParameters) + { } + + public BitAutoDataAttribute(Func createFixture, params object[] fixedTestParameters) : + base() { - private readonly Func _createFixture; - private readonly object[] _fixedTestParameters; - - public BitAutoDataAttribute(params object[] fixedTestParameters) : - this(() => new Fixture(), fixedTestParameters) - { } - - public BitAutoDataAttribute(Func createFixture, params object[] fixedTestParameters) : - base() - { - _createFixture = createFixture; - _fixedTestParameters = fixedTestParameters; - } - - public override IEnumerable GetData(MethodInfo testMethod) - => BitAutoDataAttributeHelpers.GetData(testMethod, _createFixture(), _fixedTestParameters); + _createFixture = createFixture; + _fixedTestParameters = fixedTestParameters; } + + public override IEnumerable GetData(MethodInfo testMethod) + => BitAutoDataAttributeHelpers.GetData(testMethod, _createFixture(), _fixedTestParameters); } diff --git a/test/Common/AutoFixture/Attributes/BitCustomizeAttribute.cs b/test/Common/AutoFixture/Attributes/BitCustomizeAttribute.cs index 9b9a5142d..105a6632d 100644 --- a/test/Common/AutoFixture/Attributes/BitCustomizeAttribute.cs +++ b/test/Common/AutoFixture/Attributes/BitCustomizeAttribute.cs @@ -1,21 +1,20 @@ using AutoFixture; -namespace Bit.Test.Common.AutoFixture.Attributes +namespace Bit.Test.Common.AutoFixture.Attributes; + +/// +/// +/// Base class for customizing parameters in methods decorated with the +/// Bit.Test.Common.AutoFixture.Attributes.MemberAutoDataAttribute. +/// +/// ⚠ Warning ⚠ Will not insert customizations into AutoFixture's AutoDataAttribute build chain +/// +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method | AttributeTargets.Parameter, AllowMultiple = true)] +public abstract class BitCustomizeAttribute : Attribute { /// - /// - /// Base class for customizing parameters in methods decorated with the - /// Bit.Test.Common.AutoFixture.Attributes.MemberAutoDataAttribute. - /// - /// ⚠ Warning ⚠ Will not insert customizations into AutoFixture's AutoDataAttribute build chain + /// /// Gets a customization for the method's parameters. /// - [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method | AttributeTargets.Parameter, AllowMultiple = true)] - public abstract class BitCustomizeAttribute : Attribute - { - /// - /// /// Gets a customization for the method's parameters. - /// - /// A customization for the method's paramters. - public abstract ICustomization GetCustomization(); - } + /// A customization for the method's paramters. + public abstract ICustomization GetCustomization(); } diff --git a/test/Common/AutoFixture/Attributes/BitMemberAutoDataAttribute.cs b/test/Common/AutoFixture/Attributes/BitMemberAutoDataAttribute.cs index e9604e1c9..7e6f81c30 100644 --- a/test/Common/AutoFixture/Attributes/BitMemberAutoDataAttribute.cs +++ b/test/Common/AutoFixture/Attributes/BitMemberAutoDataAttribute.cs @@ -3,23 +3,22 @@ using AutoFixture; using Bit.Test.Common.Helpers; using Xunit; -namespace Bit.Test.Common.AutoFixture.Attributes +namespace Bit.Test.Common.AutoFixture.Attributes; + +public class BitMemberAutoDataAttribute : MemberDataAttributeBase { - public class BitMemberAutoDataAttribute : MemberDataAttributeBase + private readonly Func _createFixture; + + public BitMemberAutoDataAttribute(string memberName, params object[] parameters) : + this(() => new Fixture(), memberName, parameters) + { } + + public BitMemberAutoDataAttribute(Func createFixture, string memberName, params object[] parameters) : + base(memberName, parameters) { - private readonly Func _createFixture; - - public BitMemberAutoDataAttribute(string memberName, params object[] parameters) : - this(() => new Fixture(), memberName, parameters) - { } - - public BitMemberAutoDataAttribute(Func createFixture, string memberName, params object[] parameters) : - base(memberName, parameters) - { - _createFixture = createFixture; - } - - protected override object[] ConvertDataItem(MethodInfo testMethod, object item) => - BitAutoDataAttributeHelpers.GetData(testMethod, _createFixture(), item as object[]).First(); + _createFixture = createFixture; } + + protected override object[] ConvertDataItem(MethodInfo testMethod, object item) => + BitAutoDataAttributeHelpers.GetData(testMethod, _createFixture(), item as object[]).First(); } diff --git a/test/Common/AutoFixture/Attributes/ControllerCustomizeAttribute.cs b/test/Common/AutoFixture/Attributes/ControllerCustomizeAttribute.cs index 6cab60bae..7627562b7 100644 --- a/test/Common/AutoFixture/Attributes/ControllerCustomizeAttribute.cs +++ b/test/Common/AutoFixture/Attributes/ControllerCustomizeAttribute.cs @@ -1,23 +1,22 @@ using AutoFixture; -namespace Bit.Test.Common.AutoFixture.Attributes +namespace Bit.Test.Common.AutoFixture.Attributes; + +/// +/// Disables setting of Auto Properties on the Controller to avoid ASP.net initialization errors from a mock environment. Still sets constructor dependencies. +/// +public class ControllerCustomizeAttribute : BitCustomizeAttribute { + private readonly Type _controllerType; + /// - /// Disables setting of Auto Properties on the Controller to avoid ASP.net initialization errors from a mock environment. Still sets constructor dependencies. + /// Initialize an instance of the ControllerCustomizeAttribute class /// - public class ControllerCustomizeAttribute : BitCustomizeAttribute + /// The Type of the controller to allow autofixture to create + public ControllerCustomizeAttribute(Type controllerType) { - private readonly Type _controllerType; - - /// - /// Initialize an instance of the ControllerCustomizeAttribute class - /// - /// The Type of the controller to allow autofixture to create - public ControllerCustomizeAttribute(Type controllerType) - { - _controllerType = controllerType; - } - - public override ICustomization GetCustomization() => new ControllerCustomization(_controllerType); + _controllerType = controllerType; } + + public override ICustomization GetCustomization() => new ControllerCustomization(_controllerType); } diff --git a/test/Common/AutoFixture/Attributes/CustomAutoDataAttribute.cs b/test/Common/AutoFixture/Attributes/CustomAutoDataAttribute.cs index 75308e448..6aac53ca3 100644 --- a/test/Common/AutoFixture/Attributes/CustomAutoDataAttribute.cs +++ b/test/Common/AutoFixture/Attributes/CustomAutoDataAttribute.cs @@ -1,23 +1,22 @@ using AutoFixture; using AutoFixture.Xunit2; -namespace Bit.Test.Common.AutoFixture.Attributes -{ - public class CustomAutoDataAttribute : AutoDataAttribute - { - public CustomAutoDataAttribute(params Type[] iCustomizationTypes) : this(iCustomizationTypes - .Select(t => (ICustomization)Activator.CreateInstance(t)).ToArray()) - { } +namespace Bit.Test.Common.AutoFixture.Attributes; - public CustomAutoDataAttribute(params ICustomization[] customizations) : base(() => +public class CustomAutoDataAttribute : AutoDataAttribute +{ + public CustomAutoDataAttribute(params Type[] iCustomizationTypes) : this(iCustomizationTypes + .Select(t => (ICustomization)Activator.CreateInstance(t)).ToArray()) + { } + + public CustomAutoDataAttribute(params ICustomization[] customizations) : base(() => + { + var fixture = new Fixture(); + foreach (var customization in customizations) { - var fixture = new Fixture(); - foreach (var customization in customizations) - { - fixture.Customize(customization); - } - return fixture; - }) - { } - } + fixture.Customize(customization); + } + return fixture; + }) + { } } diff --git a/test/Common/AutoFixture/Attributes/EnvironmentDataAttribute.cs b/test/Common/AutoFixture/Attributes/EnvironmentDataAttribute.cs index 5479d766d..acdf737be 100644 --- a/test/Common/AutoFixture/Attributes/EnvironmentDataAttribute.cs +++ b/test/Common/AutoFixture/Attributes/EnvironmentDataAttribute.cs @@ -1,43 +1,42 @@ using System.Reflection; using Xunit.Sdk; -namespace Bit.Test.Common.AutoFixture.Attributes +namespace Bit.Test.Common.AutoFixture.Attributes; + +/// +/// Used for collecting data from environment useful for when we want to test an integration with another service and +/// it might require an api key or other piece of sensitive data that we don't want slipping into the wrong hands. +/// +/// +/// It probably should be refactored to support fixtures and other customization so it can more easily be used in conjunction +/// with more parameters. Currently it attempt to match environment variable names to values of the parameter type in that positions. +/// It will start from the first parameter and go for each supplied name. +/// +public class EnvironmentDataAttribute : DataAttribute { - /// - /// Used for collecting data from environment useful for when we want to test an integration with another service and - /// it might require an api key or other piece of sensitive data that we don't want slipping into the wrong hands. - /// - /// - /// It probably should be refactored to support fixtures and other customization so it can more easily be used in conjunction - /// with more parameters. Currently it attempt to match environment variable names to values of the parameter type in that positions. - /// It will start from the first parameter and go for each supplied name. - /// - public class EnvironmentDataAttribute : DataAttribute + private readonly string[] _environmentVariableNames; + + public EnvironmentDataAttribute(params string[] environmentVariableNames) { - private readonly string[] _environmentVariableNames; + _environmentVariableNames = environmentVariableNames; + } - public EnvironmentDataAttribute(params string[] environmentVariableNames) + public override IEnumerable GetData(MethodInfo testMethod) + { + var methodParameters = testMethod.GetParameters(); + + if (methodParameters.Length < _environmentVariableNames.Length) { - _environmentVariableNames = environmentVariableNames; + throw new ArgumentException($"The target test method only has {methodParameters.Length} arguments but you supplied {_environmentVariableNames.Length}"); } - public override IEnumerable GetData(MethodInfo testMethod) + var values = new object[_environmentVariableNames.Length]; + + for (var i = 0; i < _environmentVariableNames.Length; i++) { - var methodParameters = testMethod.GetParameters(); - - if (methodParameters.Length < _environmentVariableNames.Length) - { - throw new ArgumentException($"The target test method only has {methodParameters.Length} arguments but you supplied {_environmentVariableNames.Length}"); - } - - var values = new object[_environmentVariableNames.Length]; - - for (var i = 0; i < _environmentVariableNames.Length; i++) - { - values[i] = Convert.ChangeType(Environment.GetEnvironmentVariable(_environmentVariableNames[i]), methodParameters[i].ParameterType); - } - - return new[] { values }; + values[i] = Convert.ChangeType(Environment.GetEnvironmentVariable(_environmentVariableNames[i]), methodParameters[i].ParameterType); } + + return new[] { values }; } } diff --git a/test/Common/AutoFixture/Attributes/InlineCustomAutoDataAttribute.cs b/test/Common/AutoFixture/Attributes/InlineCustomAutoDataAttribute.cs index b8c27f746..fb16d2f90 100644 --- a/test/Common/AutoFixture/Attributes/InlineCustomAutoDataAttribute.cs +++ b/test/Common/AutoFixture/Attributes/InlineCustomAutoDataAttribute.cs @@ -3,20 +3,19 @@ using AutoFixture.Xunit2; using Xunit; using Xunit.Sdk; -namespace Bit.Test.Common.AutoFixture.Attributes -{ - public class InlineCustomAutoDataAttribute : CompositeDataAttribute - { - public InlineCustomAutoDataAttribute(Type[] iCustomizationTypes, params object[] values) : base(new DataAttribute[] { - new InlineDataAttribute(values), - new CustomAutoDataAttribute(iCustomizationTypes) - }) - { } +namespace Bit.Test.Common.AutoFixture.Attributes; - public InlineCustomAutoDataAttribute(ICustomization[] customizations, params object[] values) : base(new DataAttribute[] { - new InlineDataAttribute(values), - new CustomAutoDataAttribute(customizations) - }) - { } - } +public class InlineCustomAutoDataAttribute : CompositeDataAttribute +{ + public InlineCustomAutoDataAttribute(Type[] iCustomizationTypes, params object[] values) : base(new DataAttribute[] { + new InlineDataAttribute(values), + new CustomAutoDataAttribute(iCustomizationTypes) + }) + { } + + public InlineCustomAutoDataAttribute(ICustomization[] customizations, params object[] values) : base(new DataAttribute[] { + new InlineDataAttribute(values), + new CustomAutoDataAttribute(customizations) + }) + { } } diff --git a/test/Common/AutoFixture/Attributes/InlineSutAutoDataAttribute.cs b/test/Common/AutoFixture/Attributes/InlineSutAutoDataAttribute.cs index b2709a330..ae32b476c 100644 --- a/test/Common/AutoFixture/Attributes/InlineSutAutoDataAttribute.cs +++ b/test/Common/AutoFixture/Attributes/InlineSutAutoDataAttribute.cs @@ -1,18 +1,17 @@ using AutoFixture; -namespace Bit.Test.Common.AutoFixture.Attributes -{ - public class InlineSutAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineSutAutoDataAttribute(params object[] values) : base( - new Type[] { typeof(SutProviderCustomization) }, values) - { } - public InlineSutAutoDataAttribute(Type[] iCustomizationTypes, params object[] values) : base( - iCustomizationTypes.Append(typeof(SutProviderCustomization)).ToArray(), values) - { } +namespace Bit.Test.Common.AutoFixture.Attributes; - public InlineSutAutoDataAttribute(ICustomization[] customizations, params object[] values) : base( - customizations.Append(new SutProviderCustomization()).ToArray(), values) - { } - } +public class InlineSutAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineSutAutoDataAttribute(params object[] values) : base( + new Type[] { typeof(SutProviderCustomization) }, values) + { } + public InlineSutAutoDataAttribute(Type[] iCustomizationTypes, params object[] values) : base( + iCustomizationTypes.Append(typeof(SutProviderCustomization)).ToArray(), values) + { } + + public InlineSutAutoDataAttribute(ICustomization[] customizations, params object[] values) : base( + customizations.Append(new SutProviderCustomization()).ToArray(), values) + { } } diff --git a/test/Common/AutoFixture/Attributes/JsonDocumentCustomizeAttribute.cs b/test/Common/AutoFixture/Attributes/JsonDocumentCustomizeAttribute.cs index d4df0599a..41b1dc63b 100644 --- a/test/Common/AutoFixture/Attributes/JsonDocumentCustomizeAttribute.cs +++ b/test/Common/AutoFixture/Attributes/JsonDocumentCustomizeAttribute.cs @@ -1,11 +1,10 @@ using AutoFixture; using Bit.Test.Common.AutoFixture.JsonDocumentFixtures; -namespace Bit.Test.Common.AutoFixture.Attributes +namespace Bit.Test.Common.AutoFixture.Attributes; + +public class JsonDocumentCustomizeAttribute : BitCustomizeAttribute { - public class JsonDocumentCustomizeAttribute : BitCustomizeAttribute - { - public string Json { get; set; } - public override ICustomization GetCustomization() => new JsonDocumentCustomization() { Json = Json }; - } + public string Json { get; set; } + public override ICustomization GetCustomization() => new JsonDocumentCustomization() { Json = Json }; } diff --git a/test/Common/AutoFixture/Attributes/RequiredEnvironmentTheoryAttribute.cs b/test/Common/AutoFixture/Attributes/RequiredEnvironmentTheoryAttribute.cs index 183001063..5bb0c3485 100644 --- a/test/Common/AutoFixture/Attributes/RequiredEnvironmentTheoryAttribute.cs +++ b/test/Common/AutoFixture/Attributes/RequiredEnvironmentTheoryAttribute.cs @@ -1,38 +1,37 @@ using Xunit; -namespace Bit.Test.Common.AutoFixture.Attributes +namespace Bit.Test.Common.AutoFixture.Attributes; + +/// +/// Used for requiring certain environment variables exist at the time. Mostly used for more edge unit tests that shouldn't +/// be run during CI builds or should only be ran in CI builds when pieces of information are available. +/// +public class RequiredEnvironmentTheoryAttribute : TheoryAttribute { - /// - /// Used for requiring certain environment variables exist at the time. Mostly used for more edge unit tests that shouldn't - /// be run during CI builds or should only be ran in CI builds when pieces of information are available. - /// - public class RequiredEnvironmentTheoryAttribute : TheoryAttribute + private readonly string[] _environmentVariableNames; + + public RequiredEnvironmentTheoryAttribute(params string[] environmentVariableNames) { - private readonly string[] _environmentVariableNames; + _environmentVariableNames = environmentVariableNames; - public RequiredEnvironmentTheoryAttribute(params string[] environmentVariableNames) + if (!HasRequiredEnvironmentVariables()) { - _environmentVariableNames = environmentVariableNames; - - if (!HasRequiredEnvironmentVariables()) - { - Skip = $"Missing one or more required environment variables. ({string.Join(", ", _environmentVariableNames)})"; - } - } - - private bool HasRequiredEnvironmentVariables() - { - foreach (var env in _environmentVariableNames) - { - var value = Environment.GetEnvironmentVariable(env); - - if (value == null) - { - return false; - } - } - - return true; + Skip = $"Missing one or more required environment variables. ({string.Join(", ", _environmentVariableNames)})"; } } + + private bool HasRequiredEnvironmentVariables() + { + foreach (var env in _environmentVariableNames) + { + var value = Environment.GetEnvironmentVariable(env); + + if (value == null) + { + return false; + } + } + + return true; + } } diff --git a/test/Common/AutoFixture/Attributes/SutAutoDataAttribute.cs b/test/Common/AutoFixture/Attributes/SutAutoDataAttribute.cs index 3680f4a66..a84bc3118 100644 --- a/test/Common/AutoFixture/Attributes/SutAutoDataAttribute.cs +++ b/test/Common/AutoFixture/Attributes/SutAutoDataAttribute.cs @@ -1,16 +1,15 @@ using AutoFixture; -namespace Bit.Test.Common.AutoFixture.Attributes -{ - public class SutProviderCustomizeAttribute : BitCustomizeAttribute - { - public override ICustomization GetCustomization() => new SutProviderCustomization(); - } +namespace Bit.Test.Common.AutoFixture.Attributes; - public class SutAutoDataAttribute : CustomAutoDataAttribute - { - public SutAutoDataAttribute(params Type[] iCustomizationTypes) : base( - iCustomizationTypes.Append(typeof(SutProviderCustomization)).ToArray()) - { } - } +public class SutProviderCustomizeAttribute : BitCustomizeAttribute +{ + public override ICustomization GetCustomization() => new SutProviderCustomization(); +} + +public class SutAutoDataAttribute : CustomAutoDataAttribute +{ + public SutAutoDataAttribute(params Type[] iCustomizationTypes) : base( + iCustomizationTypes.Append(typeof(SutProviderCustomization)).ToArray()) + { } } diff --git a/test/Common/AutoFixture/BuilderWithoutAutoProperties.cs b/test/Common/AutoFixture/BuilderWithoutAutoProperties.cs index b2bdae0d4..039475fad 100644 --- a/test/Common/AutoFixture/BuilderWithoutAutoProperties.cs +++ b/test/Common/AutoFixture/BuilderWithoutAutoProperties.cs @@ -1,39 +1,38 @@ using AutoFixture; using AutoFixture.Kernel; -namespace Bit.Test.Common.AutoFixture +namespace Bit.Test.Common.AutoFixture; + +public class BuilderWithoutAutoProperties : ISpecimenBuilder { - public class BuilderWithoutAutoProperties : ISpecimenBuilder + private readonly Type _type; + public BuilderWithoutAutoProperties(Type type) { - private readonly Type _type; - public BuilderWithoutAutoProperties(Type type) - { - _type = type; - } - - public object Create(object request, ISpecimenContext context) - { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - var type = request as Type; - if (type == null || type != _type) - { - return new NoSpecimen(); - } - - var fixture = new Fixture(); - // This is the equivalent of _fixture.Build<_type>().OmitAutoProperties().Create(request, context), but no overload for - // Build(Type type) exists. - dynamic reflectedComposer = typeof(Fixture).GetMethod("Build").MakeGenericMethod(_type).Invoke(fixture, null); - return reflectedComposer.OmitAutoProperties().Create(request, context); - } + _type = type; } - public class BuilderWithoutAutoProperties : ISpecimenBuilder + + public object Create(object request, ISpecimenContext context) { - public object Create(object request, ISpecimenContext context) => - new BuilderWithoutAutoProperties(typeof(T)).Create(request, context); + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + + var type = request as Type; + if (type == null || type != _type) + { + return new NoSpecimen(); + } + + var fixture = new Fixture(); + // This is the equivalent of _fixture.Build<_type>().OmitAutoProperties().Create(request, context), but no overload for + // Build(Type type) exists. + dynamic reflectedComposer = typeof(Fixture).GetMethod("Build").MakeGenericMethod(_type).Invoke(fixture, null); + return reflectedComposer.OmitAutoProperties().Create(request, context); } } +public class BuilderWithoutAutoProperties : ISpecimenBuilder +{ + public object Create(object request, ISpecimenContext context) => + new BuilderWithoutAutoProperties(typeof(T)).Create(request, context); +} diff --git a/test/Common/AutoFixture/ControllerCustomization.cs b/test/Common/AutoFixture/ControllerCustomization.cs index 9592466aa..f695f86b5 100644 --- a/test/Common/AutoFixture/ControllerCustomization.cs +++ b/test/Common/AutoFixture/ControllerCustomization.cs @@ -2,32 +2,31 @@ using Microsoft.AspNetCore.Mvc; using Org.BouncyCastle.Security; -namespace Bit.Test.Common.AutoFixture +namespace Bit.Test.Common.AutoFixture; + +/// +/// Disables setting of Auto Properties on the Controller to avoid ASP.net initialization errors. Still sets constructor dependencies. +/// +/// +public class ControllerCustomization : ICustomization { - /// - /// Disables setting of Auto Properties on the Controller to avoid ASP.net initialization errors. Still sets constructor dependencies. - /// - /// - public class ControllerCustomization : ICustomization + private readonly Type _controllerType; + public ControllerCustomization(Type controllerType) { - private readonly Type _controllerType; - public ControllerCustomization(Type controllerType) + if (!controllerType.IsAssignableTo(typeof(Controller))) { - if (!controllerType.IsAssignableTo(typeof(Controller))) - { - throw new InvalidParameterException($"{nameof(controllerType)} must derive from {typeof(Controller).Name}"); - } - - _controllerType = controllerType; + throw new InvalidParameterException($"{nameof(controllerType)} must derive from {typeof(Controller).Name}"); } - public void Customize(IFixture fixture) - { - fixture.Customizations.Add(new BuilderWithoutAutoProperties(_controllerType)); - } + _controllerType = controllerType; } - public class ControllerCustomization : ICustomization where T : Controller + + public void Customize(IFixture fixture) { - public void Customize(IFixture fixture) => new ControllerCustomization(typeof(T)).Customize(fixture); + fixture.Customizations.Add(new BuilderWithoutAutoProperties(_controllerType)); } } +public class ControllerCustomization : ICustomization where T : Controller +{ + public void Customize(IFixture fixture) => new ControllerCustomization(typeof(T)).Customize(fixture); +} diff --git a/test/Common/AutoFixture/FixtureExtensions.cs b/test/Common/AutoFixture/FixtureExtensions.cs index 162784a35..300967666 100644 --- a/test/Common/AutoFixture/FixtureExtensions.cs +++ b/test/Common/AutoFixture/FixtureExtensions.cs @@ -1,14 +1,13 @@ using AutoFixture; using AutoFixture.AutoNSubstitute; -namespace Bit.Test.Common.AutoFixture -{ - public static class FixtureExtensions - { - public static IFixture WithAutoNSubstitutions(this IFixture fixture) - => fixture.Customize(new AutoNSubstituteCustomization()); +namespace Bit.Test.Common.AutoFixture; - public static IFixture WithAutoNSubstitutionsAutoPopulatedProperties(this IFixture fixture) - => fixture.Customize(new AutoNSubstituteCustomization { ConfigureMembers = true }); - } +public static class FixtureExtensions +{ + public static IFixture WithAutoNSubstitutions(this IFixture fixture) + => fixture.Customize(new AutoNSubstituteCustomization()); + + public static IFixture WithAutoNSubstitutionsAutoPopulatedProperties(this IFixture fixture) + => fixture.Customize(new AutoNSubstituteCustomization { ConfigureMembers = true }); } diff --git a/test/Common/AutoFixture/GlobalSettingsFixtures.cs b/test/Common/AutoFixture/GlobalSettingsFixtures.cs index 86f460909..3a2a319ee 100644 --- a/test/Common/AutoFixture/GlobalSettingsFixtures.cs +++ b/test/Common/AutoFixture/GlobalSettingsFixtures.cs @@ -1,16 +1,15 @@ using AutoFixture; -namespace Bit.Test.Common.AutoFixture +namespace Bit.Test.Common.AutoFixture; + +public class GlobalSettings : ICustomization { - public class GlobalSettings : ICustomization + public void Customize(IFixture fixture) { - public void Customize(IFixture fixture) - { - fixture.Customize(composer => composer - .Without(s => s.BaseServiceUri) - .Without(s => s.Attachment) - .Without(s => s.Send) - .Without(s => s.DataProtection)); - } + fixture.Customize(composer => composer + .Without(s => s.BaseServiceUri) + .Without(s => s.Attachment) + .Without(s => s.Send) + .Without(s => s.DataProtection)); } } diff --git a/test/Common/AutoFixture/ISutProvider.cs b/test/Common/AutoFixture/ISutProvider.cs index 9f6b0b23a..1ce9c7a00 100644 --- a/test/Common/AutoFixture/ISutProvider.cs +++ b/test/Common/AutoFixture/ISutProvider.cs @@ -1,8 +1,7 @@ -namespace Bit.Test.Common.AutoFixture +namespace Bit.Test.Common.AutoFixture; + +public interface ISutProvider { - public interface ISutProvider - { - Type SutType { get; } - ISutProvider Create(); - } + Type SutType { get; } + ISutProvider Create(); } diff --git a/test/Common/AutoFixture/JsonDocumentFixtures.cs b/test/Common/AutoFixture/JsonDocumentFixtures.cs index e39b7f990..df27aa8ce 100644 --- a/test/Common/AutoFixture/JsonDocumentFixtures.cs +++ b/test/Common/AutoFixture/JsonDocumentFixtures.cs @@ -2,31 +2,30 @@ using AutoFixture; using AutoFixture.Kernel; -namespace Bit.Test.Common.AutoFixture.JsonDocumentFixtures +namespace Bit.Test.Common.AutoFixture.JsonDocumentFixtures; + +public class JsonDocumentCustomization : ICustomization, ISpecimenBuilder { - public class JsonDocumentCustomization : ICustomization, ISpecimenBuilder + + public string Json { get; set; } + + public void Customize(IFixture fixture) { + fixture.Customizations.Add(this); + } - public string Json { get; set; } - - public void Customize(IFixture fixture) + public object Create(object request, ISpecimenContext context) + { + if (context == null) { - fixture.Customizations.Add(this); + throw new ArgumentNullException(nameof(context)); + } + var type = request as Type; + if (type == null || (type != typeof(JsonDocument))) + { + return new NoSpecimen(); } - public object Create(object request, ISpecimenContext context) - { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - var type = request as Type; - if (type == null || (type != typeof(JsonDocument))) - { - return new NoSpecimen(); - } - - return JsonDocument.Parse(Json ?? "{}"); - } + return JsonDocument.Parse(Json ?? "{}"); } } diff --git a/test/Common/AutoFixture/SutProvider.cs b/test/Common/AutoFixture/SutProvider.cs index 2b00ed0cf..3a3d6409b 100644 --- a/test/Common/AutoFixture/SutProvider.cs +++ b/test/Common/AutoFixture/SutProvider.cs @@ -2,133 +2,132 @@ using AutoFixture; using AutoFixture.Kernel; -namespace Bit.Test.Common.AutoFixture +namespace Bit.Test.Common.AutoFixture; + +public class SutProvider : ISutProvider { - public class SutProvider : ISutProvider + private Dictionary> _dependencies; + private readonly IFixture _fixture; + private readonly ConstructorParameterRelay _constructorParameterRelay; + + public TSut Sut { get; private set; } + public Type SutType => typeof(TSut); + + public SutProvider() : this(new Fixture()) { } + + public SutProvider(IFixture fixture) { - private Dictionary> _dependencies; + _dependencies = new Dictionary>(); + _fixture = (fixture ?? new Fixture()).WithAutoNSubstitutions().Customize(new GlobalSettings()); + _constructorParameterRelay = new ConstructorParameterRelay(this, _fixture); + _fixture.Customizations.Add(_constructorParameterRelay); + } + + public SutProvider SetDependency(T dependency, string parameterName = "") + => SetDependency(typeof(T), dependency, parameterName); + public SutProvider SetDependency(Type dependencyType, object dependency, string parameterName = "") + { + if (_dependencies.ContainsKey(dependencyType)) + { + _dependencies[dependencyType][parameterName] = dependency; + } + else + { + _dependencies[dependencyType] = new Dictionary { { parameterName, dependency } }; + } + + return this; + } + + public T GetDependency(string parameterName = "") => (T)GetDependency(typeof(T), parameterName); + public object GetDependency(Type dependencyType, string parameterName = "") + { + if (DependencyIsSet(dependencyType, parameterName)) + { + return _dependencies[dependencyType][parameterName]; + } + else if (_dependencies.ContainsKey(dependencyType)) + { + var knownDependencies = _dependencies[dependencyType]; + if (knownDependencies.Values.Count == 1) + { + return _dependencies[dependencyType].Values.Single(); + } + else + { + throw new ArgumentException(string.Concat($"Dependency of type {dependencyType.Name} and name ", + $"{parameterName} does not exist. Available dependency names are: ", + string.Join(", ", knownDependencies.Keys))); + } + } + else + { + throw new ArgumentException($"Dependency of type {dependencyType.Name} and name {parameterName} has not been set."); + } + } + + public void Reset() + { + _dependencies = new Dictionary>(); + Sut = default; + } + + ISutProvider ISutProvider.Create() => Create(); + public SutProvider Create() + { + Sut = _fixture.Create(); + return this; + } + + private bool DependencyIsSet(Type dependencyType, string parameterName = "") + => _dependencies.ContainsKey(dependencyType) && _dependencies[dependencyType].ContainsKey(parameterName); + + private object GetDefault(Type type) => type.IsValueType ? Activator.CreateInstance(type) : null; + + private class ConstructorParameterRelay : ISpecimenBuilder + { + private readonly SutProvider _sutProvider; private readonly IFixture _fixture; - private readonly ConstructorParameterRelay _constructorParameterRelay; - public TSut Sut { get; private set; } - public Type SutType => typeof(TSut); - - public SutProvider() : this(new Fixture()) { } - - public SutProvider(IFixture fixture) + public ConstructorParameterRelay(SutProvider sutProvider, IFixture fixture) { - _dependencies = new Dictionary>(); - _fixture = (fixture ?? new Fixture()).WithAutoNSubstitutions().Customize(new GlobalSettings()); - _constructorParameterRelay = new ConstructorParameterRelay(this, _fixture); - _fixture.Customizations.Add(_constructorParameterRelay); + _sutProvider = sutProvider; + _fixture = fixture; } - public SutProvider SetDependency(T dependency, string parameterName = "") - => SetDependency(typeof(T), dependency, parameterName); - public SutProvider SetDependency(Type dependencyType, object dependency, string parameterName = "") + public object Create(object request, ISpecimenContext context) { - if (_dependencies.ContainsKey(dependencyType)) + if (context == null) { - _dependencies[dependencyType][parameterName] = dependency; + throw new ArgumentNullException(nameof(context)); } - else + if (!(request is ParameterInfo parameterInfo)) { - _dependencies[dependencyType] = new Dictionary { { parameterName, dependency } }; + return new NoSpecimen(); + } + if (parameterInfo.Member.DeclaringType != typeof(T) || + parameterInfo.Member.MemberType != MemberTypes.Constructor) + { + return new NoSpecimen(); } - return this; - } - - public T GetDependency(string parameterName = "") => (T)GetDependency(typeof(T), parameterName); - public object GetDependency(Type dependencyType, string parameterName = "") - { - if (DependencyIsSet(dependencyType, parameterName)) + if (_sutProvider.DependencyIsSet(parameterInfo.ParameterType, parameterInfo.Name)) { - return _dependencies[dependencyType][parameterName]; + return _sutProvider.GetDependency(parameterInfo.ParameterType, parameterInfo.Name); } - else if (_dependencies.ContainsKey(dependencyType)) + // Return default type if set + else if (_sutProvider.DependencyIsSet(parameterInfo.ParameterType, "")) { - var knownDependencies = _dependencies[dependencyType]; - if (knownDependencies.Values.Count == 1) - { - return _dependencies[dependencyType].Values.Single(); - } - else - { - throw new ArgumentException(string.Concat($"Dependency of type {dependencyType.Name} and name ", - $"{parameterName} does not exist. Available dependency names are: ", - string.Join(", ", knownDependencies.Keys))); - } - } - else - { - throw new ArgumentException($"Dependency of type {dependencyType.Name} and name {parameterName} has not been set."); - } - } - - public void Reset() - { - _dependencies = new Dictionary>(); - Sut = default; - } - - ISutProvider ISutProvider.Create() => Create(); - public SutProvider Create() - { - Sut = _fixture.Create(); - return this; - } - - private bool DependencyIsSet(Type dependencyType, string parameterName = "") - => _dependencies.ContainsKey(dependencyType) && _dependencies[dependencyType].ContainsKey(parameterName); - - private object GetDefault(Type type) => type.IsValueType ? Activator.CreateInstance(type) : null; - - private class ConstructorParameterRelay : ISpecimenBuilder - { - private readonly SutProvider _sutProvider; - private readonly IFixture _fixture; - - public ConstructorParameterRelay(SutProvider sutProvider, IFixture fixture) - { - _sutProvider = sutProvider; - _fixture = fixture; + return _sutProvider.GetDependency(parameterInfo.ParameterType, ""); } - public object Create(object request, ISpecimenContext context) - { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - if (!(request is ParameterInfo parameterInfo)) - { - return new NoSpecimen(); - } - if (parameterInfo.Member.DeclaringType != typeof(T) || - parameterInfo.Member.MemberType != MemberTypes.Constructor) - { - return new NoSpecimen(); - } - if (_sutProvider.DependencyIsSet(parameterInfo.ParameterType, parameterInfo.Name)) - { - return _sutProvider.GetDependency(parameterInfo.ParameterType, parameterInfo.Name); - } - // Return default type if set - else if (_sutProvider.DependencyIsSet(parameterInfo.ParameterType, "")) - { - return _sutProvider.GetDependency(parameterInfo.ParameterType, ""); - } - - - // This is the equivalent of _fixture.Create, but no overload for - // Create(Type type) exists. - var dependency = new SpecimenContext(_fixture).Resolve(new SeededRequest(parameterInfo.ParameterType, - _sutProvider.GetDefault(parameterInfo.ParameterType))); - _sutProvider.SetDependency(parameterInfo.ParameterType, dependency, parameterInfo.Name); - return dependency; - } + // This is the equivalent of _fixture.Create, but no overload for + // Create(Type type) exists. + var dependency = new SpecimenContext(_fixture).Resolve(new SeededRequest(parameterInfo.ParameterType, + _sutProvider.GetDefault(parameterInfo.ParameterType))); + _sutProvider.SetDependency(parameterInfo.ParameterType, dependency, parameterInfo.Name); + return dependency; } } } diff --git a/test/Common/AutoFixture/SutProviderCustomization.cs b/test/Common/AutoFixture/SutProviderCustomization.cs index 148592394..5cbff6a71 100644 --- a/test/Common/AutoFixture/SutProviderCustomization.cs +++ b/test/Common/AutoFixture/SutProviderCustomization.cs @@ -1,34 +1,33 @@ using AutoFixture; using AutoFixture.Kernel; -namespace Bit.Test.Common.AutoFixture.Attributes +namespace Bit.Test.Common.AutoFixture.Attributes; + +public class SutProviderCustomization : ICustomization, ISpecimenBuilder { - public class SutProviderCustomization : ICustomization, ISpecimenBuilder + private IFixture _fixture = null; + + public object Create(object request, ISpecimenContext context) { - private IFixture _fixture = null; - - public object Create(object request, ISpecimenContext context) + if (context == null) { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - if (!(request is Type typeRequest)) - { - return new NoSpecimen(); - } - if (!typeof(ISutProvider).IsAssignableFrom(typeRequest)) - { - return new NoSpecimen(); - } - - return ((ISutProvider)Activator.CreateInstance(typeRequest, _fixture)).Create(); + throw new ArgumentNullException(nameof(context)); + } + if (!(request is Type typeRequest)) + { + return new NoSpecimen(); + } + if (!typeof(ISutProvider).IsAssignableFrom(typeRequest)) + { + return new NoSpecimen(); } - public void Customize(IFixture fixture) - { - _fixture = fixture; - fixture.Customizations.Add(this); - } + return ((ISutProvider)Activator.CreateInstance(typeRequest, _fixture)).Create(); + } + + public void Customize(IFixture fixture) + { + _fixture = fixture; + fixture.Customizations.Add(this); } } diff --git a/test/Common/Helpers/AssertHelper.cs b/test/Common/Helpers/AssertHelper.cs index 7239cb6db..d690837ef 100644 --- a/test/Common/Helpers/AssertHelper.cs +++ b/test/Common/Helpers/AssertHelper.cs @@ -7,223 +7,222 @@ using Microsoft.AspNetCore.Http; using Xunit; using Xunit.Sdk; -namespace Bit.Test.Common.Helpers +namespace Bit.Test.Common.Helpers; + +public static class AssertHelper { - public static class AssertHelper + public static void AssertPropertyEqual(object expected, object actual, params string[] excludedPropertyStrings) { - public static void AssertPropertyEqual(object expected, object actual, params string[] excludedPropertyStrings) + var relevantExcludedProperties = excludedPropertyStrings.Where(name => !name.Contains('.')).ToList(); + if (expected == null) { - var relevantExcludedProperties = excludedPropertyStrings.Where(name => !name.Contains('.')).ToList(); - if (expected == null) - { - Assert.Null(actual); - return; - } - - if (actual == null) - { - throw new Exception("Actual object is null but expected is not"); - } - - foreach (var expectedPropInfo in expected.GetType().GetProperties().Where(pi => !relevantExcludedProperties.Contains(pi.Name))) - { - var actualPropInfo = actual.GetType().GetProperty(expectedPropInfo.Name); - - if (actualPropInfo == null) - { - throw new Exception(string.Concat($"Expected actual object to contain a property named {expectedPropInfo.Name}, but it does not\n", - $"Expected:\n{JsonSerializer.Serialize(expected, JsonHelpers.Indented)}\n", - $"Actual:\n{JsonSerializer.Serialize(actual, JsonHelpers.Indented)}")); - } - - if (expectedPropInfo.PropertyType == typeof(string) || expectedPropInfo.PropertyType.IsValueType) - { - Assert.Equal(expectedPropInfo.GetValue(expected), actualPropInfo.GetValue(actual)); - } - else if (expectedPropInfo.PropertyType == typeof(JsonDocument) && actualPropInfo.PropertyType == typeof(JsonDocument)) - { - static string JsonDocString(PropertyInfo info, object obj) => JsonSerializer.Serialize(info.GetValue(obj)); - Assert.Equal(JsonDocString(expectedPropInfo, expected), JsonDocString(actualPropInfo, actual)); - } - else - { - var prefix = $"{expectedPropInfo.PropertyType.Name}."; - var nextExcludedProperties = excludedPropertyStrings.Where(name => name.StartsWith(prefix)) - .Select(name => name[prefix.Length..]).ToArray(); - AssertPropertyEqual(expectedPropInfo.GetValue(expected), actualPropInfo.GetValue(actual), nextExcludedProperties); - } - } + Assert.Null(actual); + return; } - private static Predicate AssertPropertyEqualPredicate(T expected, params string[] excludedPropertyStrings) => (actual) => + if (actual == null) { - AssertPropertyEqual(expected, actual, excludedPropertyStrings); - return true; - }; - - public static Expression> AssertPropertyEqual(T expected, params string[] excludedPropertyStrings) => - (T actual) => AssertPropertyEqualPredicate(expected, excludedPropertyStrings)(actual); - - private static Predicate> AssertPropertyEqualPredicate(IEnumerable expected, params string[] excludedPropertyStrings) => (actual) => - { - // IEnumerable.Zip doesn't account for different lengths, we need to check this ourselves - if (actual.Count() != expected.Count()) - { - throw new Exception(string.Concat($"Actual IEnumerable does not have the expected length.\n", - $"Expected: {expected.Count()}\n", - $"Actual: {actual.Count()}")); - } - - var elements = expected.Zip(actual); - foreach (var (expectedEl, actualEl) in elements) - { - AssertPropertyEqual(expectedEl, actualEl, excludedPropertyStrings); - } - - return true; - }; - - public static Expression>> AssertPropertyEqual(IEnumerable expected, params string[] excludedPropertyStrings) => - (actual) => AssertPropertyEqualPredicate(expected, excludedPropertyStrings)(actual); - - private static Predicate AssertEqualExpectedPredicate(T expected) => (actual) => - { - Assert.Equal(expected, actual); - return true; - }; - - public static Expression> AssertEqualExpected(T expected) => - (T actual) => AssertEqualExpectedPredicate(expected)(actual); - - public static JsonElement AssertJsonProperty(JsonElement element, string propertyName, JsonValueKind jsonValueKind) - { - if (!element.TryGetProperty(propertyName, out var subElement)) - { - throw new XunitException($"Could not find property by name '{propertyName}'"); - } - - Assert.Equal(jsonValueKind, subElement.ValueKind); - return subElement; + throw new Exception("Actual object is null but expected is not"); } - public static void AssertEqualJson(JsonElement a, JsonElement b) + foreach (var expectedPropInfo in expected.GetType().GetProperties().Where(pi => !relevantExcludedProperties.Contains(pi.Name))) { - switch (a.ValueKind) + var actualPropInfo = actual.GetType().GetProperty(expectedPropInfo.Name); + + if (actualPropInfo == null) { - case JsonValueKind.Array: - Assert.Equal(JsonValueKind.Array, b.ValueKind); - AssertEqualJsonArray(a, b); - break; - case JsonValueKind.Object: - Assert.Equal(JsonValueKind.Object, b.ValueKind); - AssertEqualJsonObject(a, b); - break; - case JsonValueKind.False: - Assert.Equal(JsonValueKind.False, b.ValueKind); - break; - case JsonValueKind.True: - Assert.Equal(JsonValueKind.True, b.ValueKind); - break; - case JsonValueKind.Number: - Assert.Equal(JsonValueKind.Number, b.ValueKind); - Assert.Equal(a.GetDouble(), b.GetDouble()); - break; - case JsonValueKind.String: - Assert.Equal(JsonValueKind.String, b.ValueKind); - Assert.Equal(a.GetString(), b.GetString()); - break; - case JsonValueKind.Null: - Assert.Equal(JsonValueKind.Null, b.ValueKind); - break; - default: - throw new XunitException($"Bad JsonValueKind '{a.ValueKind}'"); + throw new Exception(string.Concat($"Expected actual object to contain a property named {expectedPropInfo.Name}, but it does not\n", + $"Expected:\n{JsonSerializer.Serialize(expected, JsonHelpers.Indented)}\n", + $"Actual:\n{JsonSerializer.Serialize(actual, JsonHelpers.Indented)}")); } - } - private static void AssertEqualJsonObject(JsonElement a, JsonElement b) - { - Debug.Assert(a.ValueKind == JsonValueKind.Object && b.ValueKind == JsonValueKind.Object); - - var aObjectEnumerator = a.EnumerateObject(); - var bObjectEnumerator = b.EnumerateObject(); - - while (true) + if (expectedPropInfo.PropertyType == typeof(string) || expectedPropInfo.PropertyType.IsValueType) { - var aCanMove = aObjectEnumerator.MoveNext(); - var bCanMove = bObjectEnumerator.MoveNext(); - - if (aCanMove) - { - Assert.True(bCanMove, $"a was able to enumerate over object '{a}' but b was NOT able to '{b}'"); - } - else - { - Assert.False(bCanMove, $"a was NOT able to enumerate over object '{a}' but b was able to '{b}'"); - } - - if (aCanMove == false && bCanMove == false) - { - // They both can't continue to enumerate at the same time, that is valid - break; - } - - var aProp = aObjectEnumerator.Current; - var bProp = bObjectEnumerator.Current; - - Assert.Equal(aProp.Name, bProp.Name); - // Recursion! - AssertEqualJson(aProp.Value, bProp.Value); + Assert.Equal(expectedPropInfo.GetValue(expected), actualPropInfo.GetValue(actual)); } - } - - private static void AssertEqualJsonArray(JsonElement a, JsonElement b) - { - Debug.Assert(a.ValueKind == JsonValueKind.Array && b.ValueKind == JsonValueKind.Array); - - var aArrayEnumerator = a.EnumerateArray(); - var bArrayEnumerator = b.EnumerateArray(); - - while (true) + else if (expectedPropInfo.PropertyType == typeof(JsonDocument) && actualPropInfo.PropertyType == typeof(JsonDocument)) { - var aCanMove = aArrayEnumerator.MoveNext(); - var bCanMove = bArrayEnumerator.MoveNext(); - - if (aCanMove) - { - Assert.True(bCanMove, $"a was able to enumerate over array '{a}' but b was NOT able to '{b}'"); - } - else - { - Assert.False(bCanMove, $"a was NOT able to enumerate over array '{a}' but b was able to '{b}'"); - } - - if (aCanMove == false && bCanMove == false) - { - // They both can't continue to enumerate at the same time, that is valid - break; - } - - var aElement = aArrayEnumerator.Current; - var bElement = bArrayEnumerator.Current; - - // Recursion! - AssertEqualJson(aElement, bElement); + static string JsonDocString(PropertyInfo info, object obj) => JsonSerializer.Serialize(info.GetValue(obj)); + Assert.Equal(JsonDocString(expectedPropInfo, expected), JsonDocString(actualPropInfo, actual)); + } + else + { + var prefix = $"{expectedPropInfo.PropertyType.Name}."; + var nextExcludedProperties = excludedPropertyStrings.Where(name => name.StartsWith(prefix)) + .Select(name => name[prefix.Length..]).ToArray(); + AssertPropertyEqual(expectedPropInfo.GetValue(expected), actualPropInfo.GetValue(actual), nextExcludedProperties); } - } - - public async static Task AssertResponseTypeIs(HttpContext context) - { - return await JsonSerializer.DeserializeAsync(context.Response.Body); - } - - public static TimeSpan AssertRecent(DateTime dateTime, int skewSeconds = 2) - => AssertRecent(dateTime, TimeSpan.FromSeconds(skewSeconds)); - - public static TimeSpan AssertRecent(DateTime dateTime, TimeSpan skew) - { - var difference = DateTime.UtcNow - dateTime; - Assert.True(difference < skew); - return difference; } } + + private static Predicate AssertPropertyEqualPredicate(T expected, params string[] excludedPropertyStrings) => (actual) => + { + AssertPropertyEqual(expected, actual, excludedPropertyStrings); + return true; + }; + + public static Expression> AssertPropertyEqual(T expected, params string[] excludedPropertyStrings) => + (T actual) => AssertPropertyEqualPredicate(expected, excludedPropertyStrings)(actual); + + private static Predicate> AssertPropertyEqualPredicate(IEnumerable expected, params string[] excludedPropertyStrings) => (actual) => + { + // IEnumerable.Zip doesn't account for different lengths, we need to check this ourselves + if (actual.Count() != expected.Count()) + { + throw new Exception(string.Concat($"Actual IEnumerable does not have the expected length.\n", + $"Expected: {expected.Count()}\n", + $"Actual: {actual.Count()}")); + } + + var elements = expected.Zip(actual); + foreach (var (expectedEl, actualEl) in elements) + { + AssertPropertyEqual(expectedEl, actualEl, excludedPropertyStrings); + } + + return true; + }; + + public static Expression>> AssertPropertyEqual(IEnumerable expected, params string[] excludedPropertyStrings) => + (actual) => AssertPropertyEqualPredicate(expected, excludedPropertyStrings)(actual); + + private static Predicate AssertEqualExpectedPredicate(T expected) => (actual) => + { + Assert.Equal(expected, actual); + return true; + }; + + public static Expression> AssertEqualExpected(T expected) => + (T actual) => AssertEqualExpectedPredicate(expected)(actual); + + public static JsonElement AssertJsonProperty(JsonElement element, string propertyName, JsonValueKind jsonValueKind) + { + if (!element.TryGetProperty(propertyName, out var subElement)) + { + throw new XunitException($"Could not find property by name '{propertyName}'"); + } + + Assert.Equal(jsonValueKind, subElement.ValueKind); + return subElement; + } + + public static void AssertEqualJson(JsonElement a, JsonElement b) + { + switch (a.ValueKind) + { + case JsonValueKind.Array: + Assert.Equal(JsonValueKind.Array, b.ValueKind); + AssertEqualJsonArray(a, b); + break; + case JsonValueKind.Object: + Assert.Equal(JsonValueKind.Object, b.ValueKind); + AssertEqualJsonObject(a, b); + break; + case JsonValueKind.False: + Assert.Equal(JsonValueKind.False, b.ValueKind); + break; + case JsonValueKind.True: + Assert.Equal(JsonValueKind.True, b.ValueKind); + break; + case JsonValueKind.Number: + Assert.Equal(JsonValueKind.Number, b.ValueKind); + Assert.Equal(a.GetDouble(), b.GetDouble()); + break; + case JsonValueKind.String: + Assert.Equal(JsonValueKind.String, b.ValueKind); + Assert.Equal(a.GetString(), b.GetString()); + break; + case JsonValueKind.Null: + Assert.Equal(JsonValueKind.Null, b.ValueKind); + break; + default: + throw new XunitException($"Bad JsonValueKind '{a.ValueKind}'"); + } + } + + private static void AssertEqualJsonObject(JsonElement a, JsonElement b) + { + Debug.Assert(a.ValueKind == JsonValueKind.Object && b.ValueKind == JsonValueKind.Object); + + var aObjectEnumerator = a.EnumerateObject(); + var bObjectEnumerator = b.EnumerateObject(); + + while (true) + { + var aCanMove = aObjectEnumerator.MoveNext(); + var bCanMove = bObjectEnumerator.MoveNext(); + + if (aCanMove) + { + Assert.True(bCanMove, $"a was able to enumerate over object '{a}' but b was NOT able to '{b}'"); + } + else + { + Assert.False(bCanMove, $"a was NOT able to enumerate over object '{a}' but b was able to '{b}'"); + } + + if (aCanMove == false && bCanMove == false) + { + // They both can't continue to enumerate at the same time, that is valid + break; + } + + var aProp = aObjectEnumerator.Current; + var bProp = bObjectEnumerator.Current; + + Assert.Equal(aProp.Name, bProp.Name); + // Recursion! + AssertEqualJson(aProp.Value, bProp.Value); + } + } + + private static void AssertEqualJsonArray(JsonElement a, JsonElement b) + { + Debug.Assert(a.ValueKind == JsonValueKind.Array && b.ValueKind == JsonValueKind.Array); + + var aArrayEnumerator = a.EnumerateArray(); + var bArrayEnumerator = b.EnumerateArray(); + + while (true) + { + var aCanMove = aArrayEnumerator.MoveNext(); + var bCanMove = bArrayEnumerator.MoveNext(); + + if (aCanMove) + { + Assert.True(bCanMove, $"a was able to enumerate over array '{a}' but b was NOT able to '{b}'"); + } + else + { + Assert.False(bCanMove, $"a was NOT able to enumerate over array '{a}' but b was able to '{b}'"); + } + + if (aCanMove == false && bCanMove == false) + { + // They both can't continue to enumerate at the same time, that is valid + break; + } + + var aElement = aArrayEnumerator.Current; + var bElement = bArrayEnumerator.Current; + + // Recursion! + AssertEqualJson(aElement, bElement); + } + } + + public async static Task AssertResponseTypeIs(HttpContext context) + { + return await JsonSerializer.DeserializeAsync(context.Response.Body); + } + + public static TimeSpan AssertRecent(DateTime dateTime, int skewSeconds = 2) + => AssertRecent(dateTime, TimeSpan.FromSeconds(skewSeconds)); + + public static TimeSpan AssertRecent(DateTime dateTime, TimeSpan skew) + { + var difference = DateTime.UtcNow - dateTime; + Assert.True(difference < skew); + return difference; + } } diff --git a/test/Common/Helpers/BitAutoDataAttributeHelpers.cs b/test/Common/Helpers/BitAutoDataAttributeHelpers.cs index aae8d72dc..32cacc49d 100644 --- a/test/Common/Helpers/BitAutoDataAttributeHelpers.cs +++ b/test/Common/Helpers/BitAutoDataAttributeHelpers.cs @@ -4,49 +4,48 @@ using AutoFixture.Kernel; using AutoFixture.Xunit2; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Test.Common.Helpers +namespace Bit.Test.Common.Helpers; + +public static class BitAutoDataAttributeHelpers { - public static class BitAutoDataAttributeHelpers + public static IEnumerable GetData(MethodInfo testMethod, IFixture fixture, object[] fixedTestParameters) { - public static IEnumerable GetData(MethodInfo testMethod, IFixture fixture, object[] fixedTestParameters) + var methodParameters = testMethod.GetParameters(); + var classCustomizations = testMethod.DeclaringType.GetCustomAttributes().Select(attr => attr.GetCustomization()); + var methodCustomizations = testMethod.GetCustomAttributes().Select(attr => attr.GetCustomization()); + + fixedTestParameters = fixedTestParameters ?? Array.Empty(); + + fixture = ApplyCustomizations(ApplyCustomizations(fixture, classCustomizations), methodCustomizations); + var missingParameters = methodParameters.Skip(fixedTestParameters.Length).Select(p => CustomizeAndCreate(p, fixture)); + + return new object[1][] { fixedTestParameters.Concat(missingParameters).ToArray() }; + } + + public static object CustomizeAndCreate(ParameterInfo p, IFixture fixture) + { + var customizations = p.GetCustomAttributes(typeof(CustomizeAttribute), false) + .OfType() + .Select(attr => attr.GetCustomization(p)); + + var context = new SpecimenContext(ApplyCustomizations(fixture, customizations)); + return context.Resolve(p); + } + + public static IFixture ApplyCustomizations(IFixture fixture, IEnumerable customizations) + { + var newFixture = new Fixture(); + + foreach (var customization in fixture.Customizations.Reverse().Select(b => b.ToCustomization())) { - var methodParameters = testMethod.GetParameters(); - var classCustomizations = testMethod.DeclaringType.GetCustomAttributes().Select(attr => attr.GetCustomization()); - var methodCustomizations = testMethod.GetCustomAttributes().Select(attr => attr.GetCustomization()); - - fixedTestParameters = fixedTestParameters ?? Array.Empty(); - - fixture = ApplyCustomizations(ApplyCustomizations(fixture, classCustomizations), methodCustomizations); - var missingParameters = methodParameters.Skip(fixedTestParameters.Length).Select(p => CustomizeAndCreate(p, fixture)); - - return new object[1][] { fixedTestParameters.Concat(missingParameters).ToArray() }; + newFixture.Customize(customization); } - public static object CustomizeAndCreate(ParameterInfo p, IFixture fixture) + foreach (var customization in customizations) { - var customizations = p.GetCustomAttributes(typeof(CustomizeAttribute), false) - .OfType() - .Select(attr => attr.GetCustomization(p)); - - var context = new SpecimenContext(ApplyCustomizations(fixture, customizations)); - return context.Resolve(p); + newFixture.Customize(customization); } - public static IFixture ApplyCustomizations(IFixture fixture, IEnumerable customizations) - { - var newFixture = new Fixture(); - - foreach (var customization in fixture.Customizations.Reverse().Select(b => b.ToCustomization())) - { - newFixture.Customize(customization); - } - - foreach (var customization in customizations) - { - newFixture.Customize(customization); - } - - return newFixture; - } + return newFixture; } } diff --git a/test/Common/Helpers/TestCaseHelper.cs b/test/Common/Helpers/TestCaseHelper.cs index c31d66e17..279229fc5 100644 --- a/test/Common/Helpers/TestCaseHelper.cs +++ b/test/Common/Helpers/TestCaseHelper.cs @@ -1,45 +1,44 @@ -namespace Bit.Test.Common.Helpers +namespace Bit.Test.Common.Helpers; + +public static class TestCaseHelper { - public static class TestCaseHelper + public static IEnumerable> GetCombinations(params T[] items) { - public static IEnumerable> GetCombinations(params T[] items) + var count = Math.Pow(2, items.Length); + for (var i = 0; i < count; i++) { - var count = Math.Pow(2, items.Length); - for (var i = 0; i < count; i++) + var str = Convert.ToString(i, 2).PadLeft(items.Length, '0'); + List combination = new(); + for (var j = 0; j < str.Length; j++) { - var str = Convert.ToString(i, 2).PadLeft(items.Length, '0'); - List combination = new(); - for (var j = 0; j < str.Length; j++) + if (str[j] == '1') { - if (str[j] == '1') - { - combination.Add(items[j]); - } + combination.Add(items[j]); } - yield return combination; } + yield return combination; + } + } + + public static IEnumerable> GetCombinationsOfMultipleLists(params IEnumerable[] optionLists) + { + if (!optionLists.Any()) + { + yield break; } - public static IEnumerable> GetCombinationsOfMultipleLists(params IEnumerable[] optionLists) + foreach (var item in optionLists.First()) { - if (!optionLists.Any()) + var itemArray = new[] { item }; + + if (optionLists.Length == 1) { - yield break; + yield return itemArray; } - foreach (var item in optionLists.First()) + foreach (var nextCombination in GetCombinationsOfMultipleLists(optionLists.Skip(1).ToArray())) { - var itemArray = new[] { item }; - - if (optionLists.Length == 1) - { - yield return itemArray; - } - - foreach (var nextCombination in GetCombinationsOfMultipleLists(optionLists.Skip(1).ToArray())) - { - yield return itemArray.Concat(nextCombination); - } + yield return itemArray.Concat(nextCombination); } } } diff --git a/test/Common/Test/TestCaseHelperTests.cs b/test/Common/Test/TestCaseHelperTests.cs index 697899813..4d18aa76e 100644 --- a/test/Common/Test/TestCaseHelperTests.cs +++ b/test/Common/Test/TestCaseHelperTests.cs @@ -1,51 +1,50 @@ using Bit.Test.Common.Helpers; using Xunit; -namespace Bit.Test.Common.Test +namespace Bit.Test.Common.Test; + +public class TestCaseHelperTests { - public class TestCaseHelperTests + [Fact] + public void GetCombinations_EmptyList() { - [Fact] - public void GetCombinations_EmptyList() - { - Assert.Equal(new[] { Array.Empty() }, TestCaseHelper.GetCombinations(Array.Empty()).ToArray()); - } + Assert.Equal(new[] { Array.Empty() }, TestCaseHelper.GetCombinations(Array.Empty()).ToArray()); + } - [Fact] - public void GetCombinations_OneItemList() - { - Assert.Equal(new[] { Array.Empty(), new[] { 1 } }, TestCaseHelper.GetCombinations(1)); - } + [Fact] + public void GetCombinations_OneItemList() + { + Assert.Equal(new[] { Array.Empty(), new[] { 1 } }, TestCaseHelper.GetCombinations(1)); + } - [Fact] - public void GetCombinations_TwoItemList() - { - Assert.Equal(new[] { Array.Empty(), new[] { 2 }, new[] { 1 }, new[] { 1, 2 } }, TestCaseHelper.GetCombinations(1, 2)); - } + [Fact] + public void GetCombinations_TwoItemList() + { + Assert.Equal(new[] { Array.Empty(), new[] { 2 }, new[] { 1 }, new[] { 1, 2 } }, TestCaseHelper.GetCombinations(1, 2)); + } - [Fact] - public void GetCombinationsOfMultipleLists_OneOne() - { - Assert.Equal(new[] { new object[] { 1, "1" } }, TestCaseHelper.GetCombinationsOfMultipleLists(new object[] { 1 }, new object[] { "1" })); - } + [Fact] + public void GetCombinationsOfMultipleLists_OneOne() + { + Assert.Equal(new[] { new object[] { 1, "1" } }, TestCaseHelper.GetCombinationsOfMultipleLists(new object[] { 1 }, new object[] { "1" })); + } - [Fact] - public void GetCombinationsOfMultipleLists_OneTwo() - { - Assert.Equal(new[] { new object[] { 1, "1" }, new object[] { 1, "2" } }, TestCaseHelper.GetCombinationsOfMultipleLists(new object[] { 1 }, new object[] { "1", "2" })); - } + [Fact] + public void GetCombinationsOfMultipleLists_OneTwo() + { + Assert.Equal(new[] { new object[] { 1, "1" }, new object[] { 1, "2" } }, TestCaseHelper.GetCombinationsOfMultipleLists(new object[] { 1 }, new object[] { "1", "2" })); + } - [Fact] - public void GetCombinationsOfMultipleLists_TwoOne() - { - Assert.Equal(new[] { new object[] { 1, "1" }, new object[] { 2, "1" } }, TestCaseHelper.GetCombinationsOfMultipleLists(new object[] { 1, 2 }, new object[] { "1" })); - } + [Fact] + public void GetCombinationsOfMultipleLists_TwoOne() + { + Assert.Equal(new[] { new object[] { 1, "1" }, new object[] { 2, "1" } }, TestCaseHelper.GetCombinationsOfMultipleLists(new object[] { 1, 2 }, new object[] { "1" })); + } - [Fact] - public void GetCombinationsOfMultipleLists_TwoTwo() - { - Assert.Equal(new[] { new object[] { 1, "1" }, new object[] { 1, "2" }, new object[] { 2, "1" }, new object[] { 2, "2" } }, TestCaseHelper.GetCombinationsOfMultipleLists(new object[] { 1, 2 }, new object[] { "1", "2" })); - } + [Fact] + public void GetCombinationsOfMultipleLists_TwoTwo() + { + Assert.Equal(new[] { new object[] { 1, "1" }, new object[] { 1, "2" }, new object[] { 2, "1" }, new object[] { 2, "2" } }, TestCaseHelper.GetCombinationsOfMultipleLists(new object[] { 1, 2 }, new object[] { "1", "2" })); } } diff --git a/test/Core.Test/AutoFixture/Attributes/CiSkippedTheory.cs b/test/Core.Test/AutoFixture/Attributes/CiSkippedTheory.cs index 90981704f..269988272 100644 --- a/test/Core.Test/AutoFixture/Attributes/CiSkippedTheory.cs +++ b/test/Core.Test/AutoFixture/Attributes/CiSkippedTheory.cs @@ -1,14 +1,13 @@ -namespace Bit.Core.Test.AutoFixture.Attributes +namespace Bit.Core.Test.AutoFixture.Attributes; + +public sealed class CiSkippedTheory : Xunit.TheoryAttribute { - public sealed class CiSkippedTheory : Xunit.TheoryAttribute + private static bool IsGithubActions() => Environment.GetEnvironmentVariable("CI") != null; + public CiSkippedTheory() { - private static bool IsGithubActions() => Environment.GetEnvironmentVariable("CI") != null; - public CiSkippedTheory() + if (IsGithubActions()) { - if (IsGithubActions()) - { - Skip = "Ignore during CI builds"; - } + Skip = "Ignore during CI builds"; } } } diff --git a/test/Core.Test/AutoFixture/CipherAttachmentMetaDataFixtures.cs b/test/Core.Test/AutoFixture/CipherAttachmentMetaDataFixtures.cs index 7b41f76be..ef18dcd5f 100644 --- a/test/Core.Test/AutoFixture/CipherAttachmentMetaDataFixtures.cs +++ b/test/Core.Test/AutoFixture/CipherAttachmentMetaDataFixtures.cs @@ -2,32 +2,31 @@ using AutoFixture.Dsl; using Bit.Core.Models.Data; -namespace Bit.Core.Test.AutoFixture.CipherAttachmentMetaData +namespace Bit.Core.Test.AutoFixture.CipherAttachmentMetaData; + +public class MetaData : ICustomization { - public class MetaData : ICustomization + protected virtual IPostprocessComposer ComposerAction(IFixture fixture, + ICustomizationComposer composer) { - protected virtual IPostprocessComposer ComposerAction(IFixture fixture, - ICustomizationComposer composer) - { - return composer.With(d => d.Size, fixture.Create()); - } - public void Customize(IFixture fixture) - { - fixture.Customize(composer => ComposerAction(fixture, composer)); - } + return composer.With(d => d.Size, fixture.Create()); } - - public class MetaDataWithoutContainer : MetaData + public void Customize(IFixture fixture) { - protected override IPostprocessComposer ComposerAction(IFixture fixture, - ICustomizationComposer composer) => - base.ComposerAction(fixture, composer).With(d => d.ContainerName, (string)null); - } - - public class MetaDataWithoutKey : MetaDataWithoutContainer - { - protected override IPostprocessComposer ComposerAction(IFixture fixture, - ICustomizationComposer composer) => - base.ComposerAction(fixture, composer).Without(d => d.Key); + fixture.Customize(composer => ComposerAction(fixture, composer)); } } + +public class MetaDataWithoutContainer : MetaData +{ + protected override IPostprocessComposer ComposerAction(IFixture fixture, + ICustomizationComposer composer) => + base.ComposerAction(fixture, composer).With(d => d.ContainerName, (string)null); +} + +public class MetaDataWithoutKey : MetaDataWithoutContainer +{ + protected override IPostprocessComposer ComposerAction(IFixture fixture, + ICustomizationComposer composer) => + base.ComposerAction(fixture, composer).Without(d => d.Key); +} diff --git a/test/Core.Test/AutoFixture/CipherFixtures.cs b/test/Core.Test/AutoFixture/CipherFixtures.cs index 5ef2976fb..b4c87ef8a 100644 --- a/test/Core.Test/AutoFixture/CipherFixtures.cs +++ b/test/Core.Test/AutoFixture/CipherFixtures.cs @@ -3,67 +3,66 @@ using Bit.Core.Entities; using Bit.Test.Common.AutoFixture.Attributes; using Core.Models.Data; -namespace Bit.Core.Test.AutoFixture.CipherFixtures +namespace Bit.Core.Test.AutoFixture.CipherFixtures; + +internal class OrganizationCipher : ICustomization { - internal class OrganizationCipher : ICustomization + public Guid? OrganizationId { get; set; } + public void Customize(IFixture fixture) { - public Guid? OrganizationId { get; set; } - public void Customize(IFixture fixture) - { - fixture.Customize(composer => composer - .With(c => c.OrganizationId, OrganizationId ?? Guid.NewGuid()) - .Without(c => c.UserId)); - fixture.Customize(composer => composer - .With(c => c.OrganizationId, Guid.NewGuid()) - .Without(c => c.UserId)); - } - } - - internal class UserCipher : ICustomization - { - public Guid? UserId { get; set; } - public void Customize(IFixture fixture) - { - fixture.Customize(composer => composer - .With(c => c.UserId, UserId ?? Guid.NewGuid()) - .Without(c => c.OrganizationId)); - fixture.Customize(composer => composer - .With(c => c.UserId, Guid.NewGuid()) - .Without(c => c.OrganizationId)); - } - } - - internal class UserCipherAutoDataAttribute : CustomAutoDataAttribute - { - public UserCipherAutoDataAttribute(string userId = null) : base(new SutProviderCustomization(), - new UserCipher { UserId = userId == null ? (Guid?)null : new Guid(userId) }) - { } - } - internal class InlineUserCipherAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineUserCipherAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(UserCipher) }, values) - { } - } - - internal class InlineKnownUserCipherAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineKnownUserCipherAutoDataAttribute(string userId, params object[] values) : base(new ICustomization[] - { new SutProviderCustomization(), new UserCipher { UserId = new Guid(userId) } }, values) - { } - } - - internal class OrganizationCipherAutoDataAttribute : CustomAutoDataAttribute - { - public OrganizationCipherAutoDataAttribute(string organizationId = null) : base(new SutProviderCustomization(), - new OrganizationCipher { OrganizationId = organizationId == null ? (Guid?)null : new Guid(organizationId) }) - { } - } - - internal class InlineOrganizationCipherAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineOrganizationCipherAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(OrganizationCipher) }, values) - { } + fixture.Customize(composer => composer + .With(c => c.OrganizationId, OrganizationId ?? Guid.NewGuid()) + .Without(c => c.UserId)); + fixture.Customize(composer => composer + .With(c => c.OrganizationId, Guid.NewGuid()) + .Without(c => c.UserId)); } } + +internal class UserCipher : ICustomization +{ + public Guid? UserId { get; set; } + public void Customize(IFixture fixture) + { + fixture.Customize(composer => composer + .With(c => c.UserId, UserId ?? Guid.NewGuid()) + .Without(c => c.OrganizationId)); + fixture.Customize(composer => composer + .With(c => c.UserId, Guid.NewGuid()) + .Without(c => c.OrganizationId)); + } +} + +internal class UserCipherAutoDataAttribute : CustomAutoDataAttribute +{ + public UserCipherAutoDataAttribute(string userId = null) : base(new SutProviderCustomization(), + new UserCipher { UserId = userId == null ? (Guid?)null : new Guid(userId) }) + { } +} +internal class InlineUserCipherAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineUserCipherAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(UserCipher) }, values) + { } +} + +internal class InlineKnownUserCipherAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineKnownUserCipherAutoDataAttribute(string userId, params object[] values) : base(new ICustomization[] + { new SutProviderCustomization(), new UserCipher { UserId = new Guid(userId) } }, values) + { } +} + +internal class OrganizationCipherAutoDataAttribute : CustomAutoDataAttribute +{ + public OrganizationCipherAutoDataAttribute(string organizationId = null) : base(new SutProviderCustomization(), + new OrganizationCipher { OrganizationId = organizationId == null ? (Guid?)null : new Guid(organizationId) }) + { } +} + +internal class InlineOrganizationCipherAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineOrganizationCipherAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(OrganizationCipher) }, values) + { } +} diff --git a/test/Core.Test/AutoFixture/CollectionFixtures.cs b/test/Core.Test/AutoFixture/CollectionFixtures.cs index 38517f5c0..26c169a44 100644 --- a/test/Core.Test/AutoFixture/CollectionFixtures.cs +++ b/test/Core.Test/AutoFixture/CollectionFixtures.cs @@ -1,11 +1,10 @@ using Bit.Core.Test.AutoFixture.OrganizationFixtures; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Core.Test.AutoFixture.CollectionFixtures +namespace Bit.Core.Test.AutoFixture.CollectionFixtures; + +internal class CollectionAutoDataAttribute : CustomAutoDataAttribute { - internal class CollectionAutoDataAttribute : CustomAutoDataAttribute - { - public CollectionAutoDataAttribute() : base(new SutProviderCustomization(), new OrganizationCustomization()) - { } - } + public CollectionAutoDataAttribute() : base(new SutProviderCustomization(), new OrganizationCustomization()) + { } } diff --git a/test/Core.Test/AutoFixture/CurrentContextFixtures.cs b/test/Core.Test/AutoFixture/CurrentContextFixtures.cs index 90187cf6e..1949dedd7 100644 --- a/test/Core.Test/AutoFixture/CurrentContextFixtures.cs +++ b/test/Core.Test/AutoFixture/CurrentContextFixtures.cs @@ -3,36 +3,35 @@ using AutoFixture.Kernel; using Bit.Core.Context; using Bit.Test.Common.AutoFixture; -namespace Bit.Core.Test.AutoFixture.CurrentContextFixtures +namespace Bit.Core.Test.AutoFixture.CurrentContextFixtures; + +internal class CurrentContext : ICustomization { - internal class CurrentContext : ICustomization + public void Customize(IFixture fixture) { - public void Customize(IFixture fixture) - { - fixture.Customizations.Add(new CurrentContextBuilder()); - } - } - - internal class CurrentContextBuilder : ISpecimenBuilder - { - public object Create(object request, ISpecimenContext context) - { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - if (!(request is Type typeRequest)) - { - return new NoSpecimen(); - } - if (typeof(ICurrentContext) != typeRequest) - { - return new NoSpecimen(); - } - - var obj = new Fixture().WithAutoNSubstitutions().Create(); - obj.Organizations = context.Create>(); - return obj; - } + fixture.Customizations.Add(new CurrentContextBuilder()); + } +} + +internal class CurrentContextBuilder : ISpecimenBuilder +{ + public object Create(object request, ISpecimenContext context) + { + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + if (!(request is Type typeRequest)) + { + return new NoSpecimen(); + } + if (typeof(ICurrentContext) != typeRequest) + { + return new NoSpecimen(); + } + + var obj = new Fixture().WithAutoNSubstitutions().Create(); + obj.Organizations = context.Create>(); + return obj; } } diff --git a/test/Core.Test/AutoFixture/GlobalSettingsFixtures.cs b/test/Core.Test/AutoFixture/GlobalSettingsFixtures.cs index a893840a4..b9c053c29 100644 --- a/test/Core.Test/AutoFixture/GlobalSettingsFixtures.cs +++ b/test/Core.Test/AutoFixture/GlobalSettingsFixtures.cs @@ -4,29 +4,28 @@ using AutoFixture.Kernel; using AutoFixture.Xunit2; using Bit.Core.Test.Helpers.Factories; -namespace Bit.Test.Common.AutoFixture +namespace Bit.Test.Common.AutoFixture; + +public class GlobalSettingsBuilder : ISpecimenBuilder { - public class GlobalSettingsBuilder : ISpecimenBuilder + public object Create(object request, ISpecimenContext context) { - public object Create(object request, ISpecimenContext context) + if (context == null) { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - var pi = request as ParameterInfo; - var fixture = new Fixture(); - - if (pi == null || pi.ParameterType != typeof(Bit.Core.Settings.GlobalSettings)) - return new NoSpecimen(); - - return GlobalSettingsFactory.GlobalSettings; + throw new ArgumentNullException(nameof(context)); } - } - public class GlobalSettingsCustomizeAttribute : CustomizeAttribute - { - public override ICustomization GetCustomization(ParameterInfo parameter) => new GlobalSettings(); + var pi = request as ParameterInfo; + var fixture = new Fixture(); + + if (pi == null || pi.ParameterType != typeof(Bit.Core.Settings.GlobalSettings)) + return new NoSpecimen(); + + return GlobalSettingsFactory.GlobalSettings; } } + +public class GlobalSettingsCustomizeAttribute : CustomizeAttribute +{ + public override ICustomization GetCustomization(ParameterInfo parameter) => new GlobalSettings(); +} diff --git a/test/Core.Test/AutoFixture/GroupFixtures.cs b/test/Core.Test/AutoFixture/GroupFixtures.cs index 07b8f6a67..2501bbfc3 100644 --- a/test/Core.Test/AutoFixture/GroupFixtures.cs +++ b/test/Core.Test/AutoFixture/GroupFixtures.cs @@ -1,19 +1,18 @@ using Bit.Core.Test.AutoFixture.OrganizationFixtures; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Core.Test.AutoFixture.GroupFixtures -{ - internal class GroupOrganizationAutoDataAttribute : CustomAutoDataAttribute - { - public GroupOrganizationAutoDataAttribute() : base( - new SutProviderCustomization(), new OrganizationCustomization { UseGroups = true }) - { } - } +namespace Bit.Core.Test.AutoFixture.GroupFixtures; - internal class GroupOrganizationNotUseGroupsAutoDataAttribute : CustomAutoDataAttribute - { - public GroupOrganizationNotUseGroupsAutoDataAttribute() : base( - new SutProviderCustomization(), new OrganizationCustomization { UseGroups = false }) - { } - } +internal class GroupOrganizationAutoDataAttribute : CustomAutoDataAttribute +{ + public GroupOrganizationAutoDataAttribute() : base( + new SutProviderCustomization(), new OrganizationCustomization { UseGroups = true }) + { } +} + +internal class GroupOrganizationNotUseGroupsAutoDataAttribute : CustomAutoDataAttribute +{ + public GroupOrganizationNotUseGroupsAutoDataAttribute() : base( + new SutProviderCustomization(), new OrganizationCustomization { UseGroups = false }) + { } } diff --git a/test/Core.Test/AutoFixture/OrganizationFixtures.cs b/test/Core.Test/AutoFixture/OrganizationFixtures.cs index c496471c1..0641cb29e 100644 --- a/test/Core.Test/AutoFixture/OrganizationFixtures.cs +++ b/test/Core.Test/AutoFixture/OrganizationFixtures.cs @@ -10,177 +10,176 @@ using Bit.Core.Utilities; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Core.Test.AutoFixture.OrganizationFixtures +namespace Bit.Core.Test.AutoFixture.OrganizationFixtures; + +public class OrganizationCustomization : ICustomization { - public class OrganizationCustomization : ICustomization + public bool UseGroups { get; set; } + + public void Customize(IFixture fixture) { - public bool UseGroups { get; set; } + var organizationId = Guid.NewGuid(); + var maxConnections = (short)new Random().Next(10, short.MaxValue); - public void Customize(IFixture fixture) - { - var organizationId = Guid.NewGuid(); - var maxConnections = (short)new Random().Next(10, short.MaxValue); + fixture.Customize(composer => composer + .With(o => o.Id, organizationId) + .With(o => o.MaxCollections, maxConnections) + .With(o => o.UseGroups, UseGroups)); - fixture.Customize(composer => composer - .With(o => o.Id, organizationId) - .With(o => o.MaxCollections, maxConnections) - .With(o => o.UseGroups, UseGroups)); + fixture.Customize(composer => + composer + .With(c => c.OrganizationId, organizationId) + .Without(o => o.CreationDate) + .Without(o => o.RevisionDate)); - fixture.Customize(composer => - composer - .With(c => c.OrganizationId, organizationId) - .Without(o => o.CreationDate) - .Without(o => o.RevisionDate)); - - fixture.Customize(composer => composer.With(g => g.OrganizationId, organizationId)); - } - } - - internal class OrganizationBuilder : ISpecimenBuilder - { - public object Create(object request, ISpecimenContext context) - { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - var type = request as Type; - if (type == null || type != typeof(Organization)) - { - return new NoSpecimen(); - } - - var fixture = new Fixture(); - var providers = fixture.Create>(); - var organization = new Fixture().WithAutoNSubstitutions().Create(); - organization.SetTwoFactorProviders(providers); - return organization; - } - } - - internal class PaidOrganization : ICustomization - { - public PlanType CheckedPlanType { get; set; } - public void Customize(IFixture fixture) - { - var validUpgradePlans = StaticStore.Plans.Where(p => p.Type != PlanType.Free && !p.Disabled).Select(p => p.Type).ToList(); - var lowestActivePaidPlan = validUpgradePlans.First(); - CheckedPlanType = CheckedPlanType.Equals(PlanType.Free) ? lowestActivePaidPlan : CheckedPlanType; - validUpgradePlans.Remove(lowestActivePaidPlan); - fixture.Customize(composer => composer - .With(o => o.PlanType, CheckedPlanType)); - fixture.Customize(composer => composer - .With(ou => ou.Plan, validUpgradePlans.First())); - } - } - - internal class FreeOrganization : ICustomization - { - public void Customize(IFixture fixture) - { - fixture.Customize(composer => composer - .With(o => o.PlanType, PlanType.Free)); - } - } - - internal class FreeOrganizationUpgrade : ICustomization - { - public void Customize(IFixture fixture) - { - fixture.Customize(composer => composer - .With(o => o.PlanType, PlanType.Free)); - - var plansToIgnore = new List { PlanType.Free, PlanType.Custom }; - var selectedPlan = StaticStore.Plans.Last(p => !plansToIgnore.Contains(p.Type) && !p.Disabled); - - fixture.Customize(composer => composer - .With(ou => ou.Plan, selectedPlan.Type) - .With(ou => ou.PremiumAccessAddon, selectedPlan.HasPremiumAccessOption)); - fixture.Customize(composer => composer - .Without(o => o.GatewaySubscriptionId)); - } - } - - internal class OrganizationInvite : ICustomization - { - public OrganizationUserType InviteeUserType { get; set; } - public OrganizationUserType InvitorUserType { get; set; } - public string PermissionsBlob { get; set; } - public void Customize(IFixture fixture) - { - var organizationId = new Guid(); - PermissionsBlob = PermissionsBlob ?? JsonSerializer.Serialize(new Permissions(), new JsonSerializerOptions - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - }); - fixture.Customize(composer => composer - .With(o => o.Id, organizationId) - .With(o => o.Seats, (short)100)); - fixture.Customize(composer => composer - .With(ou => ou.OrganizationId, organizationId) - .With(ou => ou.Type, InvitorUserType) - .With(ou => ou.Permissions, PermissionsBlob)); - fixture.Customize(composer => composer - .With(oi => oi.Type, InviteeUserType)); - } - } - - internal class PaidOrganizationAutoDataAttribute : CustomAutoDataAttribute - { - public PaidOrganizationAutoDataAttribute(PlanType planType) : base(new SutProviderCustomization(), - new PaidOrganization { CheckedPlanType = planType }) - { } - public PaidOrganizationAutoDataAttribute(int planType = 0) : this((PlanType)planType) { } - } - - internal class InlinePaidOrganizationAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlinePaidOrganizationAutoDataAttribute(PlanType planType, object[] values) : base( - new ICustomization[] { new SutProviderCustomization(), new PaidOrganization { CheckedPlanType = planType } }, values) - { } - - public InlinePaidOrganizationAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(PaidOrganization) }, values) - { } - } - - internal class InlineFreeOrganizationAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineFreeOrganizationAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(FreeOrganization) }, values) - { } - } - - internal class FreeOrganizationUpgradeAutoDataAttribute : CustomAutoDataAttribute - { - public FreeOrganizationUpgradeAutoDataAttribute() : base(new SutProviderCustomization(), new FreeOrganizationUpgrade()) - { } - } - - internal class InlineFreeOrganizationUpgradeAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineFreeOrganizationUpgradeAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(FreeOrganizationUpgrade) }, values) - { } - } - - internal class OrganizationInviteAutoDataAttribute : CustomAutoDataAttribute - { - public OrganizationInviteAutoDataAttribute(int inviteeUserType = 0, int invitorUserType = 0, string permissionsBlob = null) : base(new SutProviderCustomization(), - new OrganizationInvite - { - InviteeUserType = (OrganizationUserType)inviteeUserType, - InvitorUserType = (OrganizationUserType)invitorUserType, - PermissionsBlob = permissionsBlob, - }) - { } - } - - internal class InlineOrganizationInviteAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineOrganizationInviteAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(OrganizationInvite) }, values) - { } + fixture.Customize(composer => composer.With(g => g.OrganizationId, organizationId)); } } + +internal class OrganizationBuilder : ISpecimenBuilder +{ + public object Create(object request, ISpecimenContext context) + { + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + + var type = request as Type; + if (type == null || type != typeof(Organization)) + { + return new NoSpecimen(); + } + + var fixture = new Fixture(); + var providers = fixture.Create>(); + var organization = new Fixture().WithAutoNSubstitutions().Create(); + organization.SetTwoFactorProviders(providers); + return organization; + } +} + +internal class PaidOrganization : ICustomization +{ + public PlanType CheckedPlanType { get; set; } + public void Customize(IFixture fixture) + { + var validUpgradePlans = StaticStore.Plans.Where(p => p.Type != PlanType.Free && !p.Disabled).Select(p => p.Type).ToList(); + var lowestActivePaidPlan = validUpgradePlans.First(); + CheckedPlanType = CheckedPlanType.Equals(PlanType.Free) ? lowestActivePaidPlan : CheckedPlanType; + validUpgradePlans.Remove(lowestActivePaidPlan); + fixture.Customize(composer => composer + .With(o => o.PlanType, CheckedPlanType)); + fixture.Customize(composer => composer + .With(ou => ou.Plan, validUpgradePlans.First())); + } +} + +internal class FreeOrganization : ICustomization +{ + public void Customize(IFixture fixture) + { + fixture.Customize(composer => composer + .With(o => o.PlanType, PlanType.Free)); + } +} + +internal class FreeOrganizationUpgrade : ICustomization +{ + public void Customize(IFixture fixture) + { + fixture.Customize(composer => composer + .With(o => o.PlanType, PlanType.Free)); + + var plansToIgnore = new List { PlanType.Free, PlanType.Custom }; + var selectedPlan = StaticStore.Plans.Last(p => !plansToIgnore.Contains(p.Type) && !p.Disabled); + + fixture.Customize(composer => composer + .With(ou => ou.Plan, selectedPlan.Type) + .With(ou => ou.PremiumAccessAddon, selectedPlan.HasPremiumAccessOption)); + fixture.Customize(composer => composer + .Without(o => o.GatewaySubscriptionId)); + } +} + +internal class OrganizationInvite : ICustomization +{ + public OrganizationUserType InviteeUserType { get; set; } + public OrganizationUserType InvitorUserType { get; set; } + public string PermissionsBlob { get; set; } + public void Customize(IFixture fixture) + { + var organizationId = new Guid(); + PermissionsBlob = PermissionsBlob ?? JsonSerializer.Serialize(new Permissions(), new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + }); + fixture.Customize(composer => composer + .With(o => o.Id, organizationId) + .With(o => o.Seats, (short)100)); + fixture.Customize(composer => composer + .With(ou => ou.OrganizationId, organizationId) + .With(ou => ou.Type, InvitorUserType) + .With(ou => ou.Permissions, PermissionsBlob)); + fixture.Customize(composer => composer + .With(oi => oi.Type, InviteeUserType)); + } +} + +internal class PaidOrganizationAutoDataAttribute : CustomAutoDataAttribute +{ + public PaidOrganizationAutoDataAttribute(PlanType planType) : base(new SutProviderCustomization(), + new PaidOrganization { CheckedPlanType = planType }) + { } + public PaidOrganizationAutoDataAttribute(int planType = 0) : this((PlanType)planType) { } +} + +internal class InlinePaidOrganizationAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlinePaidOrganizationAutoDataAttribute(PlanType planType, object[] values) : base( + new ICustomization[] { new SutProviderCustomization(), new PaidOrganization { CheckedPlanType = planType } }, values) + { } + + public InlinePaidOrganizationAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(PaidOrganization) }, values) + { } +} + +internal class InlineFreeOrganizationAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineFreeOrganizationAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(FreeOrganization) }, values) + { } +} + +internal class FreeOrganizationUpgradeAutoDataAttribute : CustomAutoDataAttribute +{ + public FreeOrganizationUpgradeAutoDataAttribute() : base(new SutProviderCustomization(), new FreeOrganizationUpgrade()) + { } +} + +internal class InlineFreeOrganizationUpgradeAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineFreeOrganizationUpgradeAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(FreeOrganizationUpgrade) }, values) + { } +} + +internal class OrganizationInviteAutoDataAttribute : CustomAutoDataAttribute +{ + public OrganizationInviteAutoDataAttribute(int inviteeUserType = 0, int invitorUserType = 0, string permissionsBlob = null) : base(new SutProviderCustomization(), + new OrganizationInvite + { + InviteeUserType = (OrganizationUserType)inviteeUserType, + InvitorUserType = (OrganizationUserType)invitorUserType, + PermissionsBlob = permissionsBlob, + }) + { } +} + +internal class InlineOrganizationInviteAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineOrganizationInviteAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(OrganizationInvite) }, values) + { } +} diff --git a/test/Core.Test/AutoFixture/OrganizationLicenseCustomization.cs b/test/Core.Test/AutoFixture/OrganizationLicenseCustomization.cs index 11c8cd8cb..66a7f5224 100644 --- a/test/Core.Test/AutoFixture/OrganizationLicenseCustomization.cs +++ b/test/Core.Test/AutoFixture/OrganizationLicenseCustomization.cs @@ -2,18 +2,17 @@ using Bit.Core.Models.Business; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Core.Test.AutoFixture +namespace Bit.Core.Test.AutoFixture; + +public class OrganizationLicenseCustomizeAttribute : BitCustomizeAttribute { - public class OrganizationLicenseCustomizeAttribute : BitCustomizeAttribute + public override ICustomization GetCustomization() => new OrganizationLicenseCustomization(); +} +public class OrganizationLicenseCustomization : ICustomization +{ + public void Customize(IFixture fixture) { - public override ICustomization GetCustomization() => new OrganizationLicenseCustomization(); - } - public class OrganizationLicenseCustomization : ICustomization - { - public void Customize(IFixture fixture) - { - fixture.Customize(composer => composer - .With(o => o.Signature, Guid.NewGuid().ToString().Replace('-', '+'))); - } + fixture.Customize(composer => composer + .With(o => o.Signature, Guid.NewGuid().ToString().Replace('-', '+'))); } } diff --git a/test/Core.Test/AutoFixture/OrganizationSponsorshipFixtures.cs b/test/Core.Test/AutoFixture/OrganizationSponsorshipFixtures.cs index a40c15917..b9172ae70 100644 --- a/test/Core.Test/AutoFixture/OrganizationSponsorshipFixtures.cs +++ b/test/Core.Test/AutoFixture/OrganizationSponsorshipFixtures.cs @@ -2,32 +2,31 @@ using Bit.Core.Entities; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Core.Test.AutoFixture.OrganizationSponsorshipFixtures +namespace Bit.Core.Test.AutoFixture.OrganizationSponsorshipFixtures; + +public class OrganizationSponsorshipCustomizeAttribute : BitCustomizeAttribute { - public class OrganizationSponsorshipCustomizeAttribute : BitCustomizeAttribute - { - public bool ToDelete = false; - public override ICustomization GetCustomization() => ToDelete ? - new ToDeleteOrganizationSponsorship() : - new ValidOrganizationSponsorship(); - } + public bool ToDelete = false; + public override ICustomization GetCustomization() => ToDelete ? + new ToDeleteOrganizationSponsorship() : + new ValidOrganizationSponsorship(); +} - public class ValidOrganizationSponsorship : ICustomization +public class ValidOrganizationSponsorship : ICustomization +{ + public void Customize(IFixture fixture) { - public void Customize(IFixture fixture) - { - fixture.Customize(composer => composer - .With(s => s.ToDelete, false) - .With(s => s.LastSyncDate, DateTime.UtcNow.AddDays(new Random().Next(-90, 0)))); - } - } - - public class ToDeleteOrganizationSponsorship : ICustomization - { - public void Customize(IFixture fixture) - { - fixture.Customize(composer => composer - .With(s => s.ToDelete, true)); - } + fixture.Customize(composer => composer + .With(s => s.ToDelete, false) + .With(s => s.LastSyncDate, DateTime.UtcNow.AddDays(new Random().Next(-90, 0)))); + } +} + +public class ToDeleteOrganizationSponsorship : ICustomization +{ + public void Customize(IFixture fixture) + { + fixture.Customize(composer => composer + .With(s => s.ToDelete, true)); } } diff --git a/test/Core.Test/AutoFixture/OrganizationUserFixtures.cs b/test/Core.Test/AutoFixture/OrganizationUserFixtures.cs index 975b45313..74bdbfc51 100644 --- a/test/Core.Test/AutoFixture/OrganizationUserFixtures.cs +++ b/test/Core.Test/AutoFixture/OrganizationUserFixtures.cs @@ -4,43 +4,42 @@ using AutoFixture.Xunit2; using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Test.AutoFixture.OrganizationUserFixtures +namespace Bit.Core.Test.AutoFixture.OrganizationUserFixtures; + +public class OrganizationUserCustomization : ICustomization { - public class OrganizationUserCustomization : ICustomization + public OrganizationUserStatusType Status { get; set; } + public OrganizationUserType Type { get; set; } + + public OrganizationUserCustomization(OrganizationUserStatusType status, OrganizationUserType type) { - public OrganizationUserStatusType Status { get; set; } - public OrganizationUserType Type { get; set; } - - public OrganizationUserCustomization(OrganizationUserStatusType status, OrganizationUserType type) - { - Status = status; - Type = type; - } - - public void Customize(IFixture fixture) - { - fixture.Customize(composer => composer - .With(o => o.Type, Type) - .With(o => o.Status, Status)); - } + Status = status; + Type = type; } - public class OrganizationUserAttribute : CustomizeAttribute + public void Customize(IFixture fixture) { - private readonly OrganizationUserStatusType _status; - private readonly OrganizationUserType _type; - - public OrganizationUserAttribute( - OrganizationUserStatusType status = OrganizationUserStatusType.Confirmed, - OrganizationUserType type = OrganizationUserType.User) - { - _status = status; - _type = type; - } - - public override ICustomization GetCustomization(ParameterInfo parameter) - { - return new OrganizationUserCustomization(_status, _type); - } + fixture.Customize(composer => composer + .With(o => o.Type, Type) + .With(o => o.Status, Status)); + } +} + +public class OrganizationUserAttribute : CustomizeAttribute +{ + private readonly OrganizationUserStatusType _status; + private readonly OrganizationUserType _type; + + public OrganizationUserAttribute( + OrganizationUserStatusType status = OrganizationUserStatusType.Confirmed, + OrganizationUserType type = OrganizationUserType.User) + { + _status = status; + _type = type; + } + + public override ICustomization GetCustomization(ParameterInfo parameter) + { + return new OrganizationUserCustomization(_status, _type); } } diff --git a/test/Core.Test/AutoFixture/PolicyFixtures.cs b/test/Core.Test/AutoFixture/PolicyFixtures.cs index b3da0e698..fb8109baf 100644 --- a/test/Core.Test/AutoFixture/PolicyFixtures.cs +++ b/test/Core.Test/AutoFixture/PolicyFixtures.cs @@ -4,38 +4,37 @@ using AutoFixture.Xunit2; using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.Test.AutoFixture.PolicyFixtures +namespace Bit.Core.Test.AutoFixture.PolicyFixtures; + +internal class PolicyCustomization : ICustomization { - internal class PolicyCustomization : ICustomization + public PolicyType Type { get; set; } + + public PolicyCustomization(PolicyType type) { - public PolicyType Type { get; set; } - - public PolicyCustomization(PolicyType type) - { - Type = type; - } - - public void Customize(IFixture fixture) - { - fixture.Customize(composer => composer - .With(o => o.OrganizationId, Guid.NewGuid()) - .With(o => o.Type, Type) - .With(o => o.Enabled, true)); - } + Type = type; } - public class PolicyAttribute : CustomizeAttribute + public void Customize(IFixture fixture) { - private readonly PolicyType _type; - - public PolicyAttribute(PolicyType type) - { - _type = type; - } - - public override ICustomization GetCustomization(ParameterInfo parameter) - { - return new PolicyCustomization(_type); - } + fixture.Customize(composer => composer + .With(o => o.OrganizationId, Guid.NewGuid()) + .With(o => o.Type, Type) + .With(o => o.Enabled, true)); + } +} + +public class PolicyAttribute : CustomizeAttribute +{ + private readonly PolicyType _type; + + public PolicyAttribute(PolicyType type) + { + _type = type; + } + + public override ICustomization GetCustomization(ParameterInfo parameter) + { + return new PolicyCustomization(_type); } } diff --git a/test/Core.Test/AutoFixture/SendFixtures.cs b/test/Core.Test/AutoFixture/SendFixtures.cs index 573f32288..b7cdeeafd 100644 --- a/test/Core.Test/AutoFixture/SendFixtures.cs +++ b/test/Core.Test/AutoFixture/SendFixtures.cs @@ -2,63 +2,62 @@ using Bit.Core.Entities; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Core.Test.AutoFixture.SendFixtures +namespace Bit.Core.Test.AutoFixture.SendFixtures; + +internal class OrganizationSend : ICustomization { - internal class OrganizationSend : ICustomization + public Guid? OrganizationId { get; set; } + public void Customize(IFixture fixture) { - public Guid? OrganizationId { get; set; } - public void Customize(IFixture fixture) - { - fixture.Customize(composer => composer - .With(s => s.OrganizationId, OrganizationId ?? Guid.NewGuid()) - .Without(s => s.UserId)); - } - } - - internal class UserSend : ICustomization - { - public Guid? UserId { get; set; } - public void Customize(IFixture fixture) - { - fixture.Customize(composer => composer - .With(s => s.UserId, UserId ?? Guid.NewGuid()) - .Without(s => s.OrganizationId)); - } - } - - internal class UserSendAutoDataAttribute : CustomAutoDataAttribute - { - public UserSendAutoDataAttribute(string userId = null) : base(new SutProviderCustomization(), - new UserSend { UserId = userId == null ? (Guid?)null : new Guid(userId) }) - { } - } - internal class InlineUserSendAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineUserSendAutoDataAttribute(params object[] values) : base(new[] { typeof(CurrentContextFixtures.CurrentContext), - typeof(SutProviderCustomization), typeof(UserSend) }, values) - { } - } - - internal class InlineKnownUserSendAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineKnownUserSendAutoDataAttribute(string userId, params object[] values) : base(new ICustomization[] - { new CurrentContextFixtures.CurrentContext(), new SutProviderCustomization(), - new UserSend { UserId = new Guid(userId) } }, values) - { } - } - - internal class OrganizationSendAutoDataAttribute : CustomAutoDataAttribute - { - public OrganizationSendAutoDataAttribute(string organizationId = null) : base(new CurrentContextFixtures.CurrentContext(), - new SutProviderCustomization(), - new OrganizationSend { OrganizationId = organizationId == null ? (Guid?)null : new Guid(organizationId) }) - { } - } - - internal class InlineOrganizationSendAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineOrganizationSendAutoDataAttribute(params object[] values) : base(new[] { typeof(CurrentContextFixtures.CurrentContext), - typeof(SutProviderCustomization), typeof(OrganizationSend) }, values) - { } + fixture.Customize(composer => composer + .With(s => s.OrganizationId, OrganizationId ?? Guid.NewGuid()) + .Without(s => s.UserId)); } } + +internal class UserSend : ICustomization +{ + public Guid? UserId { get; set; } + public void Customize(IFixture fixture) + { + fixture.Customize(composer => composer + .With(s => s.UserId, UserId ?? Guid.NewGuid()) + .Without(s => s.OrganizationId)); + } +} + +internal class UserSendAutoDataAttribute : CustomAutoDataAttribute +{ + public UserSendAutoDataAttribute(string userId = null) : base(new SutProviderCustomization(), + new UserSend { UserId = userId == null ? (Guid?)null : new Guid(userId) }) + { } +} +internal class InlineUserSendAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineUserSendAutoDataAttribute(params object[] values) : base(new[] { typeof(CurrentContextFixtures.CurrentContext), + typeof(SutProviderCustomization), typeof(UserSend) }, values) + { } +} + +internal class InlineKnownUserSendAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineKnownUserSendAutoDataAttribute(string userId, params object[] values) : base(new ICustomization[] + { new CurrentContextFixtures.CurrentContext(), new SutProviderCustomization(), + new UserSend { UserId = new Guid(userId) } }, values) + { } +} + +internal class OrganizationSendAutoDataAttribute : CustomAutoDataAttribute +{ + public OrganizationSendAutoDataAttribute(string organizationId = null) : base(new CurrentContextFixtures.CurrentContext(), + new SutProviderCustomization(), + new OrganizationSend { OrganizationId = organizationId == null ? (Guid?)null : new Guid(organizationId) }) + { } +} + +internal class InlineOrganizationSendAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineOrganizationSendAutoDataAttribute(params object[] values) : base(new[] { typeof(CurrentContextFixtures.CurrentContext), + typeof(SutProviderCustomization), typeof(OrganizationSend) }, values) + { } +} diff --git a/test/Core.Test/AutoFixture/UserFixtures.cs b/test/Core.Test/AutoFixture/UserFixtures.cs index 98707938a..39221aafc 100644 --- a/test/Core.Test/AutoFixture/UserFixtures.cs +++ b/test/Core.Test/AutoFixture/UserFixtures.cs @@ -6,49 +6,48 @@ using Bit.Core.Models; using Bit.Core.Test.AutoFixture.OrganizationFixtures; using Bit.Test.Common.AutoFixture; -namespace Bit.Core.Test.AutoFixture.UserFixtures +namespace Bit.Core.Test.AutoFixture.UserFixtures; + +public class UserBuilder : ISpecimenBuilder { - public class UserBuilder : ISpecimenBuilder + public object Create(object request, ISpecimenContext context) { - public object Create(object request, ISpecimenContext context) + if (context == null) { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } + throw new ArgumentNullException(nameof(context)); + } - var type = request as Type; - if (type == typeof(User)) + var type = request as Type; + if (type == typeof(User)) + { + var fixture = new Fixture(); + var providers = fixture.Create>(); + var user = fixture.WithAutoNSubstitutions().Create(); + user.SetTwoFactorProviders(providers); + return user; + } + else if (type == typeof(List)) + { + var fixture = new Fixture(); + var users = fixture.WithAutoNSubstitutions().CreateMany(2); + foreach (var user in users) { - var fixture = new Fixture(); var providers = fixture.Create>(); - var user = fixture.WithAutoNSubstitutions().Create(); user.SetTwoFactorProviders(providers); - return user; } - else if (type == typeof(List)) - { - var fixture = new Fixture(); - var users = fixture.WithAutoNSubstitutions().CreateMany(2); - foreach (var user in users) - { - var providers = fixture.Create>(); - user.SetTwoFactorProviders(providers); - } - return users; - } - - return new NoSpecimen(); + return users; } - } - public class UserFixture : ICustomization - { - public virtual void Customize(IFixture fixture) - { - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new UserBuilder()); - fixture.Customizations.Add(new OrganizationBuilder()); - } + return new NoSpecimen(); + } +} + +public class UserFixture : ICustomization +{ + public virtual void Customize(IFixture fixture) + { + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new UserBuilder()); + fixture.Customizations.Add(new OrganizationBuilder()); } } diff --git a/test/Core.Test/Entities/OrganizationTests.cs b/test/Core.Test/Entities/OrganizationTests.cs index c24d6effc..5a86c3fd0 100644 --- a/test/Core.Test/Entities/OrganizationTests.cs +++ b/test/Core.Test/Entities/OrganizationTests.cs @@ -5,96 +5,95 @@ using Bit.Core.Models; using Bit.Test.Common.Helpers; using Xunit; -namespace Bit.Core.Test.Entities +namespace Bit.Core.Test.Entities; + +public class OrganizationTests { - public class OrganizationTests + private static readonly Dictionary _testConfig = new Dictionary() { - private static readonly Dictionary _testConfig = new Dictionary() + [TwoFactorProviderType.OrganizationDuo] = new TwoFactorProvider { - [TwoFactorProviderType.OrganizationDuo] = new TwoFactorProvider + Enabled = true, + MetaData = new Dictionary { - Enabled = true, - MetaData = new Dictionary - { - ["IKey"] = "IKey_value", - ["SKey"] = "SKey_value", - ["Host"] = "Host_value", - }, - } + ["IKey"] = "IKey_value", + ["SKey"] = "SKey_value", + ["Host"] = "Host_value", + }, + } + }; + + + [Fact] + public void SetTwoFactorProviders_Success() + { + var organization = new Organization(); + organization.SetTwoFactorProviders(_testConfig); + + using var jsonDocument = JsonDocument.Parse(organization.TwoFactorProviders); + var root = jsonDocument.RootElement; + + var duo = AssertHelper.AssertJsonProperty(root, "6", JsonValueKind.Object); + AssertHelper.AssertJsonProperty(duo, "Enabled", JsonValueKind.True); + var duoMetaData = AssertHelper.AssertJsonProperty(duo, "MetaData", JsonValueKind.Object); + var iKey = AssertHelper.AssertJsonProperty(duoMetaData, "IKey", JsonValueKind.String).GetString(); + Assert.Equal("IKey_value", iKey); + var sKey = AssertHelper.AssertJsonProperty(duoMetaData, "SKey", JsonValueKind.String).GetString(); + Assert.Equal("SKey_value", sKey); + var host = AssertHelper.AssertJsonProperty(duoMetaData, "Host", JsonValueKind.String).GetString(); + Assert.Equal("Host_value", host); + } + + [Fact] + public void GetTwoFactorProviders_Success() + { + // This is to get rid of the cached dictionary the SetTwoFactorProviders keeps so we can fully test the JSON reading + // It intent is to mimic a storing of the entity in the database and it being read later + var tempOrganization = new Organization(); + tempOrganization.SetTwoFactorProviders(_testConfig); + var organization = new Organization + { + TwoFactorProviders = tempOrganization.TwoFactorProviders, }; + var twoFactorProviders = organization.GetTwoFactorProviders(); - [Fact] - public void SetTwoFactorProviders_Success() - { - var organization = new Organization(); - organization.SetTwoFactorProviders(_testConfig); + var duo = Assert.Contains(TwoFactorProviderType.OrganizationDuo, (IDictionary)twoFactorProviders); + Assert.True(duo.Enabled); + Assert.NotNull(duo.MetaData); + var iKey = Assert.Contains("IKey", (IDictionary)duo.MetaData); + Assert.Equal("IKey_value", iKey); + var sKey = Assert.Contains("SKey", (IDictionary)duo.MetaData); + Assert.Equal("SKey_value", sKey); + var host = Assert.Contains("Host", (IDictionary)duo.MetaData); + Assert.Equal("Host_value", host); + } - using var jsonDocument = JsonDocument.Parse(organization.TwoFactorProviders); - var root = jsonDocument.RootElement; + [Fact] + public void GetTwoFactorProviders_SavedWithName_Success() + { + var organization = new Organization(); + // This should save items with the string name of the enum and we will validate that we can read + // from that just incase some organizations have it saved that way. + organization.TwoFactorProviders = JsonSerializer.Serialize(_testConfig); - var duo = AssertHelper.AssertJsonProperty(root, "6", JsonValueKind.Object); - AssertHelper.AssertJsonProperty(duo, "Enabled", JsonValueKind.True); - var duoMetaData = AssertHelper.AssertJsonProperty(duo, "MetaData", JsonValueKind.Object); - var iKey = AssertHelper.AssertJsonProperty(duoMetaData, "IKey", JsonValueKind.String).GetString(); - Assert.Equal("IKey_value", iKey); - var sKey = AssertHelper.AssertJsonProperty(duoMetaData, "SKey", JsonValueKind.String).GetString(); - Assert.Equal("SKey_value", sKey); - var host = AssertHelper.AssertJsonProperty(duoMetaData, "Host", JsonValueKind.String).GetString(); - Assert.Equal("Host_value", host); - } + // Preliminary Asserts to make sure we are testing what we want to be testing + using var jsonDocument = JsonDocument.Parse(organization.TwoFactorProviders); + var root = jsonDocument.RootElement; + // This means it saved the enum as its string name + AssertHelper.AssertJsonProperty(root, "OrganizationDuo", JsonValueKind.Object); - [Fact] - public void GetTwoFactorProviders_Success() - { - // This is to get rid of the cached dictionary the SetTwoFactorProviders keeps so we can fully test the JSON reading - // It intent is to mimic a storing of the entity in the database and it being read later - var tempOrganization = new Organization(); - tempOrganization.SetTwoFactorProviders(_testConfig); - var organization = new Organization - { - TwoFactorProviders = tempOrganization.TwoFactorProviders, - }; + // Actual checks + var twoFactorProviders = organization.GetTwoFactorProviders(); - var twoFactorProviders = organization.GetTwoFactorProviders(); - - var duo = Assert.Contains(TwoFactorProviderType.OrganizationDuo, (IDictionary)twoFactorProviders); - Assert.True(duo.Enabled); - Assert.NotNull(duo.MetaData); - var iKey = Assert.Contains("IKey", (IDictionary)duo.MetaData); - Assert.Equal("IKey_value", iKey); - var sKey = Assert.Contains("SKey", (IDictionary)duo.MetaData); - Assert.Equal("SKey_value", sKey); - var host = Assert.Contains("Host", (IDictionary)duo.MetaData); - Assert.Equal("Host_value", host); - } - - [Fact] - public void GetTwoFactorProviders_SavedWithName_Success() - { - var organization = new Organization(); - // This should save items with the string name of the enum and we will validate that we can read - // from that just incase some organizations have it saved that way. - organization.TwoFactorProviders = JsonSerializer.Serialize(_testConfig); - - // Preliminary Asserts to make sure we are testing what we want to be testing - using var jsonDocument = JsonDocument.Parse(organization.TwoFactorProviders); - var root = jsonDocument.RootElement; - // This means it saved the enum as its string name - AssertHelper.AssertJsonProperty(root, "OrganizationDuo", JsonValueKind.Object); - - // Actual checks - var twoFactorProviders = organization.GetTwoFactorProviders(); - - var duo = Assert.Contains(TwoFactorProviderType.OrganizationDuo, (IDictionary)twoFactorProviders); - Assert.True(duo.Enabled); - Assert.NotNull(duo.MetaData); - var iKey = Assert.Contains("IKey", (IDictionary)duo.MetaData); - Assert.Equal("IKey_value", iKey); - var sKey = Assert.Contains("SKey", (IDictionary)duo.MetaData); - Assert.Equal("SKey_value", sKey); - var host = Assert.Contains("Host", (IDictionary)duo.MetaData); - Assert.Equal("Host_value", host); - } + var duo = Assert.Contains(TwoFactorProviderType.OrganizationDuo, (IDictionary)twoFactorProviders); + Assert.True(duo.Enabled); + Assert.NotNull(duo.MetaData); + var iKey = Assert.Contains("IKey", (IDictionary)duo.MetaData); + Assert.Equal("IKey_value", iKey); + var sKey = Assert.Contains("SKey", (IDictionary)duo.MetaData); + Assert.Equal("SKey_value", sKey); + var host = Assert.Contains("Host", (IDictionary)duo.MetaData); + Assert.Equal("Host_value", host); } } diff --git a/test/Core.Test/Entities/UserTests.cs b/test/Core.Test/Entities/UserTests.cs index c60b031da..8a1986cd9 100644 --- a/test/Core.Test/Entities/UserTests.cs +++ b/test/Core.Test/Entities/UserTests.cs @@ -5,141 +5,140 @@ using Bit.Core.Models; using Bit.Test.Common.Helpers; using Xunit; -namespace Bit.Core.Test.Entities +namespace Bit.Core.Test.Entities; + +public class UserTests { - public class UserTests + // KB MB GB + public const long Multiplier = 1024 * 1024 * 1024; + + [Fact] + public void StorageBytesRemaining_HasMax_DoesNotHaveStorage_ReturnsMaxAsBytes() { - // KB MB GB - public const long Multiplier = 1024 * 1024 * 1024; + short maxStorageGb = 1; - [Fact] - public void StorageBytesRemaining_HasMax_DoesNotHaveStorage_ReturnsMaxAsBytes() + var user = new User { - short maxStorageGb = 1; - - var user = new User - { - MaxStorageGb = maxStorageGb, - Storage = null, - }; - - var bytesRemaining = user.StorageBytesRemaining(); - - Assert.Equal(bytesRemaining, maxStorageGb * Multiplier); - } - - [Theory] - [InlineData(2, 1 * Multiplier, 1 * Multiplier)] - - public void StorageBytesRemaining_HasMax_HasStorage_ReturnRemainingStorage(short maxStorageGb, long storageBytes, long expectedRemainingBytes) - { - var user = new User - { - MaxStorageGb = maxStorageGb, - Storage = storageBytes, - }; - - var bytesRemaining = user.StorageBytesRemaining(); - - Assert.Equal(expectedRemainingBytes, bytesRemaining); - } - - private static readonly Dictionary _testTwoFactorConfig = new Dictionary - { - [TwoFactorProviderType.WebAuthn] = new TwoFactorProvider - { - Enabled = true, - MetaData = new Dictionary - { - ["Item"] = "thing", - }, - }, - [TwoFactorProviderType.Email] = new TwoFactorProvider - { - Enabled = false, - MetaData = new Dictionary - { - ["Email"] = "test@email.com", - }, - }, + MaxStorageGb = maxStorageGb, + Storage = null, }; - [Fact] - public void SetTwoFactorProviders_Success() + var bytesRemaining = user.StorageBytesRemaining(); + + Assert.Equal(bytesRemaining, maxStorageGb * Multiplier); + } + + [Theory] + [InlineData(2, 1 * Multiplier, 1 * Multiplier)] + + public void StorageBytesRemaining_HasMax_HasStorage_ReturnRemainingStorage(short maxStorageGb, long storageBytes, long expectedRemainingBytes) + { + var user = new User { - var user = new User(); - user.SetTwoFactorProviders(_testTwoFactorConfig); + MaxStorageGb = maxStorageGb, + Storage = storageBytes, + }; - using var jsonDocument = JsonDocument.Parse(user.TwoFactorProviders); - var root = jsonDocument.RootElement; + var bytesRemaining = user.StorageBytesRemaining(); - var webAuthn = AssertHelper.AssertJsonProperty(root, "7", JsonValueKind.Object); - AssertHelper.AssertJsonProperty(webAuthn, "Enabled", JsonValueKind.True); - var webMetaData = AssertHelper.AssertJsonProperty(webAuthn, "MetaData", JsonValueKind.Object); - AssertHelper.AssertJsonProperty(webMetaData, "Item", JsonValueKind.String); + Assert.Equal(expectedRemainingBytes, bytesRemaining); + } - var email = AssertHelper.AssertJsonProperty(root, "1", JsonValueKind.Object); - AssertHelper.AssertJsonProperty(email, "Enabled", JsonValueKind.False); - var emailMetaData = AssertHelper.AssertJsonProperty(email, "MetaData", JsonValueKind.Object); - AssertHelper.AssertJsonProperty(emailMetaData, "Email", JsonValueKind.String); - } - - [Fact] - public void GetTwoFactorProviders_Success() + private static readonly Dictionary _testTwoFactorConfig = new Dictionary + { + [TwoFactorProviderType.WebAuthn] = new TwoFactorProvider { - // This is to get rid of the cached dictionary the SetTwoFactorProviders keeps so we can fully test the JSON reading - // It intent is to mimic a storing of the entity in the database and it being read later - var tempUser = new User(); - tempUser.SetTwoFactorProviders(_testTwoFactorConfig); - var user = new User + Enabled = true, + MetaData = new Dictionary { - TwoFactorProviders = tempUser.TwoFactorProviders, - }; - - var twoFactorProviders = user.GetTwoFactorProviders(); - - var webAuthn = Assert.Contains(TwoFactorProviderType.WebAuthn, (IDictionary)twoFactorProviders); - Assert.True(webAuthn.Enabled); - Assert.NotNull(webAuthn.MetaData); - var webAuthnMetaDataItem = Assert.Contains("Item", (IDictionary)webAuthn.MetaData); - Assert.Equal("thing", webAuthnMetaDataItem); - - var email = Assert.Contains(TwoFactorProviderType.Email, (IDictionary)twoFactorProviders); - Assert.False(email.Enabled); - Assert.NotNull(email.MetaData); - var emailMetaDataEmail = Assert.Contains("Email", (IDictionary)email.MetaData); - Assert.Equal("test@email.com", emailMetaDataEmail); - } - - [Fact] - public void GetTwoFactorProviders_SavedWithName_Success() + ["Item"] = "thing", + }, + }, + [TwoFactorProviderType.Email] = new TwoFactorProvider { - var user = new User(); - // This should save items with the string name of the enum and we will validate that we can read - // from that just incase some users have it saved that way. - user.TwoFactorProviders = JsonSerializer.Serialize(_testTwoFactorConfig); + Enabled = false, + MetaData = new Dictionary + { + ["Email"] = "test@email.com", + }, + }, + }; - // Preliminary Asserts to make sure we are testing what we want to be testing - using var jsonDocument = JsonDocument.Parse(user.TwoFactorProviders); - var root = jsonDocument.RootElement; - // This means it saved the enum as its string name - AssertHelper.AssertJsonProperty(root, "WebAuthn", JsonValueKind.Object); - AssertHelper.AssertJsonProperty(root, "Email", JsonValueKind.Object); + [Fact] + public void SetTwoFactorProviders_Success() + { + var user = new User(); + user.SetTwoFactorProviders(_testTwoFactorConfig); - // Actual checks - var twoFactorProviders = user.GetTwoFactorProviders(); + using var jsonDocument = JsonDocument.Parse(user.TwoFactorProviders); + var root = jsonDocument.RootElement; - var webAuthn = Assert.Contains(TwoFactorProviderType.WebAuthn, (IDictionary)twoFactorProviders); - Assert.True(webAuthn.Enabled); - Assert.NotNull(webAuthn.MetaData); - var webAuthnMetaDataItem = Assert.Contains("Item", (IDictionary)webAuthn.MetaData); - Assert.Equal("thing", webAuthnMetaDataItem); + var webAuthn = AssertHelper.AssertJsonProperty(root, "7", JsonValueKind.Object); + AssertHelper.AssertJsonProperty(webAuthn, "Enabled", JsonValueKind.True); + var webMetaData = AssertHelper.AssertJsonProperty(webAuthn, "MetaData", JsonValueKind.Object); + AssertHelper.AssertJsonProperty(webMetaData, "Item", JsonValueKind.String); - var email = Assert.Contains(TwoFactorProviderType.Email, (IDictionary)twoFactorProviders); - Assert.False(email.Enabled); - Assert.NotNull(email.MetaData); - var emailMetaDataEmail = Assert.Contains("Email", (IDictionary)email.MetaData); - Assert.Equal("test@email.com", emailMetaDataEmail); - } + var email = AssertHelper.AssertJsonProperty(root, "1", JsonValueKind.Object); + AssertHelper.AssertJsonProperty(email, "Enabled", JsonValueKind.False); + var emailMetaData = AssertHelper.AssertJsonProperty(email, "MetaData", JsonValueKind.Object); + AssertHelper.AssertJsonProperty(emailMetaData, "Email", JsonValueKind.String); + } + + [Fact] + public void GetTwoFactorProviders_Success() + { + // This is to get rid of the cached dictionary the SetTwoFactorProviders keeps so we can fully test the JSON reading + // It intent is to mimic a storing of the entity in the database and it being read later + var tempUser = new User(); + tempUser.SetTwoFactorProviders(_testTwoFactorConfig); + var user = new User + { + TwoFactorProviders = tempUser.TwoFactorProviders, + }; + + var twoFactorProviders = user.GetTwoFactorProviders(); + + var webAuthn = Assert.Contains(TwoFactorProviderType.WebAuthn, (IDictionary)twoFactorProviders); + Assert.True(webAuthn.Enabled); + Assert.NotNull(webAuthn.MetaData); + var webAuthnMetaDataItem = Assert.Contains("Item", (IDictionary)webAuthn.MetaData); + Assert.Equal("thing", webAuthnMetaDataItem); + + var email = Assert.Contains(TwoFactorProviderType.Email, (IDictionary)twoFactorProviders); + Assert.False(email.Enabled); + Assert.NotNull(email.MetaData); + var emailMetaDataEmail = Assert.Contains("Email", (IDictionary)email.MetaData); + Assert.Equal("test@email.com", emailMetaDataEmail); + } + + [Fact] + public void GetTwoFactorProviders_SavedWithName_Success() + { + var user = new User(); + // This should save items with the string name of the enum and we will validate that we can read + // from that just incase some users have it saved that way. + user.TwoFactorProviders = JsonSerializer.Serialize(_testTwoFactorConfig); + + // Preliminary Asserts to make sure we are testing what we want to be testing + using var jsonDocument = JsonDocument.Parse(user.TwoFactorProviders); + var root = jsonDocument.RootElement; + // This means it saved the enum as its string name + AssertHelper.AssertJsonProperty(root, "WebAuthn", JsonValueKind.Object); + AssertHelper.AssertJsonProperty(root, "Email", JsonValueKind.Object); + + // Actual checks + var twoFactorProviders = user.GetTwoFactorProviders(); + + var webAuthn = Assert.Contains(TwoFactorProviderType.WebAuthn, (IDictionary)twoFactorProviders); + Assert.True(webAuthn.Enabled); + Assert.NotNull(webAuthn.MetaData); + var webAuthnMetaDataItem = Assert.Contains("Item", (IDictionary)webAuthn.MetaData); + Assert.Equal("thing", webAuthnMetaDataItem); + + var email = Assert.Contains(TwoFactorProviderType.Email, (IDictionary)twoFactorProviders); + Assert.False(email.Enabled); + Assert.NotNull(email.MetaData); + var emailMetaDataEmail = Assert.Contains("Email", (IDictionary)email.MetaData); + Assert.Equal("test@email.com", emailMetaDataEmail); } } diff --git a/test/Core.Test/Helpers/Factories.cs b/test/Core.Test/Helpers/Factories.cs index 3d6523bc9..7761d5cb1 100644 --- a/test/Core.Test/Helpers/Factories.cs +++ b/test/Core.Test/Helpers/Factories.cs @@ -1,16 +1,15 @@ using Bit.Core.Settings; using Microsoft.Extensions.Configuration; -namespace Bit.Core.Test.Helpers.Factories +namespace Bit.Core.Test.Helpers.Factories; + +public static class GlobalSettingsFactory { - public static class GlobalSettingsFactory + public static GlobalSettings GlobalSettings { get; } = new(); + static GlobalSettingsFactory() { - public static GlobalSettings GlobalSettings { get; } = new(); - static GlobalSettingsFactory() - { - var configBuilder = new ConfigurationBuilder().AddUserSecrets("bitwarden-Api"); - var Configuration = configBuilder.Build(); - ConfigurationBinder.Bind(Configuration.GetSection("GlobalSettings"), GlobalSettings); - } + var configBuilder = new ConfigurationBuilder().AddUserSecrets("bitwarden-Api"); + var Configuration = configBuilder.Build(); + ConfigurationBinder.Bind(Configuration.GetSection("GlobalSettings"), GlobalSettings); } } diff --git a/test/Core.Test/Identity/AuthenticationTokenProviderTests.cs b/test/Core.Test/Identity/AuthenticationTokenProviderTests.cs index 8a5de6898..7b1ad3892 100644 --- a/test/Core.Test/Identity/AuthenticationTokenProviderTests.cs +++ b/test/Core.Test/Identity/AuthenticationTokenProviderTests.cs @@ -5,35 +5,34 @@ using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; -namespace Bit.Core.Test.Identity +namespace Bit.Core.Test.Identity; + +public class AuthenticationTokenProviderTests : BaseTokenProviderTests { - public class AuthenticationTokenProviderTests : BaseTokenProviderTests + public override TwoFactorProviderType TwoFactorProviderType => TwoFactorProviderType.Authenticator; + + public static IEnumerable CanGenerateTwoFactorTokenAsyncData + => SetupCanGenerateData( + ( + new Dictionary + { + ["Key"] = "stuff", + }, + true + ), + ( + new Dictionary + { + ["Key"] = "" + }, + false + ) + ); + + [Theory, BitMemberAutoData(nameof(CanGenerateTwoFactorTokenAsyncData))] + public override async Task RunCanGenerateTwoFactorTokenAsync(Dictionary metaData, bool expectedResponse, + User user, SutProvider sutProvider) { - public override TwoFactorProviderType TwoFactorProviderType => TwoFactorProviderType.Authenticator; - - public static IEnumerable CanGenerateTwoFactorTokenAsyncData - => SetupCanGenerateData( - ( - new Dictionary - { - ["Key"] = "stuff", - }, - true - ), - ( - new Dictionary - { - ["Key"] = "" - }, - false - ) - ); - - [Theory, BitMemberAutoData(nameof(CanGenerateTwoFactorTokenAsyncData))] - public override async Task RunCanGenerateTwoFactorTokenAsync(Dictionary metaData, bool expectedResponse, - User user, SutProvider sutProvider) - { - await base.RunCanGenerateTwoFactorTokenAsync(metaData, expectedResponse, user, sutProvider); - } + await base.RunCanGenerateTwoFactorTokenAsync(metaData, expectedResponse, user, sutProvider); } } diff --git a/test/Core.Test/Identity/BaseTokenProviderTests.cs b/test/Core.Test/Identity/BaseTokenProviderTests.cs index 9de8abbe5..5a9e0316e 100644 --- a/test/Core.Test/Identity/BaseTokenProviderTests.cs +++ b/test/Core.Test/Identity/BaseTokenProviderTests.cs @@ -11,83 +11,82 @@ using Microsoft.Extensions.Options; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Identity +namespace Bit.Core.Test.Identity; + +[SutProviderCustomize] +public abstract class BaseTokenProviderTests + where T : IUserTwoFactorTokenProvider { - [SutProviderCustomize] - public abstract class BaseTokenProviderTests - where T : IUserTwoFactorTokenProvider + public abstract TwoFactorProviderType TwoFactorProviderType { get; } + + #region Helpers + protected static IEnumerable SetupCanGenerateData(params (Dictionary MetaData, bool ExpectedResponse)[] data) { - public abstract TwoFactorProviderType TwoFactorProviderType { get; } - - #region Helpers - protected static IEnumerable SetupCanGenerateData(params (Dictionary MetaData, bool ExpectedResponse)[] data) - { - return data.Select(d => - new object[] - { - d.MetaData, - d.ExpectedResponse, - }); - } - - protected virtual IUserService AdditionalSetup(SutProvider sutProvider, User user) - { - var userService = Substitute.For(); - - sutProvider.GetDependency() - .GetService(typeof(IUserService)) - .Returns(userService); - - SetupUserService(userService, user); - - return userService; - } - - protected virtual void SetupUserService(IUserService userService, User user) - { - userService - .TwoFactorProviderIsEnabledAsync(TwoFactorProviderType, user) - .Returns(true); - } - - protected static UserManager SubstituteUserManager() - { - return new UserManager(Substitute.For>(), - Substitute.For>(), - Substitute.For>(), - Enumerable.Empty>(), - Enumerable.Empty>(), - Substitute.For(), - Substitute.For(), - Substitute.For(), - Substitute.For>>()); - } - - protected void MockDatabase(User user, Dictionary metaData) - { - var providers = new Dictionary + return data.Select(d => + new object[] { - [TwoFactorProviderType] = new TwoFactorProvider - { - Enabled = true, - MetaData = metaData, - }, - }; + d.MetaData, + d.ExpectedResponse, + }); + } - user.TwoFactorProviders = JsonHelpers.LegacySerialize(providers); - } - #endregion + protected virtual IUserService AdditionalSetup(SutProvider sutProvider, User user) + { + var userService = Substitute.For(); - public virtual async Task RunCanGenerateTwoFactorTokenAsync(Dictionary metaData, bool expectedResponse, - User user, SutProvider sutProvider) + sutProvider.GetDependency() + .GetService(typeof(IUserService)) + .Returns(userService); + + SetupUserService(userService, user); + + return userService; + } + + protected virtual void SetupUserService(IUserService userService, User user) + { + userService + .TwoFactorProviderIsEnabledAsync(TwoFactorProviderType, user) + .Returns(true); + } + + protected static UserManager SubstituteUserManager() + { + return new UserManager(Substitute.For>(), + Substitute.For>(), + Substitute.For>(), + Enumerable.Empty>(), + Enumerable.Empty>(), + Substitute.For(), + Substitute.For(), + Substitute.For(), + Substitute.For>>()); + } + + protected void MockDatabase(User user, Dictionary metaData) + { + var providers = new Dictionary { - var userManager = SubstituteUserManager(); - MockDatabase(user, metaData); + [TwoFactorProviderType] = new TwoFactorProvider + { + Enabled = true, + MetaData = metaData, + }, + }; - AdditionalSetup(sutProvider, user); + user.TwoFactorProviders = JsonHelpers.LegacySerialize(providers); + } + #endregion - var response = await sutProvider.Sut.CanGenerateTwoFactorTokenAsync(userManager, user); - Assert.Equal(expectedResponse, response); - } + public virtual async Task RunCanGenerateTwoFactorTokenAsync(Dictionary metaData, bool expectedResponse, + User user, SutProvider sutProvider) + { + var userManager = SubstituteUserManager(); + MockDatabase(user, metaData); + + AdditionalSetup(sutProvider, user); + + var response = await sutProvider.Sut.CanGenerateTwoFactorTokenAsync(userManager, user); + Assert.Equal(expectedResponse, response); } } diff --git a/test/Core.Test/Identity/EmailTokenProviderTests.cs b/test/Core.Test/Identity/EmailTokenProviderTests.cs index b1b471201..707ed798d 100644 --- a/test/Core.Test/Identity/EmailTokenProviderTests.cs +++ b/test/Core.Test/Identity/EmailTokenProviderTests.cs @@ -5,42 +5,41 @@ using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; -namespace Bit.Core.Test.Identity +namespace Bit.Core.Test.Identity; + +public class EmailTokenProviderTests : BaseTokenProviderTests { - public class EmailTokenProviderTests : BaseTokenProviderTests + public override TwoFactorProviderType TwoFactorProviderType => TwoFactorProviderType.Email; + + public static IEnumerable CanGenerateTwoFactorTokenAsyncData + => SetupCanGenerateData( + ( + new Dictionary + { + ["Email"] = "test@email.com", + }, + true + ), + ( + new Dictionary + { + ["NotEmail"] = "value", + }, + false + ), + ( + new Dictionary + { + ["Email"] = "", + }, + false + ) + ); + + [Theory, BitMemberAutoData(nameof(CanGenerateTwoFactorTokenAsyncData))] + public override async Task RunCanGenerateTwoFactorTokenAsync(Dictionary metaData, bool expectedResponse, + User user, SutProvider sutProvider) { - public override TwoFactorProviderType TwoFactorProviderType => TwoFactorProviderType.Email; - - public static IEnumerable CanGenerateTwoFactorTokenAsyncData - => SetupCanGenerateData( - ( - new Dictionary - { - ["Email"] = "test@email.com", - }, - true - ), - ( - new Dictionary - { - ["NotEmail"] = "value", - }, - false - ), - ( - new Dictionary - { - ["Email"] = "", - }, - false - ) - ); - - [Theory, BitMemberAutoData(nameof(CanGenerateTwoFactorTokenAsyncData))] - public override async Task RunCanGenerateTwoFactorTokenAsync(Dictionary metaData, bool expectedResponse, - User user, SutProvider sutProvider) - { - await base.RunCanGenerateTwoFactorTokenAsync(metaData, expectedResponse, user, sutProvider); - } + await base.RunCanGenerateTwoFactorTokenAsync(metaData, expectedResponse, user, sutProvider); } } diff --git a/test/Core.Test/IdentityServer/TokenRetrievalTests.cs b/test/Core.Test/IdentityServer/TokenRetrievalTests.cs index 071f4d914..591427da9 100644 --- a/test/Core.Test/IdentityServer/TokenRetrievalTests.cs +++ b/test/Core.Test/IdentityServer/TokenRetrievalTests.cs @@ -4,91 +4,90 @@ using Microsoft.Extensions.Primitives; using NSubstitute; using Xunit; -namespace Bit.Core.Test.IdentityServer +namespace Bit.Core.Test.IdentityServer; + +public class TokenRetrievalTests { - public class TokenRetrievalTests + private readonly Func _sut = TokenRetrieval.FromAuthorizationHeaderOrQueryString(); + + [Fact] + public void RetrieveToken_FromHeader_ReturnsToken() { - private readonly Func _sut = TokenRetrieval.FromAuthorizationHeaderOrQueryString(); - - [Fact] - public void RetrieveToken_FromHeader_ReturnsToken() + // Arrange + var headers = new HeaderDictionary { - // Arrange - var headers = new HeaderDictionary - { - { "Authorization", "Bearer test_value" }, - { "X-Test-Header", "random_value" } - }; + { "Authorization", "Bearer test_value" }, + { "X-Test-Header", "random_value" } + }; - var request = Substitute.For(); + var request = Substitute.For(); - request.Headers.Returns(headers); + request.Headers.Returns(headers); - // Act - var token = _sut(request); + // Act + var token = _sut(request); - // Assert - Assert.Equal("test_value", token); - } + // Assert + Assert.Equal("test_value", token); + } - [Fact] - public void RetrieveToken_FromQueryString_ReturnsToken() + [Fact] + public void RetrieveToken_FromQueryString_ReturnsToken() + { + // Arrange + var queryString = new Dictionary { - // Arrange - var queryString = new Dictionary - { - { "access_token", "test_value" }, - { "test-query", "random_value" } - }; + { "access_token", "test_value" }, + { "test-query", "random_value" } + }; - var request = Substitute.For(); - request.Query.Returns(new QueryCollection(queryString)); + var request = Substitute.For(); + request.Query.Returns(new QueryCollection(queryString)); - // Act - var token = _sut(request); + // Act + var token = _sut(request); - // Assert - Assert.Equal("test_value", token); - } + // Assert + Assert.Equal("test_value", token); + } - [Fact] - public void RetrieveToken_HasBoth_ReturnsHeaderToken() + [Fact] + public void RetrieveToken_HasBoth_ReturnsHeaderToken() + { + // Arrange + var queryString = new Dictionary { - // Arrange - var queryString = new Dictionary - { - { "access_token", "query_string_token" }, - { "test-query", "random_value" } - }; + { "access_token", "query_string_token" }, + { "test-query", "random_value" } + }; - var headers = new HeaderDictionary - { - { "Authorization", "Bearer header_token" }, - { "X-Test-Header", "random_value" } - }; - - var request = Substitute.For(); - request.Headers.Returns(headers); - request.Query.Returns(new QueryCollection(queryString)); - - // Act - var token = _sut(request); - - // Assert - Assert.Equal("header_token", token); - } - - [Fact] - public void RetrieveToken_NoToken_ReturnsNull() + var headers = new HeaderDictionary { - // Arrange - var request = Substitute.For(); + { "Authorization", "Bearer header_token" }, + { "X-Test-Header", "random_value" } + }; - // Act - var token = _sut(request); + var request = Substitute.For(); + request.Headers.Returns(headers); + request.Query.Returns(new QueryCollection(queryString)); - // Assert - Assert.Null(token); - } + // Act + var token = _sut(request); + + // Assert + Assert.Equal("header_token", token); + } + + [Fact] + public void RetrieveToken_NoToken_ReturnsNull() + { + // Arrange + var request = Substitute.For(); + + // Act + var token = _sut(request); + + // Assert + Assert.Null(token); } } diff --git a/test/Core.Test/Models/Business/BillingInfo.cs b/test/Core.Test/Models/Business/BillingInfo.cs index 0023b4669..c6c1ae56f 100644 --- a/test/Core.Test/Models/Business/BillingInfo.cs +++ b/test/Core.Test/Models/Business/BillingInfo.cs @@ -1,23 +1,22 @@ using Bit.Core.Models.Business; using Xunit; -namespace Bit.Core.Test.Models.Business +namespace Bit.Core.Test.Models.Business; + +public class BillingInfoTests { - public class BillingInfoTests + [Fact] + public void BillingInvoice_Amount_ShouldComeFrom_InvoiceTotal() { - [Fact] - public void BillingInvoice_Amount_ShouldComeFrom_InvoiceTotal() + var invoice = new Stripe.Invoice { - var invoice = new Stripe.Invoice - { - AmountDue = 1000, - Total = 2000, - }; + AmountDue = 1000, + Total = 2000, + }; - var billingInvoice = new BillingInfo.BillingInvoice(invoice); + var billingInvoice = new BillingInfo.BillingInvoice(invoice); - // Should have been set from Total - Assert.Equal(20M, billingInvoice.Amount); - } + // Should have been set from Total + Assert.Equal(20M, billingInvoice.Amount); } } diff --git a/test/Core.Test/Models/Business/TaxInfoTests.cs b/test/Core.Test/Models/Business/TaxInfoTests.cs index 124201b62..197948006 100644 --- a/test/Core.Test/Models/Business/TaxInfoTests.cs +++ b/test/Core.Test/Models/Business/TaxInfoTests.cs @@ -1,115 +1,114 @@ using Bit.Core.Models.Business; using Xunit; -namespace Bit.Core.Test.Models.Business +namespace Bit.Core.Test.Models.Business; + +public class TaxInfoTests { - public class TaxInfoTests + // PH = Placeholder + [Theory] + [InlineData(null, null, null, null)] + [InlineData("", "", null, null)] + [InlineData("PH", "", null, null)] + [InlineData("", "PH", null, null)] + [InlineData("AE", "PH", null, "ae_trn")] + [InlineData("AU", "PH", null, "au_abn")] + [InlineData("BR", "PH", null, "br_cnpj")] + [InlineData("CA", "PH", "bec", "ca_qst")] + [InlineData("CA", "PH", null, "ca_bn")] + [InlineData("CL", "PH", null, "cl_tin")] + [InlineData("AT", "PH", null, "eu_vat")] + [InlineData("BE", "PH", null, "eu_vat")] + [InlineData("BG", "PH", null, "eu_vat")] + [InlineData("CY", "PH", null, "eu_vat")] + [InlineData("CZ", "PH", null, "eu_vat")] + [InlineData("DE", "PH", null, "eu_vat")] + [InlineData("DK", "PH", null, "eu_vat")] + [InlineData("EE", "PH", null, "eu_vat")] + [InlineData("ES", "PH", null, "eu_vat")] + [InlineData("FI", "PH", null, "eu_vat")] + [InlineData("FR", "PH", null, "eu_vat")] + [InlineData("GB", "PH", null, "eu_vat")] + [InlineData("GR", "PH", null, "eu_vat")] + [InlineData("HR", "PH", null, "eu_vat")] + [InlineData("HU", "PH", null, "eu_vat")] + [InlineData("IE", "PH", null, "eu_vat")] + [InlineData("IT", "PH", null, "eu_vat")] + [InlineData("LT", "PH", null, "eu_vat")] + [InlineData("LU", "PH", null, "eu_vat")] + [InlineData("LV", "PH", null, "eu_vat")] + [InlineData("MT", "PH", null, "eu_vat")] + [InlineData("NL", "PH", null, "eu_vat")] + [InlineData("PL", "PH", null, "eu_vat")] + [InlineData("PT", "PH", null, "eu_vat")] + [InlineData("RO", "PH", null, "eu_vat")] + [InlineData("SE", "PH", null, "eu_vat")] + [InlineData("SI", "PH", null, "eu_vat")] + [InlineData("SK", "PH", null, "eu_vat")] + [InlineData("HK", "PH", null, "hk_br")] + [InlineData("IN", "PH", null, "in_gst")] + [InlineData("JP", "PH", null, "jp_cn")] + [InlineData("KR", "PH", null, "kr_brn")] + [InlineData("LI", "PH", null, "li_uid")] + [InlineData("MX", "PH", null, "mx_rfc")] + [InlineData("MY", "PH", null, "my_sst")] + [InlineData("NO", "PH", null, "no_vat")] + [InlineData("NZ", "PH", null, "nz_gst")] + [InlineData("RU", "PH", null, "ru_inn")] + [InlineData("SA", "PH", null, "sa_vat")] + [InlineData("SG", "PH", null, "sg_gst")] + [InlineData("TH", "PH", null, "th_vat")] + [InlineData("TW", "PH", null, "tw_vat")] + [InlineData("US", "PH", null, "us_ein")] + [InlineData("ZA", "PH", null, "za_vat")] + [InlineData("ABCDEF", "PH", null, null)] + public void GetTaxIdType_Success(string billingAddressCountry, + string taxIdNumber, + string billingAddressState, + string expectedTaxIdType) { - // PH = Placeholder - [Theory] - [InlineData(null, null, null, null)] - [InlineData("", "", null, null)] - [InlineData("PH", "", null, null)] - [InlineData("", "PH", null, null)] - [InlineData("AE", "PH", null, "ae_trn")] - [InlineData("AU", "PH", null, "au_abn")] - [InlineData("BR", "PH", null, "br_cnpj")] - [InlineData("CA", "PH", "bec", "ca_qst")] - [InlineData("CA", "PH", null, "ca_bn")] - [InlineData("CL", "PH", null, "cl_tin")] - [InlineData("AT", "PH", null, "eu_vat")] - [InlineData("BE", "PH", null, "eu_vat")] - [InlineData("BG", "PH", null, "eu_vat")] - [InlineData("CY", "PH", null, "eu_vat")] - [InlineData("CZ", "PH", null, "eu_vat")] - [InlineData("DE", "PH", null, "eu_vat")] - [InlineData("DK", "PH", null, "eu_vat")] - [InlineData("EE", "PH", null, "eu_vat")] - [InlineData("ES", "PH", null, "eu_vat")] - [InlineData("FI", "PH", null, "eu_vat")] - [InlineData("FR", "PH", null, "eu_vat")] - [InlineData("GB", "PH", null, "eu_vat")] - [InlineData("GR", "PH", null, "eu_vat")] - [InlineData("HR", "PH", null, "eu_vat")] - [InlineData("HU", "PH", null, "eu_vat")] - [InlineData("IE", "PH", null, "eu_vat")] - [InlineData("IT", "PH", null, "eu_vat")] - [InlineData("LT", "PH", null, "eu_vat")] - [InlineData("LU", "PH", null, "eu_vat")] - [InlineData("LV", "PH", null, "eu_vat")] - [InlineData("MT", "PH", null, "eu_vat")] - [InlineData("NL", "PH", null, "eu_vat")] - [InlineData("PL", "PH", null, "eu_vat")] - [InlineData("PT", "PH", null, "eu_vat")] - [InlineData("RO", "PH", null, "eu_vat")] - [InlineData("SE", "PH", null, "eu_vat")] - [InlineData("SI", "PH", null, "eu_vat")] - [InlineData("SK", "PH", null, "eu_vat")] - [InlineData("HK", "PH", null, "hk_br")] - [InlineData("IN", "PH", null, "in_gst")] - [InlineData("JP", "PH", null, "jp_cn")] - [InlineData("KR", "PH", null, "kr_brn")] - [InlineData("LI", "PH", null, "li_uid")] - [InlineData("MX", "PH", null, "mx_rfc")] - [InlineData("MY", "PH", null, "my_sst")] - [InlineData("NO", "PH", null, "no_vat")] - [InlineData("NZ", "PH", null, "nz_gst")] - [InlineData("RU", "PH", null, "ru_inn")] - [InlineData("SA", "PH", null, "sa_vat")] - [InlineData("SG", "PH", null, "sg_gst")] - [InlineData("TH", "PH", null, "th_vat")] - [InlineData("TW", "PH", null, "tw_vat")] - [InlineData("US", "PH", null, "us_ein")] - [InlineData("ZA", "PH", null, "za_vat")] - [InlineData("ABCDEF", "PH", null, null)] - public void GetTaxIdType_Success(string billingAddressCountry, - string taxIdNumber, - string billingAddressState, - string expectedTaxIdType) + var taxInfo = new TaxInfo { - var taxInfo = new TaxInfo - { - BillingAddressCountry = billingAddressCountry, - TaxIdNumber = taxIdNumber, - BillingAddressState = billingAddressState, - }; + BillingAddressCountry = billingAddressCountry, + TaxIdNumber = taxIdNumber, + BillingAddressState = billingAddressState, + }; - Assert.Equal(expectedTaxIdType, taxInfo.TaxIdType); - } + Assert.Equal(expectedTaxIdType, taxInfo.TaxIdType); + } - [Fact] - public void GetTaxIdType_CreateOnce_ReturnCacheSecondTime() + [Fact] + public void GetTaxIdType_CreateOnce_ReturnCacheSecondTime() + { + var taxInfo = new TaxInfo { - var taxInfo = new TaxInfo - { - BillingAddressCountry = "US", - TaxIdNumber = "PH", - BillingAddressState = null, - }; + BillingAddressCountry = "US", + TaxIdNumber = "PH", + BillingAddressState = null, + }; - Assert.Equal("us_ein", taxInfo.TaxIdType); + Assert.Equal("us_ein", taxInfo.TaxIdType); - // Per the current spec even if the values change to something other than null it - // will return the cached version of TaxIdType. - taxInfo.BillingAddressCountry = "ZA"; + // Per the current spec even if the values change to something other than null it + // will return the cached version of TaxIdType. + taxInfo.BillingAddressCountry = "ZA"; - Assert.Equal("us_ein", taxInfo.TaxIdType); - } + Assert.Equal("us_ein", taxInfo.TaxIdType); + } - [Theory] - [InlineData(null, null, false)] - [InlineData("123", "US", true)] - [InlineData("123", "ZQ12", false)] - [InlineData(" ", "US", false)] - public void HasTaxId_ReturnsExpected(string taxIdNumber, string billingAddressCountry, bool expected) + [Theory] + [InlineData(null, null, false)] + [InlineData("123", "US", true)] + [InlineData("123", "ZQ12", false)] + [InlineData(" ", "US", false)] + public void HasTaxId_ReturnsExpected(string taxIdNumber, string billingAddressCountry, bool expected) + { + var taxInfo = new TaxInfo { - var taxInfo = new TaxInfo - { - TaxIdNumber = taxIdNumber, - BillingAddressCountry = billingAddressCountry, - }; + TaxIdNumber = taxIdNumber, + BillingAddressCountry = billingAddressCountry, + }; - Assert.Equal(expected, taxInfo.HasTaxId); - } + Assert.Equal(expected, taxInfo.HasTaxId); } } diff --git a/test/Core.Test/Models/Business/Tokenables/EmergencyAccessInviteTokenableTests.cs b/test/Core.Test/Models/Business/Tokenables/EmergencyAccessInviteTokenableTests.cs index d334c7dfa..40e390c7d 100644 --- a/test/Core.Test/Models/Business/Tokenables/EmergencyAccessInviteTokenableTests.cs +++ b/test/Core.Test/Models/Business/Tokenables/EmergencyAccessInviteTokenableTests.cs @@ -4,30 +4,29 @@ using Bit.Core.Models.Business.Tokenables; using Bit.Core.Tokens; using Xunit; -namespace Bit.Core.Test.Models.Business.Tokenables +namespace Bit.Core.Test.Models.Business.Tokenables; + +public class EmergencyAccessInviteTokenableTests { - public class EmergencyAccessInviteTokenableTests + [Theory, AutoData] + public void SerializationSetsCorrectDateTime(EmergencyAccess emergencyAccess) { - [Theory, AutoData] - public void SerializationSetsCorrectDateTime(EmergencyAccess emergencyAccess) - { - var token = new EmergencyAccessInviteTokenable(emergencyAccess, 2); - Assert.Equal(Tokenable.FromToken(token.ToToken().ToString()).ExpirationDate, - token.ExpirationDate, - TimeSpan.FromMilliseconds(10)); - } + var token = new EmergencyAccessInviteTokenable(emergencyAccess, 2); + Assert.Equal(Tokenable.FromToken(token.ToToken().ToString()).ExpirationDate, + token.ExpirationDate, + TimeSpan.FromMilliseconds(10)); + } - [Fact] - public void IsInvalidIfIdentifierIsWrong() + [Fact] + public void IsInvalidIfIdentifierIsWrong() + { + var token = new EmergencyAccessInviteTokenable(DateTime.MaxValue) { - var token = new EmergencyAccessInviteTokenable(DateTime.MaxValue) - { - Email = "email", - Id = Guid.NewGuid(), - Identifier = "not correct" - }; + Email = "email", + Id = Guid.NewGuid(), + Identifier = "not correct" + }; - Assert.False(token.Valid); - } + Assert.False(token.Valid); } } diff --git a/test/Core.Test/Models/Business/Tokenables/HCaptchaTokenableTests.cs b/test/Core.Test/Models/Business/Tokenables/HCaptchaTokenableTests.cs index ce72fa8dc..ce97cb8b3 100644 --- a/test/Core.Test/Models/Business/Tokenables/HCaptchaTokenableTests.cs +++ b/test/Core.Test/Models/Business/Tokenables/HCaptchaTokenableTests.cs @@ -5,84 +5,83 @@ using Bit.Core.Tokens; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; -namespace Bit.Core.Test.Models.Business.Tokenables +namespace Bit.Core.Test.Models.Business.Tokenables; + +public class HCaptchaTokenableTests { - public class HCaptchaTokenableTests + [Fact] + public void CanHandleNullUser() { - [Fact] - public void CanHandleNullUser() + var token = new HCaptchaTokenable(null); + + Assert.Equal(default, token.Id); + Assert.Equal(default, token.Email); + } + + [Fact] + public void TokenWithNullUserIsInvalid() + { + var token = new HCaptchaTokenable(null) { - var token = new HCaptchaTokenable(null); + ExpirationDate = DateTime.UtcNow + TimeSpan.FromDays(1) + }; - Assert.Equal(default, token.Id); - Assert.Equal(default, token.Email); - } + Assert.False(token.Valid); + } - [Fact] - public void TokenWithNullUserIsInvalid() + [Theory, BitAutoData] + public void TokenValidityCheckNullUserIdIsInvalid(User user) + { + var token = new HCaptchaTokenable(user) { - var token = new HCaptchaTokenable(null) - { - ExpirationDate = DateTime.UtcNow + TimeSpan.FromDays(1) - }; + ExpirationDate = DateTime.UtcNow + TimeSpan.FromDays(1) + }; - Assert.False(token.Valid); - } + Assert.False(token.TokenIsValid(null)); + } - [Theory, BitAutoData] - public void TokenValidityCheckNullUserIdIsInvalid(User user) + [Theory, AutoData] + public void CanUpdateExpirationToNonStandard(User user) + { + var token = new HCaptchaTokenable(user) { - var token = new HCaptchaTokenable(user) - { - ExpirationDate = DateTime.UtcNow + TimeSpan.FromDays(1) - }; + ExpirationDate = DateTime.MinValue + }; - Assert.False(token.TokenIsValid(null)); - } + Assert.Equal(DateTime.MinValue, token.ExpirationDate, TimeSpan.FromMilliseconds(10)); + } - [Theory, AutoData] - public void CanUpdateExpirationToNonStandard(User user) + [Theory, AutoData] + public void SetsDataFromUser(User user) + { + var token = new HCaptchaTokenable(user); + + Assert.Equal(user.Id, token.Id); + Assert.Equal(user.Email, token.Email); + } + + [Theory, AutoData] + public void SerializationSetsCorrectDateTime(User user) + { + var expectedDateTime = DateTime.UtcNow.AddHours(-5); + var token = new HCaptchaTokenable(user) { - var token = new HCaptchaTokenable(user) - { - ExpirationDate = DateTime.MinValue - }; + ExpirationDate = expectedDateTime + }; - Assert.Equal(DateTime.MinValue, token.ExpirationDate, TimeSpan.FromMilliseconds(10)); - } + var result = Tokenable.FromToken(token.ToToken()); - [Theory, AutoData] - public void SetsDataFromUser(User user) + Assert.Equal(expectedDateTime, result.ExpirationDate, TimeSpan.FromMilliseconds(10)); + } + + [Theory, AutoData] + public void IsInvalidIfIdentifierIsWrong(User user) + { + var token = new HCaptchaTokenable(user) { - var token = new HCaptchaTokenable(user); + Identifier = "not correct" + }; - Assert.Equal(user.Id, token.Id); - Assert.Equal(user.Email, token.Email); - } - - [Theory, AutoData] - public void SerializationSetsCorrectDateTime(User user) - { - var expectedDateTime = DateTime.UtcNow.AddHours(-5); - var token = new HCaptchaTokenable(user) - { - ExpirationDate = expectedDateTime - }; - - var result = Tokenable.FromToken(token.ToToken()); - - Assert.Equal(expectedDateTime, result.ExpirationDate, TimeSpan.FromMilliseconds(10)); - } - - [Theory, AutoData] - public void IsInvalidIfIdentifierIsWrong(User user) - { - var token = new HCaptchaTokenable(user) - { - Identifier = "not correct" - }; - - Assert.False(token.Valid); - } + Assert.False(token.Valid); } } diff --git a/test/Core.Test/Models/Business/Tokenables/OrganizationSponsorshipOfferTokenableTests.cs b/test/Core.Test/Models/Business/Tokenables/OrganizationSponsorshipOfferTokenableTests.cs index fd39c196b..172d4c911 100644 --- a/test/Core.Test/Models/Business/Tokenables/OrganizationSponsorshipOfferTokenableTests.cs +++ b/test/Core.Test/Models/Business/Tokenables/OrganizationSponsorshipOfferTokenableTests.cs @@ -4,153 +4,152 @@ using Bit.Core.Models.Business.Tokenables; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; -namespace Bit.Core.Test.Models.Business.Tokenables +namespace Bit.Core.Test.Models.Business.Tokenables; + +public class OrganizationSponsorshipOfferTokenableTests { - public class OrganizationSponsorshipOfferTokenableTests + public static IEnumerable PlanSponsorshipTypes() => Enum.GetValues().Select(x => new object[] { x }); + + [Fact] + public void IsInvalidIfIdentifierIsWrong() { - public static IEnumerable PlanSponsorshipTypes() => Enum.GetValues().Select(x => new object[] { x }); - - [Fact] - public void IsInvalidIfIdentifierIsWrong() + var token = new OrganizationSponsorshipOfferTokenable() { - var token = new OrganizationSponsorshipOfferTokenable() - { - Email = "email", - Id = Guid.NewGuid(), - Identifier = "not correct", - SponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, - }; + Email = "email", + Id = Guid.NewGuid(), + Identifier = "not correct", + SponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, + }; - Assert.False(token.Valid); - } + Assert.False(token.Valid); + } - [Fact] - public void IsInvalidIfIdIsDefault() + [Fact] + public void IsInvalidIfIdIsDefault() + { + var token = new OrganizationSponsorshipOfferTokenable() { - var token = new OrganizationSponsorshipOfferTokenable() - { - Email = "email", - Id = default, - SponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, - }; + Email = "email", + Id = default, + SponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, + }; - Assert.False(token.Valid); - } + Assert.False(token.Valid); + } - [Fact] - public void IsInvalidIfEmailIsEmpty() + [Fact] + public void IsInvalidIfEmailIsEmpty() + { + var token = new OrganizationSponsorshipOfferTokenable() { - var token = new OrganizationSponsorshipOfferTokenable() - { - Email = "", - Id = Guid.NewGuid(), - SponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, - }; + Email = "", + Id = Guid.NewGuid(), + SponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, + }; - Assert.False(token.Valid); - } + Assert.False(token.Valid); + } - [Theory, BitAutoData] - public void IsValid_Success(OrganizationSponsorship sponsorship) - { - var token = new OrganizationSponsorshipOfferTokenable(sponsorship); + [Theory, BitAutoData] + public void IsValid_Success(OrganizationSponsorship sponsorship) + { + var token = new OrganizationSponsorshipOfferTokenable(sponsorship); - Assert.True(token.IsValid(sponsorship, sponsorship.OfferedToEmail)); - } + Assert.True(token.IsValid(sponsorship, sponsorship.OfferedToEmail)); + } - [Theory, BitAutoData] - public void IsValid_RequiresNonNullSponsorship(OrganizationSponsorship sponsorship) - { - var token = new OrganizationSponsorshipOfferTokenable(sponsorship); + [Theory, BitAutoData] + public void IsValid_RequiresNonNullSponsorship(OrganizationSponsorship sponsorship) + { + var token = new OrganizationSponsorshipOfferTokenable(sponsorship); - Assert.False(token.IsValid(null, sponsorship.OfferedToEmail)); - } + Assert.False(token.IsValid(null, sponsorship.OfferedToEmail)); + } - [Theory, BitAutoData] - public void IsValid_RequiresCurrentEmailToBeSameAsOfferedToEmail(OrganizationSponsorship sponsorship, string currentEmail) - { - var token = new OrganizationSponsorshipOfferTokenable(sponsorship); + [Theory, BitAutoData] + public void IsValid_RequiresCurrentEmailToBeSameAsOfferedToEmail(OrganizationSponsorship sponsorship, string currentEmail) + { + var token = new OrganizationSponsorshipOfferTokenable(sponsorship); - Assert.False(token.IsValid(sponsorship, currentEmail)); - } + Assert.False(token.IsValid(sponsorship, currentEmail)); + } - [Theory, BitAutoData] - public void IsValid_RequiresSameSponsorshipId(OrganizationSponsorship sponsorship1, OrganizationSponsorship sponsorship2) - { - sponsorship1.Id = sponsorship2.Id; + [Theory, BitAutoData] + public void IsValid_RequiresSameSponsorshipId(OrganizationSponsorship sponsorship1, OrganizationSponsorship sponsorship2) + { + sponsorship1.Id = sponsorship2.Id; - var token = new OrganizationSponsorshipOfferTokenable(sponsorship1); + var token = new OrganizationSponsorshipOfferTokenable(sponsorship1); - Assert.False(token.IsValid(sponsorship2, sponsorship1.OfferedToEmail)); - } + Assert.False(token.IsValid(sponsorship2, sponsorship1.OfferedToEmail)); + } - [Theory, BitAutoData] - public void IsValid_RequiresSameEmail(OrganizationSponsorship sponsorship1, OrganizationSponsorship sponsorship2) - { - sponsorship1.OfferedToEmail = sponsorship2.OfferedToEmail; + [Theory, BitAutoData] + public void IsValid_RequiresSameEmail(OrganizationSponsorship sponsorship1, OrganizationSponsorship sponsorship2) + { + sponsorship1.OfferedToEmail = sponsorship2.OfferedToEmail; - var token = new OrganizationSponsorshipOfferTokenable(sponsorship1); + var token = new OrganizationSponsorshipOfferTokenable(sponsorship1); - Assert.False(token.IsValid(sponsorship2, sponsorship1.OfferedToEmail)); - } + Assert.False(token.IsValid(sponsorship2, sponsorship1.OfferedToEmail)); + } - [Theory, BitAutoData] - public void Constructor_GrabsIdFromSponsorship(OrganizationSponsorship sponsorship) - { - var token = new OrganizationSponsorshipOfferTokenable(sponsorship); + [Theory, BitAutoData] + public void Constructor_GrabsIdFromSponsorship(OrganizationSponsorship sponsorship) + { + var token = new OrganizationSponsorshipOfferTokenable(sponsorship); - Assert.Equal(sponsorship.Id, token.Id); - } + Assert.Equal(sponsorship.Id, token.Id); + } - [Theory, BitAutoData] - public void Constructor_GrabsEmailFromSponsorshipOfferedToEmail(OrganizationSponsorship sponsorship) - { - var token = new OrganizationSponsorshipOfferTokenable(sponsorship); + [Theory, BitAutoData] + public void Constructor_GrabsEmailFromSponsorshipOfferedToEmail(OrganizationSponsorship sponsorship) + { + var token = new OrganizationSponsorshipOfferTokenable(sponsorship); - Assert.Equal(sponsorship.OfferedToEmail, token.Email); - } + Assert.Equal(sponsorship.OfferedToEmail, token.Email); + } - [Theory, BitMemberAutoData(nameof(PlanSponsorshipTypes))] - public void Constructor_GrabsSponsorshipType(PlanSponsorshipType planSponsorshipType, - OrganizationSponsorship sponsorship) - { - sponsorship.PlanSponsorshipType = planSponsorshipType; - var token = new OrganizationSponsorshipOfferTokenable(sponsorship); + [Theory, BitMemberAutoData(nameof(PlanSponsorshipTypes))] + public void Constructor_GrabsSponsorshipType(PlanSponsorshipType planSponsorshipType, + OrganizationSponsorship sponsorship) + { + sponsorship.PlanSponsorshipType = planSponsorshipType; + var token = new OrganizationSponsorshipOfferTokenable(sponsorship); - Assert.Equal(sponsorship.PlanSponsorshipType, token.SponsorshipType); - } + Assert.Equal(sponsorship.PlanSponsorshipType, token.SponsorshipType); + } - [Theory, BitAutoData] - public void Constructor_DefaultId_Throws(OrganizationSponsorship sponsorship) - { - sponsorship.Id = default; + [Theory, BitAutoData] + public void Constructor_DefaultId_Throws(OrganizationSponsorship sponsorship) + { + sponsorship.Id = default; - Assert.Throws(() => new OrganizationSponsorshipOfferTokenable(sponsorship)); - } + Assert.Throws(() => new OrganizationSponsorshipOfferTokenable(sponsorship)); + } - [Theory, BitAutoData] - public void Constructor_NoOfferedToEmail_Throws(OrganizationSponsorship sponsorship) - { - sponsorship.OfferedToEmail = null; + [Theory, BitAutoData] + public void Constructor_NoOfferedToEmail_Throws(OrganizationSponsorship sponsorship) + { + sponsorship.OfferedToEmail = null; - Assert.Throws(() => new OrganizationSponsorshipOfferTokenable(sponsorship)); - } + Assert.Throws(() => new OrganizationSponsorshipOfferTokenable(sponsorship)); + } - [Theory, BitAutoData] - public void Constructor_EmptyOfferedToEmail_Throws(OrganizationSponsorship sponsorship) - { - sponsorship.OfferedToEmail = ""; + [Theory, BitAutoData] + public void Constructor_EmptyOfferedToEmail_Throws(OrganizationSponsorship sponsorship) + { + sponsorship.OfferedToEmail = ""; - Assert.Throws(() => new OrganizationSponsorshipOfferTokenable(sponsorship)); - } + Assert.Throws(() => new OrganizationSponsorshipOfferTokenable(sponsorship)); + } - [Theory, BitAutoData] - public void Constructor_NoPlanSponsorshipType_Throws(OrganizationSponsorship sponsorship) - { - sponsorship.PlanSponsorshipType = null; + [Theory, BitAutoData] + public void Constructor_NoPlanSponsorshipType_Throws(OrganizationSponsorship sponsorship) + { + sponsorship.PlanSponsorshipType = null; - Assert.Throws(() => new OrganizationSponsorshipOfferTokenable(sponsorship)); - } + Assert.Throws(() => new OrganizationSponsorshipOfferTokenable(sponsorship)); } } diff --git a/test/Core.Test/Models/Business/Tokenables/SsoTokenableTests.cs b/test/Core.Test/Models/Business/Tokenables/SsoTokenableTests.cs index aef71e5ba..0ec4b5a35 100644 --- a/test/Core.Test/Models/Business/Tokenables/SsoTokenableTests.cs +++ b/test/Core.Test/Models/Business/Tokenables/SsoTokenableTests.cs @@ -5,85 +5,84 @@ using Bit.Core.Tokens; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; -namespace Bit.Core.Test.Models.Business.Tokenables +namespace Bit.Core.Test.Models.Business.Tokenables; + +public class SsoTokenableTests { - public class SsoTokenableTests + [Fact] + public void CanHandleNullOrganization() { - [Fact] - public void CanHandleNullOrganization() + var token = new SsoTokenable(null, default); + + Assert.Equal(default, token.OrganizationId); + Assert.Equal(default, token.DomainHint); + } + + [Fact] + public void TokenWithNullOrganizationIsInvalid() + { + var token = new SsoTokenable(null, 500) { - var token = new SsoTokenable(null, default); + ExpirationDate = DateTime.UtcNow + TimeSpan.FromDays(1) + }; - Assert.Equal(default, token.OrganizationId); - Assert.Equal(default, token.DomainHint); - } + Assert.False(token.Valid); + } - [Fact] - public void TokenWithNullOrganizationIsInvalid() + [Theory, BitAutoData] + public void TokenValidityCheckNullOrganizationIsInvalid(Organization organization) + { + var token = new SsoTokenable(organization, 500) { - var token = new SsoTokenable(null, 500) - { - ExpirationDate = DateTime.UtcNow + TimeSpan.FromDays(1) - }; + ExpirationDate = DateTime.UtcNow + TimeSpan.FromDays(1) + }; - Assert.False(token.Valid); - } + Assert.False(token.TokenIsValid(null)); + } - [Theory, BitAutoData] - public void TokenValidityCheckNullOrganizationIsInvalid(Organization organization) + [Theory, AutoData] + public void SetsDataFromOrganization(Organization organization) + { + var token = new SsoTokenable(organization, default); + + Assert.Equal(organization.Id, token.OrganizationId); + Assert.Equal(organization.Identifier, token.DomainHint); + } + + [Fact] + public void SetsExpirationFromConstructor() + { + var expectedDateTime = DateTime.UtcNow.AddSeconds(500); + var token = new SsoTokenable(null, 500); + + Assert.Equal(expectedDateTime, token.ExpirationDate, TimeSpan.FromMilliseconds(10)); + } + + [Theory, AutoData] + public void SerializationSetsCorrectDateTime(Organization organization) + { + var expectedDateTime = DateTime.UtcNow.AddHours(-5); + var token = new SsoTokenable(organization, default) { - var token = new SsoTokenable(organization, 500) - { - ExpirationDate = DateTime.UtcNow + TimeSpan.FromDays(1) - }; + ExpirationDate = expectedDateTime + }; - Assert.False(token.TokenIsValid(null)); - } + var result = Tokenable.FromToken(token.ToToken()); - [Theory, AutoData] - public void SetsDataFromOrganization(Organization organization) + Assert.Equal(expectedDateTime, result.ExpirationDate, TimeSpan.FromMilliseconds(10)); + } + + [Theory, AutoData] + public void TokenIsValidFailsWhenExpired(Organization organization) + { + var expectedDateTime = DateTime.UtcNow.AddHours(-5); + var token = new SsoTokenable(organization, default) { - var token = new SsoTokenable(organization, default); + ExpirationDate = expectedDateTime + }; - Assert.Equal(organization.Id, token.OrganizationId); - Assert.Equal(organization.Identifier, token.DomainHint); - } + var result = token.TokenIsValid(organization); - [Fact] - public void SetsExpirationFromConstructor() - { - var expectedDateTime = DateTime.UtcNow.AddSeconds(500); - var token = new SsoTokenable(null, 500); - - Assert.Equal(expectedDateTime, token.ExpirationDate, TimeSpan.FromMilliseconds(10)); - } - - [Theory, AutoData] - public void SerializationSetsCorrectDateTime(Organization organization) - { - var expectedDateTime = DateTime.UtcNow.AddHours(-5); - var token = new SsoTokenable(organization, default) - { - ExpirationDate = expectedDateTime - }; - - var result = Tokenable.FromToken(token.ToToken()); - - Assert.Equal(expectedDateTime, result.ExpirationDate, TimeSpan.FromMilliseconds(10)); - } - - [Theory, AutoData] - public void TokenIsValidFailsWhenExpired(Organization organization) - { - var expectedDateTime = DateTime.UtcNow.AddHours(-5); - var token = new SsoTokenable(organization, default) - { - ExpirationDate = expectedDateTime - }; - - var result = token.TokenIsValid(organization); - - Assert.False(result); - } + Assert.False(result); } } diff --git a/test/Core.Test/Models/CipherTests.cs b/test/Core.Test/Models/CipherTests.cs index 3993f4caf..af7a0b6e3 100644 --- a/test/Core.Test/Models/CipherTests.cs +++ b/test/Core.Test/Models/CipherTests.cs @@ -3,16 +3,15 @@ using Bit.Core.Entities; using Bit.Core.Test.AutoFixture.CipherFixtures; using Xunit; -namespace Bit.Core.Test.Models +namespace Bit.Core.Test.Models; + +public class CipherTests { - public class CipherTests + [Theory] + [InlineUserCipherAutoData] + [InlineOrganizationCipherAutoData] + public void Clone_CreatesExactCopy(Cipher cipher) { - [Theory] - [InlineUserCipherAutoData] - [InlineOrganizationCipherAutoData] - public void Clone_CreatesExactCopy(Cipher cipher) - { - Assert.Equal(JsonSerializer.Serialize(cipher), JsonSerializer.Serialize(cipher.Clone())); - } + Assert.Equal(JsonSerializer.Serialize(cipher), JsonSerializer.Serialize(cipher.Clone())); } } diff --git a/test/Core.Test/Models/Data/SendFileDataTests.cs b/test/Core.Test/Models/Data/SendFileDataTests.cs index 7a7dc9bc5..6f2afe748 100644 --- a/test/Core.Test/Models/Data/SendFileDataTests.cs +++ b/test/Core.Test/Models/Data/SendFileDataTests.cs @@ -3,26 +3,25 @@ using Bit.Core.Models.Data; using Bit.Test.Common.Helpers; using Xunit; -namespace Bit.Core.Test.Models.Data -{ - public class SendFileDataTests - { - [Fact] - public void Serialize_Success() - { - var sut = new SendFileData - { - Id = "test", - Size = 100, - FileName = "thing.pdf", - Validated = true, - }; +namespace Bit.Core.Test.Models.Data; - var json = JsonSerializer.Serialize(sut); - var document = JsonDocument.Parse(json); - var root = document.RootElement; - AssertHelper.AssertJsonProperty(root, "Size", JsonValueKind.String); - Assert.False(root.TryGetProperty("SizeString", out _)); - } +public class SendFileDataTests +{ + [Fact] + public void Serialize_Success() + { + var sut = new SendFileData + { + Id = "test", + Size = 100, + FileName = "thing.pdf", + Validated = true, + }; + + var json = JsonSerializer.Serialize(sut); + var document = JsonDocument.Parse(json); + var root = document.RootElement; + AssertHelper.AssertJsonProperty(root, "Size", JsonValueKind.String); + Assert.False(root.TryGetProperty("SizeString", out _)); } } diff --git a/test/Core.Test/Models/PermissionsTests.cs b/test/Core.Test/Models/PermissionsTests.cs index c8522eaa2..76b88f6ff 100644 --- a/test/Core.Test/Models/PermissionsTests.cs +++ b/test/Core.Test/Models/PermissionsTests.cs @@ -3,59 +3,58 @@ using Bit.Core.Models.Data; using Bit.Core.Utilities; using Xunit; -namespace Bit.Core.Test.Models +namespace Bit.Core.Test.Models; + +public class PermissionsTests { - public class PermissionsTests + private static readonly string _exampleSerializedPermissions = string.Concat( + "{", + "\"accessEventLogs\": false,", + "\"accessImportExport\": false,", + "\"accessReports\": false,", + "\"manageAllCollections\": true,", // exists for backwards compatibility + "\"createNewCollections\": true,", + "\"editAnyCollection\": true,", + "\"deleteAnyCollection\": true,", + "\"manageAssignedCollections\": false,", // exists for backwards compatibility + "\"editAssignedCollections\": false,", + "\"deleteAssignedCollections\": false,", + "\"manageGroups\": false,", + "\"managePolicies\": false,", + "\"manageSso\": false,", + "\"manageUsers\": false,", + "\"manageResetPassword\": false,", + "\"manageScim\": false", + "}"); + + [Fact] + public void Serialization_Success() { - private static readonly string _exampleSerializedPermissions = string.Concat( - "{", - "\"accessEventLogs\": false,", - "\"accessImportExport\": false,", - "\"accessReports\": false,", - "\"manageAllCollections\": true,", // exists for backwards compatibility - "\"createNewCollections\": true,", - "\"editAnyCollection\": true,", - "\"deleteAnyCollection\": true,", - "\"manageAssignedCollections\": false,", // exists for backwards compatibility - "\"editAssignedCollections\": false,", - "\"deleteAssignedCollections\": false,", - "\"manageGroups\": false,", - "\"managePolicies\": false,", - "\"manageSso\": false,", - "\"manageUsers\": false,", - "\"manageResetPassword\": false,", - "\"manageScim\": false", - "}"); - - [Fact] - public void Serialization_Success() + var permissions = new Permissions { - var permissions = new Permissions - { - AccessEventLogs = false, - AccessImportExport = false, - AccessReports = false, - CreateNewCollections = true, - EditAnyCollection = true, - DeleteAnyCollection = true, - EditAssignedCollections = false, - DeleteAssignedCollections = false, - ManageGroups = false, - ManagePolicies = false, - ManageSso = false, - ManageUsers = false, - ManageResetPassword = false, - ManageScim = false, - }; + AccessEventLogs = false, + AccessImportExport = false, + AccessReports = false, + CreateNewCollections = true, + EditAnyCollection = true, + DeleteAnyCollection = true, + EditAssignedCollections = false, + DeleteAssignedCollections = false, + ManageGroups = false, + ManagePolicies = false, + ManageSso = false, + ManageUsers = false, + ManageResetPassword = false, + ManageScim = false, + }; - // minify expected json - var expected = JsonSerializer.Serialize(permissions, JsonHelpers.CamelCase); + // minify expected json + var expected = JsonSerializer.Serialize(permissions, JsonHelpers.CamelCase); - var actual = JsonSerializer.Serialize( - JsonHelpers.DeserializeOrNew(_exampleSerializedPermissions, JsonHelpers.CamelCase), - JsonHelpers.CamelCase); + var actual = JsonSerializer.Serialize( + JsonHelpers.DeserializeOrNew(_exampleSerializedPermissions, JsonHelpers.CamelCase), + JsonHelpers.CamelCase); - Assert.Equal(expected, actual); - } + Assert.Equal(expected, actual); } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationApiKeys/GetOrganizationApiKeyCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationApiKeys/GetOrganizationApiKeyCommandTests.cs index e81d2bcc8..de568f265 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationApiKeys/GetOrganizationApiKeyCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationApiKeys/GetOrganizationApiKeyCommandTests.cs @@ -7,94 +7,93 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationApiKeys +namespace Bit.Core.Test.OrganizationFeatures.OrganizationApiKeys; + +[SutProviderCustomize] +public class GetOrganizationApiKeyCommandTests { - [SutProviderCustomize] - public class GetOrganizationApiKeyCommandTests + [Theory] + [BitAutoData] + public async Task GetOrganizationApiKey_HasOne_Returns(SutProvider sutProvider, + Guid id, Guid organizationId, OrganizationApiKeyType keyType) { - [Theory] - [BitAutoData] - public async Task GetOrganizationApiKey_HasOne_Returns(SutProvider sutProvider, - Guid id, Guid organizationId, OrganizationApiKeyType keyType) - { - sutProvider.GetDependency() - .GetManyByOrganizationIdTypeAsync(organizationId, keyType) - .Returns(new List + sutProvider.GetDependency() + .GetManyByOrganizationIdTypeAsync(organizationId, keyType) + .Returns(new List + { + new OrganizationApiKey { - new OrganizationApiKey - { - Id = id, - OrganizationId = organizationId, - ApiKey = "test", - Type = keyType, - RevisionDate = DateTime.Now.AddDays(-1), - }, - }); + Id = id, + OrganizationId = organizationId, + ApiKey = "test", + Type = keyType, + RevisionDate = DateTime.Now.AddDays(-1), + }, + }); - var apiKey = await sutProvider.Sut.GetOrganizationApiKeyAsync(organizationId, keyType); - Assert.NotNull(apiKey); - Assert.Equal(id, apiKey.Id); - } + var apiKey = await sutProvider.Sut.GetOrganizationApiKeyAsync(organizationId, keyType); + Assert.NotNull(apiKey); + Assert.Equal(id, apiKey.Id); + } - [Theory] - [BitAutoData] - public async Task GetOrganizationApiKey_HasTwo_Throws(SutProvider sutProvider, - Guid organizationId, OrganizationApiKeyType keyType) - { - sutProvider.GetDependency() - .GetManyByOrganizationIdTypeAsync(organizationId, keyType) - .Returns(new List + [Theory] + [BitAutoData] + public async Task GetOrganizationApiKey_HasTwo_Throws(SutProvider sutProvider, + Guid organizationId, OrganizationApiKeyType keyType) + { + sutProvider.GetDependency() + .GetManyByOrganizationIdTypeAsync(organizationId, keyType) + .Returns(new List + { + new OrganizationApiKey { - new OrganizationApiKey - { - Id = Guid.NewGuid(), - OrganizationId = organizationId, - ApiKey = "test", - Type = keyType, - RevisionDate = DateTime.Now.AddDays(-1), - }, - new OrganizationApiKey - { - Id = Guid.NewGuid(), - OrganizationId = organizationId, - ApiKey = "test_other", - Type = keyType, - RevisionDate = DateTime.Now.AddDays(-1), - }, - }); + Id = Guid.NewGuid(), + OrganizationId = organizationId, + ApiKey = "test", + Type = keyType, + RevisionDate = DateTime.Now.AddDays(-1), + }, + new OrganizationApiKey + { + Id = Guid.NewGuid(), + OrganizationId = organizationId, + ApiKey = "test_other", + Type = keyType, + RevisionDate = DateTime.Now.AddDays(-1), + }, + }); - await Assert.ThrowsAsync( - async () => await sutProvider.Sut.GetOrganizationApiKeyAsync(organizationId, keyType)); - } + await Assert.ThrowsAsync( + async () => await sutProvider.Sut.GetOrganizationApiKeyAsync(organizationId, keyType)); + } - [Theory] - [BitAutoData] - public async Task GetOrganizationApiKey_HasNone_CreatesAndReturns(SutProvider sutProvider, - Guid organizationId, OrganizationApiKeyType keyType) - { - sutProvider.GetDependency() - .GetManyByOrganizationIdTypeAsync(organizationId, keyType) - .Returns(Enumerable.Empty()); + [Theory] + [BitAutoData] + public async Task GetOrganizationApiKey_HasNone_CreatesAndReturns(SutProvider sutProvider, + Guid organizationId, OrganizationApiKeyType keyType) + { + sutProvider.GetDependency() + .GetManyByOrganizationIdTypeAsync(organizationId, keyType) + .Returns(Enumerable.Empty()); - var apiKey = await sutProvider.Sut.GetOrganizationApiKeyAsync(organizationId, keyType); + var apiKey = await sutProvider.Sut.GetOrganizationApiKeyAsync(organizationId, keyType); - Assert.NotNull(apiKey); - Assert.Equal(organizationId, apiKey.OrganizationId); - Assert.Equal(keyType, apiKey.Type); - await sutProvider.GetDependency() - .Received(1) - .CreateAsync(Arg.Any()); - } + Assert.NotNull(apiKey); + Assert.Equal(organizationId, apiKey.OrganizationId); + Assert.Equal(keyType, apiKey.Type); + await sutProvider.GetDependency() + .Received(1) + .CreateAsync(Arg.Any()); + } - [Theory] - [BitAutoData] - public async Task GetOrganizationApiKey_BadType_Throws(SutProvider sutProvider, - Guid organizationId, OrganizationApiKeyType keyType) - { - keyType = (OrganizationApiKeyType)byte.MaxValue; + [Theory] + [BitAutoData] + public async Task GetOrganizationApiKey_BadType_Throws(SutProvider sutProvider, + Guid organizationId, OrganizationApiKeyType keyType) + { + keyType = (OrganizationApiKeyType)byte.MaxValue; - await Assert.ThrowsAsync( - async () => await sutProvider.Sut.GetOrganizationApiKeyAsync(organizationId, keyType)); - } + await Assert.ThrowsAsync( + async () => await sutProvider.Sut.GetOrganizationApiKeyAsync(organizationId, keyType)); } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationApiKeys/RotateOrganizationApiKeyCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationApiKeys/RotateOrganizationApiKeyCommandTests.cs index 5bea4b8d2..dc2ec10c2 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationApiKeys/RotateOrganizationApiKeyCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationApiKeys/RotateOrganizationApiKeyCommandTests.cs @@ -5,19 +5,18 @@ using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.Helpers; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationApiKeys +namespace Bit.Core.Test.OrganizationFeatures.OrganizationApiKeys; + +[SutProviderCustomize] +public class RotateOrganizationApiKeyCommandTests { - [SutProviderCustomize] - public class RotateOrganizationApiKeyCommandTests + [Theory, BitAutoData] + public async Task RotateApiKeyAsync_RotatesKey(SutProvider sutProvider, + OrganizationApiKey organizationApiKey) { - [Theory, BitAutoData] - public async Task RotateApiKeyAsync_RotatesKey(SutProvider sutProvider, - OrganizationApiKey organizationApiKey) - { - var existingKey = organizationApiKey.ApiKey; - organizationApiKey = await sutProvider.Sut.RotateApiKeyAsync(organizationApiKey); - Assert.NotEqual(existingKey, organizationApiKey.ApiKey); - AssertHelper.AssertRecent(organizationApiKey.RevisionDate); - } + var existingKey = organizationApiKey.ApiKey; + organizationApiKey = await sutProvider.Sut.RotateApiKeyAsync(organizationApiKey); + Assert.NotEqual(existingKey, organizationApiKey.ApiKey); + AssertHelper.AssertRecent(organizationApiKey.RevisionDate); } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationCollections/CreateOrganizationConnectionCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationCollections/CreateOrganizationConnectionCommandTests.cs index c46a7e706..bfcb532d8 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationCollections/CreateOrganizationConnectionCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationCollections/CreateOrganizationConnectionCommandTests.cs @@ -8,20 +8,19 @@ using Bit.Test.Common.Helpers; using NSubstitute; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationConnections -{ - [SutProviderCustomize] - public class CreateOrganizationConnectionCommandTests - { - [Theory] - [BitAutoData] - public async Task CreateAsync_CallsCreate(OrganizationConnectionData data, - SutProvider sutProvider) - { - await sutProvider.Sut.CreateAsync(data); +namespace Bit.Core.Test.OrganizationFeatures.OrganizationConnections; - await sutProvider.GetDependency().Received(1) - .CreateAsync(Arg.Is(AssertHelper.AssertPropertyEqual(data.ToEntity()))); - } +[SutProviderCustomize] +public class CreateOrganizationConnectionCommandTests +{ + [Theory] + [BitAutoData] + public async Task CreateAsync_CallsCreate(OrganizationConnectionData data, + SutProvider sutProvider) + { + await sutProvider.Sut.CreateAsync(data); + + await sutProvider.GetDependency().Received(1) + .CreateAsync(Arg.Is(AssertHelper.AssertPropertyEqual(data.ToEntity()))); } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationCollections/DeleteOrganizationConnectionCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationCollections/DeleteOrganizationConnectionCommandTests.cs index 5a6690cd7..9432968f5 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationCollections/DeleteOrganizationConnectionCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationCollections/DeleteOrganizationConnectionCommandTests.cs @@ -6,20 +6,19 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationConnections -{ - [SutProviderCustomize] - public class DeleteOrganizationConnectionCommandTests - { - [Theory] - [BitAutoData] - public async Task DeleteAsync_CallsDelete(OrganizationConnection connection, - SutProvider sutProvider) - { - await sutProvider.Sut.DeleteAsync(connection); +namespace Bit.Core.Test.OrganizationFeatures.OrganizationConnections; - await sutProvider.GetDependency().Received(1) - .DeleteAsync(connection); - } +[SutProviderCustomize] +public class DeleteOrganizationConnectionCommandTests +{ + [Theory] + [BitAutoData] + public async Task DeleteAsync_CallsDelete(OrganizationConnection connection, + SutProvider sutProvider) + { + await sutProvider.Sut.DeleteAsync(connection); + + await sutProvider.GetDependency().Received(1) + .DeleteAsync(connection); } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationCollections/UpdateOrganizationConnectionCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationCollections/UpdateOrganizationConnectionCommandTests.cs index dba643214..f46d799d1 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationCollections/UpdateOrganizationConnectionCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationCollections/UpdateOrganizationConnectionCommandTests.cs @@ -10,50 +10,49 @@ using Bit.Test.Common.Helpers; using NSubstitute; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationConnections +namespace Bit.Core.Test.OrganizationFeatures.OrganizationConnections; + +[SutProviderCustomize] +public class UpdateOrganizationConnectionCommandTests { - [SutProviderCustomize] - public class UpdateOrganizationConnectionCommandTests + [Theory] + [BitAutoData] + public async Task UpdateAsync_NoId_Fails(OrganizationConnectionData data, + SutProvider sutProvider) { - [Theory] - [BitAutoData] - public async Task UpdateAsync_NoId_Fails(OrganizationConnectionData data, - SutProvider sutProvider) - { - data.Id = null; + data.Id = null; - var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateAsync(data)); + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateAsync(data)); - Assert.Contains("Cannot update connection, Connection does not exist.", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } + Assert.Contains("Cannot update connection, Connection does not exist.", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + } - [Theory] - [BitAutoData] - public async Task UpdateAsync_ConnectionDoesNotExist_ThrowsNotFound( - OrganizationConnectionData data, - SutProvider sutProvider) - { - var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateAsync(data)); + [Theory] + [BitAutoData] + public async Task UpdateAsync_ConnectionDoesNotExist_ThrowsNotFound( + OrganizationConnectionData data, + SutProvider sutProvider) + { + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateAsync(data)); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + } - [Theory] - [BitAutoData] - public async Task UpdateAsync_CallsUpsert(OrganizationConnectionData data, - OrganizationConnection existing, - SutProvider sutProvider) - { - data.Id = existing.Id; + [Theory] + [BitAutoData] + public async Task UpdateAsync_CallsUpsert(OrganizationConnectionData data, + OrganizationConnection existing, + SutProvider sutProvider) + { + data.Id = existing.Id; - sutProvider.GetDependency().GetByIdAsync(data.Id.Value).Returns(existing); - await sutProvider.Sut.UpdateAsync(data); + sutProvider.GetDependency().GetByIdAsync(data.Id.Value).Returns(existing); + await sutProvider.Sut.UpdateAsync(data); - await sutProvider.GetDependency().Received(1) - .UpsertAsync(Arg.Is(AssertHelper.AssertPropertyEqual(data.ToEntity()))); - } + await sutProvider.GetDependency().Received(1) + .UpsertAsync(Arg.Is(AssertHelper.AssertPropertyEqual(data.ToEntity()))); } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommandTestsBase.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommandTestsBase.cs index 882395721..ca684a30c 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommandTestsBase.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommandTestsBase.cs @@ -4,71 +4,70 @@ using Bit.Core.Services; using Bit.Test.Common.AutoFixture; using NSubstitute; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise; + +public abstract class CancelSponsorshipCommandTestsBase : FamiliesForEnterpriseTestsBase { - public abstract class CancelSponsorshipCommandTestsBase : FamiliesForEnterpriseTestsBase + protected async Task AssertRemovedSponsoredPaymentAsync(Organization sponsoredOrg, +OrganizationSponsorship sponsorship, SutProvider sutProvider) { - protected async Task AssertRemovedSponsoredPaymentAsync(Organization sponsoredOrg, - OrganizationSponsorship sponsorship, SutProvider sutProvider) + await sutProvider.GetDependency().Received(1) + .RemoveOrganizationSponsorshipAsync(sponsoredOrg, sponsorship); + await sutProvider.GetDependency().Received(1).UpsertAsync(sponsoredOrg); + if (sponsorship != null) { - await sutProvider.GetDependency().Received(1) - .RemoveOrganizationSponsorshipAsync(sponsoredOrg, sponsorship); - await sutProvider.GetDependency().Received(1).UpsertAsync(sponsoredOrg); - if (sponsorship != null) - { - await sutProvider.GetDependency().Received(1) - .SendFamiliesForEnterpriseSponsorshipRevertingEmailAsync(sponsoredOrg.BillingEmailAddress(), sponsorship.ValidUntil.GetValueOrDefault()); - } - } - - protected async Task AssertDeletedSponsorshipAsync(OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - await sutProvider.GetDependency().Received(1) - .DeleteAsync(sponsorship); - } - - protected static async Task AssertDidNotRemoveSponsorshipAsync(SutProvider sutProvider) - { - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .DeleteAsync(default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } - - protected async Task AssertRemovedSponsorshipAsync(OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - await sutProvider.GetDependency().Received(1) - .DeleteAsync(sponsorship); - } - - protected static async Task AssertDidNotRemoveSponsoredPaymentAsync(SutProvider sutProvider) - { - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .RemoveOrganizationSponsorshipAsync(default, default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .SendFamiliesForEnterpriseSponsorshipRevertingEmailAsync(default, default); - } - - protected static async Task AssertDidNotDeleteSponsorshipAsync(SutProvider sutProvider) - { - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .DeleteAsync(default); - } - - protected static async Task AssertDidNotUpdateSponsorshipAsync(SutProvider sutProvider) - { - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } - - protected static async Task AssertUpdatedSponsorshipAsync(OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - await sutProvider.GetDependency().Received(1).UpsertAsync(sponsorship); + await sutProvider.GetDependency().Received(1) + .SendFamiliesForEnterpriseSponsorshipRevertingEmailAsync(sponsoredOrg.BillingEmailAddress(), sponsorship.ValidUntil.GetValueOrDefault()); } } + + protected async Task AssertDeletedSponsorshipAsync(OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + await sutProvider.GetDependency().Received(1) + .DeleteAsync(sponsorship); + } + + protected static async Task AssertDidNotRemoveSponsorshipAsync(SutProvider sutProvider) + { + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .DeleteAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + } + + protected async Task AssertRemovedSponsorshipAsync(OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + await sutProvider.GetDependency().Received(1) + .DeleteAsync(sponsorship); + } + + protected static async Task AssertDidNotRemoveSponsoredPaymentAsync(SutProvider sutProvider) + { + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .RemoveOrganizationSponsorshipAsync(default, default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .SendFamiliesForEnterpriseSponsorshipRevertingEmailAsync(default, default); + } + + protected static async Task AssertDidNotDeleteSponsorshipAsync(SutProvider sutProvider) + { + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .DeleteAsync(default); + } + + protected static async Task AssertDidNotUpdateSponsorshipAsync(SutProvider sutProvider) + { + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + } + + protected static async Task AssertUpdatedSponsorshipAsync(OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + await sutProvider.GetDependency().Received(1).UpsertAsync(sponsorship); + } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudRevokeSponsorshipCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudRevokeSponsorshipCommandTests.cs index f0c7e976a..2b9a27c1a 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudRevokeSponsorshipCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudRevokeSponsorshipCommandTests.cs @@ -6,46 +6,45 @@ using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; + +[SutProviderCustomize] +[OrganizationSponsorshipCustomize] +public class CloudRevokeSponsorshipCommandTests : CancelSponsorshipCommandTestsBase { - [SutProviderCustomize] - [OrganizationSponsorshipCustomize] - public class CloudRevokeSponsorshipCommandTests : CancelSponsorshipCommandTestsBase + [Theory] + [BitAutoData] + public async Task RevokeSponsorship_NoExistingSponsorship_ThrowsBadRequest( + SutProvider sutProvider) { - [Theory] - [BitAutoData] - public async Task RevokeSponsorship_NoExistingSponsorship_ThrowsBadRequest( - SutProvider sutProvider) - { - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.RevokeSponsorshipAsync(null)); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RevokeSponsorshipAsync(null)); - Assert.Contains("You are not currently sponsoring an organization.", exception.Message); - await AssertDidNotDeleteSponsorshipAsync(sutProvider); - await AssertDidNotUpdateSponsorshipAsync(sutProvider); - } + Assert.Contains("You are not currently sponsoring an organization.", exception.Message); + await AssertDidNotDeleteSponsorshipAsync(sutProvider); + await AssertDidNotUpdateSponsorshipAsync(sutProvider); + } - [Theory] - [BitAutoData] - public async Task RevokeSponsorship_SponsorshipNotRedeemed_DeletesSponsorship(OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - sponsorship.SponsoredOrganizationId = null; + [Theory] + [BitAutoData] + public async Task RevokeSponsorship_SponsorshipNotRedeemed_DeletesSponsorship(OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + sponsorship.SponsoredOrganizationId = null; - await sutProvider.Sut.RevokeSponsorshipAsync(sponsorship); - await AssertDeletedSponsorshipAsync(sponsorship, sutProvider); - } + await sutProvider.Sut.RevokeSponsorshipAsync(sponsorship); + await AssertDeletedSponsorshipAsync(sponsorship, sutProvider); + } - [Theory] - [BitAutoData] - public async Task RevokeSponsorship_SponsorshipRedeemed_MarksForDelete(OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - await sutProvider.Sut.RevokeSponsorshipAsync(sponsorship); + [Theory] + [BitAutoData] + public async Task RevokeSponsorship_SponsorshipRedeemed_MarksForDelete(OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + await sutProvider.Sut.RevokeSponsorshipAsync(sponsorship); - Assert.True(sponsorship.ToDelete); - await AssertUpdatedSponsorshipAsync(sponsorship, sutProvider); - await AssertDidNotDeleteSponsorshipAsync(sutProvider); - } + Assert.True(sponsorship.ToDelete); + await AssertUpdatedSponsorshipAsync(sponsorship, sutProvider); + await AssertDidNotDeleteSponsorshipAsync(sutProvider); } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudSyncSponsorshipsCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudSyncSponsorshipsCommandTests.cs index 3a5517814..f7534d8a7 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudSyncSponsorshipsCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudSyncSponsorshipsCommandTests.cs @@ -10,218 +10,216 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; + +[SutProviderCustomize] +public class CloudSyncSponsorshipsCommandTests : FamiliesForEnterpriseTestsBase { - [SutProviderCustomize] - public class CloudSyncSponsorshipsCommandTests : FamiliesForEnterpriseTestsBase + [Theory] + [BitAutoData] + public async Task SyncOrganization_SponsoringOrgNotFound_ThrowsBadRequest( + IEnumerable sponsorshipsData, + SutProvider sutProvider) { + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SyncOrganization(null, sponsorshipsData)); - [Theory] - [BitAutoData] - public async Task SyncOrganization_SponsoringOrgNotFound_ThrowsBadRequest( - IEnumerable sponsorshipsData, - SutProvider sutProvider) - { - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SyncOrganization(null, sponsorshipsData)); + Assert.Contains("Failed to sync sponsorship - missing organization.", exception.Message); - Assert.Contains("Failed to sync sponsorship - missing organization.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertManyAsync(default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .DeleteManyAsync(default); + } - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertManyAsync(default); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .DeleteManyAsync(default); - } + [Theory] + [BitAutoData] + public async Task SyncOrganization_NoSponsorships_EarlyReturn( + Organization organization, + SutProvider sutProvider) + { + var result = await sutProvider.Sut.SyncOrganization(organization, Enumerable.Empty()); - [Theory] - [BitAutoData] - public async Task SyncOrganization_NoSponsorships_EarlyReturn( - Organization organization, - SutProvider sutProvider) - { - var result = await sutProvider.Sut.SyncOrganization(organization, Enumerable.Empty()); + Assert.Empty(result.Item1.SponsorshipsBatch); + Assert.Empty(result.Item2); - Assert.Empty(result.Item1.SponsorshipsBatch); - Assert.Empty(result.Item2); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertManyAsync(default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .DeleteManyAsync(default); + } - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertManyAsync(default); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .DeleteManyAsync(default); - } + [Theory] + [BitMemberAutoData(nameof(NonEnterprisePlanTypes))] + public async Task SyncOrganization_BadSponsoringOrgPlan_NoSync( + PlanType planType, + Organization organization, IEnumerable sponsorshipsData, + SutProvider sutProvider) + { + organization.PlanType = planType; - [Theory] - [BitMemberAutoData(nameof(NonEnterprisePlanTypes))] - public async Task SyncOrganization_BadSponsoringOrgPlan_NoSync( - PlanType planType, - Organization organization, IEnumerable sponsorshipsData, - SutProvider sutProvider) - { - organization.PlanType = planType; + await sutProvider.Sut.SyncOrganization(organization, sponsorshipsData); - await sutProvider.Sut.SyncOrganization(organization, sponsorshipsData); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertManyAsync(default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .DeleteManyAsync(default); + } - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertManyAsync(default); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .DeleteManyAsync(default); - } + [Theory] + [BitAutoData] + public async Task SyncOrganization_Success_RecordsEvent(Organization organization, + SutProvider sutProvider) + { + await sutProvider.Sut.SyncOrganization(organization, Array.Empty()); - [Theory] - [BitAutoData] - public async Task SyncOrganization_Success_RecordsEvent(Organization organization, - SutProvider sutProvider) - { - await sutProvider.Sut.SyncOrganization(organization, Array.Empty()); + await sutProvider.GetDependency().Received(1).LogOrganizationEventAsync(organization, EventType.Organization_SponsorshipsSynced, Arg.Any()); + } - await sutProvider.GetDependency().Received(1).LogOrganizationEventAsync(organization, EventType.Organization_SponsorshipsSynced, Arg.Any()); - } + [Theory] + [BitAutoData] + public async Task SyncOrganization_OneExisting_OneNew_Success(SutProvider sutProvider, + Organization sponsoringOrganization, OrganizationSponsorship existingSponsorship, OrganizationSponsorship newSponsorship) + { + // Arrange + sponsoringOrganization.Enabled = true; + sponsoringOrganization.PlanType = PlanType.EnterpriseAnnually; - [Theory] - [BitAutoData] - public async Task SyncOrganization_OneExisting_OneNew_Success(SutProvider sutProvider, - Organization sponsoringOrganization, OrganizationSponsorship existingSponsorship, OrganizationSponsorship newSponsorship) - { - // Arrange - sponsoringOrganization.Enabled = true; - sponsoringOrganization.PlanType = PlanType.EnterpriseAnnually; + existingSponsorship.ToDelete = false; + newSponsorship.ToDelete = false; - existingSponsorship.ToDelete = false; - newSponsorship.ToDelete = false; - - sutProvider.GetDependency() - .GetManyBySponsoringOrganizationAsync(sponsoringOrganization.Id) - .Returns(new List - { - existingSponsorship, - }); - - // Act - var (syncData, toEmailSponsorships) = await sutProvider.Sut.SyncOrganization(sponsoringOrganization, new[] + sutProvider.GetDependency() + .GetManyBySponsoringOrganizationAsync(sponsoringOrganization.Id) + .Returns(new List { - new OrganizationSponsorshipData(existingSponsorship), - new OrganizationSponsorshipData(newSponsorship), + existingSponsorship, }); - // Assert - // Should have updated the cloud copy for each item given - await sutProvider.GetDependency() - .Received(1) - .UpsertManyAsync(Arg.Is>(sponsorships => sponsorships.Count() == 2)); - - // Neither were marked as delete, should not have deleted - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .DeleteManyAsync(default); - - // Only one sponsorship was new so it should only send one - Assert.Single(toEmailSponsorships); - } - - [Theory] - [BitAutoData] - public async Task SyncOrganization_TwoToDelete_OneCanDelete_Success(SutProvider sutProvider, - Organization sponsoringOrganization, OrganizationSponsorship canDeleteSponsorship, OrganizationSponsorship cannotDeleteSponsorship) + // Act + var (syncData, toEmailSponsorships) = await sutProvider.Sut.SyncOrganization(sponsoringOrganization, new[] { - // Arrange - sponsoringOrganization.PlanType = PlanType.EnterpriseAnnually; + new OrganizationSponsorshipData(existingSponsorship), + new OrganizationSponsorshipData(newSponsorship), + }); - canDeleteSponsorship.ToDelete = true; - canDeleteSponsorship.SponsoredOrganizationId = null; + // Assert + // Should have updated the cloud copy for each item given + await sutProvider.GetDependency() + .Received(1) + .UpsertManyAsync(Arg.Is>(sponsorships => sponsorships.Count() == 2)); - cannotDeleteSponsorship.ToDelete = true; - cannotDeleteSponsorship.SponsoredOrganizationId = Guid.NewGuid(); + // Neither were marked as delete, should not have deleted + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .DeleteManyAsync(default); - sutProvider.GetDependency() - .GetManyBySponsoringOrganizationAsync(sponsoringOrganization.Id) - .Returns(new List - { - canDeleteSponsorship, - cannotDeleteSponsorship, - }); + // Only one sponsorship was new so it should only send one + Assert.Single(toEmailSponsorships); + } - // Act - var (syncData, toEmailSponsorships) = await sutProvider.Sut.SyncOrganization(sponsoringOrganization, new[] + [Theory] + [BitAutoData] + public async Task SyncOrganization_TwoToDelete_OneCanDelete_Success(SutProvider sutProvider, + Organization sponsoringOrganization, OrganizationSponsorship canDeleteSponsorship, OrganizationSponsorship cannotDeleteSponsorship) + { + // Arrange + sponsoringOrganization.PlanType = PlanType.EnterpriseAnnually; + + canDeleteSponsorship.ToDelete = true; + canDeleteSponsorship.SponsoredOrganizationId = null; + + cannotDeleteSponsorship.ToDelete = true; + cannotDeleteSponsorship.SponsoredOrganizationId = Guid.NewGuid(); + + sutProvider.GetDependency() + .GetManyBySponsoringOrganizationAsync(sponsoringOrganization.Id) + .Returns(new List { - new OrganizationSponsorshipData(canDeleteSponsorship), - new OrganizationSponsorshipData(cannotDeleteSponsorship), + canDeleteSponsorship, + cannotDeleteSponsorship, }); - // Assert - - await sutProvider.GetDependency() - .Received(1) - .UpsertManyAsync(Arg.Is>(sponsorships => sponsorships.Count() == 2)); - - // Deletes the sponsorship that had delete requested and is not sponsoring an org - await sutProvider.GetDependency() - .Received(1) - .DeleteManyAsync(Arg.Is>(toDeleteIds => - toDeleteIds.Count() == 1 && toDeleteIds.ElementAt(0) == canDeleteSponsorship.Id)); - } - - [Theory] - [BitAutoData] - public async Task SyncOrganization_BadData_DoesNotSave(SutProvider sutProvider, - Organization sponsoringOrganization, OrganizationSponsorship badOrganizationSponsorship) + // Act + var (syncData, toEmailSponsorships) = await sutProvider.Sut.SyncOrganization(sponsoringOrganization, new[] { - sponsoringOrganization.PlanType = PlanType.EnterpriseAnnually; + new OrganizationSponsorshipData(canDeleteSponsorship), + new OrganizationSponsorshipData(cannotDeleteSponsorship), + }); - badOrganizationSponsorship.ToDelete = true; - badOrganizationSponsorship.LastSyncDate = null; + // Assert - sutProvider.GetDependency() - .GetManyBySponsoringOrganizationAsync(sponsoringOrganization.Id) - .Returns(new List()); + await sutProvider.GetDependency() + .Received(1) + .UpsertManyAsync(Arg.Is>(sponsorships => sponsorships.Count() == 2)); - var (syncData, toEmailSponsorships) = await sutProvider.Sut.SyncOrganization(sponsoringOrganization, new[] - { - new OrganizationSponsorshipData(badOrganizationSponsorship), - }); + // Deletes the sponsorship that had delete requested and is not sponsoring an org + await sutProvider.GetDependency() + .Received(1) + .DeleteManyAsync(Arg.Is>(toDeleteIds => + toDeleteIds.Count() == 1 && toDeleteIds.ElementAt(0) == canDeleteSponsorship.Id)); + } - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertManyAsync(default); + [Theory] + [BitAutoData] + public async Task SyncOrganization_BadData_DoesNotSave(SutProvider sutProvider, + Organization sponsoringOrganization, OrganizationSponsorship badOrganizationSponsorship) + { + sponsoringOrganization.PlanType = PlanType.EnterpriseAnnually; - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .DeleteManyAsync(default); - } + badOrganizationSponsorship.ToDelete = true; + badOrganizationSponsorship.LastSyncDate = null; - [Theory] - [BitAutoData] - public async Task SyncOrganization_OrgDisabledForFourMonths_DoesNotSave(SutProvider sutProvider, - Organization sponsoringOrganization, OrganizationSponsorship organizationSponsorship) + sutProvider.GetDependency() + .GetManyBySponsoringOrganizationAsync(sponsoringOrganization.Id) + .Returns(new List()); + + var (syncData, toEmailSponsorships) = await sutProvider.Sut.SyncOrganization(sponsoringOrganization, new[] { - sponsoringOrganization.PlanType = PlanType.EnterpriseAnnually; - sponsoringOrganization.Enabled = false; - sponsoringOrganization.ExpirationDate = DateTime.UtcNow.AddDays(-120); + new OrganizationSponsorshipData(badOrganizationSponsorship), + }); - organizationSponsorship.ToDelete = false; + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertManyAsync(default); - sutProvider.GetDependency() - .GetManyBySponsoringOrganizationAsync(sponsoringOrganization.Id) - .Returns(new List()); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .DeleteManyAsync(default); + } - var (syncData, toEmailSponsorships) = await sutProvider.Sut.SyncOrganization(sponsoringOrganization, new[] - { - new OrganizationSponsorshipData(organizationSponsorship), - }); + [Theory] + [BitAutoData] + public async Task SyncOrganization_OrgDisabledForFourMonths_DoesNotSave(SutProvider sutProvider, + Organization sponsoringOrganization, OrganizationSponsorship organizationSponsorship) + { + sponsoringOrganization.PlanType = PlanType.EnterpriseAnnually; + sponsoringOrganization.Enabled = false; + sponsoringOrganization.ExpirationDate = DateTime.UtcNow.AddDays(-120); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertManyAsync(default); + organizationSponsorship.ToDelete = false; - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .DeleteManyAsync(default); - } + sutProvider.GetDependency() + .GetManyBySponsoringOrganizationAsync(sponsoringOrganization.Id) + .Returns(new List()); + + var (syncData, toEmailSponsorships) = await sutProvider.Sut.SyncOrganization(sponsoringOrganization, new[] + { + new OrganizationSponsorshipData(organizationSponsorship), + }); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertManyAsync(default); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .DeleteManyAsync(default); } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/OrganizationSponsorshipRenewCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/OrganizationSponsorshipRenewCommandTests.cs index ca89199a4..b85a3f234 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/OrganizationSponsorshipRenewCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/OrganizationSponsorshipRenewCommandTests.cs @@ -6,22 +6,21 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; + +[SutProviderCustomize] +public class OrganizationSponsorshipRenewCommandTests { - [SutProviderCustomize] - public class OrganizationSponsorshipRenewCommandTests + [Theory] + [BitAutoData] + public async Task UpdateExpirationDate_UpdatesValidUntil(OrganizationSponsorship sponsorship, DateTime expireDate, + SutProvider sutProvider) { - [Theory] - [BitAutoData] - public async Task UpdateExpirationDate_UpdatesValidUntil(OrganizationSponsorship sponsorship, DateTime expireDate, - SutProvider sutProvider) - { - sutProvider.GetDependency().GetBySponsoredOrganizationIdAsync(sponsorship.SponsoredOrganizationId.Value).Returns(sponsorship); + sutProvider.GetDependency().GetBySponsoredOrganizationIdAsync(sponsorship.SponsoredOrganizationId.Value).Returns(sponsorship); - await sutProvider.Sut.UpdateExpirationDateAsync(sponsorship.SponsoredOrganizationId.Value, expireDate); + await sutProvider.Sut.UpdateExpirationDateAsync(sponsorship.SponsoredOrganizationId.Value, expireDate); - await sutProvider.GetDependency().Received(1) - .UpsertAsync(sponsorship); - } + await sutProvider.GetDependency().Received(1) + .UpsertAsync(sponsorship); } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/RemoveSponsorshipCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/RemoveSponsorshipCommandTests.cs index a3ee0a7cd..29adcb486 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/RemoveSponsorshipCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/RemoveSponsorshipCommandTests.cs @@ -6,38 +6,37 @@ using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise; + +[SutProviderCustomize] +[OrganizationSponsorshipCustomize] +public class RemoveSponsorshipCommandTests : CancelSponsorshipCommandTestsBase { - [SutProviderCustomize] - [OrganizationSponsorshipCustomize] - public class RemoveSponsorshipCommandTests : CancelSponsorshipCommandTestsBase + [Theory] + [BitAutoData] + public async Task RemoveSponsorship_SponsoredOrgNull_ThrowsBadRequest(OrganizationSponsorship sponsorship, + SutProvider sutProvider) { - [Theory] - [BitAutoData] - public async Task RemoveSponsorship_SponsoredOrgNull_ThrowsBadRequest(OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - sponsorship.SponsoredOrganizationId = null; + sponsorship.SponsoredOrganizationId = null; - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.RemoveSponsorshipAsync(sponsorship)); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RemoveSponsorshipAsync(sponsorship)); - Assert.Contains("The requested organization is not currently being sponsored.", exception.Message); - Assert.False(sponsorship.ToDelete); - await AssertDidNotDeleteSponsorshipAsync(sutProvider); - await AssertDidNotUpdateSponsorshipAsync(sutProvider); - } + Assert.Contains("The requested organization is not currently being sponsored.", exception.Message); + Assert.False(sponsorship.ToDelete); + await AssertDidNotDeleteSponsorshipAsync(sutProvider); + await AssertDidNotUpdateSponsorshipAsync(sutProvider); + } - [Theory] - [BitAutoData] - public async Task RemoveSponsorship_SponsorshipNotFound_ThrowsBadRequest(SutProvider sutProvider) - { - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.RemoveSponsorshipAsync(null)); + [Theory] + [BitAutoData] + public async Task RemoveSponsorship_SponsorshipNotFound_ThrowsBadRequest(SutProvider sutProvider) + { + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RemoveSponsorshipAsync(null)); - Assert.Contains("The requested organization is not currently being sponsored.", exception.Message); - await AssertDidNotDeleteSponsorshipAsync(sutProvider); - await AssertDidNotUpdateSponsorshipAsync(sutProvider); - } + Assert.Contains("The requested organization is not currently being sponsored.", exception.Message); + await AssertDidNotDeleteSponsorshipAsync(sutProvider); + await AssertDidNotUpdateSponsorshipAsync(sutProvider); } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SendSponsorshipOfferCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SendSponsorshipOfferCommandTests.cs index 15377d7fe..f4f8a2cf4 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SendSponsorshipOfferCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SendSponsorshipOfferCommandTests.cs @@ -10,115 +10,114 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise; + +[SutProviderCustomize] +[OrganizationSponsorshipCustomize] +public class SendSponsorshipOfferCommandTests : FamiliesForEnterpriseTestsBase { - [SutProviderCustomize] - [OrganizationSponsorshipCustomize] - public class SendSponsorshipOfferCommandTests : FamiliesForEnterpriseTestsBase + [Theory] + [BitAutoData] + public async Task SendSponsorshipOffer_SendSponsorshipOfferAsync_ExistingAccount_Success(OrganizationSponsorship sponsorship, string sponsoringOrgName, User user, SutProvider sutProvider) { - [Theory] - [BitAutoData] - public async Task SendSponsorshipOffer_SendSponsorshipOfferAsync_ExistingAccount_Success(OrganizationSponsorship sponsorship, string sponsoringOrgName, User user, SutProvider sutProvider) - { - sutProvider.GetDependency().GetByEmailAsync(sponsorship.OfferedToEmail).Returns(user); + sutProvider.GetDependency().GetByEmailAsync(sponsorship.OfferedToEmail).Returns(user); - await sutProvider.Sut.SendSponsorshipOfferAsync(sponsorship, sponsoringOrgName); + await sutProvider.Sut.SendSponsorshipOfferAsync(sponsorship, sponsoringOrgName); - await sutProvider.GetDependency().Received(1).SendFamiliesForEnterpriseOfferEmailAsync(sponsoringOrgName, sponsorship.OfferedToEmail, true, Arg.Any()); - } + await sutProvider.GetDependency().Received(1).SendFamiliesForEnterpriseOfferEmailAsync(sponsoringOrgName, sponsorship.OfferedToEmail, true, Arg.Any()); + } - [Theory] - [BitAutoData] - public async Task SendSponsorshipOffer_SendSponsorshipOfferAsync_NewAccount_Success(OrganizationSponsorship sponsorship, string sponsoringOrgName, SutProvider sutProvider) - { - sutProvider.GetDependency().GetByEmailAsync(sponsorship.OfferedToEmail).Returns((User)null); + [Theory] + [BitAutoData] + public async Task SendSponsorshipOffer_SendSponsorshipOfferAsync_NewAccount_Success(OrganizationSponsorship sponsorship, string sponsoringOrgName, SutProvider sutProvider) + { + sutProvider.GetDependency().GetByEmailAsync(sponsorship.OfferedToEmail).Returns((User)null); - await sutProvider.Sut.SendSponsorshipOfferAsync(sponsorship, sponsoringOrgName); + await sutProvider.Sut.SendSponsorshipOfferAsync(sponsorship, sponsoringOrgName); - await sutProvider.GetDependency().Received(1).SendFamiliesForEnterpriseOfferEmailAsync(sponsoringOrgName, sponsorship.OfferedToEmail, false, Arg.Any()); - } + await sutProvider.GetDependency().Received(1).SendFamiliesForEnterpriseOfferEmailAsync(sponsoringOrgName, sponsorship.OfferedToEmail, false, Arg.Any()); + } - [Theory] - [BitAutoData] - public async Task ResendSponsorshipOffer_SponsoringOrgNotFound_ThrowsBadRequest( - OrganizationUser orgUser, OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SendSponsorshipOfferAsync(null, orgUser, sponsorship)); + [Theory] + [BitAutoData] + public async Task ResendSponsorshipOffer_SponsoringOrgNotFound_ThrowsBadRequest( + OrganizationUser orgUser, OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SendSponsorshipOfferAsync(null, orgUser, sponsorship)); - Assert.Contains("Cannot find the requested sponsoring organization.", exception.Message); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .SendFamiliesForEnterpriseOfferEmailAsync(default, default, default, default); - } + Assert.Contains("Cannot find the requested sponsoring organization.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SendFamiliesForEnterpriseOfferEmailAsync(default, default, default, default); + } - [Theory] - [BitAutoData] - public async Task ResendSponsorshipOffer_SponsoringOrgUserNotFound_ThrowsBadRequest(Organization org, - OrganizationSponsorship sponsorship, SutProvider sutProvider) - { - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SendSponsorshipOfferAsync(org, null, sponsorship)); + [Theory] + [BitAutoData] + public async Task ResendSponsorshipOffer_SponsoringOrgUserNotFound_ThrowsBadRequest(Organization org, + OrganizationSponsorship sponsorship, SutProvider sutProvider) + { + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SendSponsorshipOfferAsync(org, null, sponsorship)); - Assert.Contains("Only confirmed users can sponsor other organizations.", exception.Message); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .SendFamiliesForEnterpriseOfferEmailAsync(default, default, default, default); - } + Assert.Contains("Only confirmed users can sponsor other organizations.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SendFamiliesForEnterpriseOfferEmailAsync(default, default, default, default); + } - [Theory] - [BitAutoData] - [BitMemberAutoData(nameof(NonConfirmedOrganizationUsersStatuses))] - public async Task ResendSponsorshipOffer_SponsoringOrgUserNotConfirmed_ThrowsBadRequest(OrganizationUserStatusType status, - Organization org, OrganizationUser orgUser, OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - orgUser.Status = status; + [Theory] + [BitAutoData] + [BitMemberAutoData(nameof(NonConfirmedOrganizationUsersStatuses))] + public async Task ResendSponsorshipOffer_SponsoringOrgUserNotConfirmed_ThrowsBadRequest(OrganizationUserStatusType status, + Organization org, OrganizationUser orgUser, OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + orgUser.Status = status; - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SendSponsorshipOfferAsync(org, orgUser, sponsorship)); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SendSponsorshipOfferAsync(org, orgUser, sponsorship)); - Assert.Contains("Only confirmed users can sponsor other organizations.", exception.Message); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .SendFamiliesForEnterpriseOfferEmailAsync(default, default, default, default); - } + Assert.Contains("Only confirmed users can sponsor other organizations.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SendFamiliesForEnterpriseOfferEmailAsync(default, default, default, default); + } - [Theory] - [BitAutoData] - public async Task ResendSponsorshipOffer_SponsorshipNotFound_ThrowsBadRequest(Organization org, - OrganizationUser orgUser, - SutProvider sutProvider) - { - orgUser.Status = OrganizationUserStatusType.Confirmed; + [Theory] + [BitAutoData] + public async Task ResendSponsorshipOffer_SponsorshipNotFound_ThrowsBadRequest(Organization org, + OrganizationUser orgUser, + SutProvider sutProvider) + { + orgUser.Status = OrganizationUserStatusType.Confirmed; - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SendSponsorshipOfferAsync(org, orgUser, null)); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SendSponsorshipOfferAsync(org, orgUser, null)); - Assert.Contains("Cannot find an outstanding sponsorship offer for this organization.", exception.Message); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .SendFamiliesForEnterpriseOfferEmailAsync(default, default, default, default); - } + Assert.Contains("Cannot find an outstanding sponsorship offer for this organization.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SendFamiliesForEnterpriseOfferEmailAsync(default, default, default, default); + } - [Theory] - [BitAutoData] - public async Task ResendSponsorshipOffer_NoOfferToEmail_ThrowsBadRequest(Organization org, - OrganizationUser orgUser, OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - orgUser.Status = OrganizationUserStatusType.Confirmed; - sponsorship.OfferedToEmail = null; + [Theory] + [BitAutoData] + public async Task ResendSponsorshipOffer_NoOfferToEmail_ThrowsBadRequest(Organization org, + OrganizationUser orgUser, OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + orgUser.Status = OrganizationUserStatusType.Confirmed; + sponsorship.OfferedToEmail = null; - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SendSponsorshipOfferAsync(org, orgUser, sponsorship)); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SendSponsorshipOfferAsync(org, orgUser, sponsorship)); - Assert.Contains("Cannot find an outstanding sponsorship offer for this organization.", exception.Message); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .SendFamiliesForEnterpriseOfferEmailAsync(default, default, default, default); - } + Assert.Contains("Cannot find an outstanding sponsorship offer for this organization.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SendFamiliesForEnterpriseOfferEmailAsync(default, default, default, default); } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommandTests.cs index 358e4f007..5776e3e84 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommandTests.cs @@ -10,86 +10,85 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; + +[SutProviderCustomize] +[OrganizationSponsorshipCustomize] +public class SetUpSponsorshipCommandTests : FamiliesForEnterpriseTestsBase { - [SutProviderCustomize] - [OrganizationSponsorshipCustomize] - public class SetUpSponsorshipCommandTests : FamiliesForEnterpriseTestsBase + [Theory] + [BitAutoData] + public async Task SetUpSponsorship_SponsorshipNotFound_ThrowsBadRequest(Organization org, + SutProvider sutProvider) { - [Theory] - [BitAutoData] - public async Task SetUpSponsorship_SponsorshipNotFound_ThrowsBadRequest(Organization org, - SutProvider sutProvider) - { - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SetUpSponsorshipAsync(null, org)); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SetUpSponsorshipAsync(null, org)); - Assert.Contains("No unredeemed sponsorship offer exists for you.", exception.Message); - await AssertDidNotSetUpAsync(sutProvider); - } + Assert.Contains("No unredeemed sponsorship offer exists for you.", exception.Message); + await AssertDidNotSetUpAsync(sutProvider); + } - [Theory] - [BitAutoData] - public async Task SetUpSponsorship_OrgAlreadySponsored_ThrowsBadRequest(Organization org, - OrganizationSponsorship sponsorship, OrganizationSponsorship existingSponsorship, - SutProvider sutProvider) - { - sutProvider.GetDependency() - .GetBySponsoredOrganizationIdAsync(org.Id).Returns(existingSponsorship); + [Theory] + [BitAutoData] + public async Task SetUpSponsorship_OrgAlreadySponsored_ThrowsBadRequest(Organization org, + OrganizationSponsorship sponsorship, OrganizationSponsorship existingSponsorship, + SutProvider sutProvider) + { + sutProvider.GetDependency() + .GetBySponsoredOrganizationIdAsync(org.Id).Returns(existingSponsorship); - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SetUpSponsorshipAsync(sponsorship, org)); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SetUpSponsorshipAsync(sponsorship, org)); - Assert.Contains("Cannot redeem a sponsorship offer for an organization that is already sponsored. Revoke existing sponsorship first.", exception.Message); - await AssertDidNotSetUpAsync(sutProvider); - } + Assert.Contains("Cannot redeem a sponsorship offer for an organization that is already sponsored. Revoke existing sponsorship first.", exception.Message); + await AssertDidNotSetUpAsync(sutProvider); + } - [Theory] - [BitMemberAutoData(nameof(FamiliesPlanTypes))] - public async Task SetUpSponsorship_TooLongSinceLastSync_ThrowsBadRequest(PlanType planType, Organization org, - OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - org.PlanType = planType; - sponsorship.LastSyncDate = DateTime.UtcNow.AddDays(-365); + [Theory] + [BitMemberAutoData(nameof(FamiliesPlanTypes))] + public async Task SetUpSponsorship_TooLongSinceLastSync_ThrowsBadRequest(PlanType planType, Organization org, + OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + org.PlanType = planType; + sponsorship.LastSyncDate = DateTime.UtcNow.AddDays(-365); - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SetUpSponsorshipAsync(sponsorship, org)); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SetUpSponsorshipAsync(sponsorship, org)); - Assert.Contains("This sponsorship offer is more than 6 months old and has expired.", exception.Message); - await sutProvider.GetDependency() - .Received(1) - .DeleteAsync(sponsorship); - await AssertDidNotSetUpAsync(sutProvider); - } + Assert.Contains("This sponsorship offer is more than 6 months old and has expired.", exception.Message); + await sutProvider.GetDependency() + .Received(1) + .DeleteAsync(sponsorship); + await AssertDidNotSetUpAsync(sutProvider); + } - [Theory] - [BitMemberAutoData(nameof(NonFamiliesPlanTypes))] - public async Task SetUpSponsorship_OrgNotFamiles_ThrowsBadRequest(PlanType planType, - OrganizationSponsorship sponsorship, Organization org, - SutProvider sutProvider) - { - org.PlanType = planType; - sponsorship.LastSyncDate = DateTime.UtcNow; + [Theory] + [BitMemberAutoData(nameof(NonFamiliesPlanTypes))] + public async Task SetUpSponsorship_OrgNotFamiles_ThrowsBadRequest(PlanType planType, + OrganizationSponsorship sponsorship, Organization org, + SutProvider sutProvider) + { + org.PlanType = planType; + sponsorship.LastSyncDate = DateTime.UtcNow; - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SetUpSponsorshipAsync(sponsorship, org)); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SetUpSponsorshipAsync(sponsorship, org)); - Assert.Contains("Can only redeem sponsorship offer on families organizations.", exception.Message); - await AssertDidNotSetUpAsync(sutProvider); - } + Assert.Contains("Can only redeem sponsorship offer on families organizations.", exception.Message); + await AssertDidNotSetUpAsync(sutProvider); + } - private static async Task AssertDidNotSetUpAsync(SutProvider sutProvider) - { - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .SponsorOrganizationAsync(default, default); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } + private static async Task AssertDidNotSetUpAsync(SutProvider sutProvider) + { + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SponsorOrganizationAsync(default, default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertAsync(default); } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateBillingSyncKeyCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateBillingSyncKeyCommandTests.cs index 4b3426a53..9b01e3035 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateBillingSyncKeyCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateBillingSyncKeyCommandTests.cs @@ -8,51 +8,50 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; + +[SutProviderCustomize] +public class ValidateBillingSyncKeyCommandTests { - [SutProviderCustomize] - public class ValidateBillingSyncKeyCommandTests + [Theory] + [BitAutoData] + public async Task ValidateBillingSyncKeyAsync_NullOrganization_Throws(SutProvider sutProvider) { - [Theory] - [BitAutoData] - public async Task ValidateBillingSyncKeyAsync_NullOrganization_Throws(SutProvider sutProvider) - { - await Assert.ThrowsAsync(() => sutProvider.Sut.ValidateBillingSyncKeyAsync(null, null)); - } + await Assert.ThrowsAsync(() => sutProvider.Sut.ValidateBillingSyncKeyAsync(null, null)); + } - [Theory] - [BitAutoData((string)null)] - [BitAutoData("")] - [BitAutoData(" ")] - public async Task ValidateBillingSyncKeyAsync_BadString_ReturnsFalse(string billingSyncKey, SutProvider sutProvider) - { - Assert.False(await sutProvider.Sut.ValidateBillingSyncKeyAsync(new Organization(), billingSyncKey)); - } + [Theory] + [BitAutoData((string)null)] + [BitAutoData("")] + [BitAutoData(" ")] + public async Task ValidateBillingSyncKeyAsync_BadString_ReturnsFalse(string billingSyncKey, SutProvider sutProvider) + { + Assert.False(await sutProvider.Sut.ValidateBillingSyncKeyAsync(new Organization(), billingSyncKey)); + } - [Theory] - [BitAutoData] - public async Task ValidateBillingSyncKeyAsync_KeyEquals_ReturnsTrue(SutProvider sutProvider, - Organization organization, OrganizationApiKey orgApiKey, string billingSyncKey) - { - orgApiKey.ApiKey = billingSyncKey; + [Theory] + [BitAutoData] + public async Task ValidateBillingSyncKeyAsync_KeyEquals_ReturnsTrue(SutProvider sutProvider, + Organization organization, OrganizationApiKey orgApiKey, string billingSyncKey) + { + orgApiKey.ApiKey = billingSyncKey; - sutProvider.GetDependency() - .GetManyByOrganizationIdTypeAsync(organization.Id, OrganizationApiKeyType.BillingSync) - .Returns(new[] { orgApiKey }); + sutProvider.GetDependency() + .GetManyByOrganizationIdTypeAsync(organization.Id, OrganizationApiKeyType.BillingSync) + .Returns(new[] { orgApiKey }); - Assert.True(await sutProvider.Sut.ValidateBillingSyncKeyAsync(organization, billingSyncKey)); - } + Assert.True(await sutProvider.Sut.ValidateBillingSyncKeyAsync(organization, billingSyncKey)); + } - [Theory] - [BitAutoData] - public async Task ValidateBillingSyncKeyAsync_KeyDoesNotEqual_ReturnsFalse(SutProvider sutProvider, - Organization organization, OrganizationApiKey orgApiKey, string billingSyncKey) - { - sutProvider.GetDependency() - .GetManyByOrganizationIdTypeAsync(organization.Id, OrganizationApiKeyType.BillingSync) - .Returns(new[] { orgApiKey }); + [Theory] + [BitAutoData] + public async Task ValidateBillingSyncKeyAsync_KeyDoesNotEqual_ReturnsFalse(SutProvider sutProvider, + Organization organization, OrganizationApiKey orgApiKey, string billingSyncKey) + { + sutProvider.GetDependency() + .GetManyByOrganizationIdTypeAsync(organization.Id, OrganizationApiKeyType.BillingSync) + .Returns(new[] { orgApiKey }); - Assert.False(await sutProvider.Sut.ValidateBillingSyncKeyAsync(organization, billingSyncKey)); - } + Assert.False(await sutProvider.Sut.ValidateBillingSyncKeyAsync(organization, billingSyncKey)); } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateRedemptionTokenCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateRedemptionTokenCommandTests.cs index 9bbaaed1d..65aa4cfb2 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateRedemptionTokenCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateRedemptionTokenCommandTests.cs @@ -9,79 +9,78 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; + +[SutProviderCustomize] +public class ValidateRedemptionTokenCommandTests { - [SutProviderCustomize] - public class ValidateRedemptionTokenCommandTests + [Theory] + [BitAutoData] + public async Task ValidateRedemptionTokenAsync_CannotUnprotect_ReturnsFalse(SutProvider sutProvider, + string encryptedString) { - [Theory] - [BitAutoData] - public async Task ValidateRedemptionTokenAsync_CannotUnprotect_ReturnsFalse(SutProvider sutProvider, - string encryptedString) - { - sutProvider - .GetDependency>() - .TryUnprotect(encryptedString, out _) - .Returns(call => - { - call[1] = null; - return false; - }); + sutProvider + .GetDependency>() + .TryUnprotect(encryptedString, out _) + .Returns(call => + { + call[1] = null; + return false; + }); - var (valid, sponsorship) = await sutProvider.Sut.ValidateRedemptionTokenAsync(encryptedString, null); - Assert.False(valid); - Assert.Null(sponsorship); - } + var (valid, sponsorship) = await sutProvider.Sut.ValidateRedemptionTokenAsync(encryptedString, null); + Assert.False(valid); + Assert.Null(sponsorship); + } - [Theory] - [BitAutoData] - public async Task ValidateRedemptionTokenAsync_NoSponsorship_ReturnsFalse(SutProvider sutProvider, - string encryptedString, OrganizationSponsorshipOfferTokenable tokenable) - { - sutProvider - .GetDependency>() - .TryUnprotect(encryptedString, out _) - .Returns(call => - { - call[1] = tokenable; - return true; - }); + [Theory] + [BitAutoData] + public async Task ValidateRedemptionTokenAsync_NoSponsorship_ReturnsFalse(SutProvider sutProvider, + string encryptedString, OrganizationSponsorshipOfferTokenable tokenable) + { + sutProvider + .GetDependency>() + .TryUnprotect(encryptedString, out _) + .Returns(call => + { + call[1] = tokenable; + return true; + }); - var (valid, sponsorship) = await sutProvider.Sut.ValidateRedemptionTokenAsync(encryptedString, "test@email.com"); - Assert.False(valid); - Assert.Null(sponsorship); - } + var (valid, sponsorship) = await sutProvider.Sut.ValidateRedemptionTokenAsync(encryptedString, "test@email.com"); + Assert.False(valid); + Assert.Null(sponsorship); + } - [Theory] - [BitAutoData] - public async Task ValidateRedemptionTokenAsync_ValidSponsorship_ReturnsFalse(SutProvider sutProvider, - string encryptedString, string email, OrganizationSponsorshipOfferTokenable tokenable) - { - tokenable.Email = email; + [Theory] + [BitAutoData] + public async Task ValidateRedemptionTokenAsync_ValidSponsorship_ReturnsFalse(SutProvider sutProvider, + string encryptedString, string email, OrganizationSponsorshipOfferTokenable tokenable) + { + tokenable.Email = email; - sutProvider - .GetDependency>() - .TryUnprotect(encryptedString, out _) - .Returns(call => - { - call[1] = tokenable; - return true; - }); + sutProvider + .GetDependency>() + .TryUnprotect(encryptedString, out _) + .Returns(call => + { + call[1] = tokenable; + return true; + }); - sutProvider.GetDependency() - .GetByIdAsync(tokenable.Id) - .Returns(new OrganizationSponsorship - { - Id = tokenable.Id, - PlanSponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, - OfferedToEmail = email - }); + sutProvider.GetDependency() + .GetByIdAsync(tokenable.Id) + .Returns(new OrganizationSponsorship + { + Id = tokenable.Id, + PlanSponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, + OfferedToEmail = email + }); - var (valid, sponsorship) = await sutProvider.Sut - .ValidateRedemptionTokenAsync(encryptedString, email); + var (valid, sponsorship) = await sutProvider.Sut + .ValidateRedemptionTokenAsync(encryptedString, email); - Assert.True(valid); - Assert.NotNull(sponsorship); - } + Assert.True(valid); + Assert.NotNull(sponsorship); } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommandTests.cs index f1beb64f3..a187f5b29 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommandTests.cs @@ -8,247 +8,246 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; + +[SutProviderCustomize] +[OrganizationSponsorshipCustomize] +public class ValidateSponsorshipCommandTests : CancelSponsorshipCommandTestsBase { - [SutProviderCustomize] - [OrganizationSponsorshipCustomize] - public class ValidateSponsorshipCommandTests : CancelSponsorshipCommandTestsBase + [Theory] + [BitAutoData] + public async Task ValidateSponsorshipAsync_NoSponsoredOrg_EarlyReturn(Guid sponsoredOrgId, + SutProvider sutProvider) { - [Theory] - [BitAutoData] - public async Task ValidateSponsorshipAsync_NoSponsoredOrg_EarlyReturn(Guid sponsoredOrgId, - SutProvider sutProvider) - { - sutProvider.GetDependency().GetByIdAsync(sponsoredOrgId).Returns((Organization)null); + sutProvider.GetDependency().GetByIdAsync(sponsoredOrgId).Returns((Organization)null); - var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrgId); + var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrgId); - Assert.False(result); - await AssertDidNotRemoveSponsoredPaymentAsync(sutProvider); - await AssertDidNotDeleteSponsorshipAsync(sutProvider); - } + Assert.False(result); + await AssertDidNotRemoveSponsoredPaymentAsync(sutProvider); + await AssertDidNotDeleteSponsorshipAsync(sutProvider); + } - [Theory] - [BitAutoData] - public async Task ValidateSponsorshipAsync_NoExistingSponsorship_UpdatesStripePlan(Organization sponsoredOrg, - SutProvider sutProvider) - { - sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); + [Theory] + [BitAutoData] + public async Task ValidateSponsorshipAsync_NoExistingSponsorship_UpdatesStripePlan(Organization sponsoredOrg, + SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); - var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); + var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); - Assert.False(result); - await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, null, sutProvider); - } + Assert.False(result); + await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, null, sutProvider); + } - [Theory] - [BitAutoData] - public async Task ValidateSponsorshipAsync_SponsoringOrgDefault_UpdatesStripePlan(Organization sponsoredOrg, - OrganizationSponsorship existingSponsorship, SutProvider sutProvider) - { - existingSponsorship.SponsoringOrganizationId = default; + [Theory] + [BitAutoData] + public async Task ValidateSponsorshipAsync_SponsoringOrgDefault_UpdatesStripePlan(Organization sponsoredOrg, + OrganizationSponsorship existingSponsorship, SutProvider sutProvider) + { + existingSponsorship.SponsoringOrganizationId = default; - sutProvider.GetDependency() - .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); - sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); + sutProvider.GetDependency() + .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); + sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); - var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); + var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); - Assert.False(result); - await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); - await AssertDeletedSponsorshipAsync(existingSponsorship, sutProvider); - } + Assert.False(result); + await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); + await AssertDeletedSponsorshipAsync(existingSponsorship, sutProvider); + } - [Theory] - [BitAutoData] - public async Task ValidateSponsorshipAsync_SponsoringOrgUserDefault_UpdatesStripePlan(Organization sponsoredOrg, - OrganizationSponsorship existingSponsorship, SutProvider sutProvider) - { - existingSponsorship.SponsoringOrganizationUserId = default; + [Theory] + [BitAutoData] + public async Task ValidateSponsorshipAsync_SponsoringOrgUserDefault_UpdatesStripePlan(Organization sponsoredOrg, + OrganizationSponsorship existingSponsorship, SutProvider sutProvider) + { + existingSponsorship.SponsoringOrganizationUserId = default; - sutProvider.GetDependency() - .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); - sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); + sutProvider.GetDependency() + .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); + sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); - var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); + var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); - Assert.False(result); - await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); - await AssertDeletedSponsorshipAsync(existingSponsorship, sutProvider); - } + Assert.False(result); + await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); + await AssertDeletedSponsorshipAsync(existingSponsorship, sutProvider); + } - [Theory] - [BitAutoData] - public async Task ValidateSponsorshipAsync_SponsorshipTypeNull_UpdatesStripePlan(Organization sponsoredOrg, - OrganizationSponsorship existingSponsorship, SutProvider sutProvider) - { - existingSponsorship.PlanSponsorshipType = null; + [Theory] + [BitAutoData] + public async Task ValidateSponsorshipAsync_SponsorshipTypeNull_UpdatesStripePlan(Organization sponsoredOrg, + OrganizationSponsorship existingSponsorship, SutProvider sutProvider) + { + existingSponsorship.PlanSponsorshipType = null; - sutProvider.GetDependency() - .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); - sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); + sutProvider.GetDependency() + .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); + sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); - var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); + var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); - Assert.False(result); - await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); - await AssertDeletedSponsorshipAsync(existingSponsorship, sutProvider); - } + Assert.False(result); + await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); + await AssertDeletedSponsorshipAsync(existingSponsorship, sutProvider); + } - [Theory] - [BitAutoData] - public async Task ValidateSponsorshipAsync_SponsoringOrgNotFound_UpdatesStripePlan(Organization sponsoredOrg, - OrganizationSponsorship existingSponsorship, SutProvider sutProvider) - { - sutProvider.GetDependency() - .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); - sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); + [Theory] + [BitAutoData] + public async Task ValidateSponsorshipAsync_SponsoringOrgNotFound_UpdatesStripePlan(Organization sponsoredOrg, + OrganizationSponsorship existingSponsorship, SutProvider sutProvider) + { + sutProvider.GetDependency() + .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); + sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); - var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); + var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); - Assert.False(result); - await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); - await AssertDeletedSponsorshipAsync(existingSponsorship, sutProvider); - } + Assert.False(result); + await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); + await AssertDeletedSponsorshipAsync(existingSponsorship, sutProvider); + } - [Theory] - [BitMemberAutoData(nameof(NonEnterprisePlanTypes))] - public async Task ValidateSponsorshipAsync_SponsoringOrgNotEnterprise_UpdatesStripePlan(PlanType planType, - Organization sponsoredOrg, OrganizationSponsorship existingSponsorship, Organization sponsoringOrg, - SutProvider sutProvider) - { - sponsoringOrg.PlanType = planType; - existingSponsorship.SponsoringOrganizationId = sponsoringOrg.Id; + [Theory] + [BitMemberAutoData(nameof(NonEnterprisePlanTypes))] + public async Task ValidateSponsorshipAsync_SponsoringOrgNotEnterprise_UpdatesStripePlan(PlanType planType, + Organization sponsoredOrg, OrganizationSponsorship existingSponsorship, Organization sponsoringOrg, + SutProvider sutProvider) + { + sponsoringOrg.PlanType = planType; + existingSponsorship.SponsoringOrganizationId = sponsoringOrg.Id; - sutProvider.GetDependency() - .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); - sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); - sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); + sutProvider.GetDependency() + .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); + sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); + sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); - var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); + var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); - Assert.False(result); - await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); - await AssertDeletedSponsorshipAsync(existingSponsorship, sutProvider); - } + Assert.False(result); + await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); + await AssertDeletedSponsorshipAsync(existingSponsorship, sutProvider); + } - [Theory] - [BitMemberAutoData(nameof(EnterprisePlanTypes))] - public async Task ValidateSponsorshipAsync_SponsoringOrgDisabledLongerThanGrace_UpdatesStripePlan(PlanType planType, - Organization sponsoredOrg, OrganizationSponsorship existingSponsorship, Organization sponsoringOrg, - SutProvider sutProvider) - { - sponsoringOrg.PlanType = planType; - sponsoringOrg.Enabled = false; - sponsoringOrg.ExpirationDate = DateTime.UtcNow.AddDays(-100); - existingSponsorship.SponsoringOrganizationId = sponsoringOrg.Id; + [Theory] + [BitMemberAutoData(nameof(EnterprisePlanTypes))] + public async Task ValidateSponsorshipAsync_SponsoringOrgDisabledLongerThanGrace_UpdatesStripePlan(PlanType planType, + Organization sponsoredOrg, OrganizationSponsorship existingSponsorship, Organization sponsoringOrg, + SutProvider sutProvider) + { + sponsoringOrg.PlanType = planType; + sponsoringOrg.Enabled = false; + sponsoringOrg.ExpirationDate = DateTime.UtcNow.AddDays(-100); + existingSponsorship.SponsoringOrganizationId = sponsoringOrg.Id; - sutProvider.GetDependency() - .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); - sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); - sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); + sutProvider.GetDependency() + .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); + sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); + sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); - var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); + var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); - Assert.False(result); - await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); - await AssertDeletedSponsorshipAsync(existingSponsorship, sutProvider); - } + Assert.False(result); + await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); + await AssertDeletedSponsorshipAsync(existingSponsorship, sutProvider); + } - [Theory] - [OrganizationSponsorshipCustomize(ToDelete = true)] - [BitMemberAutoData(nameof(EnterprisePlanTypes))] - public async Task ValidateSponsorshipAsync_ToDeleteSponsorship_IsInvalid(PlanType planType, - Organization sponsoredOrg, OrganizationSponsorship sponsorship, Organization sponsoringOrg, - SutProvider sutProvider) - { - sponsoringOrg.PlanType = planType; - sponsoringOrg.Enabled = true; - sponsorship.SponsoringOrganizationId = sponsoringOrg.Id; + [Theory] + [OrganizationSponsorshipCustomize(ToDelete = true)] + [BitMemberAutoData(nameof(EnterprisePlanTypes))] + public async Task ValidateSponsorshipAsync_ToDeleteSponsorship_IsInvalid(PlanType planType, + Organization sponsoredOrg, OrganizationSponsorship sponsorship, Organization sponsoringOrg, + SutProvider sutProvider) + { + sponsoringOrg.PlanType = planType; + sponsoringOrg.Enabled = true; + sponsorship.SponsoringOrganizationId = sponsoringOrg.Id; - sutProvider.GetDependency() - .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(sponsorship); - sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); - sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); + sutProvider.GetDependency() + .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(sponsorship); + sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); + sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); - var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); + var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); - Assert.False(result); + Assert.False(result); - await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, sponsorship, sutProvider); - await AssertDeletedSponsorshipAsync(sponsorship, sutProvider); - } + await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, sponsorship, sutProvider); + await AssertDeletedSponsorshipAsync(sponsorship, sutProvider); + } - [Theory] - [BitMemberAutoData(nameof(EnterprisePlanTypes))] - public async Task ValidateSponsorshipAsync_SponsoringOrgDisabledUnknownTime_UpdatesStripePlan(PlanType planType, - Organization sponsoredOrg, OrganizationSponsorship existingSponsorship, Organization sponsoringOrg, - SutProvider sutProvider) - { - sponsoringOrg.PlanType = planType; - sponsoringOrg.Enabled = false; - sponsoringOrg.ExpirationDate = null; - existingSponsorship.SponsoringOrganizationId = sponsoringOrg.Id; + [Theory] + [BitMemberAutoData(nameof(EnterprisePlanTypes))] + public async Task ValidateSponsorshipAsync_SponsoringOrgDisabledUnknownTime_UpdatesStripePlan(PlanType planType, + Organization sponsoredOrg, OrganizationSponsorship existingSponsorship, Organization sponsoringOrg, + SutProvider sutProvider) + { + sponsoringOrg.PlanType = planType; + sponsoringOrg.Enabled = false; + sponsoringOrg.ExpirationDate = null; + existingSponsorship.SponsoringOrganizationId = sponsoringOrg.Id; - sutProvider.GetDependency() - .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); - sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); - sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); + sutProvider.GetDependency() + .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); + sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); + sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); - var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); + var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); - Assert.False(result); - await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); - await AssertRemovedSponsorshipAsync(existingSponsorship, sutProvider); - } + Assert.False(result); + await AssertRemovedSponsoredPaymentAsync(sponsoredOrg, existingSponsorship, sutProvider); + await AssertRemovedSponsorshipAsync(existingSponsorship, sutProvider); + } - [Theory] - [BitMemberAutoData(nameof(EnterprisePlanTypes))] - public async Task ValidateSponsorshipAsync_SponsoringOrgDisabledLessThanGrace_Valid(PlanType planType, - Organization sponsoredOrg, OrganizationSponsorship existingSponsorship, Organization sponsoringOrg, - SutProvider sutProvider) - { - sponsoringOrg.PlanType = planType; - sponsoringOrg.Enabled = true; - sponsoringOrg.ExpirationDate = DateTime.UtcNow.AddDays(-1); - existingSponsorship.SponsoringOrganizationId = sponsoringOrg.Id; + [Theory] + [BitMemberAutoData(nameof(EnterprisePlanTypes))] + public async Task ValidateSponsorshipAsync_SponsoringOrgDisabledLessThanGrace_Valid(PlanType planType, + Organization sponsoredOrg, OrganizationSponsorship existingSponsorship, Organization sponsoringOrg, + SutProvider sutProvider) + { + sponsoringOrg.PlanType = planType; + sponsoringOrg.Enabled = true; + sponsoringOrg.ExpirationDate = DateTime.UtcNow.AddDays(-1); + existingSponsorship.SponsoringOrganizationId = sponsoringOrg.Id; - sutProvider.GetDependency() - .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); - sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); - sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); + sutProvider.GetDependency() + .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); + sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); + sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); - var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); + var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); - Assert.True(result); + Assert.True(result); - await AssertDidNotRemoveSponsoredPaymentAsync(sutProvider); - await AssertDidNotRemoveSponsorshipAsync(sutProvider); - } + await AssertDidNotRemoveSponsoredPaymentAsync(sutProvider); + await AssertDidNotRemoveSponsorshipAsync(sutProvider); + } - [Theory] - [BitMemberAutoData(nameof(EnterprisePlanTypes))] - public async Task ValidateSponsorshipAsync_Valid(PlanType planType, - Organization sponsoredOrg, OrganizationSponsorship existingSponsorship, Organization sponsoringOrg, - SutProvider sutProvider) - { - sponsoringOrg.PlanType = planType; - sponsoringOrg.Enabled = true; - existingSponsorship.SponsoringOrganizationId = sponsoringOrg.Id; + [Theory] + [BitMemberAutoData(nameof(EnterprisePlanTypes))] + public async Task ValidateSponsorshipAsync_Valid(PlanType planType, + Organization sponsoredOrg, OrganizationSponsorship existingSponsorship, Organization sponsoringOrg, + SutProvider sutProvider) + { + sponsoringOrg.PlanType = planType; + sponsoringOrg.Enabled = true; + existingSponsorship.SponsoringOrganizationId = sponsoringOrg.Id; - sutProvider.GetDependency() - .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); - sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); - sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); + sutProvider.GetDependency() + .GetBySponsoredOrganizationIdAsync(sponsoredOrg.Id).Returns(existingSponsorship); + sutProvider.GetDependency().GetByIdAsync(sponsoredOrg.Id).Returns(sponsoredOrg); + sutProvider.GetDependency().GetByIdAsync(sponsoringOrg.Id).Returns(sponsoringOrg); - var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); + var result = await sutProvider.Sut.ValidateSponsorshipAsync(sponsoredOrg.Id); - Assert.True(result); + Assert.True(result); - await AssertDidNotRemoveSponsoredPaymentAsync(sutProvider); - await AssertDidNotDeleteSponsorshipAsync(sutProvider); - } + await AssertDidNotRemoveSponsoredPaymentAsync(sutProvider); + await AssertDidNotDeleteSponsorshipAsync(sutProvider); } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommandTests.cs index b4e014d06..4eb2779d9 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommandTests.cs @@ -13,167 +13,166 @@ using NSubstitute.ExceptionExtensions; using NSubstitute.ReturnsExtensions; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise; + +[SutProviderCustomize] +public class CreateSponsorshipCommandTests : FamiliesForEnterpriseTestsBase { - [SutProviderCustomize] - public class CreateSponsorshipCommandTests : FamiliesForEnterpriseTestsBase + private bool SponsorshipValidator(OrganizationSponsorship sponsorship, OrganizationSponsorship expectedSponsorship) { - private bool SponsorshipValidator(OrganizationSponsorship sponsorship, OrganizationSponsorship expectedSponsorship) + try { - try - { - AssertHelper.AssertPropertyEqual(sponsorship, expectedSponsorship, nameof(OrganizationSponsorship.Id)); - return true; - } - catch - { - return false; - } + AssertHelper.AssertPropertyEqual(sponsorship, expectedSponsorship, nameof(OrganizationSponsorship.Id)); + return true; } - - [Theory, BitAutoData] - public async Task CreateSponsorship_OfferedToNotFound_ThrowsBadRequest(OrganizationUser orgUser, SutProvider sutProvider) + catch { - sutProvider.GetDependency().GetUserByIdAsync(orgUser.UserId.Value).ReturnsNull(); - - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.CreateSponsorshipAsync(null, orgUser, PlanSponsorshipType.FamiliesForEnterprise, default, default)); - - Assert.Contains("Cannot offer a Families Organization Sponsorship to yourself. Choose a different email.", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CreateAsync(default); - } - - [Theory, BitAutoData] - public async Task CreateSponsorship_OfferedToSelf_ThrowsBadRequest(OrganizationUser orgUser, string sponsoredEmail, User user, SutProvider sutProvider) - { - user.Email = sponsoredEmail; - sutProvider.GetDependency().GetUserByIdAsync(orgUser.UserId.Value).Returns(user); - - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.CreateSponsorshipAsync(null, orgUser, PlanSponsorshipType.FamiliesForEnterprise, sponsoredEmail, default)); - - Assert.Contains("Cannot offer a Families Organization Sponsorship to yourself. Choose a different email.", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CreateAsync(default); - } - - [Theory, BitMemberAutoData(nameof(NonEnterprisePlanTypes))] - public async Task CreateSponsorship_BadSponsoringOrgPlan_ThrowsBadRequest(PlanType sponsoringOrgPlan, - Organization org, OrganizationUser orgUser, User user, SutProvider sutProvider) - { - org.PlanType = sponsoringOrgPlan; - orgUser.Status = OrganizationUserStatusType.Confirmed; - - sutProvider.GetDependency().GetUserByIdAsync(orgUser.UserId.Value).Returns(user); - - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.CreateSponsorshipAsync(org, orgUser, PlanSponsorshipType.FamiliesForEnterprise, default, default)); - - Assert.Contains("Specified Organization cannot sponsor other organizations.", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CreateAsync(default); - } - - [Theory] - [BitMemberAutoData(nameof(NonConfirmedOrganizationUsersStatuses))] - public async Task CreateSponsorship_BadSponsoringUserStatus_ThrowsBadRequest( - OrganizationUserStatusType statusType, Organization org, OrganizationUser orgUser, User user, - SutProvider sutProvider) - { - org.PlanType = PlanType.EnterpriseAnnually; - orgUser.Status = statusType; - - sutProvider.GetDependency().GetUserByIdAsync(orgUser.UserId.Value).Returns(user); - - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.CreateSponsorshipAsync(org, orgUser, PlanSponsorshipType.FamiliesForEnterprise, default, default)); - - Assert.Contains("Only confirmed users can sponsor other organizations.", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CreateAsync(default); - } - - [Theory] - [OrganizationSponsorshipCustomize] - [BitAutoData] - public async Task CreateSponsorship_AlreadySponsoring_Throws(Organization org, - OrganizationUser orgUser, User user, OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - org.PlanType = PlanType.EnterpriseAnnually; - orgUser.Status = OrganizationUserStatusType.Confirmed; - - sutProvider.GetDependency().GetUserByIdAsync(orgUser.UserId.Value).Returns(user); - sutProvider.GetDependency() - .GetBySponsoringOrganizationUserIdAsync(orgUser.Id).Returns(sponsorship); - - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.CreateSponsorshipAsync(org, orgUser, sponsorship.PlanSponsorshipType.Value, default, default)); - - Assert.Contains("Can only sponsor one organization per Organization User.", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CreateAsync(default); - } - - [Theory] - [BitAutoData] - public async Task CreateSponsorship_CreatesSponsorship(Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, User user, - string sponsoredEmail, string friendlyName, Guid sponsorshipId, SutProvider sutProvider) - { - sponsoringOrg.PlanType = PlanType.EnterpriseAnnually; - sponsoringOrgUser.Status = OrganizationUserStatusType.Confirmed; - - sutProvider.GetDependency().GetUserByIdAsync(sponsoringOrgUser.UserId.Value).Returns(user); - sutProvider.GetDependency().WhenForAnyArgs(x => x.UpsertAsync(default)).Do(callInfo => - { - var sponsorship = callInfo.Arg(); - sponsorship.Id = sponsorshipId; - }); - - - await sutProvider.Sut.CreateSponsorshipAsync(sponsoringOrg, sponsoringOrgUser, - PlanSponsorshipType.FamiliesForEnterprise, sponsoredEmail, friendlyName); - - var expectedSponsorship = new OrganizationSponsorship - { - Id = sponsorshipId, - SponsoringOrganizationId = sponsoringOrg.Id, - SponsoringOrganizationUserId = sponsoringOrgUser.Id, - FriendlyName = friendlyName, - OfferedToEmail = sponsoredEmail, - PlanSponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, - }; - - await sutProvider.GetDependency().Received(1) - .UpsertAsync(Arg.Is(s => SponsorshipValidator(s, expectedSponsorship))); - } - - [Theory] - [BitAutoData] - public async Task CreateSponsorship_CreateSponsorshipThrows_RevertsDatabase(Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, User user, - string sponsoredEmail, string friendlyName, SutProvider sutProvider) - { - sponsoringOrg.PlanType = PlanType.EnterpriseAnnually; - sponsoringOrgUser.Status = OrganizationUserStatusType.Confirmed; - - var expectedException = new Exception(); - OrganizationSponsorship createdSponsorship = null; - sutProvider.GetDependency().GetUserByIdAsync(sponsoringOrgUser.UserId.Value).Returns(user); - sutProvider.GetDependency().UpsertAsync(default).ThrowsForAnyArgs(callInfo => - { - createdSponsorship = callInfo.ArgAt(0); - createdSponsorship.Id = Guid.NewGuid(); - return expectedException; - }); - - var actualException = await Assert.ThrowsAsync(() => - sutProvider.Sut.CreateSponsorshipAsync(sponsoringOrg, sponsoringOrgUser, - PlanSponsorshipType.FamiliesForEnterprise, sponsoredEmail, friendlyName)); - Assert.Same(expectedException, actualException); - - await sutProvider.GetDependency().Received(1) - .DeleteAsync(createdSponsorship); + return false; } } + + [Theory, BitAutoData] + public async Task CreateSponsorship_OfferedToNotFound_ThrowsBadRequest(OrganizationUser orgUser, SutProvider sutProvider) + { + sutProvider.GetDependency().GetUserByIdAsync(orgUser.UserId.Value).ReturnsNull(); + + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.CreateSponsorshipAsync(null, orgUser, PlanSponsorshipType.FamiliesForEnterprise, default, default)); + + Assert.Contains("Cannot offer a Families Organization Sponsorship to yourself. Choose a different email.", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .CreateAsync(default); + } + + [Theory, BitAutoData] + public async Task CreateSponsorship_OfferedToSelf_ThrowsBadRequest(OrganizationUser orgUser, string sponsoredEmail, User user, SutProvider sutProvider) + { + user.Email = sponsoredEmail; + sutProvider.GetDependency().GetUserByIdAsync(orgUser.UserId.Value).Returns(user); + + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.CreateSponsorshipAsync(null, orgUser, PlanSponsorshipType.FamiliesForEnterprise, sponsoredEmail, default)); + + Assert.Contains("Cannot offer a Families Organization Sponsorship to yourself. Choose a different email.", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .CreateAsync(default); + } + + [Theory, BitMemberAutoData(nameof(NonEnterprisePlanTypes))] + public async Task CreateSponsorship_BadSponsoringOrgPlan_ThrowsBadRequest(PlanType sponsoringOrgPlan, + Organization org, OrganizationUser orgUser, User user, SutProvider sutProvider) + { + org.PlanType = sponsoringOrgPlan; + orgUser.Status = OrganizationUserStatusType.Confirmed; + + sutProvider.GetDependency().GetUserByIdAsync(orgUser.UserId.Value).Returns(user); + + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.CreateSponsorshipAsync(org, orgUser, PlanSponsorshipType.FamiliesForEnterprise, default, default)); + + Assert.Contains("Specified Organization cannot sponsor other organizations.", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .CreateAsync(default); + } + + [Theory] + [BitMemberAutoData(nameof(NonConfirmedOrganizationUsersStatuses))] + public async Task CreateSponsorship_BadSponsoringUserStatus_ThrowsBadRequest( + OrganizationUserStatusType statusType, Organization org, OrganizationUser orgUser, User user, + SutProvider sutProvider) + { + org.PlanType = PlanType.EnterpriseAnnually; + orgUser.Status = statusType; + + sutProvider.GetDependency().GetUserByIdAsync(orgUser.UserId.Value).Returns(user); + + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.CreateSponsorshipAsync(org, orgUser, PlanSponsorshipType.FamiliesForEnterprise, default, default)); + + Assert.Contains("Only confirmed users can sponsor other organizations.", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .CreateAsync(default); + } + + [Theory] + [OrganizationSponsorshipCustomize] + [BitAutoData] + public async Task CreateSponsorship_AlreadySponsoring_Throws(Organization org, + OrganizationUser orgUser, User user, OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + org.PlanType = PlanType.EnterpriseAnnually; + orgUser.Status = OrganizationUserStatusType.Confirmed; + + sutProvider.GetDependency().GetUserByIdAsync(orgUser.UserId.Value).Returns(user); + sutProvider.GetDependency() + .GetBySponsoringOrganizationUserIdAsync(orgUser.Id).Returns(sponsorship); + + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.CreateSponsorshipAsync(org, orgUser, sponsorship.PlanSponsorshipType.Value, default, default)); + + Assert.Contains("Can only sponsor one organization per Organization User.", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .CreateAsync(default); + } + + [Theory] + [BitAutoData] + public async Task CreateSponsorship_CreatesSponsorship(Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, User user, + string sponsoredEmail, string friendlyName, Guid sponsorshipId, SutProvider sutProvider) + { + sponsoringOrg.PlanType = PlanType.EnterpriseAnnually; + sponsoringOrgUser.Status = OrganizationUserStatusType.Confirmed; + + sutProvider.GetDependency().GetUserByIdAsync(sponsoringOrgUser.UserId.Value).Returns(user); + sutProvider.GetDependency().WhenForAnyArgs(x => x.UpsertAsync(default)).Do(callInfo => + { + var sponsorship = callInfo.Arg(); + sponsorship.Id = sponsorshipId; + }); + + + await sutProvider.Sut.CreateSponsorshipAsync(sponsoringOrg, sponsoringOrgUser, + PlanSponsorshipType.FamiliesForEnterprise, sponsoredEmail, friendlyName); + + var expectedSponsorship = new OrganizationSponsorship + { + Id = sponsorshipId, + SponsoringOrganizationId = sponsoringOrg.Id, + SponsoringOrganizationUserId = sponsoringOrgUser.Id, + FriendlyName = friendlyName, + OfferedToEmail = sponsoredEmail, + PlanSponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, + }; + + await sutProvider.GetDependency().Received(1) + .UpsertAsync(Arg.Is(s => SponsorshipValidator(s, expectedSponsorship))); + } + + [Theory] + [BitAutoData] + public async Task CreateSponsorship_CreateSponsorshipThrows_RevertsDatabase(Organization sponsoringOrg, OrganizationUser sponsoringOrgUser, User user, + string sponsoredEmail, string friendlyName, SutProvider sutProvider) + { + sponsoringOrg.PlanType = PlanType.EnterpriseAnnually; + sponsoringOrgUser.Status = OrganizationUserStatusType.Confirmed; + + var expectedException = new Exception(); + OrganizationSponsorship createdSponsorship = null; + sutProvider.GetDependency().GetUserByIdAsync(sponsoringOrgUser.UserId.Value).Returns(user); + sutProvider.GetDependency().UpsertAsync(default).ThrowsForAnyArgs(callInfo => + { + createdSponsorship = callInfo.ArgAt(0); + createdSponsorship.Id = Guid.NewGuid(); + return expectedException; + }); + + var actualException = await Assert.ThrowsAsync(() => + sutProvider.Sut.CreateSponsorshipAsync(sponsoringOrg, sponsoringOrgUser, + PlanSponsorshipType.FamiliesForEnterprise, sponsoredEmail, friendlyName)); + Assert.Same(expectedException, actualException); + + await sutProvider.GetDependency().Received(1) + .DeleteAsync(createdSponsorship); + } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/FamiliesForEnterpriseTestsBase.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/FamiliesForEnterpriseTestsBase.cs index 862ae6e80..e49b095d7 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/FamiliesForEnterpriseTestsBase.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/FamiliesForEnterpriseTestsBase.cs @@ -1,25 +1,24 @@ using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise; + +public abstract class FamiliesForEnterpriseTestsBase { - public abstract class FamiliesForEnterpriseTestsBase - { - public static IEnumerable EnterprisePlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product == ProductType.Enterprise).Select(p => new object[] { p }); + public static IEnumerable EnterprisePlanTypes => + Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product == ProductType.Enterprise).Select(p => new object[] { p }); - public static IEnumerable NonEnterprisePlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product != ProductType.Enterprise).Select(p => new object[] { p }); + public static IEnumerable NonEnterprisePlanTypes => + Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product != ProductType.Enterprise).Select(p => new object[] { p }); - public static IEnumerable FamiliesPlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product == ProductType.Families).Select(p => new object[] { p }); + public static IEnumerable FamiliesPlanTypes => + Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product == ProductType.Families).Select(p => new object[] { p }); - public static IEnumerable NonFamiliesPlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product != ProductType.Families).Select(p => new object[] { p }); + public static IEnumerable NonFamiliesPlanTypes => + Enum.GetValues().Where(p => StaticStore.GetPlan(p).Product != ProductType.Families).Select(p => new object[] { p }); - public static IEnumerable NonConfirmedOrganizationUsersStatuses => - Enum.GetValues() - .Where(s => s != OrganizationUserStatusType.Confirmed) - .Select(s => new object[] { s }); - } + public static IEnumerable NonConfirmedOrganizationUsersStatuses => + Enum.GetValues() + .Where(s => s != OrganizationUserStatusType.Confirmed) + .Select(s => new object[] { s }); } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedRevokeSponsorshipCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedRevokeSponsorshipCommandTests.cs index 6dd913383..7ac1c7128 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedRevokeSponsorshipCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedRevokeSponsorshipCommandTests.cs @@ -6,48 +6,47 @@ using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.SelfHosted +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.SelfHosted; + +[SutProviderCustomize] +[OrganizationSponsorshipCustomize] +public class SelfHostedRevokeSponsorshipCommandTests : CancelSponsorshipCommandTestsBase { - [SutProviderCustomize] - [OrganizationSponsorshipCustomize] - public class SelfHostedRevokeSponsorshipCommandTests : CancelSponsorshipCommandTestsBase + [Theory] + [BitAutoData] + public async Task RevokeSponsorship_NoExistingSponsorship_ThrowsBadRequest( + SutProvider sutProvider) { - [Theory] - [BitAutoData] - public async Task RevokeSponsorship_NoExistingSponsorship_ThrowsBadRequest( - SutProvider sutProvider) - { - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.RevokeSponsorshipAsync(null)); + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RevokeSponsorshipAsync(null)); - Assert.Contains("You are not currently sponsoring an organization.", exception.Message); - await AssertDidNotDeleteSponsorshipAsync(sutProvider); - await AssertDidNotUpdateSponsorshipAsync(sutProvider); - } + Assert.Contains("You are not currently sponsoring an organization.", exception.Message); + await AssertDidNotDeleteSponsorshipAsync(sutProvider); + await AssertDidNotUpdateSponsorshipAsync(sutProvider); + } - [Theory] - [BitAutoData] - public async Task RevokeSponsorship_SponsorshipNotSynced_DeletesSponsorship(OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - sponsorship.LastSyncDate = null; + [Theory] + [BitAutoData] + public async Task RevokeSponsorship_SponsorshipNotSynced_DeletesSponsorship(OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + sponsorship.LastSyncDate = null; - await sutProvider.Sut.RevokeSponsorshipAsync(sponsorship); - await AssertDeletedSponsorshipAsync(sponsorship, sutProvider); - } + await sutProvider.Sut.RevokeSponsorshipAsync(sponsorship); + await AssertDeletedSponsorshipAsync(sponsorship, sutProvider); + } - [Theory] - [BitAutoData] - public async Task RevokeSponsorship_SponsorshipSynced_MarksForDeletion(OrganizationSponsorship sponsorship, - SutProvider sutProvider) - { - sponsorship.LastSyncDate = DateTime.UtcNow; + [Theory] + [BitAutoData] + public async Task RevokeSponsorship_SponsorshipSynced_MarksForDeletion(OrganizationSponsorship sponsorship, + SutProvider sutProvider) + { + sponsorship.LastSyncDate = DateTime.UtcNow; - await sutProvider.Sut.RevokeSponsorshipAsync(sponsorship); + await sutProvider.Sut.RevokeSponsorshipAsync(sponsorship); - Assert.True(sponsorship.ToDelete); - await AssertUpdatedSponsorshipAsync(sponsorship, sutProvider); - await AssertDidNotDeleteSponsorshipAsync(sutProvider); - } + Assert.True(sponsorship.ToDelete); + await AssertUpdatedSponsorshipAsync(sponsorship, sutProvider); + await AssertDidNotDeleteSponsorshipAsync(sutProvider); } } diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedSyncSponsorshipsCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedSyncSponsorshipsCommandTests.cs index 5c9741f35..5ec93a976 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedSyncSponsorshipsCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedSyncSponsorshipsCommandTests.cs @@ -15,174 +15,172 @@ using NSubstitute; using RichardSzalay.MockHttp; using Xunit; -namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.SelfHosted +namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.SelfHosted; + +public class SelfHostedSyncSponsorshipsCommandTests : FamiliesForEnterpriseTestsBase { - public class SelfHostedSyncSponsorshipsCommandTests : FamiliesForEnterpriseTestsBase + public static SutProvider GetSutProvider(bool enableCloudCommunication = true, string identityResponse = null, string apiResponse = null) { + var fixture = new Fixture().WithAutoNSubstitutionsAutoPopulatedProperties(); + fixture.AddMockHttp(); - public static SutProvider GetSutProvider(bool enableCloudCommunication = true, string identityResponse = null, string apiResponse = null) + var settings = fixture.Create(); + settings.SelfHosted = true; + settings.EnableCloudCommunication = enableCloudCommunication; + + var apiUri = fixture.Create(); + var identityUri = fixture.Create(); + settings.Installation.ApiUri.Returns(apiUri.ToString()); + settings.Installation.IdentityUri.Returns(identityUri.ToString()); + + var apiHandler = new MockHttpMessageHandler(); + var identityHandler = new MockHttpMessageHandler(); + var syncUri = string.Concat(apiUri, "organization/sponsorship/sync"); + var tokenUri = string.Concat(identityUri, "connect/token"); + + apiHandler.When(HttpMethod.Post, syncUri) + .Respond("application/json", apiResponse); + identityHandler.When(HttpMethod.Post, tokenUri) + .Respond("application/json", identityResponse ?? "{\"access_token\":\"string\",\"expires_in\":3600,\"token_type\":\"Bearer\",\"scope\":\"string\"}"); + + + var apiHttp = apiHandler.ToHttpClient(); + var identityHttp = identityHandler.ToHttpClient(); + + var mockHttpClientFactory = Substitute.For(); + mockHttpClientFactory.CreateClient(Arg.Is("client")).Returns(apiHttp); + mockHttpClientFactory.CreateClient(Arg.Is("identity")).Returns(identityHttp); + + return new SutProvider(fixture) + .SetDependency(settings) + .SetDependency(mockHttpClientFactory) + .Create(); + } + + [Theory] + [BitAutoData] + public async Task SyncOrganization_BillingSyncKeyDisabled_ThrowsBadRequest( + Guid cloudOrganizationId, OrganizationConnection billingSyncConnection) + { + var sutProvider = GetSutProvider(); + billingSyncConnection.Enabled = false; + billingSyncConnection.SetConfig(new BillingSyncConfig { - var fixture = new Fixture().WithAutoNSubstitutionsAutoPopulatedProperties(); - fixture.AddMockHttp(); + BillingSyncKey = "okslkcslkjf" + }); - var settings = fixture.Create(); - settings.SelfHosted = true; - settings.EnableCloudCommunication = enableCloudCommunication; + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SyncOrganization(billingSyncConnection.OrganizationId, cloudOrganizationId, billingSyncConnection)); - var apiUri = fixture.Create(); - var identityUri = fixture.Create(); - settings.Installation.ApiUri.Returns(apiUri.ToString()); - settings.Installation.IdentityUri.Returns(identityUri.ToString()); + Assert.Contains($"Billing Sync Key disabled", exception.Message); - var apiHandler = new MockHttpMessageHandler(); - var identityHandler = new MockHttpMessageHandler(); - var syncUri = string.Concat(apiUri, "organization/sponsorship/sync"); - var tokenUri = string.Concat(identityUri, "connect/token"); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .DeleteManyAsync(default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertManyAsync(default); + } - apiHandler.When(HttpMethod.Post, syncUri) - .Respond("application/json", apiResponse); - identityHandler.When(HttpMethod.Post, tokenUri) - .Respond("application/json", identityResponse ?? "{\"access_token\":\"string\",\"expires_in\":3600,\"token_type\":\"Bearer\",\"scope\":\"string\"}"); + [Theory] + [BitAutoData] + public async Task SyncOrganization_BillingSyncKeyEmpty_ThrowsBadRequest( + Guid cloudOrganizationId, OrganizationConnection billingSyncConnection) + { + var sutProvider = GetSutProvider(); + billingSyncConnection.Config = ""; + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SyncOrganization(billingSyncConnection.OrganizationId, cloudOrganizationId, billingSyncConnection)); - var apiHttp = apiHandler.ToHttpClient(); - var identityHttp = identityHandler.ToHttpClient(); + Assert.Contains($"No Billing Sync Key known", exception.Message); - var mockHttpClientFactory = Substitute.For(); - mockHttpClientFactory.CreateClient(Arg.Is("client")).Returns(apiHttp); - mockHttpClientFactory.CreateClient(Arg.Is("identity")).Returns(identityHttp); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .DeleteManyAsync(default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertManyAsync(default); + } - return new SutProvider(fixture) - .SetDependency(settings) - .SetDependency(mockHttpClientFactory) - .Create(); - } + [Theory] + [BitAutoData] + public async Task SyncOrganization_CloudCommunicationDisabled_EarlyReturn( + Guid cloudOrganizationId, OrganizationConnection billingSyncConnection) + { + var sutProvider = GetSutProvider(false); - [Theory] - [BitAutoData] - public async Task SyncOrganization_BillingSyncKeyDisabled_ThrowsBadRequest( - Guid cloudOrganizationId, OrganizationConnection billingSyncConnection) - { - var sutProvider = GetSutProvider(); - billingSyncConnection.Enabled = false; - billingSyncConnection.SetConfig(new BillingSyncConfig + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SyncOrganization(billingSyncConnection.OrganizationId, cloudOrganizationId, billingSyncConnection)); + + Assert.Contains($"Cloud communication is disabled", exception.Message); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .DeleteManyAsync(default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertManyAsync(default); + } + + [Theory] + [OrganizationSponsorshipCustomize] + [BitAutoData] + public async Task SyncOrganization_SyncsSponsorships( + Guid cloudOrganizationId, OrganizationConnection billingSyncConnection, IEnumerable sponsorships) + { + var syncJsonResponse = JsonSerializer.Serialize(new OrganizationSponsorshipSyncResponseModel( + new OrganizationSponsorshipSyncData { - BillingSyncKey = "okslkcslkjf" - }); + SponsorshipsBatch = sponsorships.Select(o => new OrganizationSponsorshipData(o)) + })); - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SyncOrganization(billingSyncConnection.OrganizationId, cloudOrganizationId, billingSyncConnection)); - - Assert.Contains($"Billing Sync Key disabled", exception.Message); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .DeleteManyAsync(default); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertManyAsync(default); - } - - [Theory] - [BitAutoData] - public async Task SyncOrganization_BillingSyncKeyEmpty_ThrowsBadRequest( - Guid cloudOrganizationId, OrganizationConnection billingSyncConnection) + var sutProvider = GetSutProvider(apiResponse: syncJsonResponse); + billingSyncConnection.SetConfig(new BillingSyncConfig { - var sutProvider = GetSutProvider(); - billingSyncConnection.Config = ""; + BillingSyncKey = "okslkcslkjf" + }); + sutProvider.GetDependency() + .GetManyBySponsoringOrganizationAsync(Arg.Any()).Returns(sponsorships.ToList()); - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SyncOrganization(billingSyncConnection.OrganizationId, cloudOrganizationId, billingSyncConnection)); + await sutProvider.Sut.SyncOrganization(billingSyncConnection.OrganizationId, cloudOrganizationId, billingSyncConnection); - Assert.Contains($"No Billing Sync Key known", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .DeleteManyAsync(default); + await sutProvider.GetDependency() + .Received(1) + .UpsertManyAsync(Arg.Any>()); + } - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .DeleteManyAsync(default); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertManyAsync(default); - } - - [Theory] - [BitAutoData] - public async Task SyncOrganization_CloudCommunicationDisabled_EarlyReturn( - Guid cloudOrganizationId, OrganizationConnection billingSyncConnection) - { - var sutProvider = GetSutProvider(false); - - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SyncOrganization(billingSyncConnection.OrganizationId, cloudOrganizationId, billingSyncConnection)); - - Assert.Contains($"Cloud communication is disabled", exception.Message); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .DeleteManyAsync(default); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertManyAsync(default); - } - - [Theory] - [OrganizationSponsorshipCustomize] - [BitAutoData] - public async Task SyncOrganization_SyncsSponsorships( - Guid cloudOrganizationId, OrganizationConnection billingSyncConnection, IEnumerable sponsorships) - { - var syncJsonResponse = JsonSerializer.Serialize(new OrganizationSponsorshipSyncResponseModel( - new OrganizationSponsorshipSyncData - { - SponsorshipsBatch = sponsorships.Select(o => new OrganizationSponsorshipData(o)) - })); - - var sutProvider = GetSutProvider(apiResponse: syncJsonResponse); - billingSyncConnection.SetConfig(new BillingSyncConfig + [Theory] + [OrganizationSponsorshipCustomize(ToDelete = true)] + [BitAutoData] + public async Task SyncOrganization_DeletesSponsorships( + Guid cloudOrganizationId, OrganizationConnection billingSyncConnection, IEnumerable sponsorships) + { + var syncJsonResponse = JsonSerializer.Serialize(new OrganizationSponsorshipSyncResponseModel( + new OrganizationSponsorshipSyncData { - BillingSyncKey = "okslkcslkjf" - }); - sutProvider.GetDependency() - .GetManyBySponsoringOrganizationAsync(Arg.Any()).Returns(sponsorships.ToList()); + SponsorshipsBatch = sponsorships.Select(o => new OrganizationSponsorshipData(o) { CloudSponsorshipRemoved = true }) + })); - await sutProvider.Sut.SyncOrganization(billingSyncConnection.OrganizationId, cloudOrganizationId, billingSyncConnection); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .DeleteManyAsync(default); - await sutProvider.GetDependency() - .Received(1) - .UpsertManyAsync(Arg.Any>()); - } - - [Theory] - [OrganizationSponsorshipCustomize(ToDelete = true)] - [BitAutoData] - public async Task SyncOrganization_DeletesSponsorships( - Guid cloudOrganizationId, OrganizationConnection billingSyncConnection, IEnumerable sponsorships) + var sutProvider = GetSutProvider(apiResponse: syncJsonResponse); + billingSyncConnection.SetConfig(new BillingSyncConfig { - var syncJsonResponse = JsonSerializer.Serialize(new OrganizationSponsorshipSyncResponseModel( - new OrganizationSponsorshipSyncData - { - SponsorshipsBatch = sponsorships.Select(o => new OrganizationSponsorshipData(o) { CloudSponsorshipRemoved = true }) - })); + BillingSyncKey = "okslkcslkjf" + }); + sutProvider.GetDependency() + .GetManyBySponsoringOrganizationAsync(Arg.Any()).Returns(sponsorships.ToList()); - var sutProvider = GetSutProvider(apiResponse: syncJsonResponse); - billingSyncConnection.SetConfig(new BillingSyncConfig - { - BillingSyncKey = "okslkcslkjf" - }); - sutProvider.GetDependency() - .GetManyBySponsoringOrganizationAsync(Arg.Any()).Returns(sponsorships.ToList()); + await sutProvider.Sut.SyncOrganization(billingSyncConnection.OrganizationId, cloudOrganizationId, billingSyncConnection); - await sutProvider.Sut.SyncOrganization(billingSyncConnection.OrganizationId, cloudOrganizationId, billingSyncConnection); - - await sutProvider.GetDependency() - .Received(1) - .DeleteManyAsync(Arg.Any>()); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertManyAsync(default); - } + await sutProvider.GetDependency() + .Received(1) + .DeleteManyAsync(Arg.Any>()); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertManyAsync(default); } } diff --git a/test/Core.Test/Resources/VerifyResources.cs b/test/Core.Test/Resources/VerifyResources.cs index 821eb87e2..028ac3e9e 100644 --- a/test/Core.Test/Resources/VerifyResources.cs +++ b/test/Core.Test/Resources/VerifyResources.cs @@ -1,28 +1,27 @@ using Bit.Core.Utilities; using Xunit; -namespace Bit.Core.Test.Resources +namespace Bit.Core.Test.Resources; + +public class VerifyResources { - public class VerifyResources + [Theory] + [MemberData(nameof(GetResources))] + public void Resource_FoundAndReadable(string resourceName) { - [Theory] - [MemberData(nameof(GetResources))] - public void Resource_FoundAndReadable(string resourceName) - { - var assembly = typeof(CoreHelpers).Assembly; + var assembly = typeof(CoreHelpers).Assembly; - using (var resource = assembly.GetManifestResourceStream(resourceName)) - { - Assert.NotNull(resource); - Assert.True(resource.CanRead); - } - } - - public static IEnumerable GetResources() + using (var resource = assembly.GetManifestResourceStream(resourceName)) { - yield return new[] { "Bit.Core.licensing.cer" }; - yield return new[] { "Bit.Core.MailTemplates.Handlebars.AddedCredit.html.hbs" }; - yield return new[] { "Bit.Core.MailTemplates.Handlebars.Layouts.Basic.html.hbs" }; + Assert.NotNull(resource); + Assert.True(resource.CanRead); } } + + public static IEnumerable GetResources() + { + yield return new[] { "Bit.Core.licensing.cer" }; + yield return new[] { "Bit.Core.MailTemplates.Handlebars.AddedCredit.html.hbs" }; + yield return new[] { "Bit.Core.MailTemplates.Handlebars.Layouts.Basic.html.hbs" }; + } } diff --git a/test/Core.Test/Services/AmazonSesMailDeliveryServiceTests.cs b/test/Core.Test/Services/AmazonSesMailDeliveryServiceTests.cs index 6c07c897d..71bbc9f13 100644 --- a/test/Core.Test/Services/AmazonSesMailDeliveryServiceTests.cs +++ b/test/Core.Test/Services/AmazonSesMailDeliveryServiceTests.cs @@ -8,80 +8,79 @@ using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class AmazonSesMailDeliveryServiceTests : IDisposable { - public class AmazonSesMailDeliveryServiceTests : IDisposable + private readonly AmazonSesMailDeliveryService _sut; + + private readonly GlobalSettings _globalSettings; + private readonly IWebHostEnvironment _hostingEnvironment; + private readonly ILogger _logger; + private readonly IAmazonSimpleEmailService _amazonSimpleEmailService; + + public AmazonSesMailDeliveryServiceTests() { - private readonly AmazonSesMailDeliveryService _sut; - - private readonly GlobalSettings _globalSettings; - private readonly IWebHostEnvironment _hostingEnvironment; - private readonly ILogger _logger; - private readonly IAmazonSimpleEmailService _amazonSimpleEmailService; - - public AmazonSesMailDeliveryServiceTests() + _globalSettings = new GlobalSettings { - _globalSettings = new GlobalSettings + Amazon = + { + AccessKeyId = "AccessKeyId-AmazonSesMailDeliveryServiceTests", + AccessKeySecret = "AccessKeySecret-AmazonSesMailDeliveryServiceTests", + Region = "Region-AmazonSesMailDeliveryServiceTests" + } + }; + + _hostingEnvironment = Substitute.For(); + _logger = Substitute.For>(); + _amazonSimpleEmailService = Substitute.For(); + + _sut = new AmazonSesMailDeliveryService( + _globalSettings, + _hostingEnvironment, + _logger, + _amazonSimpleEmailService + ); + } + + public void Dispose() + { + _sut?.Dispose(); + } + + [Fact] + public async Task SendEmailAsync_CallsSendEmailAsync_WhenMessageIsValid() + { + var mailMessage = new MailMessage + { + ToEmails = new List { "ToEmails" }, + BccEmails = new List { "BccEmails" }, + Subject = "Subject", + HtmlContent = "HtmlContent", + TextContent = "TextContent", + Category = "Category" + }; + + await _sut.SendEmailAsync(mailMessage); + + await _amazonSimpleEmailService.Received(1).SendEmailAsync( + Arg.Do(request => { - Amazon = - { - AccessKeyId = "AccessKeyId-AmazonSesMailDeliveryServiceTests", - AccessKeySecret = "AccessKeySecret-AmazonSesMailDeliveryServiceTests", - Region = "Region-AmazonSesMailDeliveryServiceTests" - } - }; + Assert.False(string.IsNullOrEmpty(request.Source)); - _hostingEnvironment = Substitute.For(); - _logger = Substitute.For>(); - _amazonSimpleEmailService = Substitute.For(); + Assert.Single(request.Destination.ToAddresses); + Assert.Equal(mailMessage.ToEmails.First(), request.Destination.ToAddresses.First()); - _sut = new AmazonSesMailDeliveryService( - _globalSettings, - _hostingEnvironment, - _logger, - _amazonSimpleEmailService - ); - } + Assert.Equal(mailMessage.Subject, request.Message.Subject.Data); + Assert.Equal(mailMessage.HtmlContent, request.Message.Body.Html.Data); + Assert.Equal(mailMessage.TextContent, request.Message.Body.Text.Data); - public void Dispose() - { - _sut?.Dispose(); - } + Assert.Single(request.Destination.BccAddresses); + Assert.Equal(mailMessage.BccEmails.First(), request.Destination.BccAddresses.First()); - [Fact] - public async Task SendEmailAsync_CallsSendEmailAsync_WhenMessageIsValid() - { - var mailMessage = new MailMessage - { - ToEmails = new List { "ToEmails" }, - BccEmails = new List { "BccEmails" }, - Subject = "Subject", - HtmlContent = "HtmlContent", - TextContent = "TextContent", - Category = "Category" - }; - - await _sut.SendEmailAsync(mailMessage); - - await _amazonSimpleEmailService.Received(1).SendEmailAsync( - Arg.Do(request => - { - Assert.False(string.IsNullOrEmpty(request.Source)); - - Assert.Single(request.Destination.ToAddresses); - Assert.Equal(mailMessage.ToEmails.First(), request.Destination.ToAddresses.First()); - - Assert.Equal(mailMessage.Subject, request.Message.Subject.Data); - Assert.Equal(mailMessage.HtmlContent, request.Message.Body.Html.Data); - Assert.Equal(mailMessage.TextContent, request.Message.Body.Text.Data); - - Assert.Single(request.Destination.BccAddresses); - Assert.Equal(mailMessage.BccEmails.First(), request.Destination.BccAddresses.First()); - - Assert.Contains(request.Tags, x => x.Name == "Environment"); - Assert.Contains(request.Tags, x => x.Name == "Sender"); - Assert.Contains(request.Tags, x => x.Name == "Category"); - })); - } + Assert.Contains(request.Tags, x => x.Name == "Environment"); + Assert.Contains(request.Tags, x => x.Name == "Sender"); + Assert.Contains(request.Tags, x => x.Name == "Category"); + })); } } diff --git a/test/Core.Test/Services/AmazonSqsBlockIpServiceTests.cs b/test/Core.Test/Services/AmazonSqsBlockIpServiceTests.cs index cf24d6293..5c74386ed 100644 --- a/test/Core.Test/Services/AmazonSqsBlockIpServiceTests.cs +++ b/test/Core.Test/Services/AmazonSqsBlockIpServiceTests.cs @@ -4,74 +4,73 @@ using Bit.Core.Settings; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class AmazonSqsBlockIpServiceTests : IDisposable { - public class AmazonSqsBlockIpServiceTests : IDisposable + private readonly AmazonSqsBlockIpService _sut; + + private readonly GlobalSettings _globalSettings; + private readonly IAmazonSQS _amazonSqs; + + public AmazonSqsBlockIpServiceTests() { - private readonly AmazonSqsBlockIpService _sut; - - private readonly GlobalSettings _globalSettings; - private readonly IAmazonSQS _amazonSqs; - - public AmazonSqsBlockIpServiceTests() + _globalSettings = new GlobalSettings { - _globalSettings = new GlobalSettings + Amazon = { - Amazon = - { - AccessKeyId = "AccessKeyId-AmazonSesMailDeliveryServiceTests", - AccessKeySecret = "AccessKeySecret-AmazonSesMailDeliveryServiceTests", - Region = "Region-AmazonSesMailDeliveryServiceTests" - } - }; + AccessKeyId = "AccessKeyId-AmazonSesMailDeliveryServiceTests", + AccessKeySecret = "AccessKeySecret-AmazonSesMailDeliveryServiceTests", + Region = "Region-AmazonSesMailDeliveryServiceTests" + } + }; - _amazonSqs = Substitute.For(); + _amazonSqs = Substitute.For(); - _sut = new AmazonSqsBlockIpService(_globalSettings, _amazonSqs); - } + _sut = new AmazonSqsBlockIpService(_globalSettings, _amazonSqs); + } - public void Dispose() - { - _sut?.Dispose(); - } + public void Dispose() + { + _sut?.Dispose(); + } - [Fact] - public async Task BlockIpAsync_UnblockCalled_WhenNotPermanent() - { - const string expectedIp = "ip"; + [Fact] + public async Task BlockIpAsync_UnblockCalled_WhenNotPermanent() + { + const string expectedIp = "ip"; - await _sut.BlockIpAsync(expectedIp, false); + await _sut.BlockIpAsync(expectedIp, false); - await _amazonSqs.Received(2).SendMessageAsync( - Arg.Any(), - Arg.Is(expectedIp)); - } + await _amazonSqs.Received(2).SendMessageAsync( + Arg.Any(), + Arg.Is(expectedIp)); + } - [Fact] - public async Task BlockIpAsync_UnblockNotCalled_WhenPermanent() - { - const string expectedIp = "ip"; + [Fact] + public async Task BlockIpAsync_UnblockNotCalled_WhenPermanent() + { + const string expectedIp = "ip"; - await _sut.BlockIpAsync(expectedIp, true); + await _sut.BlockIpAsync(expectedIp, true); - await _amazonSqs.Received(1).SendMessageAsync( - Arg.Any(), - Arg.Is(expectedIp)); - } + await _amazonSqs.Received(1).SendMessageAsync( + Arg.Any(), + Arg.Is(expectedIp)); + } - [Fact] - public async Task BlockIpAsync_NotBlocked_WhenAlreadyBlockedRecently() - { - const string expectedIp = "ip"; + [Fact] + public async Task BlockIpAsync_NotBlocked_WhenAlreadyBlockedRecently() + { + const string expectedIp = "ip"; - await _sut.BlockIpAsync(expectedIp, true); + await _sut.BlockIpAsync(expectedIp, true); - // The second call should hit the already blocked guard clause - await _sut.BlockIpAsync(expectedIp, true); + // The second call should hit the already blocked guard clause + await _sut.BlockIpAsync(expectedIp, true); - await _amazonSqs.Received(1).SendMessageAsync( - Arg.Any(), - Arg.Is(expectedIp)); - } + await _amazonSqs.Received(1).SendMessageAsync( + Arg.Any(), + Arg.Is(expectedIp)); } } diff --git a/test/Core.Test/Services/AppleIapServiceTests.cs b/test/Core.Test/Services/AppleIapServiceTests.cs index ff14e52e8..c376af288 100644 --- a/test/Core.Test/Services/AppleIapServiceTests.cs +++ b/test/Core.Test/Services/AppleIapServiceTests.cs @@ -6,36 +6,35 @@ using NSubstitute; using NSubstitute.Core; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +[SutProviderCustomize] +public class AppleIapServiceTests { - [SutProviderCustomize] - public class AppleIapServiceTests + [Theory, BitAutoData] + public async Task GetReceiptStatusAsync_MoreThanFourAttempts_Throws(SutProvider sutProvider) { - [Theory, BitAutoData] - public async Task GetReceiptStatusAsync_MoreThanFourAttempts_Throws(SutProvider sutProvider) + var result = await sutProvider.Sut.GetReceiptStatusAsync("test", false, 5, null); + Assert.Null(result); + + var errorLog = sutProvider.GetDependency>() + .ReceivedCalls() + .SingleOrDefault(LogOneWarning); + + Assert.True(errorLog != null, "Must contain one error log of warning level containing 'null'"); + + static bool LogOneWarning(ICall call) { - var result = await sutProvider.Sut.GetReceiptStatusAsync("test", false, 5, null); - Assert.Null(result); - - var errorLog = sutProvider.GetDependency>() - .ReceivedCalls() - .SingleOrDefault(LogOneWarning); - - Assert.True(errorLog != null, "Must contain one error log of warning level containing 'null'"); - - static bool LogOneWarning(ICall call) + if (call.GetMethodInfo().Name != "Log") { - if (call.GetMethodInfo().Name != "Log") - { - return false; - } - - var args = call.GetArguments(); - var logLevel = (LogLevel)args[0]; - var exception = (Exception)args[3]; - - return logLevel == LogLevel.Warning && exception.Message.Contains("null"); + return false; } + + var args = call.GetArguments(); + var logLevel = (LogLevel)args[0]; + var exception = (Exception)args[3]; + + return logLevel == LogLevel.Warning && exception.Message.Contains("null"); } } } diff --git a/test/Core.Test/Services/AzureAttachmentStorageServiceTests.cs b/test/Core.Test/Services/AzureAttachmentStorageServiceTests.cs index 75cda61aa..21a5fa3f8 100644 --- a/test/Core.Test/Services/AzureAttachmentStorageServiceTests.cs +++ b/test/Core.Test/Services/AzureAttachmentStorageServiceTests.cs @@ -4,29 +4,28 @@ using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class AzureAttachmentStorageServiceTests { - public class AzureAttachmentStorageServiceTests + private readonly AzureAttachmentStorageService _sut; + + private readonly GlobalSettings _globalSettings; + private readonly ILogger _logger; + + public AzureAttachmentStorageServiceTests() { - private readonly AzureAttachmentStorageService _sut; + _globalSettings = new GlobalSettings(); + _logger = Substitute.For>(); - private readonly GlobalSettings _globalSettings; - private readonly ILogger _logger; + _sut = new AzureAttachmentStorageService(_globalSettings, _logger); + } - public AzureAttachmentStorageServiceTests() - { - _globalSettings = new GlobalSettings(); - _logger = Substitute.For>(); - - _sut = new AzureAttachmentStorageService(_globalSettings, _logger); - } - - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact(Skip = "Needs additional work")] - public void ServiceExists() - { - Assert.NotNull(_sut); - } + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact(Skip = "Needs additional work")] + public void ServiceExists() + { + Assert.NotNull(_sut); } } diff --git a/test/Core.Test/Services/AzureQueueBlockIpServiceTests.cs b/test/Core.Test/Services/AzureQueueBlockIpServiceTests.cs index e4ad8bab9..9efbe7180 100644 --- a/test/Core.Test/Services/AzureQueueBlockIpServiceTests.cs +++ b/test/Core.Test/Services/AzureQueueBlockIpServiceTests.cs @@ -2,27 +2,26 @@ using Bit.Core.Settings; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class AzureQueueBlockIpServiceTests { - public class AzureQueueBlockIpServiceTests + private readonly AzureQueueBlockIpService _sut; + + private readonly GlobalSettings _globalSettings; + + public AzureQueueBlockIpServiceTests() { - private readonly AzureQueueBlockIpService _sut; + _globalSettings = new GlobalSettings(); - private readonly GlobalSettings _globalSettings; + _sut = new AzureQueueBlockIpService(_globalSettings); + } - public AzureQueueBlockIpServiceTests() - { - _globalSettings = new GlobalSettings(); - - _sut = new AzureQueueBlockIpService(_globalSettings); - } - - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact(Skip = "Needs additional work")] - public void ServiceExists() - { - Assert.NotNull(_sut); - } + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact(Skip = "Needs additional work")] + public void ServiceExists() + { + Assert.NotNull(_sut); } } diff --git a/test/Core.Test/Services/AzureQueueEventWriteServiceTests.cs b/test/Core.Test/Services/AzureQueueEventWriteServiceTests.cs index ce44b5f30..2c4916dc6 100644 --- a/test/Core.Test/Services/AzureQueueEventWriteServiceTests.cs +++ b/test/Core.Test/Services/AzureQueueEventWriteServiceTests.cs @@ -4,31 +4,30 @@ using Bit.Core.Settings; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class AzureQueueEventWriteServiceTests { - public class AzureQueueEventWriteServiceTests + private readonly AzureQueueEventWriteService _sut; + + private readonly GlobalSettings _globalSettings; + private readonly IEventRepository _eventRepository; + + public AzureQueueEventWriteServiceTests() { - private readonly AzureQueueEventWriteService _sut; + _globalSettings = new GlobalSettings(); + _eventRepository = Substitute.For(); - private readonly GlobalSettings _globalSettings; - private readonly IEventRepository _eventRepository; + _sut = new AzureQueueEventWriteService( + _globalSettings + ); + } - public AzureQueueEventWriteServiceTests() - { - _globalSettings = new GlobalSettings(); - _eventRepository = Substitute.For(); - - _sut = new AzureQueueEventWriteService( - _globalSettings - ); - } - - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact(Skip = "Needs additional work")] - public void ServiceExists() - { - Assert.NotNull(_sut); - } + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact(Skip = "Needs additional work")] + public void ServiceExists() + { + Assert.NotNull(_sut); } } diff --git a/test/Core.Test/Services/AzureQueuePushNotificationServiceTests.cs b/test/Core.Test/Services/AzureQueuePushNotificationServiceTests.cs index abb6ad31a..7f9cb750a 100644 --- a/test/Core.Test/Services/AzureQueuePushNotificationServiceTests.cs +++ b/test/Core.Test/Services/AzureQueuePushNotificationServiceTests.cs @@ -4,32 +4,31 @@ using Microsoft.AspNetCore.Http; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class AzureQueuePushNotificationServiceTests { - public class AzureQueuePushNotificationServiceTests + private readonly AzureQueuePushNotificationService _sut; + + private readonly GlobalSettings _globalSettings; + private readonly IHttpContextAccessor _httpContextAccessor; + + public AzureQueuePushNotificationServiceTests() { - private readonly AzureQueuePushNotificationService _sut; + _globalSettings = new GlobalSettings(); + _httpContextAccessor = Substitute.For(); - private readonly GlobalSettings _globalSettings; - private readonly IHttpContextAccessor _httpContextAccessor; + _sut = new AzureQueuePushNotificationService( + _globalSettings, + _httpContextAccessor + ); + } - public AzureQueuePushNotificationServiceTests() - { - _globalSettings = new GlobalSettings(); - _httpContextAccessor = Substitute.For(); - - _sut = new AzureQueuePushNotificationService( - _globalSettings, - _httpContextAccessor - ); - } - - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact(Skip = "Needs additional work")] - public void ServiceExists() - { - Assert.NotNull(_sut); - } + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact(Skip = "Needs additional work")] + public void ServiceExists() + { + Assert.NotNull(_sut); } } diff --git a/test/Core.Test/Services/CipherServiceTests.cs b/test/Core.Test/Services/CipherServiceTests.cs index 1e3444481..f036b973f 100644 --- a/test/Core.Test/Services/CipherServiceTests.cs +++ b/test/Core.Test/Services/CipherServiceTests.cs @@ -9,209 +9,208 @@ using Core.Models.Data; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class CipherServiceTests { - public class CipherServiceTests + [Theory, UserCipherAutoData] + public async Task SaveAsync_WrongRevisionDate_Throws(SutProvider sutProvider, Cipher cipher) { - [Theory, UserCipherAutoData] - public async Task SaveAsync_WrongRevisionDate_Throws(SutProvider sutProvider, Cipher cipher) - { - var lastKnownRevisionDate = cipher.RevisionDate.AddDays(-1); + var lastKnownRevisionDate = cipher.RevisionDate.AddDays(-1); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(cipher, cipher.UserId.Value, lastKnownRevisionDate)); - Assert.Contains("out of date", exception.Message); - } + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(cipher, cipher.UserId.Value, lastKnownRevisionDate)); + Assert.Contains("out of date", exception.Message); + } - [Theory, UserCipherAutoData] - public async Task SaveDetailsAsync_WrongRevisionDate_Throws(SutProvider sutProvider, - CipherDetails cipherDetails) - { - var lastKnownRevisionDate = cipherDetails.RevisionDate.AddDays(-1); + [Theory, UserCipherAutoData] + public async Task SaveDetailsAsync_WrongRevisionDate_Throws(SutProvider sutProvider, + CipherDetails cipherDetails) + { + var lastKnownRevisionDate = cipherDetails.RevisionDate.AddDays(-1); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveDetailsAsync(cipherDetails, cipherDetails.UserId.Value, lastKnownRevisionDate)); - Assert.Contains("out of date", exception.Message); - } + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveDetailsAsync(cipherDetails, cipherDetails.UserId.Value, lastKnownRevisionDate)); + Assert.Contains("out of date", exception.Message); + } - [Theory, UserCipherAutoData] - public async Task ShareAsync_WrongRevisionDate_Throws(SutProvider sutProvider, Cipher cipher, - Organization organization, List collectionIds) - { - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - var lastKnownRevisionDate = cipher.RevisionDate.AddDays(-1); + [Theory, UserCipherAutoData] + public async Task ShareAsync_WrongRevisionDate_Throws(SutProvider sutProvider, Cipher cipher, + Organization organization, List collectionIds) + { + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + var lastKnownRevisionDate = cipher.RevisionDate.AddDays(-1); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.ShareAsync(cipher, cipher, organization.Id, collectionIds, cipher.UserId.Value, - lastKnownRevisionDate)); - Assert.Contains("out of date", exception.Message); - } + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ShareAsync(cipher, cipher, organization.Id, collectionIds, cipher.UserId.Value, + lastKnownRevisionDate)); + Assert.Contains("out of date", exception.Message); + } - [Theory, UserCipherAutoData("99ab4f6c-44f8-4ff5-be7a-75c37c33c69e")] - public async Task ShareManyAsync_WrongRevisionDate_Throws(SutProvider sutProvider, - IEnumerable ciphers, Guid organizationId, List collectionIds) - { - sutProvider.GetDependency().GetByIdAsync(organizationId) - .Returns(new Organization - { - PlanType = Enums.PlanType.EnterpriseAnnually, - MaxStorageGb = 100 - }); - - var cipherInfos = ciphers.Select(c => (c, (DateTime?)c.RevisionDate.AddDays(-1))); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.ShareManyAsync(cipherInfos, organizationId, collectionIds, ciphers.First().UserId.Value)); - Assert.Contains("out of date", exception.Message); - } - - [Theory] - [InlineUserCipherAutoData("")] - [InlineUserCipherAutoData("Correct Time")] - public async Task SaveAsync_CorrectRevisionDate_Passes(string revisionDateString, - SutProvider sutProvider, Cipher cipher) - { - var lastKnownRevisionDate = string.IsNullOrEmpty(revisionDateString) ? (DateTime?)null : cipher.RevisionDate; - - await sutProvider.Sut.SaveAsync(cipher, cipher.UserId.Value, lastKnownRevisionDate); - await sutProvider.GetDependency().Received(1).ReplaceAsync(cipher); - } - - [Theory] - [InlineUserCipherAutoData("")] - [InlineUserCipherAutoData("Correct Time")] - public async Task SaveDetailsAsync_CorrectRevisionDate_Passes(string revisionDateString, - SutProvider sutProvider, CipherDetails cipherDetails) - { - var lastKnownRevisionDate = string.IsNullOrEmpty(revisionDateString) ? (DateTime?)null : cipherDetails.RevisionDate; - - await sutProvider.Sut.SaveDetailsAsync(cipherDetails, cipherDetails.UserId.Value, lastKnownRevisionDate); - await sutProvider.GetDependency().Received(1).ReplaceAsync(cipherDetails); - } - - [Theory] - [InlineUserCipherAutoData("")] - [InlineUserCipherAutoData("Correct Time")] - public async Task ShareAsync_CorrectRevisionDate_Passes(string revisionDateString, - SutProvider sutProvider, Cipher cipher, Organization organization, List collectionIds) - { - var lastKnownRevisionDate = string.IsNullOrEmpty(revisionDateString) ? (DateTime?)null : cipher.RevisionDate; - var cipherRepository = sutProvider.GetDependency(); - cipherRepository.ReplaceAsync(cipher, collectionIds).Returns(true); - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - - await sutProvider.Sut.ShareAsync(cipher, cipher, organization.Id, collectionIds, cipher.UserId.Value, - lastKnownRevisionDate); - await cipherRepository.Received(1).ReplaceAsync(cipher, collectionIds); - } - - [Theory] - [InlineKnownUserCipherAutoData(userId: "99ab4f6c-44f8-4ff5-be7a-75c37c33c69e", "")] - [InlineKnownUserCipherAutoData(userId: "99ab4f6c-44f8-4ff5-be7a-75c37c33c69e", "CorrectTime")] - public async Task ShareManyAsync_CorrectRevisionDate_Passes(string revisionDateString, - SutProvider sutProvider, IEnumerable ciphers, Organization organization, List collectionIds) - { - sutProvider.GetDependency().GetByIdAsync(organization.Id) - .Returns(new Organization - { - PlanType = Enums.PlanType.EnterpriseAnnually, - MaxStorageGb = 100 - }); - - var cipherInfos = ciphers.Select(c => (c, - string.IsNullOrEmpty(revisionDateString) ? null : (DateTime?)c.RevisionDate)); - var sharingUserId = ciphers.First().UserId.Value; - - await sutProvider.Sut.ShareManyAsync(cipherInfos, organization.Id, collectionIds, sharingUserId); - await sutProvider.GetDependency().Received(1).UpdateCiphersAsync(sharingUserId, - Arg.Is>(arg => arg.Except(ciphers).IsNullOrEmpty())); - } - - [Theory] - [InlineKnownUserCipherAutoData("c64d8a15-606e-41d6-9c7e-174d4d8f3b2e", "c64d8a15-606e-41d6-9c7e-174d4d8f3b2e")] - [InlineOrganizationCipherAutoData("c64d8a15-606e-41d6-9c7e-174d4d8f3b2e")] - public async Task RestoreAsync_UpdatesCipher(Guid restoringUserId, Cipher cipher, SutProvider sutProvider) - { - sutProvider.GetDependency().GetCanEditByIdAsync(restoringUserId, cipher.Id).Returns(true); - - var initialRevisionDate = new DateTime(1970, 1, 1, 0, 0, 0); - cipher.DeletedDate = initialRevisionDate; - cipher.RevisionDate = initialRevisionDate; - - await sutProvider.Sut.RestoreAsync(cipher, restoringUserId, cipher.OrganizationId.HasValue); - - Assert.Null(cipher.DeletedDate); - Assert.NotEqual(initialRevisionDate, cipher.RevisionDate); - } - - [Theory] - [InlineKnownUserCipherAutoData("c64d8a15-606e-41d6-9c7e-174d4d8f3b2e", "c64d8a15-606e-41d6-9c7e-174d4d8f3b2e")] - public async Task RestoreManyAsync_UpdatesCiphers(Guid restoringUserId, IEnumerable ciphers, - SutProvider sutProvider) - { - var previousRevisionDate = DateTime.UtcNow; - foreach (var cipher in ciphers) + [Theory, UserCipherAutoData("99ab4f6c-44f8-4ff5-be7a-75c37c33c69e")] + public async Task ShareManyAsync_WrongRevisionDate_Throws(SutProvider sutProvider, + IEnumerable ciphers, Guid organizationId, List collectionIds) + { + sutProvider.GetDependency().GetByIdAsync(organizationId) + .Returns(new Organization { - cipher.RevisionDate = previousRevisionDate; - } - - var revisionDate = previousRevisionDate + TimeSpan.FromMinutes(1); - sutProvider.GetDependency().RestoreAsync(Arg.Any>(), restoringUserId) - .Returns(revisionDate); - - await sutProvider.Sut.RestoreManyAsync(ciphers, restoringUserId); - - foreach (var cipher in ciphers) - { - Assert.Null(cipher.DeletedDate); - Assert.Equal(revisionDate, cipher.RevisionDate); - } - } - - [Theory] - [InlineUserCipherAutoData] - public async Task ShareManyAsync_FreeOrgWithAttachment_Throws(SutProvider sutProvider, - IEnumerable ciphers, Guid organizationId, List collectionIds) - { - sutProvider.GetDependency().GetByIdAsync(organizationId).Returns(new Organization - { - PlanType = Enums.PlanType.Free + PlanType = Enums.PlanType.EnterpriseAnnually, + MaxStorageGb = 100 }); - ciphers.FirstOrDefault().Attachments = - "{\"attachment1\":{\"Size\":\"250\",\"FileName\":\"superCoolFile\"," - + "\"Key\":\"superCoolFile\",\"ContainerName\":\"testContainer\",\"Validated\":false}}"; - var cipherInfos = ciphers.Select(c => (c, - (DateTime?)c.RevisionDate)); - var sharingUserId = ciphers.First().UserId.Value; + var cipherInfos = ciphers.Select(c => (c, (DateTime?)c.RevisionDate.AddDays(-1))); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.ShareManyAsync(cipherInfos, organizationId, collectionIds, sharingUserId)); - Assert.Contains("This organization cannot use attachments", exception.Message); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ShareManyAsync(cipherInfos, organizationId, collectionIds, ciphers.First().UserId.Value)); + Assert.Contains("out of date", exception.Message); + } + + [Theory] + [InlineUserCipherAutoData("")] + [InlineUserCipherAutoData("Correct Time")] + public async Task SaveAsync_CorrectRevisionDate_Passes(string revisionDateString, + SutProvider sutProvider, Cipher cipher) + { + var lastKnownRevisionDate = string.IsNullOrEmpty(revisionDateString) ? (DateTime?)null : cipher.RevisionDate; + + await sutProvider.Sut.SaveAsync(cipher, cipher.UserId.Value, lastKnownRevisionDate); + await sutProvider.GetDependency().Received(1).ReplaceAsync(cipher); + } + + [Theory] + [InlineUserCipherAutoData("")] + [InlineUserCipherAutoData("Correct Time")] + public async Task SaveDetailsAsync_CorrectRevisionDate_Passes(string revisionDateString, + SutProvider sutProvider, CipherDetails cipherDetails) + { + var lastKnownRevisionDate = string.IsNullOrEmpty(revisionDateString) ? (DateTime?)null : cipherDetails.RevisionDate; + + await sutProvider.Sut.SaveDetailsAsync(cipherDetails, cipherDetails.UserId.Value, lastKnownRevisionDate); + await sutProvider.GetDependency().Received(1).ReplaceAsync(cipherDetails); + } + + [Theory] + [InlineUserCipherAutoData("")] + [InlineUserCipherAutoData("Correct Time")] + public async Task ShareAsync_CorrectRevisionDate_Passes(string revisionDateString, + SutProvider sutProvider, Cipher cipher, Organization organization, List collectionIds) + { + var lastKnownRevisionDate = string.IsNullOrEmpty(revisionDateString) ? (DateTime?)null : cipher.RevisionDate; + var cipherRepository = sutProvider.GetDependency(); + cipherRepository.ReplaceAsync(cipher, collectionIds).Returns(true); + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + + await sutProvider.Sut.ShareAsync(cipher, cipher, organization.Id, collectionIds, cipher.UserId.Value, + lastKnownRevisionDate); + await cipherRepository.Received(1).ReplaceAsync(cipher, collectionIds); + } + + [Theory] + [InlineKnownUserCipherAutoData(userId: "99ab4f6c-44f8-4ff5-be7a-75c37c33c69e", "")] + [InlineKnownUserCipherAutoData(userId: "99ab4f6c-44f8-4ff5-be7a-75c37c33c69e", "CorrectTime")] + public async Task ShareManyAsync_CorrectRevisionDate_Passes(string revisionDateString, + SutProvider sutProvider, IEnumerable ciphers, Organization organization, List collectionIds) + { + sutProvider.GetDependency().GetByIdAsync(organization.Id) + .Returns(new Organization + { + PlanType = Enums.PlanType.EnterpriseAnnually, + MaxStorageGb = 100 + }); + + var cipherInfos = ciphers.Select(c => (c, + string.IsNullOrEmpty(revisionDateString) ? null : (DateTime?)c.RevisionDate)); + var sharingUserId = ciphers.First().UserId.Value; + + await sutProvider.Sut.ShareManyAsync(cipherInfos, organization.Id, collectionIds, sharingUserId); + await sutProvider.GetDependency().Received(1).UpdateCiphersAsync(sharingUserId, + Arg.Is>(arg => arg.Except(ciphers).IsNullOrEmpty())); + } + + [Theory] + [InlineKnownUserCipherAutoData("c64d8a15-606e-41d6-9c7e-174d4d8f3b2e", "c64d8a15-606e-41d6-9c7e-174d4d8f3b2e")] + [InlineOrganizationCipherAutoData("c64d8a15-606e-41d6-9c7e-174d4d8f3b2e")] + public async Task RestoreAsync_UpdatesCipher(Guid restoringUserId, Cipher cipher, SutProvider sutProvider) + { + sutProvider.GetDependency().GetCanEditByIdAsync(restoringUserId, cipher.Id).Returns(true); + + var initialRevisionDate = new DateTime(1970, 1, 1, 0, 0, 0); + cipher.DeletedDate = initialRevisionDate; + cipher.RevisionDate = initialRevisionDate; + + await sutProvider.Sut.RestoreAsync(cipher, restoringUserId, cipher.OrganizationId.HasValue); + + Assert.Null(cipher.DeletedDate); + Assert.NotEqual(initialRevisionDate, cipher.RevisionDate); + } + + [Theory] + [InlineKnownUserCipherAutoData("c64d8a15-606e-41d6-9c7e-174d4d8f3b2e", "c64d8a15-606e-41d6-9c7e-174d4d8f3b2e")] + public async Task RestoreManyAsync_UpdatesCiphers(Guid restoringUserId, IEnumerable ciphers, + SutProvider sutProvider) + { + var previousRevisionDate = DateTime.UtcNow; + foreach (var cipher in ciphers) + { + cipher.RevisionDate = previousRevisionDate; } - [Theory] - [InlineUserCipherAutoData] - public async Task ShareManyAsync_PaidOrgWithAttachment_Passes(SutProvider sutProvider, - IEnumerable ciphers, Guid organizationId, List collectionIds) + var revisionDate = previousRevisionDate + TimeSpan.FromMinutes(1); + sutProvider.GetDependency().RestoreAsync(Arg.Any>(), restoringUserId) + .Returns(revisionDate); + + await sutProvider.Sut.RestoreManyAsync(ciphers, restoringUserId); + + foreach (var cipher in ciphers) { - sutProvider.GetDependency().GetByIdAsync(organizationId) - .Returns(new Organization - { - PlanType = Enums.PlanType.EnterpriseAnnually, - MaxStorageGb = 100 - }); - ciphers.FirstOrDefault().Attachments = - "{\"attachment1\":{\"Size\":\"250\",\"FileName\":\"superCoolFile\"," - + "\"Key\":\"superCoolFile\",\"ContainerName\":\"testContainer\",\"Validated\":false}}"; - - var cipherInfos = ciphers.Select(c => (c, - (DateTime?)c.RevisionDate)); - var sharingUserId = ciphers.First().UserId.Value; - - await sutProvider.Sut.ShareManyAsync(cipherInfos, organizationId, collectionIds, sharingUserId); - await sutProvider.GetDependency().Received(1).UpdateCiphersAsync(sharingUserId, - Arg.Is>(arg => arg.Except(ciphers).IsNullOrEmpty())); + Assert.Null(cipher.DeletedDate); + Assert.Equal(revisionDate, cipher.RevisionDate); } } + + [Theory] + [InlineUserCipherAutoData] + public async Task ShareManyAsync_FreeOrgWithAttachment_Throws(SutProvider sutProvider, + IEnumerable ciphers, Guid organizationId, List collectionIds) + { + sutProvider.GetDependency().GetByIdAsync(organizationId).Returns(new Organization + { + PlanType = Enums.PlanType.Free + }); + ciphers.FirstOrDefault().Attachments = + "{\"attachment1\":{\"Size\":\"250\",\"FileName\":\"superCoolFile\"," + + "\"Key\":\"superCoolFile\",\"ContainerName\":\"testContainer\",\"Validated\":false}}"; + + var cipherInfos = ciphers.Select(c => (c, + (DateTime?)c.RevisionDate)); + var sharingUserId = ciphers.First().UserId.Value; + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ShareManyAsync(cipherInfos, organizationId, collectionIds, sharingUserId)); + Assert.Contains("This organization cannot use attachments", exception.Message); + } + + [Theory] + [InlineUserCipherAutoData] + public async Task ShareManyAsync_PaidOrgWithAttachment_Passes(SutProvider sutProvider, + IEnumerable ciphers, Guid organizationId, List collectionIds) + { + sutProvider.GetDependency().GetByIdAsync(organizationId) + .Returns(new Organization + { + PlanType = Enums.PlanType.EnterpriseAnnually, + MaxStorageGb = 100 + }); + ciphers.FirstOrDefault().Attachments = + "{\"attachment1\":{\"Size\":\"250\",\"FileName\":\"superCoolFile\"," + + "\"Key\":\"superCoolFile\",\"ContainerName\":\"testContainer\",\"Validated\":false}}"; + + var cipherInfos = ciphers.Select(c => (c, + (DateTime?)c.RevisionDate)); + var sharingUserId = ciphers.First().UserId.Value; + + await sutProvider.Sut.ShareManyAsync(cipherInfos, organizationId, collectionIds, sharingUserId); + await sutProvider.GetDependency().Received(1).UpdateCiphersAsync(sharingUserId, + Arg.Is>(arg => arg.Except(ciphers).IsNullOrEmpty())); + } } diff --git a/test/Core.Test/Services/CollectionServiceTests.cs b/test/Core.Test/Services/CollectionServiceTests.cs index cf4228b3f..b6a68b58e 100644 --- a/test/Core.Test/Services/CollectionServiceTests.cs +++ b/test/Core.Test/Services/CollectionServiceTests.cs @@ -10,161 +10,160 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class CollectionServiceTest { - public class CollectionServiceTest + [Theory, CollectionAutoData] + public async Task SaveAsync_DefaultId_CreatesCollectionInTheRepository(Collection collection, Organization organization, SutProvider sutProvider) { - [Theory, CollectionAutoData] - public async Task SaveAsync_DefaultId_CreatesCollectionInTheRepository(Collection collection, Organization organization, SutProvider sutProvider) - { - collection.Id = default; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - var utcNow = DateTime.UtcNow; + collection.Id = default; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + var utcNow = DateTime.UtcNow; - await sutProvider.Sut.SaveAsync(collection); + await sutProvider.Sut.SaveAsync(collection); - await sutProvider.GetDependency().Received().CreateAsync(collection); - await sutProvider.GetDependency().Received() - .LogCollectionEventAsync(collection, EventType.Collection_Created); - Assert.True(collection.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(collection.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } + await sutProvider.GetDependency().Received().CreateAsync(collection); + await sutProvider.GetDependency().Received() + .LogCollectionEventAsync(collection, EventType.Collection_Created); + Assert.True(collection.CreationDate - utcNow < TimeSpan.FromSeconds(1)); + Assert.True(collection.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + } - [Theory, CollectionAutoData] - public async Task SaveAsync_DefaultIdWithGroups_CreateCollectionWithGroupsInRepository(Collection collection, - IEnumerable groups, Organization organization, SutProvider sutProvider) - { - collection.Id = default; - organization.UseGroups = true; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - var utcNow = DateTime.UtcNow; + [Theory, CollectionAutoData] + public async Task SaveAsync_DefaultIdWithGroups_CreateCollectionWithGroupsInRepository(Collection collection, + IEnumerable groups, Organization organization, SutProvider sutProvider) + { + collection.Id = default; + organization.UseGroups = true; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + var utcNow = DateTime.UtcNow; - await sutProvider.Sut.SaveAsync(collection, groups); + await sutProvider.Sut.SaveAsync(collection, groups); - await sutProvider.GetDependency().Received().CreateAsync(collection, groups); - await sutProvider.GetDependency().Received() - .LogCollectionEventAsync(collection, EventType.Collection_Created); - Assert.True(collection.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(collection.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } + await sutProvider.GetDependency().Received().CreateAsync(collection, groups); + await sutProvider.GetDependency().Received() + .LogCollectionEventAsync(collection, EventType.Collection_Created); + Assert.True(collection.CreationDate - utcNow < TimeSpan.FromSeconds(1)); + Assert.True(collection.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + } - [Theory, CollectionAutoData] - public async Task SaveAsync_NonDefaultId_ReplacesCollectionInRepository(Collection collection, Organization organization, SutProvider sutProvider) - { - var creationDate = collection.CreationDate; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - var utcNow = DateTime.UtcNow; + [Theory, CollectionAutoData] + public async Task SaveAsync_NonDefaultId_ReplacesCollectionInRepository(Collection collection, Organization organization, SutProvider sutProvider) + { + var creationDate = collection.CreationDate; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + var utcNow = DateTime.UtcNow; - await sutProvider.Sut.SaveAsync(collection); + await sutProvider.Sut.SaveAsync(collection); - await sutProvider.GetDependency().Received().ReplaceAsync(collection); - await sutProvider.GetDependency().Received() - .LogCollectionEventAsync(collection, EventType.Collection_Updated); - Assert.Equal(collection.CreationDate, creationDate); - Assert.True(collection.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } + await sutProvider.GetDependency().Received().ReplaceAsync(collection); + await sutProvider.GetDependency().Received() + .LogCollectionEventAsync(collection, EventType.Collection_Updated); + Assert.Equal(collection.CreationDate, creationDate); + Assert.True(collection.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + } - [Theory, CollectionAutoData] - public async Task SaveAsync_OrganizationNotUseGroup_CreateCollectionWithoutGroupsInRepository(Collection collection, IEnumerable groups, - Organization organization, SutProvider sutProvider) - { - collection.Id = default; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - var utcNow = DateTime.UtcNow; + [Theory, CollectionAutoData] + public async Task SaveAsync_OrganizationNotUseGroup_CreateCollectionWithoutGroupsInRepository(Collection collection, IEnumerable groups, + Organization organization, SutProvider sutProvider) + { + collection.Id = default; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + var utcNow = DateTime.UtcNow; - await sutProvider.Sut.SaveAsync(collection, groups); + await sutProvider.Sut.SaveAsync(collection, groups); - await sutProvider.GetDependency().Received().CreateAsync(collection); - await sutProvider.GetDependency().Received() - .LogCollectionEventAsync(collection, EventType.Collection_Created); - Assert.True(collection.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(collection.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } + await sutProvider.GetDependency().Received().CreateAsync(collection); + await sutProvider.GetDependency().Received() + .LogCollectionEventAsync(collection, EventType.Collection_Created); + Assert.True(collection.CreationDate - utcNow < TimeSpan.FromSeconds(1)); + Assert.True(collection.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + } - [Theory, CollectionAutoData] - public async Task SaveAsync_DefaultIdWithUserId_UpdateUserInCollectionRepository(Collection collection, - Organization organization, OrganizationUser organizationUser, SutProvider sutProvider) - { - collection.Id = default; - organizationUser.Status = OrganizationUserStatusType.Confirmed; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - sutProvider.GetDependency().GetByOrganizationAsync(organization.Id, organizationUser.Id) - .Returns(organizationUser); - var utcNow = DateTime.UtcNow; + [Theory, CollectionAutoData] + public async Task SaveAsync_DefaultIdWithUserId_UpdateUserInCollectionRepository(Collection collection, + Organization organization, OrganizationUser organizationUser, SutProvider sutProvider) + { + collection.Id = default; + organizationUser.Status = OrganizationUserStatusType.Confirmed; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + sutProvider.GetDependency().GetByOrganizationAsync(organization.Id, organizationUser.Id) + .Returns(organizationUser); + var utcNow = DateTime.UtcNow; - await sutProvider.Sut.SaveAsync(collection, null, organizationUser.Id); + await sutProvider.Sut.SaveAsync(collection, null, organizationUser.Id); - await sutProvider.GetDependency().Received().CreateAsync(collection); - await sutProvider.GetDependency().Received() - .GetByOrganizationAsync(organization.Id, organizationUser.Id); - await sutProvider.GetDependency().Received().UpdateUsersAsync(collection.Id, Arg.Any>()); - await sutProvider.GetDependency().Received() - .LogCollectionEventAsync(collection, EventType.Collection_Created); - Assert.True(collection.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(collection.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } + await sutProvider.GetDependency().Received().CreateAsync(collection); + await sutProvider.GetDependency().Received() + .GetByOrganizationAsync(organization.Id, organizationUser.Id); + await sutProvider.GetDependency().Received().UpdateUsersAsync(collection.Id, Arg.Any>()); + await sutProvider.GetDependency().Received() + .LogCollectionEventAsync(collection, EventType.Collection_Created); + Assert.True(collection.CreationDate - utcNow < TimeSpan.FromSeconds(1)); + Assert.True(collection.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_NonExistingOrganizationId_ThrowsBadRequest(Collection collection, SutProvider sutProvider) - { - var ex = await Assert.ThrowsAsync(() => sutProvider.Sut.SaveAsync(collection)); - Assert.Contains("Organization not found", ex.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default, default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().LogCollectionEventAsync(default, default); - } + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_NonExistingOrganizationId_ThrowsBadRequest(Collection collection, SutProvider sutProvider) + { + var ex = await Assert.ThrowsAsync(() => sutProvider.Sut.SaveAsync(collection)); + Assert.Contains("Organization not found", ex.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default, default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().LogCollectionEventAsync(default, default); + } - [Theory, CollectionAutoData] - public async Task SaveAsync_ExceedsOrganizationMaxCollections_ThrowsBadRequest(Collection collection, Organization organization, SutProvider sutProvider) - { - collection.Id = default; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - sutProvider.GetDependency().GetCountByOrganizationIdAsync(organization.Id) - .Returns(organization.MaxCollections.Value); + [Theory, CollectionAutoData] + public async Task SaveAsync_ExceedsOrganizationMaxCollections_ThrowsBadRequest(Collection collection, Organization organization, SutProvider sutProvider) + { + collection.Id = default; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + sutProvider.GetDependency().GetCountByOrganizationIdAsync(organization.Id) + .Returns(organization.MaxCollections.Value); - var ex = await Assert.ThrowsAsync(() => sutProvider.Sut.SaveAsync(collection)); - Assert.Equal($@"You have reached the maximum number of collections ({organization.MaxCollections.Value}) for this organization.", ex.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default, default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().LogCollectionEventAsync(default, default); - } + var ex = await Assert.ThrowsAsync(() => sutProvider.Sut.SaveAsync(collection)); + Assert.Equal($@"You have reached the maximum number of collections ({organization.MaxCollections.Value}) for this organization.", ex.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default, default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().LogCollectionEventAsync(default, default); + } - [Theory, CollectionAutoData] - public async Task DeleteUserAsync_DeletesValidUserWhoBelongsToCollection(Collection collection, - Organization organization, OrganizationUser organizationUser, SutProvider sutProvider) - { - collection.OrganizationId = organization.Id; - organizationUser.OrganizationId = organization.Id; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - sutProvider.GetDependency().GetByIdAsync(organizationUser.Id) - .Returns(organizationUser); + [Theory, CollectionAutoData] + public async Task DeleteUserAsync_DeletesValidUserWhoBelongsToCollection(Collection collection, + Organization organization, OrganizationUser organizationUser, SutProvider sutProvider) + { + collection.OrganizationId = organization.Id; + organizationUser.OrganizationId = organization.Id; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + sutProvider.GetDependency().GetByIdAsync(organizationUser.Id) + .Returns(organizationUser); - await sutProvider.Sut.DeleteUserAsync(collection, organizationUser.Id); + await sutProvider.Sut.DeleteUserAsync(collection, organizationUser.Id); - await sutProvider.GetDependency().Received() - .DeleteUserAsync(collection.Id, organizationUser.Id); - await sutProvider.GetDependency().Received().LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Updated); - } + await sutProvider.GetDependency().Received() + .DeleteUserAsync(collection.Id, organizationUser.Id); + await sutProvider.GetDependency().Received().LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Updated); + } - [Theory, CollectionAutoData] - public async Task DeleteUserAsync_InvalidUser_ThrowsNotFound(Collection collection, Organization organization, - OrganizationUser organizationUser, SutProvider sutProvider) - { - collection.OrganizationId = organization.Id; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - sutProvider.GetDependency().GetByIdAsync(organizationUser.Id) - .Returns(organizationUser); + [Theory, CollectionAutoData] + public async Task DeleteUserAsync_InvalidUser_ThrowsNotFound(Collection collection, Organization organization, + OrganizationUser organizationUser, SutProvider sutProvider) + { + collection.OrganizationId = organization.Id; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + sutProvider.GetDependency().GetByIdAsync(organizationUser.Id) + .Returns(organizationUser); - // user not in organization - await Assert.ThrowsAsync(() => - sutProvider.Sut.DeleteUserAsync(collection, organizationUser.Id)); - // invalid user - await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteUserAsync(collection, Guid.NewGuid())); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().DeleteUserAsync(default, default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .LogOrganizationUserEventAsync(default, default); - } + // user not in organization + await Assert.ThrowsAsync(() => + sutProvider.Sut.DeleteUserAsync(collection, organizationUser.Id)); + // invalid user + await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteUserAsync(collection, Guid.NewGuid())); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().DeleteUserAsync(default, default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .LogOrganizationUserEventAsync(default, default); } } diff --git a/test/Core.Test/Services/DeviceServiceTests.cs b/test/Core.Test/Services/DeviceServiceTests.cs index f3a50d4d0..8bc921283 100644 --- a/test/Core.Test/Services/DeviceServiceTests.cs +++ b/test/Core.Test/Services/DeviceServiceTests.cs @@ -5,33 +5,32 @@ using Bit.Core.Services; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class DeviceServiceTests { - public class DeviceServiceTests + [Fact] + public async Task DeviceSaveShouldUpdateRevisionDateAndPushRegistration() { - [Fact] - public async Task DeviceSaveShouldUpdateRevisionDateAndPushRegistration() + var deviceRepo = Substitute.For(); + var pushRepo = Substitute.For(); + var deviceService = new DeviceService(deviceRepo, pushRepo); + + var id = Guid.NewGuid(); + var userId = Guid.NewGuid(); + var device = new Device { - var deviceRepo = Substitute.For(); - var pushRepo = Substitute.For(); - var deviceService = new DeviceService(deviceRepo, pushRepo); + Id = id, + Name = "test device", + Type = DeviceType.Android, + UserId = userId, + PushToken = "testtoken", + Identifier = "testid" + }; + await deviceService.SaveAsync(device); - var id = Guid.NewGuid(); - var userId = Guid.NewGuid(); - var device = new Device - { - Id = id, - Name = "test device", - Type = DeviceType.Android, - UserId = userId, - PushToken = "testtoken", - Identifier = "testid" - }; - await deviceService.SaveAsync(device); - - Assert.True(device.RevisionDate - DateTime.UtcNow < TimeSpan.FromSeconds(1)); - await pushRepo.Received().CreateOrUpdateRegistrationAsync("testtoken", id.ToString(), - userId.ToString(), "testid", DeviceType.Android); - } + Assert.True(device.RevisionDate - DateTime.UtcNow < TimeSpan.FromSeconds(1)); + await pushRepo.Received().CreateOrUpdateRegistrationAsync("testtoken", id.ToString(), + userId.ToString(), "testid", DeviceType.Android); } } diff --git a/test/Core.Test/Services/EmergencyAccessServiceTests.cs b/test/Core.Test/Services/EmergencyAccessServiceTests.cs index bdbb6953b..6f8576d8e 100644 --- a/test/Core.Test/Services/EmergencyAccessServiceTests.cs +++ b/test/Core.Test/Services/EmergencyAccessServiceTests.cs @@ -9,163 +9,162 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class EmergencyAccessServiceTests { - public class EmergencyAccessServiceTests + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_PremiumCannotUpdate( + SutProvider sutProvider, User savingUser) { - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_PremiumCannotUpdate( - SutProvider sutProvider, User savingUser) + savingUser.Premium = false; + var emergencyAccess = new EmergencyAccess { - savingUser.Premium = false; - var emergencyAccess = new EmergencyAccess - { - Type = Enums.EmergencyAccessType.Takeover, - GrantorId = savingUser.Id, - }; + Type = Enums.EmergencyAccessType.Takeover, + GrantorId = savingUser.Id, + }; - sutProvider.GetDependency().GetUserByIdAsync(savingUser.Id).Returns(savingUser); + sutProvider.GetDependency().GetUserByIdAsync(savingUser.Id).Returns(savingUser); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(emergencyAccess, savingUser)); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(emergencyAccess, savingUser)); - Assert.Contains("Not a premium user.", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); - } + Assert.Contains("Not a premium user.", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task InviteAsync_UserWithKeyConnectorCannotUseTakeover( - SutProvider sutProvider, User invitingUser, string email, int waitTime) + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task InviteAsync_UserWithKeyConnectorCannotUseTakeover( + SutProvider sutProvider, User invitingUser, string email, int waitTime) + { + invitingUser.UsesKeyConnector = true; + sutProvider.GetDependency().CanAccessPremium(invitingUser).Returns(true); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.InviteAsync(invitingUser, email, Enums.EmergencyAccessType.Takeover, waitTime)); + + Assert.Contains("You cannot use Emergency Access Takeover because you are using Key Connector", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task ConfirmUserAsync_UserWithKeyConnectorCannotUseTakeover( + SutProvider sutProvider, User confirmingUser, string key) + { + confirmingUser.UsesKeyConnector = true; + var emergencyAccess = new EmergencyAccess { - invitingUser.UsesKeyConnector = true; - sutProvider.GetDependency().CanAccessPremium(invitingUser).Returns(true); + Status = Enums.EmergencyAccessStatusType.Accepted, + GrantorId = confirmingUser.Id, + Type = Enums.EmergencyAccessType.Takeover, + }; - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.InviteAsync(invitingUser, email, Enums.EmergencyAccessType.Takeover, waitTime)); + sutProvider.GetDependency().GetByIdAsync(confirmingUser.Id).Returns(confirmingUser); + sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(emergencyAccess); - Assert.Contains("You cannot use Emergency Access Takeover because you are using Key Connector", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default); - } + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ConfirmUserAsync(new Guid(), key, confirmingUser.Id)); - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task ConfirmUserAsync_UserWithKeyConnectorCannotUseTakeover( - SutProvider sutProvider, User confirmingUser, string key) + Assert.Contains("You cannot use Emergency Access Takeover because you are using Key Connector", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_UserWithKeyConnectorCannotUseTakeover( + SutProvider sutProvider, User savingUser) + { + savingUser.UsesKeyConnector = true; + var emergencyAccess = new EmergencyAccess { - confirmingUser.UsesKeyConnector = true; - var emergencyAccess = new EmergencyAccess - { - Status = Enums.EmergencyAccessStatusType.Accepted, - GrantorId = confirmingUser.Id, - Type = Enums.EmergencyAccessType.Takeover, - }; + Type = Enums.EmergencyAccessType.Takeover, + GrantorId = savingUser.Id, + }; - sutProvider.GetDependency().GetByIdAsync(confirmingUser.Id).Returns(confirmingUser); - sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(emergencyAccess); + var userService = sutProvider.GetDependency(); + userService.GetUserByIdAsync(savingUser.Id).Returns(savingUser); + userService.CanAccessPremium(savingUser).Returns(true); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.ConfirmUserAsync(new Guid(), key, confirmingUser.Id)); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(emergencyAccess, savingUser)); - Assert.Contains("You cannot use Emergency Access Takeover because you are using Key Connector", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); - } + Assert.Contains("You cannot use Emergency Access Takeover because you are using Key Connector", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_UserWithKeyConnectorCannotUseTakeover( - SutProvider sutProvider, User savingUser) + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task InitiateAsync_UserWithKeyConnectorCannotUseTakeover( + SutProvider sutProvider, User initiatingUser, User grantor) + { + grantor.UsesKeyConnector = true; + var emergencyAccess = new EmergencyAccess { - savingUser.UsesKeyConnector = true; - var emergencyAccess = new EmergencyAccess - { - Type = Enums.EmergencyAccessType.Takeover, - GrantorId = savingUser.Id, - }; + Status = Enums.EmergencyAccessStatusType.Confirmed, + GranteeId = initiatingUser.Id, + GrantorId = grantor.Id, + Type = Enums.EmergencyAccessType.Takeover, + }; - var userService = sutProvider.GetDependency(); - userService.GetUserByIdAsync(savingUser.Id).Returns(savingUser); - userService.CanAccessPremium(savingUser).Returns(true); + sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(emergencyAccess); + sutProvider.GetDependency().GetByIdAsync(grantor.Id).Returns(grantor); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(emergencyAccess, savingUser)); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.InitiateAsync(new Guid(), initiatingUser)); - Assert.Contains("You cannot use Emergency Access Takeover because you are using Key Connector", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); - } + Assert.Contains("You cannot takeover an account that is using Key Connector", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task InitiateAsync_UserWithKeyConnectorCannotUseTakeover( - SutProvider sutProvider, User initiatingUser, User grantor) + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task TakeoverAsync_UserWithKeyConnectorCannotUseTakeover( + SutProvider sutProvider, User requestingUser, User grantor) + { + grantor.UsesKeyConnector = true; + var emergencyAccess = new EmergencyAccess { - grantor.UsesKeyConnector = true; - var emergencyAccess = new EmergencyAccess - { - Status = Enums.EmergencyAccessStatusType.Confirmed, - GranteeId = initiatingUser.Id, - GrantorId = grantor.Id, - Type = Enums.EmergencyAccessType.Takeover, - }; + GrantorId = grantor.Id, + GranteeId = requestingUser.Id, + Status = Enums.EmergencyAccessStatusType.RecoveryApproved, + Type = Enums.EmergencyAccessType.Takeover, + }; - sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(emergencyAccess); - sutProvider.GetDependency().GetByIdAsync(grantor.Id).Returns(grantor); + sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(emergencyAccess); + sutProvider.GetDependency().GetByIdAsync(grantor.Id).Returns(grantor); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.InitiateAsync(new Guid(), initiatingUser)); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.TakeoverAsync(new Guid(), requestingUser)); - Assert.Contains("You cannot takeover an account that is using Key Connector", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); - } + Assert.Contains("You cannot takeover an account that is using Key Connector", exception.Message); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task TakeoverAsync_UserWithKeyConnectorCannotUseTakeover( - SutProvider sutProvider, User requestingUser, User grantor) + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task PasswordAsync_Disables_2FA_Providers_And_Unknown_Device_Verification_On_The_Grantor( + SutProvider sutProvider, User requestingUser, User grantor) + { + grantor.UsesKeyConnector = true; + grantor.UnknownDeviceVerificationEnabled = true; + grantor.SetTwoFactorProviders(new Dictionary { - grantor.UsesKeyConnector = true; - var emergencyAccess = new EmergencyAccess + [TwoFactorProviderType.Email] = new TwoFactorProvider { - GrantorId = grantor.Id, - GranteeId = requestingUser.Id, - Status = Enums.EmergencyAccessStatusType.RecoveryApproved, - Type = Enums.EmergencyAccessType.Takeover, - }; - - sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(emergencyAccess); - sutProvider.GetDependency().GetByIdAsync(grantor.Id).Returns(grantor); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.TakeoverAsync(new Guid(), requestingUser)); - - Assert.Contains("You cannot takeover an account that is using Key Connector", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task PasswordAsync_Disables_2FA_Providers_And_Unknown_Device_Verification_On_The_Grantor( - SutProvider sutProvider, User requestingUser, User grantor) + MetaData = new Dictionary { ["Email"] = "asdfasf" }, + Enabled = true + } + }); + var emergencyAccess = new EmergencyAccess { - grantor.UsesKeyConnector = true; - grantor.UnknownDeviceVerificationEnabled = true; - grantor.SetTwoFactorProviders(new Dictionary - { - [TwoFactorProviderType.Email] = new TwoFactorProvider - { - MetaData = new Dictionary { ["Email"] = "asdfasf" }, - Enabled = true - } - }); - var emergencyAccess = new EmergencyAccess - { - GrantorId = grantor.Id, - GranteeId = requestingUser.Id, - Status = Enums.EmergencyAccessStatusType.RecoveryApproved, - Type = Enums.EmergencyAccessType.Takeover, - }; + GrantorId = grantor.Id, + GranteeId = requestingUser.Id, + Status = Enums.EmergencyAccessStatusType.RecoveryApproved, + Type = Enums.EmergencyAccessType.Takeover, + }; - sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(emergencyAccess); - sutProvider.GetDependency().GetByIdAsync(grantor.Id).Returns(grantor); + sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(emergencyAccess); + sutProvider.GetDependency().GetByIdAsync(grantor.Id).Returns(grantor); - await sutProvider.Sut.PasswordAsync(Guid.NewGuid(), requestingUser, "blablahash", "blablakey"); + await sutProvider.Sut.PasswordAsync(Guid.NewGuid(), requestingUser, "blablahash", "blablakey"); - Assert.False(grantor.UnknownDeviceVerificationEnabled); - Assert.Empty(grantor.GetTwoFactorProviders()); - await sutProvider.GetDependency().Received().ReplaceAsync(grantor); - } + Assert.False(grantor.UnknownDeviceVerificationEnabled); + Assert.Empty(grantor.GetTwoFactorProviders()); + await sutProvider.GetDependency().Received().ReplaceAsync(grantor); } } diff --git a/test/Core.Test/Services/EventServiceTests.cs b/test/Core.Test/Services/EventServiceTests.cs index 988d84e13..214f120b8 100644 --- a/test/Core.Test/Services/EventServiceTests.cs +++ b/test/Core.Test/Services/EventServiceTests.cs @@ -11,99 +11,98 @@ using Bit.Test.Common.Helpers; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +[SutProviderCustomize] +public class EventServiceTests { - [SutProviderCustomize] - public class EventServiceTests + public static IEnumerable InstallationIdTestCases => TestCaseHelper.GetCombinationsOfMultipleLists( + new object[] { Guid.NewGuid(), null }, + Enum.GetValues().Select(e => (object)e) + ).Select(p => p.ToArray()); + + [Theory] + [BitMemberAutoData(nameof(InstallationIdTestCases))] + public async Task LogOrganizationEvent_ProvidesInstallationId(Guid? installationId, EventType eventType, + Organization organization, SutProvider sutProvider) { - public static IEnumerable InstallationIdTestCases => TestCaseHelper.GetCombinationsOfMultipleLists( - new object[] { Guid.NewGuid(), null }, - Enum.GetValues().Select(e => (object)e) - ).Select(p => p.ToArray()); + organization.Enabled = true; + organization.UseEvents = true; - [Theory] - [BitMemberAutoData(nameof(InstallationIdTestCases))] - public async Task LogOrganizationEvent_ProvidesInstallationId(Guid? installationId, EventType eventType, - Organization organization, SutProvider sutProvider) + sutProvider.GetDependency().InstallationId.Returns(installationId); + + await sutProvider.Sut.LogOrganizationEventAsync(organization, eventType); + + await sutProvider.GetDependency().Received(1).CreateAsync(Arg.Is(e => + e.OrganizationId == organization.Id && + e.Type == eventType && + e.InstallationId == installationId)); + } + + [Theory, BitAutoData] + public async Task LogOrganizationUserEvent_LogsRequiredInfo(OrganizationUser orgUser, EventType eventType, DateTime date, + Guid actingUserId, Guid providerId, string ipAddress, DeviceType deviceType, SutProvider sutProvider) + { + var orgAbilities = new Dictionary() { - organization.Enabled = true; - organization.UseEvents = true; + {orgUser.OrganizationId, new OrganizationAbility() { UseEvents = true, Enabled = true } } + }; + sutProvider.GetDependency().GetOrganizationAbilitiesAsync().Returns(orgAbilities); + sutProvider.GetDependency().UserId.Returns(actingUserId); + sutProvider.GetDependency().IpAddress.Returns(ipAddress); + sutProvider.GetDependency().DeviceType.Returns(deviceType); + sutProvider.GetDependency().ProviderIdForOrg(Arg.Any()).Returns(providerId); - sutProvider.GetDependency().InstallationId.Returns(installationId); + await sutProvider.Sut.LogOrganizationUserEventAsync(orgUser, eventType, date); - await sutProvider.Sut.LogOrganizationEventAsync(organization, eventType); - - await sutProvider.GetDependency().Received(1).CreateAsync(Arg.Is(e => - e.OrganizationId == organization.Id && - e.Type == eventType && - e.InstallationId == installationId)); - } - - [Theory, BitAutoData] - public async Task LogOrganizationUserEvent_LogsRequiredInfo(OrganizationUser orgUser, EventType eventType, DateTime date, - Guid actingUserId, Guid providerId, string ipAddress, DeviceType deviceType, SutProvider sutProvider) - { - var orgAbilities = new Dictionary() + var expected = new List() { + new EventMessage() { - {orgUser.OrganizationId, new OrganizationAbility() { UseEvents = true, Enabled = true } } - }; - sutProvider.GetDependency().GetOrganizationAbilitiesAsync().Returns(orgAbilities); - sutProvider.GetDependency().UserId.Returns(actingUserId); - sutProvider.GetDependency().IpAddress.Returns(ipAddress); - sutProvider.GetDependency().DeviceType.Returns(deviceType); - sutProvider.GetDependency().ProviderIdForOrg(Arg.Any()).Returns(providerId); + IpAddress = ipAddress, + DeviceType = deviceType, + OrganizationId = orgUser.OrganizationId, + UserId = orgUser.UserId, + OrganizationUserId = orgUser.Id, + ProviderId = providerId, + Type = eventType, + ActingUserId = actingUserId, + Date = date + } + }; - await sutProvider.Sut.LogOrganizationUserEventAsync(orgUser, eventType, date); + await sutProvider.GetDependency().Received(1).CreateManyAsync(Arg.Is(AssertHelper.AssertPropertyEqual(expected, new[] { "IdempotencyId" }))); + } - var expected = new List() { - new EventMessage() - { - IpAddress = ipAddress, - DeviceType = deviceType, - OrganizationId = orgUser.OrganizationId, - UserId = orgUser.UserId, - OrganizationUserId = orgUser.Id, - ProviderId = providerId, - Type = eventType, - ActingUserId = actingUserId, - Date = date - } - }; - - await sutProvider.GetDependency().Received(1).CreateManyAsync(Arg.Is(AssertHelper.AssertPropertyEqual(expected, new[] { "IdempotencyId" }))); - } - - [Theory, BitAutoData] - public async Task LogProviderUserEvent_LogsRequiredInfo(ProviderUser providerUser, EventType eventType, DateTime date, - Guid actingUserId, Guid providerId, string ipAddress, DeviceType deviceType, SutProvider sutProvider) + [Theory, BitAutoData] + public async Task LogProviderUserEvent_LogsRequiredInfo(ProviderUser providerUser, EventType eventType, DateTime date, + Guid actingUserId, Guid providerId, string ipAddress, DeviceType deviceType, SutProvider sutProvider) + { + var providerAbilities = new Dictionary() { - var providerAbilities = new Dictionary() + {providerUser.ProviderId, new ProviderAbility() { UseEvents = true, Enabled = true } } + }; + sutProvider.GetDependency().GetProviderAbilitiesAsync().Returns(providerAbilities); + sutProvider.GetDependency().UserId.Returns(actingUserId); + sutProvider.GetDependency().IpAddress.Returns(ipAddress); + sutProvider.GetDependency().DeviceType.Returns(deviceType); + sutProvider.GetDependency().ProviderIdForOrg(Arg.Any()).Returns(providerId); + + await sutProvider.Sut.LogProviderUserEventAsync(providerUser, eventType, date); + + var expected = new List() { + new EventMessage() { - {providerUser.ProviderId, new ProviderAbility() { UseEvents = true, Enabled = true } } - }; - sutProvider.GetDependency().GetProviderAbilitiesAsync().Returns(providerAbilities); - sutProvider.GetDependency().UserId.Returns(actingUserId); - sutProvider.GetDependency().IpAddress.Returns(ipAddress); - sutProvider.GetDependency().DeviceType.Returns(deviceType); - sutProvider.GetDependency().ProviderIdForOrg(Arg.Any()).Returns(providerId); + IpAddress = ipAddress, + DeviceType = deviceType, + ProviderId = providerUser.ProviderId, + UserId = providerUser.UserId, + ProviderUserId = providerUser.Id, + Type = eventType, + ActingUserId = actingUserId, + Date = date + } + }; - await sutProvider.Sut.LogProviderUserEventAsync(providerUser, eventType, date); - - var expected = new List() { - new EventMessage() - { - IpAddress = ipAddress, - DeviceType = deviceType, - ProviderId = providerUser.ProviderId, - UserId = providerUser.UserId, - ProviderUserId = providerUser.Id, - Type = eventType, - ActingUserId = actingUserId, - Date = date - } - }; - - await sutProvider.GetDependency().Received(1).CreateManyAsync(Arg.Is(AssertHelper.AssertPropertyEqual(expected, new[] { "IdempotencyId" }))); - } + await sutProvider.GetDependency().Received(1).CreateManyAsync(Arg.Is(AssertHelper.AssertPropertyEqual(expected, new[] { "IdempotencyId" }))); } } diff --git a/test/Core.Test/Services/GroupServiceTests.cs b/test/Core.Test/Services/GroupServiceTests.cs index 84f11cbb1..04aad9726 100644 --- a/test/Core.Test/Services/GroupServiceTests.cs +++ b/test/Core.Test/Services/GroupServiceTests.cs @@ -10,128 +10,127 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class GroupServiceTests { - public class GroupServiceTests + [Theory, GroupOrganizationAutoData] + public async Task SaveAsync_DefaultGroupId_CreatesGroupInRepository(Group group, Organization organization, SutProvider sutProvider) { - [Theory, GroupOrganizationAutoData] - public async Task SaveAsync_DefaultGroupId_CreatesGroupInRepository(Group group, Organization organization, SutProvider sutProvider) - { - group.Id = default(Guid); - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - organization.UseGroups = true; - var utcNow = DateTime.UtcNow; + group.Id = default(Guid); + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + organization.UseGroups = true; + var utcNow = DateTime.UtcNow; - await sutProvider.Sut.SaveAsync(group); + await sutProvider.Sut.SaveAsync(group); - await sutProvider.GetDependency().Received().CreateAsync(group); - await sutProvider.GetDependency().Received() - .LogGroupEventAsync(group, EventType.Group_Created); - Assert.True(group.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(group.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } + await sutProvider.GetDependency().Received().CreateAsync(group); + await sutProvider.GetDependency().Received() + .LogGroupEventAsync(group, EventType.Group_Created); + Assert.True(group.CreationDate - utcNow < TimeSpan.FromSeconds(1)); + Assert.True(group.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + } - [Theory, GroupOrganizationAutoData] - public async Task SaveAsync_DefaultGroupIdAndCollections_CreatesGroupInRepository(Group group, Organization organization, List collections, SutProvider sutProvider) - { - group.Id = default(Guid); - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - organization.UseGroups = true; - var utcNow = DateTime.UtcNow; + [Theory, GroupOrganizationAutoData] + public async Task SaveAsync_DefaultGroupIdAndCollections_CreatesGroupInRepository(Group group, Organization organization, List collections, SutProvider sutProvider) + { + group.Id = default(Guid); + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + organization.UseGroups = true; + var utcNow = DateTime.UtcNow; - await sutProvider.Sut.SaveAsync(group, collections); + await sutProvider.Sut.SaveAsync(group, collections); - await sutProvider.GetDependency().Received().CreateAsync(group, collections); - await sutProvider.GetDependency().Received() - .LogGroupEventAsync(group, EventType.Group_Created); - Assert.True(group.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(group.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } + await sutProvider.GetDependency().Received().CreateAsync(group, collections); + await sutProvider.GetDependency().Received() + .LogGroupEventAsync(group, EventType.Group_Created); + Assert.True(group.CreationDate - utcNow < TimeSpan.FromSeconds(1)); + Assert.True(group.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + } - [Theory, GroupOrganizationAutoData] - public async Task SaveAsync_NonDefaultGroupId_ReplaceGroupInRepository(Group group, Organization organization, List collections, SutProvider sutProvider) - { - organization.UseGroups = true; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + [Theory, GroupOrganizationAutoData] + public async Task SaveAsync_NonDefaultGroupId_ReplaceGroupInRepository(Group group, Organization organization, List collections, SutProvider sutProvider) + { + organization.UseGroups = true; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - await sutProvider.Sut.SaveAsync(group, collections); + await sutProvider.Sut.SaveAsync(group, collections); - await sutProvider.GetDependency().Received().ReplaceAsync(group, collections); - await sutProvider.GetDependency().Received() - .LogGroupEventAsync(group, EventType.Group_Updated); - Assert.True(group.RevisionDate - DateTime.UtcNow < TimeSpan.FromSeconds(1)); - } + await sutProvider.GetDependency().Received().ReplaceAsync(group, collections); + await sutProvider.GetDependency().Received() + .LogGroupEventAsync(group, EventType.Group_Updated); + Assert.True(group.RevisionDate - DateTime.UtcNow < TimeSpan.FromSeconds(1)); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_NonExistingOrganizationId_ThrowsBadRequest(Group group, SutProvider sutProvider) - { - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(group)); - Assert.Contains("Organization not found", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().LogGroupEventAsync(default, default, default); - } + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_NonExistingOrganizationId_ThrowsBadRequest(Group group, SutProvider sutProvider) + { + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(group)); + Assert.Contains("Organization not found", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().LogGroupEventAsync(default, default, default); + } - [Theory, GroupOrganizationNotUseGroupsAutoData] - public async Task SaveAsync_OrganizationDoesNotUseGroups_ThrowsBadRequest(Group group, Organization organization, SutProvider sutProvider) - { - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + [Theory, GroupOrganizationNotUseGroupsAutoData] + public async Task SaveAsync_OrganizationDoesNotUseGroups_ThrowsBadRequest(Group group, Organization organization, SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(group)); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(group)); - Assert.Contains("This organization cannot use groups", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().LogGroupEventAsync(default, default, default); - } + Assert.Contains("This organization cannot use groups", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().CreateAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().LogGroupEventAsync(default, default, default); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteAsync_ValidData_DeletesGroup(Group group, SutProvider sutProvider) - { - await sutProvider.Sut.DeleteAsync(group); + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteAsync_ValidData_DeletesGroup(Group group, SutProvider sutProvider) + { + await sutProvider.Sut.DeleteAsync(group); - await sutProvider.GetDependency().Received().DeleteAsync(group); - await sutProvider.GetDependency().Received() - .LogGroupEventAsync(group, EventType.Group_Deleted); - } + await sutProvider.GetDependency().Received().DeleteAsync(group); + await sutProvider.GetDependency().Received() + .LogGroupEventAsync(group, EventType.Group_Deleted); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUserAsync_ValidData_DeletesUserInGroupRepository(Group group, Organization organization, OrganizationUser organizationUser, SutProvider sutProvider) - { - group.OrganizationId = organization.Id; - organization.UseGroups = true; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - organizationUser.OrganizationId = organization.Id; - sutProvider.GetDependency().GetByIdAsync(organizationUser.Id) - .Returns(organizationUser); + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUserAsync_ValidData_DeletesUserInGroupRepository(Group group, Organization organization, OrganizationUser organizationUser, SutProvider sutProvider) + { + group.OrganizationId = organization.Id; + organization.UseGroups = true; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + organizationUser.OrganizationId = organization.Id; + sutProvider.GetDependency().GetByIdAsync(organizationUser.Id) + .Returns(organizationUser); - await sutProvider.Sut.DeleteUserAsync(group, organizationUser.Id); + await sutProvider.Sut.DeleteUserAsync(group, organizationUser.Id); - await sutProvider.GetDependency().Received().DeleteUserAsync(group.Id, organizationUser.Id); - await sutProvider.GetDependency().Received() - .LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_UpdatedGroups); - } + await sutProvider.GetDependency().Received().DeleteUserAsync(group.Id, organizationUser.Id); + await sutProvider.GetDependency().Received() + .LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_UpdatedGroups); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUserAsync_InvalidUser_ThrowsNotFound(Group group, Organization organization, OrganizationUser organizationUser, SutProvider sutProvider) - { - group.OrganizationId = organization.Id; - organization.UseGroups = true; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - // organizationUser.OrganizationId = organization.Id; - sutProvider.GetDependency().GetByIdAsync(organizationUser.Id) - .Returns(organizationUser); + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUserAsync_InvalidUser_ThrowsNotFound(Group group, Organization organization, OrganizationUser organizationUser, SutProvider sutProvider) + { + group.OrganizationId = organization.Id; + organization.UseGroups = true; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + // organizationUser.OrganizationId = organization.Id; + sutProvider.GetDependency().GetByIdAsync(organizationUser.Id) + .Returns(organizationUser); - // user not in organization - await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteUserAsync(group, organizationUser.Id)); - // invalid user - await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteUserAsync(group, Guid.NewGuid())); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .DeleteUserAsync(default, default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .LogOrganizationUserEventAsync(default, default); - } + // user not in organization + await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteUserAsync(group, organizationUser.Id)); + // invalid user + await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteUserAsync(group, Guid.NewGuid())); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .DeleteUserAsync(default, default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .LogOrganizationUserEventAsync(default, default); } } diff --git a/test/Core.Test/Services/HandlebarsMailServiceTests.cs b/test/Core.Test/Services/HandlebarsMailServiceTests.cs index 39348ad8e..5127eb2b4 100644 --- a/test/Core.Test/Services/HandlebarsMailServiceTests.cs +++ b/test/Core.Test/Services/HandlebarsMailServiceTests.cs @@ -8,169 +8,168 @@ using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class HandlebarsMailServiceTests { - public class HandlebarsMailServiceTests + private readonly HandlebarsMailService _sut; + + private readonly GlobalSettings _globalSettings; + private readonly IMailDeliveryService _mailDeliveryService; + private readonly IMailEnqueuingService _mailEnqueuingService; + + public HandlebarsMailServiceTests() { - private readonly HandlebarsMailService _sut; + _globalSettings = new GlobalSettings(); + _mailDeliveryService = Substitute.For(); + _mailEnqueuingService = Substitute.For(); - private readonly GlobalSettings _globalSettings; - private readonly IMailDeliveryService _mailDeliveryService; - private readonly IMailEnqueuingService _mailEnqueuingService; + _sut = new HandlebarsMailService( + _globalSettings, + _mailDeliveryService, + _mailEnqueuingService + ); + } - public HandlebarsMailServiceTests() + [Fact(Skip = "For local development")] + public async Task SendAllEmails() + { + // This test is only opt in and is more for development purposes. + // This will send all emails to the test email address so that they can be viewed. + var namedParameters = new Dictionary<(string, Type), object> { - _globalSettings = new GlobalSettings(); - _mailDeliveryService = Substitute.For(); - _mailEnqueuingService = Substitute.For(); + // TODO: Swith to use env variable + { ("email", typeof(string)), "test@bitwarden.com" }, + { ("user", typeof(User)), new User + { + Id = Guid.NewGuid(), + Email = "test@bitwarden.com", + }}, + { ("userId", typeof(Guid)), Guid.NewGuid() }, + { ("token", typeof(string)), "test_token" }, + { ("fromEmail", typeof(string)), "test@bitwarden.com" }, + { ("toEmail", typeof(string)), "test@bitwarden.com" }, + { ("newEmailAddress", typeof(string)), "test@bitwarden.com" }, + { ("hint", typeof(string)), "Test Hint" }, + { ("organizationName", typeof(string)), "Test Organization Name" }, + { ("orgUser", typeof(OrganizationUser)), new OrganizationUser + { + Id = Guid.NewGuid(), + Email = "test@bitwarden.com", + OrganizationId = Guid.NewGuid(), - _sut = new HandlebarsMailService( - _globalSettings, - _mailDeliveryService, - _mailEnqueuingService - ); + }}, + { ("token", typeof(ExpiringToken)), new ExpiringToken("test_token", DateTime.UtcNow.AddDays(1))}, + { ("organization", typeof(Organization)), new Organization + { + Id = Guid.NewGuid(), + Name = "Test Organization Name", + Seats = 5 + }}, + { ("initialSeatCount", typeof(int)), 5}, + { ("ownerEmails", typeof(IEnumerable)), new [] { "test@bitwarden.com" }}, + { ("maxSeatCount", typeof(int)), 5 }, + { ("userIdentifier", typeof(string)), "test_user" }, + { ("adminEmails", typeof(IEnumerable)), new [] { "test@bitwarden.com" }}, + { ("returnUrl", typeof(string)), "https://bitwarden.com/" }, + { ("amount", typeof(decimal)), 1.00M }, + { ("dueDate", typeof(DateTime)), DateTime.UtcNow.AddDays(1) }, + { ("items", typeof(List)), new List { "test@bitwarden.com" }}, + { ("mentionInvoices", typeof(bool)), true }, + { ("emails", typeof(IEnumerable)), new [] { "test@bitwarden.com" }}, + { ("deviceType", typeof(string)), "Mobile" }, + { ("timestamp", typeof(DateTime)), DateTime.UtcNow.AddDays(1)}, + { ("ip", typeof(string)), "127.0.0.1" }, + { ("emergencyAccess", typeof(EmergencyAccess)), new EmergencyAccess + { + Id = Guid.NewGuid(), + Email = "test@bitwarden.com", + }}, + { ("granteeEmail", typeof(string)), "test@bitwarden.com" }, + { ("grantorName", typeof(string)), "Test User" }, + { ("initiatingName", typeof(string)), "Test" }, + { ("approvingName", typeof(string)), "Test Name" }, + { ("rejectingName", typeof(string)), "Test Name" }, + { ("provider", typeof(Provider)), new Provider + { + Id = Guid.NewGuid(), + }}, + { ("name", typeof(string)), "Test Name" }, + { ("ea", typeof(EmergencyAccess)), new EmergencyAccess + { + Id = Guid.NewGuid(), + Email = "test@bitwarden.com", + }}, + { ("userName", typeof(string)), "testUser" }, + { ("orgName", typeof(string)), "Test Org Name" }, + { ("providerName", typeof(string)), "testProvider" }, + { ("providerUser", typeof(ProviderUser)), new ProviderUser + { + ProviderId = Guid.NewGuid(), + Id = Guid.NewGuid(), + }}, + { ("familyUserEmail", typeof(string)), "test@bitwarden.com" }, + { ("sponsorEmail", typeof(string)), "test@bitwarden.com" }, + { ("familyOrgName", typeof(string)), "Test Org Name" }, + // Swap existingAccount to true or false to generate different versions of the SendFamiliesForEnterpriseOfferEmailAsync emails. + { ("existingAccount", typeof(bool)), false }, + { ("sponsorshipEndDate", typeof(DateTime)), DateTime.UtcNow.AddDays(1)}, + { ("sponsorOrgName", typeof(string)), "Sponsor Test Org Name" }, + { ("expirationDate", typeof(DateTime)), DateTime.Now.AddDays(3) }, + { ("utcNow", typeof(DateTime)), DateTime.UtcNow }, + }; + + var globalSettings = new GlobalSettings + { + Mail = new GlobalSettings.MailSettings + { + Smtp = new GlobalSettings.MailSettings.SmtpSettings + { + Host = "localhost", + TrustServer = true, + Port = 10250, + }, + ReplyToEmail = "noreply@bitwarden.com", + }, + SiteName = "Bitwarden", + }; + + var mailDeliveryService = new MailKitSmtpMailDeliveryService(globalSettings, Substitute.For>()); + + var handlebarsService = new HandlebarsMailService(globalSettings, mailDeliveryService, new BlockingMailEnqueuingService()); + + var sendMethods = typeof(IMailService).GetMethods(BindingFlags.Public | BindingFlags.Instance) + .Where(m => m.Name.StartsWith("Send") && m.Name != "SendEnqueuedMailMessageAsync"); + + foreach (var sendMethod in sendMethods) + { + await InvokeMethod(sendMethod); } - [Fact(Skip = "For local development")] - public async Task SendAllEmails() + async Task InvokeMethod(MethodInfo method) { - // This test is only opt in and is more for development purposes. - // This will send all emails to the test email address so that they can be viewed. - var namedParameters = new Dictionary<(string, Type), object> + var parameters = method.GetParameters(); + var args = new object[parameters.Length]; + + for (var i = 0; i < parameters.Length; i++) { - // TODO: Swith to use env variable - { ("email", typeof(string)), "test@bitwarden.com" }, - { ("user", typeof(User)), new User + if (!namedParameters.TryGetValue((parameters[i].Name, parameters[i].ParameterType), out var value)) { - Id = Guid.NewGuid(), - Email = "test@bitwarden.com", - }}, - { ("userId", typeof(Guid)), Guid.NewGuid() }, - { ("token", typeof(string)), "test_token" }, - { ("fromEmail", typeof(string)), "test@bitwarden.com" }, - { ("toEmail", typeof(string)), "test@bitwarden.com" }, - { ("newEmailAddress", typeof(string)), "test@bitwarden.com" }, - { ("hint", typeof(string)), "Test Hint" }, - { ("organizationName", typeof(string)), "Test Organization Name" }, - { ("orgUser", typeof(OrganizationUser)), new OrganizationUser - { - Id = Guid.NewGuid(), - Email = "test@bitwarden.com", - OrganizationId = Guid.NewGuid(), - - }}, - { ("token", typeof(ExpiringToken)), new ExpiringToken("test_token", DateTime.UtcNow.AddDays(1))}, - { ("organization", typeof(Organization)), new Organization - { - Id = Guid.NewGuid(), - Name = "Test Organization Name", - Seats = 5 - }}, - { ("initialSeatCount", typeof(int)), 5}, - { ("ownerEmails", typeof(IEnumerable)), new [] { "test@bitwarden.com" }}, - { ("maxSeatCount", typeof(int)), 5 }, - { ("userIdentifier", typeof(string)), "test_user" }, - { ("adminEmails", typeof(IEnumerable)), new [] { "test@bitwarden.com" }}, - { ("returnUrl", typeof(string)), "https://bitwarden.com/" }, - { ("amount", typeof(decimal)), 1.00M }, - { ("dueDate", typeof(DateTime)), DateTime.UtcNow.AddDays(1) }, - { ("items", typeof(List)), new List { "test@bitwarden.com" }}, - { ("mentionInvoices", typeof(bool)), true }, - { ("emails", typeof(IEnumerable)), new [] { "test@bitwarden.com" }}, - { ("deviceType", typeof(string)), "Mobile" }, - { ("timestamp", typeof(DateTime)), DateTime.UtcNow.AddDays(1)}, - { ("ip", typeof(string)), "127.0.0.1" }, - { ("emergencyAccess", typeof(EmergencyAccess)), new EmergencyAccess - { - Id = Guid.NewGuid(), - Email = "test@bitwarden.com", - }}, - { ("granteeEmail", typeof(string)), "test@bitwarden.com" }, - { ("grantorName", typeof(string)), "Test User" }, - { ("initiatingName", typeof(string)), "Test" }, - { ("approvingName", typeof(string)), "Test Name" }, - { ("rejectingName", typeof(string)), "Test Name" }, - { ("provider", typeof(Provider)), new Provider - { - Id = Guid.NewGuid(), - }}, - { ("name", typeof(string)), "Test Name" }, - { ("ea", typeof(EmergencyAccess)), new EmergencyAccess - { - Id = Guid.NewGuid(), - Email = "test@bitwarden.com", - }}, - { ("userName", typeof(string)), "testUser" }, - { ("orgName", typeof(string)), "Test Org Name" }, - { ("providerName", typeof(string)), "testProvider" }, - { ("providerUser", typeof(ProviderUser)), new ProviderUser - { - ProviderId = Guid.NewGuid(), - Id = Guid.NewGuid(), - }}, - { ("familyUserEmail", typeof(string)), "test@bitwarden.com" }, - { ("sponsorEmail", typeof(string)), "test@bitwarden.com" }, - { ("familyOrgName", typeof(string)), "Test Org Name" }, - // Swap existingAccount to true or false to generate different versions of the SendFamiliesForEnterpriseOfferEmailAsync emails. - { ("existingAccount", typeof(bool)), false }, - { ("sponsorshipEndDate", typeof(DateTime)), DateTime.UtcNow.AddDays(1)}, - { ("sponsorOrgName", typeof(string)), "Sponsor Test Org Name" }, - { ("expirationDate", typeof(DateTime)), DateTime.Now.AddDays(3) }, - { ("utcNow", typeof(DateTime)), DateTime.UtcNow }, - }; - - var globalSettings = new GlobalSettings - { - Mail = new GlobalSettings.MailSettings - { - Smtp = new GlobalSettings.MailSettings.SmtpSettings - { - Host = "localhost", - TrustServer = true, - Port = 10250, - }, - ReplyToEmail = "noreply@bitwarden.com", - }, - SiteName = "Bitwarden", - }; - - var mailDeliveryService = new MailKitSmtpMailDeliveryService(globalSettings, Substitute.For>()); - - var handlebarsService = new HandlebarsMailService(globalSettings, mailDeliveryService, new BlockingMailEnqueuingService()); - - var sendMethods = typeof(IMailService).GetMethods(BindingFlags.Public | BindingFlags.Instance) - .Where(m => m.Name.StartsWith("Send") && m.Name != "SendEnqueuedMailMessageAsync"); - - foreach (var sendMethod in sendMethods) - { - await InvokeMethod(sendMethod); - } - - async Task InvokeMethod(MethodInfo method) - { - var parameters = method.GetParameters(); - var args = new object[parameters.Length]; - - for (var i = 0; i < parameters.Length; i++) - { - if (!namedParameters.TryGetValue((parameters[i].Name, parameters[i].ParameterType), out var value)) - { - throw new InvalidOperationException($"Couldn't find a parameter for name '{parameters[i].Name}' and type '{parameters[i].ParameterType.FullName}'"); - } - - args[i] = value; + throw new InvalidOperationException($"Couldn't find a parameter for name '{parameters[i].Name}' and type '{parameters[i].ParameterType.FullName}'"); } - await (Task)method.Invoke(handlebarsService, args); + args[i] = value; } - } - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact] - public void ServiceExists() - { - Assert.NotNull(_sut); + await (Task)method.Invoke(handlebarsService, args); } } + + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact] + public void ServiceExists() + { + Assert.NotNull(_sut); + } } diff --git a/test/Core.Test/Services/InMemoryApplicationCacheServiceTests.cs b/test/Core.Test/Services/InMemoryApplicationCacheServiceTests.cs index 8deae6364..ff8e734b3 100644 --- a/test/Core.Test/Services/InMemoryApplicationCacheServiceTests.cs +++ b/test/Core.Test/Services/InMemoryApplicationCacheServiceTests.cs @@ -3,29 +3,28 @@ using Bit.Core.Services; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class InMemoryApplicationCacheServiceTests { - public class InMemoryApplicationCacheServiceTests + private readonly InMemoryApplicationCacheService _sut; + + private readonly IOrganizationRepository _organizationRepository; + private readonly IProviderRepository _providerRepository; + + public InMemoryApplicationCacheServiceTests() { - private readonly InMemoryApplicationCacheService _sut; + _organizationRepository = Substitute.For(); + _providerRepository = Substitute.For(); - private readonly IOrganizationRepository _organizationRepository; - private readonly IProviderRepository _providerRepository; + _sut = new InMemoryApplicationCacheService(_organizationRepository, _providerRepository); + } - public InMemoryApplicationCacheServiceTests() - { - _organizationRepository = Substitute.For(); - _providerRepository = Substitute.For(); - - _sut = new InMemoryApplicationCacheService(_organizationRepository, _providerRepository); - } - - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact] - public void ServiceExists() - { - Assert.NotNull(_sut); - } + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact] + public void ServiceExists() + { + Assert.NotNull(_sut); } } diff --git a/test/Core.Test/Services/InMemoryServiceBusApplicationCacheServiceTests.cs b/test/Core.Test/Services/InMemoryServiceBusApplicationCacheServiceTests.cs index 33f23ea18..f74aa6f50 100644 --- a/test/Core.Test/Services/InMemoryServiceBusApplicationCacheServiceTests.cs +++ b/test/Core.Test/Services/InMemoryServiceBusApplicationCacheServiceTests.cs @@ -4,35 +4,34 @@ using Bit.Core.Settings; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class InMemoryServiceBusApplicationCacheServiceTests { - public class InMemoryServiceBusApplicationCacheServiceTests + private readonly InMemoryServiceBusApplicationCacheService _sut; + + private readonly IOrganizationRepository _organizationRepository; + private readonly IProviderRepository _providerRepository; + private readonly GlobalSettings _globalSettings; + + public InMemoryServiceBusApplicationCacheServiceTests() { - private readonly InMemoryServiceBusApplicationCacheService _sut; + _organizationRepository = Substitute.For(); + _providerRepository = Substitute.For(); + _globalSettings = new GlobalSettings(); - private readonly IOrganizationRepository _organizationRepository; - private readonly IProviderRepository _providerRepository; - private readonly GlobalSettings _globalSettings; + _sut = new InMemoryServiceBusApplicationCacheService( + _organizationRepository, + _providerRepository, + _globalSettings + ); + } - public InMemoryServiceBusApplicationCacheServiceTests() - { - _organizationRepository = Substitute.For(); - _providerRepository = Substitute.For(); - _globalSettings = new GlobalSettings(); - - _sut = new InMemoryServiceBusApplicationCacheService( - _organizationRepository, - _providerRepository, - _globalSettings - ); - } - - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact(Skip = "Needs additional work")] - public void ServiceExists() - { - Assert.NotNull(_sut); - } + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact(Skip = "Needs additional work")] + public void ServiceExists() + { + Assert.NotNull(_sut); } } diff --git a/test/Core.Test/Services/LicensingServiceTests.cs b/test/Core.Test/Services/LicensingServiceTests.cs index 2e94ef2b5..4a8ba0255 100644 --- a/test/Core.Test/Services/LicensingServiceTests.cs +++ b/test/Core.Test/Services/LicensingServiceTests.cs @@ -9,53 +9,52 @@ using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +[SutProviderCustomize] +public class LicensingServiceTests { - [SutProviderCustomize] - public class LicensingServiceTests + private static string licenseFilePath(Guid orgId) => + Path.Combine(OrganizationLicenseDirectory.Value, $"{orgId}.json"); + private static string LicenseDirectory => Path.GetDirectoryName(OrganizationLicenseDirectory.Value); + private static Lazy OrganizationLicenseDirectory => new(() => { - private static string licenseFilePath(Guid orgId) => - Path.Combine(OrganizationLicenseDirectory.Value, $"{orgId}.json"); - private static string LicenseDirectory => Path.GetDirectoryName(OrganizationLicenseDirectory.Value); - private static Lazy OrganizationLicenseDirectory => new(() => + var directory = Path.Combine(Path.GetTempPath(), "organization"); + if (!Directory.Exists(directory)) { - var directory = Path.Combine(Path.GetTempPath(), "organization"); - if (!Directory.Exists(directory)) - { - Directory.CreateDirectory(directory); - } - return directory; - }); - - public static SutProvider GetSutProvider() - { - var fixture = new Fixture().WithAutoNSubstitutions(); - - var settings = fixture.Create(); - settings.LicenseDirectory = LicenseDirectory; - settings.SelfHosted = true; - - return new SutProvider(fixture) - .SetDependency(settings) - .Create(); + Directory.CreateDirectory(directory); } + return directory; + }); - [Theory, BitAutoData, OrganizationLicenseCustomize] - public async Task ReadOrganizationLicense(Organization organization, OrganizationLicense license) + public static SutProvider GetSutProvider() + { + var fixture = new Fixture().WithAutoNSubstitutions(); + + var settings = fixture.Create(); + settings.LicenseDirectory = LicenseDirectory; + settings.SelfHosted = true; + + return new SutProvider(fixture) + .SetDependency(settings) + .Create(); + } + + [Theory, BitAutoData, OrganizationLicenseCustomize] + public async Task ReadOrganizationLicense(Organization organization, OrganizationLicense license) + { + var sutProvider = GetSutProvider(); + + File.WriteAllText(licenseFilePath(organization.Id), JsonSerializer.Serialize(license)); + + var actual = await sutProvider.Sut.ReadOrganizationLicenseAsync(organization); + try { - var sutProvider = GetSutProvider(); - - File.WriteAllText(licenseFilePath(organization.Id), JsonSerializer.Serialize(license)); - - var actual = await sutProvider.Sut.ReadOrganizationLicenseAsync(organization); - try - { - Assert.Equal(JsonSerializer.Serialize(license), JsonSerializer.Serialize(actual)); - } - finally - { - Directory.Delete(OrganizationLicenseDirectory.Value, true); - } + Assert.Equal(JsonSerializer.Serialize(license), JsonSerializer.Serialize(actual)); + } + finally + { + Directory.Delete(OrganizationLicenseDirectory.Value, true); } } } diff --git a/test/Core.Test/Services/LocalAttachmentStorageServiceTests.cs b/test/Core.Test/Services/LocalAttachmentStorageServiceTests.cs index 63a3e8bc8..cf05933f4 100644 --- a/test/Core.Test/Services/LocalAttachmentStorageServiceTests.cs +++ b/test/Core.Test/Services/LocalAttachmentStorageServiceTests.cs @@ -11,194 +11,75 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class LocalAttachmentStorageServiceTests { - public class LocalAttachmentStorageServiceTests + + private void AssertFileCreation(string expectedPath, string expectedFileContents) { + Assert.True(File.Exists(expectedPath)); + Assert.Equal(expectedFileContents, File.ReadAllText(expectedPath)); + } - private void AssertFileCreation(string expectedPath, string expectedFileContents) + [Theory] + [InlineCustomAutoData(new[] { typeof(UserCipher), typeof(MetaData) })] + [InlineCustomAutoData(new[] { typeof(UserCipher), typeof(MetaDataWithoutContainer) })] + [InlineCustomAutoData(new[] { typeof(UserCipher), typeof(MetaDataWithoutKey) })] + public async Task UploadNewAttachmentAsync_Success(string stream, Cipher cipher, CipherAttachment.MetaData attachmentData) + { + using (var tempDirectory = new TempDirectory()) { - Assert.True(File.Exists(expectedPath)); - Assert.Equal(expectedFileContents, File.ReadAllText(expectedPath)); - } + var sutProvider = GetSutProvider(tempDirectory); - [Theory] - [InlineCustomAutoData(new[] { typeof(UserCipher), typeof(MetaData) })] - [InlineCustomAutoData(new[] { typeof(UserCipher), typeof(MetaDataWithoutContainer) })] - [InlineCustomAutoData(new[] { typeof(UserCipher), typeof(MetaDataWithoutKey) })] - public async Task UploadNewAttachmentAsync_Success(string stream, Cipher cipher, CipherAttachment.MetaData attachmentData) + await sutProvider.Sut.UploadNewAttachmentAsync(new MemoryStream(Encoding.UTF8.GetBytes(stream)), + cipher, attachmentData); + + AssertFileCreation($"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}", stream); + } + } + + [Theory] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaData) })] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutContainer) })] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutKey) })] + public async Task UploadShareAttachmentAsync_Success(string stream, Cipher cipher, CipherAttachment.MetaData attachmentData) + { + using (var tempDirectory = new TempDirectory()) { - using (var tempDirectory = new TempDirectory()) - { - var sutProvider = GetSutProvider(tempDirectory); + var sutProvider = GetSutProvider(tempDirectory); - await sutProvider.Sut.UploadNewAttachmentAsync(new MemoryStream(Encoding.UTF8.GetBytes(stream)), - cipher, attachmentData); + await sutProvider.Sut.UploadShareAttachmentAsync(new MemoryStream(Encoding.UTF8.GetBytes(stream)), + cipher.Id, cipher.OrganizationId.Value, attachmentData); - AssertFileCreation($"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}", stream); - } + AssertFileCreation($"{tempDirectory}/temp/{cipher.Id}/{cipher.OrganizationId}/{attachmentData.AttachmentId}", stream); } + } - [Theory] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaData) })] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutContainer) })] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutKey) })] - public async Task UploadShareAttachmentAsync_Success(string stream, Cipher cipher, CipherAttachment.MetaData attachmentData) + [Theory] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaData) })] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutContainer) })] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutKey) })] + public async Task StartShareAttachmentAsync_NoSource_NoWork(Cipher cipher, CipherAttachment.MetaData attachmentData) + { + using (var tempDirectory = new TempDirectory()) { - using (var tempDirectory = new TempDirectory()) - { - var sutProvider = GetSutProvider(tempDirectory); + var sutProvider = GetSutProvider(tempDirectory); - await sutProvider.Sut.UploadShareAttachmentAsync(new MemoryStream(Encoding.UTF8.GetBytes(stream)), - cipher.Id, cipher.OrganizationId.Value, attachmentData); + await sutProvider.Sut.StartShareAttachmentAsync(cipher.Id, cipher.OrganizationId.Value, attachmentData); - AssertFileCreation($"{tempDirectory}/temp/{cipher.Id}/{cipher.OrganizationId}/{attachmentData.AttachmentId}", stream); - } + Assert.False(File.Exists($"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}")); + Assert.False(File.Exists($"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}")); } + } - [Theory] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaData) })] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutContainer) })] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutKey) })] - public async Task StartShareAttachmentAsync_NoSource_NoWork(Cipher cipher, CipherAttachment.MetaData attachmentData) - { - using (var tempDirectory = new TempDirectory()) - { - var sutProvider = GetSutProvider(tempDirectory); - - await sutProvider.Sut.StartShareAttachmentAsync(cipher.Id, cipher.OrganizationId.Value, attachmentData); - - Assert.False(File.Exists($"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}")); - Assert.False(File.Exists($"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}")); - } - } - - [Theory] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaData) })] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutContainer) })] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutKey) })] - public async Task StartShareAttachmentAsync_NoDest_NoWork(string source, Cipher cipher, CipherAttachment.MetaData attachmentData) - { - using (var tempDirectory = new TempDirectory()) - { - var sutProvider = GetSutProvider(tempDirectory); - - var sourcePath = $"{tempDirectory}/temp/{cipher.Id}/{cipher.OrganizationId}/{attachmentData.AttachmentId}"; - var destPath = $"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}"; - var rollBackPath = $"{tempDirectory}/temp/{cipher.Id}/{attachmentData.AttachmentId}"; - Directory.CreateDirectory(Path.GetDirectoryName(sourcePath)); - File.WriteAllText(sourcePath, source); - - await sutProvider.Sut.StartShareAttachmentAsync(cipher.Id, cipher.OrganizationId.Value, attachmentData); - - Assert.True(File.Exists(sourcePath)); - Assert.Equal(source, File.ReadAllText(sourcePath)); - Assert.False(File.Exists(destPath)); - Assert.False(File.Exists(rollBackPath)); - } - } - - - [Theory] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaData) })] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutContainer) })] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutKey) })] - public async Task StartShareAttachmentAsync_Success(string source, string destOriginal, Cipher cipher, CipherAttachment.MetaData attachmentData) - { - using (var tempDirectory = new TempDirectory()) - { - await StartShareAttachmentAsync(source, destOriginal, cipher, attachmentData, tempDirectory); - } - } - - [Theory] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaData) })] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutContainer) })] - [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutKey) })] - public async Task RollbackShareAttachmentAsync_Success(string source, string destOriginal, Cipher cipher, CipherAttachment.MetaData attachmentData) - { - using (var tempDirectory = new TempDirectory()) - { - var sutProvider = GetSutProvider(tempDirectory); - - var sourcePath = $"{tempDirectory}/temp/{cipher.Id}/{cipher.OrganizationId}/{attachmentData.AttachmentId}"; - var destPath = $"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}"; - var rollBackPath = $"{tempDirectory}/temp/{cipher.Id}/{attachmentData.AttachmentId}"; - - await StartShareAttachmentAsync(source, destOriginal, cipher, attachmentData, tempDirectory); - await sutProvider.Sut.RollbackShareAttachmentAsync(cipher.Id, cipher.OrganizationId.Value, attachmentData, "Not Used Here"); - - Assert.True(File.Exists(destPath)); - Assert.Equal(destOriginal, File.ReadAllText(destPath)); - Assert.False(File.Exists(sourcePath)); - Assert.False(File.Exists(rollBackPath)); - } - } - - [Theory] - [InlineCustomAutoData(new[] { typeof(UserCipher), typeof(MetaData) })] - [InlineCustomAutoData(new[] { typeof(UserCipher), typeof(MetaDataWithoutContainer) })] - [InlineCustomAutoData(new[] { typeof(UserCipher), typeof(MetaDataWithoutKey) })] - public async Task DeleteAttachmentAsync_Success(Cipher cipher, CipherAttachment.MetaData attachmentData) - { - using (var tempDirectory = new TempDirectory()) - { - var sutProvider = GetSutProvider(tempDirectory); - - var expectedPath = $"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}"; - Directory.CreateDirectory(Path.GetDirectoryName(expectedPath)); - File.Create(expectedPath).Close(); - - await sutProvider.Sut.DeleteAttachmentAsync(cipher.Id, attachmentData); - - Assert.False(File.Exists(expectedPath)); - } - } - - [Theory] - [InlineUserCipherAutoData] - [InlineOrganizationCipherAutoData] - public async Task CleanupAsync_Succes(Cipher cipher) - { - using (var tempDirectory = new TempDirectory()) - { - var sutProvider = GetSutProvider(tempDirectory); - - var tempPath = $"{tempDirectory}/temp/{cipher.Id}"; - var permPath = $"{tempDirectory}/{cipher.Id}"; - Directory.CreateDirectory(tempPath); - Directory.CreateDirectory(permPath); - - await sutProvider.Sut.CleanupAsync(cipher.Id); - - Assert.False(Directory.Exists(tempPath)); - Assert.True(Directory.Exists(permPath)); - } - } - - [Theory] - [InlineUserCipherAutoData] - [InlineOrganizationCipherAutoData] - public async Task DeleteAttachmentsForCipherAsync_Succes(Cipher cipher) - { - using (var tempDirectory = new TempDirectory()) - { - var sutProvider = GetSutProvider(tempDirectory); - - var tempPath = $"{tempDirectory}/temp/{cipher.Id}"; - var permPath = $"{tempDirectory}/{cipher.Id}"; - Directory.CreateDirectory(tempPath); - Directory.CreateDirectory(permPath); - - await sutProvider.Sut.DeleteAttachmentsForCipherAsync(cipher.Id); - - Assert.True(Directory.Exists(tempPath)); - Assert.False(Directory.Exists(permPath)); - } - } - - private async Task StartShareAttachmentAsync(string source, string destOriginal, Cipher cipher, - CipherAttachment.MetaData attachmentData, TempDirectory tempDirectory) + [Theory] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaData) })] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutContainer) })] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutKey) })] + public async Task StartShareAttachmentAsync_NoDest_NoWork(string source, Cipher cipher, CipherAttachment.MetaData attachmentData) + { + using (var tempDirectory = new TempDirectory()) { var sutProvider = GetSutProvider(tempDirectory); @@ -206,26 +87,144 @@ namespace Bit.Core.Test.Services var destPath = $"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}"; var rollBackPath = $"{tempDirectory}/temp/{cipher.Id}/{attachmentData.AttachmentId}"; Directory.CreateDirectory(Path.GetDirectoryName(sourcePath)); - Directory.CreateDirectory(Path.GetDirectoryName(destPath)); File.WriteAllText(sourcePath, source); - File.WriteAllText(destPath, destOriginal); await sutProvider.Sut.StartShareAttachmentAsync(cipher.Id, cipher.OrganizationId.Value, attachmentData); - Assert.False(File.Exists(sourcePath)); - Assert.True(File.Exists(destPath)); - Assert.Equal(source, File.ReadAllText(destPath)); - Assert.True(File.Exists(rollBackPath)); - Assert.Equal(destOriginal, File.ReadAllText(rollBackPath)); - } - - private SutProvider GetSutProvider(TempDirectory tempDirectory) - { - var fixture = new Fixture().WithAutoNSubstitutions(); - fixture.Freeze().Attachment.BaseDirectory.Returns(tempDirectory.Directory); - fixture.Freeze().Attachment.BaseUrl.Returns(Guid.NewGuid().ToString()); - - return new SutProvider(fixture).Create(); + Assert.True(File.Exists(sourcePath)); + Assert.Equal(source, File.ReadAllText(sourcePath)); + Assert.False(File.Exists(destPath)); + Assert.False(File.Exists(rollBackPath)); } } + + + [Theory] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaData) })] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutContainer) })] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutKey) })] + public async Task StartShareAttachmentAsync_Success(string source, string destOriginal, Cipher cipher, CipherAttachment.MetaData attachmentData) + { + using (var tempDirectory = new TempDirectory()) + { + await StartShareAttachmentAsync(source, destOriginal, cipher, attachmentData, tempDirectory); + } + } + + [Theory] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaData) })] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutContainer) })] + [InlineCustomAutoData(new[] { typeof(OrganizationCipher), typeof(MetaDataWithoutKey) })] + public async Task RollbackShareAttachmentAsync_Success(string source, string destOriginal, Cipher cipher, CipherAttachment.MetaData attachmentData) + { + using (var tempDirectory = new TempDirectory()) + { + var sutProvider = GetSutProvider(tempDirectory); + + var sourcePath = $"{tempDirectory}/temp/{cipher.Id}/{cipher.OrganizationId}/{attachmentData.AttachmentId}"; + var destPath = $"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}"; + var rollBackPath = $"{tempDirectory}/temp/{cipher.Id}/{attachmentData.AttachmentId}"; + + await StartShareAttachmentAsync(source, destOriginal, cipher, attachmentData, tempDirectory); + await sutProvider.Sut.RollbackShareAttachmentAsync(cipher.Id, cipher.OrganizationId.Value, attachmentData, "Not Used Here"); + + Assert.True(File.Exists(destPath)); + Assert.Equal(destOriginal, File.ReadAllText(destPath)); + Assert.False(File.Exists(sourcePath)); + Assert.False(File.Exists(rollBackPath)); + } + } + + [Theory] + [InlineCustomAutoData(new[] { typeof(UserCipher), typeof(MetaData) })] + [InlineCustomAutoData(new[] { typeof(UserCipher), typeof(MetaDataWithoutContainer) })] + [InlineCustomAutoData(new[] { typeof(UserCipher), typeof(MetaDataWithoutKey) })] + public async Task DeleteAttachmentAsync_Success(Cipher cipher, CipherAttachment.MetaData attachmentData) + { + using (var tempDirectory = new TempDirectory()) + { + var sutProvider = GetSutProvider(tempDirectory); + + var expectedPath = $"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}"; + Directory.CreateDirectory(Path.GetDirectoryName(expectedPath)); + File.Create(expectedPath).Close(); + + await sutProvider.Sut.DeleteAttachmentAsync(cipher.Id, attachmentData); + + Assert.False(File.Exists(expectedPath)); + } + } + + [Theory] + [InlineUserCipherAutoData] + [InlineOrganizationCipherAutoData] + public async Task CleanupAsync_Succes(Cipher cipher) + { + using (var tempDirectory = new TempDirectory()) + { + var sutProvider = GetSutProvider(tempDirectory); + + var tempPath = $"{tempDirectory}/temp/{cipher.Id}"; + var permPath = $"{tempDirectory}/{cipher.Id}"; + Directory.CreateDirectory(tempPath); + Directory.CreateDirectory(permPath); + + await sutProvider.Sut.CleanupAsync(cipher.Id); + + Assert.False(Directory.Exists(tempPath)); + Assert.True(Directory.Exists(permPath)); + } + } + + [Theory] + [InlineUserCipherAutoData] + [InlineOrganizationCipherAutoData] + public async Task DeleteAttachmentsForCipherAsync_Succes(Cipher cipher) + { + using (var tempDirectory = new TempDirectory()) + { + var sutProvider = GetSutProvider(tempDirectory); + + var tempPath = $"{tempDirectory}/temp/{cipher.Id}"; + var permPath = $"{tempDirectory}/{cipher.Id}"; + Directory.CreateDirectory(tempPath); + Directory.CreateDirectory(permPath); + + await sutProvider.Sut.DeleteAttachmentsForCipherAsync(cipher.Id); + + Assert.True(Directory.Exists(tempPath)); + Assert.False(Directory.Exists(permPath)); + } + } + + private async Task StartShareAttachmentAsync(string source, string destOriginal, Cipher cipher, + CipherAttachment.MetaData attachmentData, TempDirectory tempDirectory) + { + var sutProvider = GetSutProvider(tempDirectory); + + var sourcePath = $"{tempDirectory}/temp/{cipher.Id}/{cipher.OrganizationId}/{attachmentData.AttachmentId}"; + var destPath = $"{tempDirectory}/{cipher.Id}/{attachmentData.AttachmentId}"; + var rollBackPath = $"{tempDirectory}/temp/{cipher.Id}/{attachmentData.AttachmentId}"; + Directory.CreateDirectory(Path.GetDirectoryName(sourcePath)); + Directory.CreateDirectory(Path.GetDirectoryName(destPath)); + File.WriteAllText(sourcePath, source); + File.WriteAllText(destPath, destOriginal); + + await sutProvider.Sut.StartShareAttachmentAsync(cipher.Id, cipher.OrganizationId.Value, attachmentData); + + Assert.False(File.Exists(sourcePath)); + Assert.True(File.Exists(destPath)); + Assert.Equal(source, File.ReadAllText(destPath)); + Assert.True(File.Exists(rollBackPath)); + Assert.Equal(destOriginal, File.ReadAllText(rollBackPath)); + } + + private SutProvider GetSutProvider(TempDirectory tempDirectory) + { + var fixture = new Fixture().WithAutoNSubstitutions(); + fixture.Freeze().Attachment.BaseDirectory.Returns(tempDirectory.Directory); + fixture.Freeze().Attachment.BaseUrl.Returns(Guid.NewGuid().ToString()); + + return new SutProvider(fixture).Create(); + } } diff --git a/test/Core.Test/Services/MailKitSmtpMailDeliveryServiceTests.cs b/test/Core.Test/Services/MailKitSmtpMailDeliveryServiceTests.cs index d4c5208a2..4e7e36fe0 100644 --- a/test/Core.Test/Services/MailKitSmtpMailDeliveryServiceTests.cs +++ b/test/Core.Test/Services/MailKitSmtpMailDeliveryServiceTests.cs @@ -4,35 +4,34 @@ using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class MailKitSmtpMailDeliveryServiceTests { - public class MailKitSmtpMailDeliveryServiceTests + private readonly MailKitSmtpMailDeliveryService _sut; + + private readonly GlobalSettings _globalSettings; + private readonly ILogger _logger; + + public MailKitSmtpMailDeliveryServiceTests() { - private readonly MailKitSmtpMailDeliveryService _sut; + _globalSettings = new GlobalSettings(); + _logger = Substitute.For>(); - private readonly GlobalSettings _globalSettings; - private readonly ILogger _logger; + _globalSettings.Mail.Smtp.Host = "unittests.example.com"; + _globalSettings.Mail.ReplyToEmail = "noreply@unittests.example.com"; - public MailKitSmtpMailDeliveryServiceTests() - { - _globalSettings = new GlobalSettings(); - _logger = Substitute.For>(); + _sut = new MailKitSmtpMailDeliveryService( + _globalSettings, + _logger + ); + } - _globalSettings.Mail.Smtp.Host = "unittests.example.com"; - _globalSettings.Mail.ReplyToEmail = "noreply@unittests.example.com"; - - _sut = new MailKitSmtpMailDeliveryService( - _globalSettings, - _logger - ); - } - - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact] - public void ServiceExists() - { - Assert.NotNull(_sut); - } + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact] + public void ServiceExists() + { + Assert.NotNull(_sut); } } diff --git a/test/Core.Test/Services/MultiServicePushNotificationServiceTests.cs b/test/Core.Test/Services/MultiServicePushNotificationServiceTests.cs index 925456619..b1876f1dd 100644 --- a/test/Core.Test/Services/MultiServicePushNotificationServiceTests.cs +++ b/test/Core.Test/Services/MultiServicePushNotificationServiceTests.cs @@ -6,50 +6,49 @@ using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class MultiServicePushNotificationServiceTests { - public class MultiServicePushNotificationServiceTests + private readonly MultiServicePushNotificationService _sut; + + private readonly IHttpClientFactory _httpFactory; + private readonly IDeviceRepository _deviceRepository; + private readonly IInstallationDeviceRepository _installationDeviceRepository; + private readonly GlobalSettings _globalSettings; + private readonly IHttpContextAccessor _httpContextAccessor; + private readonly ILogger _logger; + private readonly ILogger _relayLogger; + private readonly ILogger _hubLogger; + + public MultiServicePushNotificationServiceTests() { - private readonly MultiServicePushNotificationService _sut; + _httpFactory = Substitute.For(); + _deviceRepository = Substitute.For(); + _installationDeviceRepository = Substitute.For(); + _globalSettings = new GlobalSettings(); + _httpContextAccessor = Substitute.For(); + _logger = Substitute.For>(); + _relayLogger = Substitute.For>(); + _hubLogger = Substitute.For>(); - private readonly IHttpClientFactory _httpFactory; - private readonly IDeviceRepository _deviceRepository; - private readonly IInstallationDeviceRepository _installationDeviceRepository; - private readonly GlobalSettings _globalSettings; - private readonly IHttpContextAccessor _httpContextAccessor; - private readonly ILogger _logger; - private readonly ILogger _relayLogger; - private readonly ILogger _hubLogger; + _sut = new MultiServicePushNotificationService( + _httpFactory, + _deviceRepository, + _installationDeviceRepository, + _globalSettings, + _httpContextAccessor, + _logger, + _relayLogger, + _hubLogger + ); + } - public MultiServicePushNotificationServiceTests() - { - _httpFactory = Substitute.For(); - _deviceRepository = Substitute.For(); - _installationDeviceRepository = Substitute.For(); - _globalSettings = new GlobalSettings(); - _httpContextAccessor = Substitute.For(); - _logger = Substitute.For>(); - _relayLogger = Substitute.For>(); - _hubLogger = Substitute.For>(); - - _sut = new MultiServicePushNotificationService( - _httpFactory, - _deviceRepository, - _installationDeviceRepository, - _globalSettings, - _httpContextAccessor, - _logger, - _relayLogger, - _hubLogger - ); - } - - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact] - public void ServiceExists() - { - Assert.NotNull(_sut); - } + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact] + public void ServiceExists() + { + Assert.NotNull(_sut); } } diff --git a/test/Core.Test/Services/NotificationHubPushNotificationServiceTests.cs b/test/Core.Test/Services/NotificationHubPushNotificationServiceTests.cs index ea59da3ed..a066eee8b 100644 --- a/test/Core.Test/Services/NotificationHubPushNotificationServiceTests.cs +++ b/test/Core.Test/Services/NotificationHubPushNotificationServiceTests.cs @@ -5,35 +5,34 @@ using Microsoft.AspNetCore.Http; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class NotificationHubPushNotificationServiceTests { - public class NotificationHubPushNotificationServiceTests + private readonly NotificationHubPushNotificationService _sut; + + private readonly IInstallationDeviceRepository _installationDeviceRepository; + private readonly GlobalSettings _globalSettings; + private readonly IHttpContextAccessor _httpContextAccessor; + + public NotificationHubPushNotificationServiceTests() { - private readonly NotificationHubPushNotificationService _sut; + _installationDeviceRepository = Substitute.For(); + _globalSettings = new GlobalSettings(); + _httpContextAccessor = Substitute.For(); - private readonly IInstallationDeviceRepository _installationDeviceRepository; - private readonly GlobalSettings _globalSettings; - private readonly IHttpContextAccessor _httpContextAccessor; + _sut = new NotificationHubPushNotificationService( + _installationDeviceRepository, + _globalSettings, + _httpContextAccessor + ); + } - public NotificationHubPushNotificationServiceTests() - { - _installationDeviceRepository = Substitute.For(); - _globalSettings = new GlobalSettings(); - _httpContextAccessor = Substitute.For(); - - _sut = new NotificationHubPushNotificationService( - _installationDeviceRepository, - _globalSettings, - _httpContextAccessor - ); - } - - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact(Skip = "Needs additional work")] - public void ServiceExists() - { - Assert.NotNull(_sut); - } + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact(Skip = "Needs additional work")] + public void ServiceExists() + { + Assert.NotNull(_sut); } } diff --git a/test/Core.Test/Services/NotificationHubPushRegistrationServiceTests.cs b/test/Core.Test/Services/NotificationHubPushRegistrationServiceTests.cs index 432a79686..8e2a19d7b 100644 --- a/test/Core.Test/Services/NotificationHubPushRegistrationServiceTests.cs +++ b/test/Core.Test/Services/NotificationHubPushRegistrationServiceTests.cs @@ -4,32 +4,31 @@ using Bit.Core.Settings; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class NotificationHubPushRegistrationServiceTests { - public class NotificationHubPushRegistrationServiceTests + private readonly NotificationHubPushRegistrationService _sut; + + private readonly IInstallationDeviceRepository _installationDeviceRepository; + private readonly GlobalSettings _globalSettings; + + public NotificationHubPushRegistrationServiceTests() { - private readonly NotificationHubPushRegistrationService _sut; + _installationDeviceRepository = Substitute.For(); + _globalSettings = new GlobalSettings(); - private readonly IInstallationDeviceRepository _installationDeviceRepository; - private readonly GlobalSettings _globalSettings; + _sut = new NotificationHubPushRegistrationService( + _installationDeviceRepository, + _globalSettings + ); + } - public NotificationHubPushRegistrationServiceTests() - { - _installationDeviceRepository = Substitute.For(); - _globalSettings = new GlobalSettings(); - - _sut = new NotificationHubPushRegistrationService( - _installationDeviceRepository, - _globalSettings - ); - } - - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact(Skip = "Needs additional work")] - public void ServiceExists() - { - Assert.NotNull(_sut); - } + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact(Skip = "Needs additional work")] + public void ServiceExists() + { + Assert.NotNull(_sut); } } diff --git a/test/Core.Test/Services/NotificationsApiPushNotificationServiceTests.cs b/test/Core.Test/Services/NotificationsApiPushNotificationServiceTests.cs index 59976f57d..d1ba15d6a 100644 --- a/test/Core.Test/Services/NotificationsApiPushNotificationServiceTests.cs +++ b/test/Core.Test/Services/NotificationsApiPushNotificationServiceTests.cs @@ -5,38 +5,37 @@ using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class NotificationsApiPushNotificationServiceTests { - public class NotificationsApiPushNotificationServiceTests + private readonly NotificationsApiPushNotificationService _sut; + + private readonly IHttpClientFactory _httpFactory; + private readonly GlobalSettings _globalSettings; + private readonly IHttpContextAccessor _httpContextAccessor; + private readonly ILogger _logger; + + public NotificationsApiPushNotificationServiceTests() { - private readonly NotificationsApiPushNotificationService _sut; + _httpFactory = Substitute.For(); + _globalSettings = new GlobalSettings(); + _httpContextAccessor = Substitute.For(); + _logger = Substitute.For>(); - private readonly IHttpClientFactory _httpFactory; - private readonly GlobalSettings _globalSettings; - private readonly IHttpContextAccessor _httpContextAccessor; - private readonly ILogger _logger; + _sut = new NotificationsApiPushNotificationService( + _httpFactory, + _globalSettings, + _httpContextAccessor, + _logger + ); + } - public NotificationsApiPushNotificationServiceTests() - { - _httpFactory = Substitute.For(); - _globalSettings = new GlobalSettings(); - _httpContextAccessor = Substitute.For(); - _logger = Substitute.For>(); - - _sut = new NotificationsApiPushNotificationService( - _httpFactory, - _globalSettings, - _httpContextAccessor, - _logger - ); - } - - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact(Skip = "Needs additional work")] - public void ServiceExists() - { - Assert.NotNull(_sut); - } + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact(Skip = "Needs additional work")] + public void ServiceExists() + { + Assert.NotNull(_sut); } } diff --git a/test/Core.Test/Services/OrganizationServiceTests.cs b/test/Core.Test/Services/OrganizationServiceTests.cs index acf2b7f17..73f0b8cfe 100644 --- a/test/Core.Test/Services/OrganizationServiceTests.cs +++ b/test/Core.Test/Services/OrganizationServiceTests.cs @@ -20,936 +20,935 @@ using Organization = Bit.Core.Entities.Organization; using OrganizationUser = Bit.Core.Entities.OrganizationUser; using Policy = Bit.Core.Entities.Policy; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class OrganizationServiceTests { - public class OrganizationServiceTests + // [Fact] + [Theory, PaidOrganizationAutoData] + public async Task OrgImportCreateNewUsers(SutProvider sutProvider, Guid userId, + Organization org, List existingUsers, List newUsers) { - // [Fact] - [Theory, PaidOrganizationAutoData] - public async Task OrgImportCreateNewUsers(SutProvider sutProvider, Guid userId, - Organization org, List existingUsers, List newUsers) + org.UseDirectory = true; + org.Seats = 10; + newUsers.Add(new ImportedOrganizationUser { - org.UseDirectory = true; - org.Seats = 10; - newUsers.Add(new ImportedOrganizationUser + Email = existingUsers.First().Email, + ExternalId = existingUsers.First().ExternalId + }); + var expectedNewUsersCount = newUsers.Count - 1; + + existingUsers.First().Type = OrganizationUserType.Owner; + + sutProvider.GetDependency().GetByIdAsync(org.Id).Returns(org); + sutProvider.GetDependency().GetManyDetailsByOrganizationAsync(org.Id) + .Returns(existingUsers); + sutProvider.GetDependency().GetCountByOrganizationIdAsync(org.Id) + .Returns(existingUsers.Count); + sutProvider.GetDependency().GetManyByOrganizationAsync(org.Id, OrganizationUserType.Owner) + .Returns(existingUsers.Select(u => new OrganizationUser { Status = OrganizationUserStatusType.Confirmed, Type = OrganizationUserType.Owner, Id = u.Id }).ToList()); + sutProvider.GetDependency().ManageUsers(org.Id).Returns(true); + + await sutProvider.Sut.ImportAsync(org.Id, userId, null, newUsers, null, false); + + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + await sutProvider.GetDependency().Received(1) + .UpsertManyAsync(Arg.Is>(users => users.Count() == 0)); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .CreateAsync(default); + + // Create new users + await sutProvider.GetDependency().Received(1) + .CreateManyAsync(Arg.Is>(users => users.Count() == expectedNewUsersCount)); + await sutProvider.GetDependency().Received(1) + .BulkSendOrganizationInviteEmailAsync(org.Name, + Arg.Is>(messages => messages.Count() == expectedNewUsersCount)); + + // Send events + await sutProvider.GetDependency().Received(1) + .LogOrganizationUserEventsAsync(Arg.Is>(events => + events.Count() == expectedNewUsersCount)); + await sutProvider.GetDependency().Received(1) + .RaiseEventAsync(Arg.Is(referenceEvent => + referenceEvent.Type == ReferenceEventType.InvitedUsers && referenceEvent.Id == org.Id && + referenceEvent.Users == expectedNewUsersCount)); + } + + [Theory, PaidOrganizationAutoData] + public async Task OrgImportCreateNewUsersAndMarryExistingUser(SutProvider sutProvider, + Guid userId, Organization org, List existingUsers, + List newUsers) + { + org.UseDirectory = true; + org.Seats = newUsers.Count + existingUsers.Count + 1; + var reInvitedUser = existingUsers.First(); + reInvitedUser.ExternalId = null; + newUsers.Add(new ImportedOrganizationUser + { + Email = reInvitedUser.Email, + ExternalId = reInvitedUser.Email, + }); + var expectedNewUsersCount = newUsers.Count - 1; + + sutProvider.GetDependency().GetByIdAsync(org.Id).Returns(org); + sutProvider.GetDependency().GetManyDetailsByOrganizationAsync(org.Id) + .Returns(existingUsers); + sutProvider.GetDependency().GetCountByOrganizationIdAsync(org.Id) + .Returns(existingUsers.Count); + sutProvider.GetDependency().GetByIdAsync(reInvitedUser.Id) + .Returns(new OrganizationUser { Id = reInvitedUser.Id }); + sutProvider.GetDependency().GetManyByOrganizationAsync(org.Id, OrganizationUserType.Owner) + .Returns(existingUsers.Select(u => new OrganizationUser { Status = OrganizationUserStatusType.Confirmed, Type = OrganizationUserType.Owner, Id = u.Id }).ToList()); + var currentContext = sutProvider.GetDependency(); + currentContext.ManageUsers(org.Id).Returns(true); + + await sutProvider.Sut.ImportAsync(org.Id, userId, null, newUsers, null, false); + + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .CreateAsync(default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .CreateAsync(default, default); + + // Upserted existing user + await sutProvider.GetDependency().Received(1) + .UpsertManyAsync(Arg.Is>(users => users.Count() == 1)); + + // Created and invited new users + await sutProvider.GetDependency().Received(1) + .CreateManyAsync(Arg.Is>(users => users.Count() == expectedNewUsersCount)); + await sutProvider.GetDependency().Received(1) + .BulkSendOrganizationInviteEmailAsync(org.Name, + Arg.Is>(messages => messages.Count() == expectedNewUsersCount)); + + // Sent events + await sutProvider.GetDependency().Received(1) + .LogOrganizationUserEventsAsync(Arg.Is>(events => + events.Where(e => e.Item2 == EventType.OrganizationUser_Invited).Count() == expectedNewUsersCount)); + await sutProvider.GetDependency().Received(1) + .RaiseEventAsync(Arg.Is(referenceEvent => + referenceEvent.Type == ReferenceEventType.InvitedUsers && referenceEvent.Id == org.Id && + referenceEvent.Users == expectedNewUsersCount)); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task UpgradePlan_OrganizationIsNull_Throws(Guid organizationId, OrganizationUpgrade upgrade, + SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(organizationId).Returns(Task.FromResult(null)); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.UpgradePlanAsync(organizationId, upgrade)); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task UpgradePlan_GatewayCustomIdIsNull_Throws(Organization organization, OrganizationUpgrade upgrade, + SutProvider sutProvider) + { + organization.GatewayCustomerId = string.Empty; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.UpgradePlanAsync(organization.Id, upgrade)); + Assert.Contains("no payment method", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task UpgradePlan_AlreadyInPlan_Throws(Organization organization, OrganizationUpgrade upgrade, + SutProvider sutProvider) + { + upgrade.Plan = organization.PlanType; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.UpgradePlanAsync(organization.Id, upgrade)); + Assert.Contains("already on this plan", exception.Message); + } + + [Theory, PaidOrganizationAutoData] + public async Task UpgradePlan_UpgradeFromPaidPlan_Throws(Organization organization, OrganizationUpgrade upgrade, + SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.UpgradePlanAsync(organization.Id, upgrade)); + Assert.Contains("can only upgrade", exception.Message); + } + + [Theory] + [FreeOrganizationUpgradeAutoData] + public async Task UpgradePlan_Passes(Organization organization, OrganizationUpgrade upgrade, + SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + await sutProvider.Sut.UpgradePlanAsync(organization.Id, upgrade); + await sutProvider.GetDependency().Received(1).ReplaceAsync(organization); + } + + [Theory] + [OrganizationInviteAutoData] + public async Task InviteUser_NoEmails_Throws(Organization organization, OrganizationUser invitor, + OrganizationUserInvite invite, SutProvider sutProvider) + { + invite.Emails = null; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + await Assert.ThrowsAsync( + () => sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) })); + } + + [Theory] + [OrganizationInviteAutoData] + public async Task InviteUser_DuplicateEmails_PassesWithoutDuplicates(Organization organization, OrganizationUser invitor, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + OrganizationUserInvite invite, SutProvider sutProvider) + { + invite.Emails = invite.Emails.Append(invite.Emails.First()); + + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + sutProvider.GetDependency().OrganizationOwner(organization.Id).Returns(true); + sutProvider.GetDependency().ManageUsers(organization.Id).Returns(true); + var organizationUserRepository = sutProvider.GetDependency(); + organizationUserRepository.GetManyByOrganizationAsync(organization.Id, OrganizationUserType.Owner) + .Returns(new[] { owner }); + + await sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) }); + + await sutProvider.GetDependency().Received(1) + .BulkSendOrganizationInviteEmailAsync(organization.Name, + Arg.Is>(v => v.Count() == invite.Emails.Distinct().Count())); + } + + [Theory] + [OrganizationInviteAutoData( + inviteeUserType: (int)OrganizationUserType.Admin, + invitorUserType: (int)OrganizationUserType.Owner + )] + public async Task InviteUser_NoOwner_Throws(Organization organization, OrganizationUser invitor, + OrganizationUserInvite invite, SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + sutProvider.GetDependency().OrganizationOwner(organization.Id).Returns(true); + sutProvider.GetDependency().ManageUsers(organization.Id).Returns(true); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) })); + Assert.Contains("Organization must have at least one confirmed owner.", exception.Message); + } + + [Theory] + [OrganizationInviteAutoData( + inviteeUserType: (int)OrganizationUserType.Owner, + invitorUserType: (int)OrganizationUserType.Admin + )] + public async Task InviteUser_NonOwnerConfiguringOwner_Throws(Organization organization, OrganizationUserInvite invite, + OrganizationUser invitor, SutProvider sutProvider) + { + var organizationRepository = sutProvider.GetDependency(); + var currentContext = sutProvider.GetDependency(); + + organizationRepository.GetByIdAsync(organization.Id).Returns(organization); + currentContext.OrganizationAdmin(organization.Id).Returns(true); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) })); + Assert.Contains("only an owner", exception.Message.ToLowerInvariant()); + } + + [Theory] + [OrganizationInviteAutoData( + inviteeUserType: (int)OrganizationUserType.Custom, + invitorUserType: (int)OrganizationUserType.User + )] + public async Task InviteUser_NonAdminConfiguringAdmin_Throws(Organization organization, OrganizationUserInvite invite, + OrganizationUser invitor, SutProvider sutProvider) + { + var organizationRepository = sutProvider.GetDependency(); + var currentContext = sutProvider.GetDependency(); + + organizationRepository.GetByIdAsync(organization.Id).Returns(organization); + currentContext.OrganizationUser(organization.Id).Returns(true); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) })); + Assert.Contains("only owners and admins", exception.Message.ToLowerInvariant()); + } + + [Theory] + [OrganizationInviteAutoData( + inviteeUserType: (int)OrganizationUserType.Manager, + invitorUserType: (int)OrganizationUserType.Custom + )] + public async Task InviteUser_CustomUserWithoutManageUsersConfiguringUser_Throws(Organization organization, OrganizationUserInvite invite, + OrganizationUser invitor, SutProvider sutProvider) + { + invitor.Permissions = JsonSerializer.Serialize(new Permissions() { ManageUsers = false }, + new JsonSerializerOptions { - Email = existingUsers.First().Email, - ExternalId = existingUsers.First().ExternalId + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, }); - var expectedNewUsersCount = newUsers.Count - 1; - existingUsers.First().Type = OrganizationUserType.Owner; + var organizationRepository = sutProvider.GetDependency(); + var currentContext = sutProvider.GetDependency(); - sutProvider.GetDependency().GetByIdAsync(org.Id).Returns(org); - sutProvider.GetDependency().GetManyDetailsByOrganizationAsync(org.Id) - .Returns(existingUsers); - sutProvider.GetDependency().GetCountByOrganizationIdAsync(org.Id) - .Returns(existingUsers.Count); - sutProvider.GetDependency().GetManyByOrganizationAsync(org.Id, OrganizationUserType.Owner) - .Returns(existingUsers.Select(u => new OrganizationUser { Status = OrganizationUserStatusType.Confirmed, Type = OrganizationUserType.Owner, Id = u.Id }).ToList()); - sutProvider.GetDependency().ManageUsers(org.Id).Returns(true); + organizationRepository.GetByIdAsync(organization.Id).Returns(organization); + currentContext.OrganizationCustom(organization.Id).Returns(true); + currentContext.ManageUsers(organization.Id).Returns(false); - await sutProvider.Sut.ImportAsync(org.Id, userId, null, newUsers, null, false); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) })); + Assert.Contains("account does not have permission", exception.Message.ToLowerInvariant()); + } - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - await sutProvider.GetDependency().Received(1) - .UpsertManyAsync(Arg.Is>(users => users.Count() == 0)); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CreateAsync(default); - - // Create new users - await sutProvider.GetDependency().Received(1) - .CreateManyAsync(Arg.Is>(users => users.Count() == expectedNewUsersCount)); - await sutProvider.GetDependency().Received(1) - .BulkSendOrganizationInviteEmailAsync(org.Name, - Arg.Is>(messages => messages.Count() == expectedNewUsersCount)); - - // Send events - await sutProvider.GetDependency().Received(1) - .LogOrganizationUserEventsAsync(Arg.Is>(events => - events.Count() == expectedNewUsersCount)); - await sutProvider.GetDependency().Received(1) - .RaiseEventAsync(Arg.Is(referenceEvent => - referenceEvent.Type == ReferenceEventType.InvitedUsers && referenceEvent.Id == org.Id && - referenceEvent.Users == expectedNewUsersCount)); - } - - [Theory, PaidOrganizationAutoData] - public async Task OrgImportCreateNewUsersAndMarryExistingUser(SutProvider sutProvider, - Guid userId, Organization org, List existingUsers, - List newUsers) - { - org.UseDirectory = true; - org.Seats = newUsers.Count + existingUsers.Count + 1; - var reInvitedUser = existingUsers.First(); - reInvitedUser.ExternalId = null; - newUsers.Add(new ImportedOrganizationUser + [Theory] + [OrganizationInviteAutoData( + inviteeUserType: (int)OrganizationUserType.Admin, + invitorUserType: (int)OrganizationUserType.Custom + )] + public async Task InviteUser_CustomUserConfiguringAdmin_Throws(Organization organization, OrganizationUserInvite invite, + OrganizationUser invitor, SutProvider sutProvider) + { + invitor.Permissions = JsonSerializer.Serialize(new Permissions() { ManageUsers = true }, + new JsonSerializerOptions { - Email = reInvitedUser.Email, - ExternalId = reInvitedUser.Email, + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, }); - var expectedNewUsersCount = newUsers.Count - 1; - sutProvider.GetDependency().GetByIdAsync(org.Id).Returns(org); - sutProvider.GetDependency().GetManyDetailsByOrganizationAsync(org.Id) - .Returns(existingUsers); - sutProvider.GetDependency().GetCountByOrganizationIdAsync(org.Id) - .Returns(existingUsers.Count); - sutProvider.GetDependency().GetByIdAsync(reInvitedUser.Id) - .Returns(new OrganizationUser { Id = reInvitedUser.Id }); - sutProvider.GetDependency().GetManyByOrganizationAsync(org.Id, OrganizationUserType.Owner) - .Returns(existingUsers.Select(u => new OrganizationUser { Status = OrganizationUserStatusType.Confirmed, Type = OrganizationUserType.Owner, Id = u.Id }).ToList()); - var currentContext = sutProvider.GetDependency(); - currentContext.ManageUsers(org.Id).Returns(true); - - await sutProvider.Sut.ImportAsync(org.Id, userId, null, newUsers, null, false); - - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CreateAsync(default); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CreateAsync(default, default); - - // Upserted existing user - await sutProvider.GetDependency().Received(1) - .UpsertManyAsync(Arg.Is>(users => users.Count() == 1)); - - // Created and invited new users - await sutProvider.GetDependency().Received(1) - .CreateManyAsync(Arg.Is>(users => users.Count() == expectedNewUsersCount)); - await sutProvider.GetDependency().Received(1) - .BulkSendOrganizationInviteEmailAsync(org.Name, - Arg.Is>(messages => messages.Count() == expectedNewUsersCount)); - - // Sent events - await sutProvider.GetDependency().Received(1) - .LogOrganizationUserEventsAsync(Arg.Is>(events => - events.Where(e => e.Item2 == EventType.OrganizationUser_Invited).Count() == expectedNewUsersCount)); - await sutProvider.GetDependency().Received(1) - .RaiseEventAsync(Arg.Is(referenceEvent => - referenceEvent.Type == ReferenceEventType.InvitedUsers && referenceEvent.Id == org.Id && - referenceEvent.Users == expectedNewUsersCount)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task UpgradePlan_OrganizationIsNull_Throws(Guid organizationId, OrganizationUpgrade upgrade, - SutProvider sutProvider) - { - sutProvider.GetDependency().GetByIdAsync(organizationId).Returns(Task.FromResult(null)); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.UpgradePlanAsync(organizationId, upgrade)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task UpgradePlan_GatewayCustomIdIsNull_Throws(Organization organization, OrganizationUpgrade upgrade, - SutProvider sutProvider) - { - organization.GatewayCustomerId = string.Empty; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.UpgradePlanAsync(organization.Id, upgrade)); - Assert.Contains("no payment method", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task UpgradePlan_AlreadyInPlan_Throws(Organization organization, OrganizationUpgrade upgrade, - SutProvider sutProvider) - { - upgrade.Plan = organization.PlanType; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.UpgradePlanAsync(organization.Id, upgrade)); - Assert.Contains("already on this plan", exception.Message); - } - - [Theory, PaidOrganizationAutoData] - public async Task UpgradePlan_UpgradeFromPaidPlan_Throws(Organization organization, OrganizationUpgrade upgrade, - SutProvider sutProvider) - { - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.UpgradePlanAsync(organization.Id, upgrade)); - Assert.Contains("can only upgrade", exception.Message); - } - - [Theory] - [FreeOrganizationUpgradeAutoData] - public async Task UpgradePlan_Passes(Organization organization, OrganizationUpgrade upgrade, - SutProvider sutProvider) - { - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - await sutProvider.Sut.UpgradePlanAsync(organization.Id, upgrade); - await sutProvider.GetDependency().Received(1).ReplaceAsync(organization); - } - - [Theory] - [OrganizationInviteAutoData] - public async Task InviteUser_NoEmails_Throws(Organization organization, OrganizationUser invitor, - OrganizationUserInvite invite, SutProvider sutProvider) - { - invite.Emails = null; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - await Assert.ThrowsAsync( - () => sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) })); - } - - [Theory] - [OrganizationInviteAutoData] - public async Task InviteUser_DuplicateEmails_PassesWithoutDuplicates(Organization organization, OrganizationUser invitor, - [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, - OrganizationUserInvite invite, SutProvider sutProvider) - { - invite.Emails = invite.Emails.Append(invite.Emails.First()); - - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - sutProvider.GetDependency().OrganizationOwner(organization.Id).Returns(true); - sutProvider.GetDependency().ManageUsers(organization.Id).Returns(true); - var organizationUserRepository = sutProvider.GetDependency(); - organizationUserRepository.GetManyByOrganizationAsync(organization.Id, OrganizationUserType.Owner) - .Returns(new[] { owner }); - - await sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) }); - - await sutProvider.GetDependency().Received(1) - .BulkSendOrganizationInviteEmailAsync(organization.Name, - Arg.Is>(v => v.Count() == invite.Emails.Distinct().Count())); - } - - [Theory] - [OrganizationInviteAutoData( - inviteeUserType: (int)OrganizationUserType.Admin, - invitorUserType: (int)OrganizationUserType.Owner - )] - public async Task InviteUser_NoOwner_Throws(Organization organization, OrganizationUser invitor, - OrganizationUserInvite invite, SutProvider sutProvider) - { - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - sutProvider.GetDependency().OrganizationOwner(organization.Id).Returns(true); - sutProvider.GetDependency().ManageUsers(organization.Id).Returns(true); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) })); - Assert.Contains("Organization must have at least one confirmed owner.", exception.Message); - } - - [Theory] - [OrganizationInviteAutoData( - inviteeUserType: (int)OrganizationUserType.Owner, - invitorUserType: (int)OrganizationUserType.Admin - )] - public async Task InviteUser_NonOwnerConfiguringOwner_Throws(Organization organization, OrganizationUserInvite invite, - OrganizationUser invitor, SutProvider sutProvider) - { - var organizationRepository = sutProvider.GetDependency(); - var currentContext = sutProvider.GetDependency(); - - organizationRepository.GetByIdAsync(organization.Id).Returns(organization); - currentContext.OrganizationAdmin(organization.Id).Returns(true); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) })); - Assert.Contains("only an owner", exception.Message.ToLowerInvariant()); - } - - [Theory] - [OrganizationInviteAutoData( - inviteeUserType: (int)OrganizationUserType.Custom, - invitorUserType: (int)OrganizationUserType.User - )] - public async Task InviteUser_NonAdminConfiguringAdmin_Throws(Organization organization, OrganizationUserInvite invite, - OrganizationUser invitor, SutProvider sutProvider) - { - var organizationRepository = sutProvider.GetDependency(); - var currentContext = sutProvider.GetDependency(); - - organizationRepository.GetByIdAsync(organization.Id).Returns(organization); - currentContext.OrganizationUser(organization.Id).Returns(true); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) })); - Assert.Contains("only owners and admins", exception.Message.ToLowerInvariant()); - } - - [Theory] - [OrganizationInviteAutoData( - inviteeUserType: (int)OrganizationUserType.Manager, - invitorUserType: (int)OrganizationUserType.Custom - )] - public async Task InviteUser_CustomUserWithoutManageUsersConfiguringUser_Throws(Organization organization, OrganizationUserInvite invite, - OrganizationUser invitor, SutProvider sutProvider) - { - invitor.Permissions = JsonSerializer.Serialize(new Permissions() { ManageUsers = false }, - new JsonSerializerOptions - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - }); - - var organizationRepository = sutProvider.GetDependency(); - var currentContext = sutProvider.GetDependency(); - - organizationRepository.GetByIdAsync(organization.Id).Returns(organization); - currentContext.OrganizationCustom(organization.Id).Returns(true); - currentContext.ManageUsers(organization.Id).Returns(false); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) })); - Assert.Contains("account does not have permission", exception.Message.ToLowerInvariant()); - } - - [Theory] - [OrganizationInviteAutoData( - inviteeUserType: (int)OrganizationUserType.Admin, - invitorUserType: (int)OrganizationUserType.Custom - )] - public async Task InviteUser_CustomUserConfiguringAdmin_Throws(Organization organization, OrganizationUserInvite invite, - OrganizationUser invitor, SutProvider sutProvider) - { - invitor.Permissions = JsonSerializer.Serialize(new Permissions() { ManageUsers = true }, - new JsonSerializerOptions - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - }); - - var organizationRepository = sutProvider.GetDependency(); - var currentContext = sutProvider.GetDependency(); - - organizationRepository.GetByIdAsync(organization.Id).Returns(organization); - currentContext.OrganizationCustom(organization.Id).Returns(true); - currentContext.ManageUsers(organization.Id).Returns(true); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) })); - Assert.Contains("can not manage admins", exception.Message.ToLowerInvariant()); - } - - [Theory] - [OrganizationInviteAutoData( - inviteeUserType: (int)OrganizationUserType.User, - invitorUserType: (int)OrganizationUserType.Owner - )] - public async Task InviteUser_NoPermissionsObject_Passes(Organization organization, OrganizationUserInvite invite, - OrganizationUser invitor, SutProvider sutProvider) - { - invite.Permissions = null; - invitor.Status = OrganizationUserStatusType.Confirmed; - var organizationRepository = sutProvider.GetDependency(); - var organizationUserRepository = sutProvider.GetDependency(); - var currentContext = sutProvider.GetDependency(); - - organizationRepository.GetByIdAsync(organization.Id).Returns(organization); - organizationUserRepository.GetManyByOrganizationAsync(organization.Id, OrganizationUserType.Owner) - .Returns(new[] { invitor }); - currentContext.OrganizationOwner(organization.Id).Returns(true); - currentContext.ManageUsers(organization.Id).Returns(true); - - await sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) }); - } - - [Theory] - [OrganizationInviteAutoData( - inviteeUserType: (int)OrganizationUserType.User, - invitorUserType: (int)OrganizationUserType.Custom - )] - public async Task InviteUser_Passes(Organization organization, IEnumerable<(OrganizationUserInvite invite, string externalId)> invites, - OrganizationUser invitor, - [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, - SutProvider sutProvider) - { - // Autofixture will add collections for all of the invites, remove the first and for all the rest set all access false - invites.First().invite.AccessAll = true; - invites.First().invite.Collections = null; - invites.Skip(1).ToList().ForEach(i => i.invite.AccessAll = false); - - invitor.Permissions = JsonSerializer.Serialize(new Permissions() { ManageUsers = true }, - new JsonSerializerOptions - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - }); - - var organizationRepository = sutProvider.GetDependency(); - var organizationUserRepository = sutProvider.GetDependency(); - var currentContext = sutProvider.GetDependency(); - - organizationRepository.GetByIdAsync(organization.Id).Returns(organization); - organizationUserRepository.GetManyByOrganizationAsync(organization.Id, OrganizationUserType.Owner) - .Returns(new[] { owner }); - currentContext.ManageUsers(organization.Id).Returns(true); - - await sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, invites); - - await sutProvider.GetDependency().Received(1) - .BulkSendOrganizationInviteEmailAsync(organization.Name, - Arg.Is>(v => v.Count() == invites.SelectMany(i => i.invite.Emails).Count())); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveUser_NoUserId_Throws(OrganizationUser user, Guid? savingUserId, - IEnumerable collections, SutProvider sutProvider) - { - user.Id = default(Guid); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveUserAsync(user, savingUserId, collections)); - Assert.Contains("invite the user first", exception.Message.ToLowerInvariant()); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveUser_NoChangeToData_Throws(OrganizationUser user, Guid? savingUserId, - IEnumerable collections, SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - organizationUserRepository.GetByIdAsync(user.Id).Returns(user); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveUserAsync(user, savingUserId, collections)); - Assert.Contains("make changes before saving", exception.Message.ToLowerInvariant()); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveUser_Passes( - OrganizationUser oldUserData, - OrganizationUser newUserData, - IEnumerable collections, - [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser savingUser, - SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var currentContext = sutProvider.GetDependency(); - - newUserData.Id = oldUserData.Id; - newUserData.UserId = oldUserData.UserId; - newUserData.OrganizationId = savingUser.OrganizationId = oldUserData.OrganizationId; - organizationUserRepository.GetByIdAsync(oldUserData.Id).Returns(oldUserData); - organizationUserRepository.GetManyByOrganizationAsync(savingUser.OrganizationId, OrganizationUserType.Owner) - .Returns(new List { savingUser }); - currentContext.OrganizationOwner(savingUser.OrganizationId).Returns(true); - - await sutProvider.Sut.SaveUserAsync(newUserData, savingUser.UserId, collections); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUser_InvalidUser(OrganizationUser organizationUser, OrganizationUser deletingUser, - SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - - organizationUserRepository.GetByIdAsync(organizationUser.Id).Returns(organizationUser); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.DeleteUserAsync(Guid.NewGuid(), organizationUser.Id, deletingUser.UserId)); - Assert.Contains("User not valid.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUser_RemoveYourself(OrganizationUser deletingUser, SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - - organizationUserRepository.GetByIdAsync(deletingUser.Id).Returns(deletingUser); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.DeleteUserAsync(deletingUser.OrganizationId, deletingUser.Id, deletingUser.UserId)); - Assert.Contains("You cannot remove yourself.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUser_NonOwnerRemoveOwner( - [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser organizationUser, - [OrganizationUser(type: OrganizationUserType.Admin)] OrganizationUser deletingUser, - SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var currentContext = sutProvider.GetDependency(); - - organizationUser.OrganizationId = deletingUser.OrganizationId; - organizationUserRepository.GetByIdAsync(organizationUser.Id).Returns(organizationUser); - currentContext.OrganizationAdmin(deletingUser.OrganizationId).Returns(true); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.DeleteUserAsync(deletingUser.OrganizationId, organizationUser.Id, deletingUser.UserId)); - Assert.Contains("Only owners can delete other owners.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUser_LastOwner( - [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser organizationUser, - OrganizationUser deletingUser, - SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - - organizationUser.OrganizationId = deletingUser.OrganizationId; - organizationUserRepository.GetByIdAsync(organizationUser.Id).Returns(organizationUser); - organizationUserRepository.GetManyByOrganizationAsync(deletingUser.OrganizationId, OrganizationUserType.Owner) - .Returns(new[] { organizationUser }); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.DeleteUserAsync(deletingUser.OrganizationId, organizationUser.Id, null)); - Assert.Contains("Organization must have at least one confirmed owner.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUser_Success( - OrganizationUser organizationUser, - [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser deletingUser, - SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var currentContext = sutProvider.GetDependency(); - - organizationUser.OrganizationId = deletingUser.OrganizationId; - organizationUserRepository.GetByIdAsync(organizationUser.Id).Returns(organizationUser); - organizationUserRepository.GetByIdAsync(deletingUser.Id).Returns(deletingUser); - organizationUserRepository.GetManyByOrganizationAsync(deletingUser.OrganizationId, OrganizationUserType.Owner) - .Returns(new[] { deletingUser, organizationUser }); - currentContext.OrganizationOwner(deletingUser.OrganizationId).Returns(true); - - await sutProvider.Sut.DeleteUserAsync(deletingUser.OrganizationId, organizationUser.Id, deletingUser.UserId); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUsers_FilterInvalid(OrganizationUser organizationUser, OrganizationUser deletingUser, - SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var organizationUsers = new[] { organizationUser }; - var organizationUserIds = organizationUsers.Select(u => u.Id); - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.DeleteUsersAsync(deletingUser.OrganizationId, organizationUserIds, deletingUser.UserId)); - Assert.Contains("Users invalid.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUsers_RemoveYourself( - [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser orgUser, - OrganizationUser deletingUser, - SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var organizationUsers = new[] { deletingUser }; - var organizationUserIds = organizationUsers.Select(u => u.Id); - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); - organizationUserRepository.GetManyByOrganizationAsync(default, default).ReturnsForAnyArgs(new[] { orgUser }); - - var result = await sutProvider.Sut.DeleteUsersAsync(deletingUser.OrganizationId, organizationUserIds, deletingUser.UserId); - Assert.Contains("You cannot remove yourself.", result[0].Item2); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUsers_NonOwnerRemoveOwner( - [OrganizationUser(type: OrganizationUserType.Admin)] OrganizationUser deletingUser, - [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser orgUser1, - [OrganizationUser(OrganizationUserStatusType.Confirmed)] OrganizationUser orgUser2, - SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - - orgUser1.OrganizationId = orgUser2.OrganizationId = deletingUser.OrganizationId; - var organizationUsers = new[] { orgUser1 }; - var organizationUserIds = organizationUsers.Select(u => u.Id); - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); - organizationUserRepository.GetManyByOrganizationAsync(default, default).ReturnsForAnyArgs(new[] { orgUser2 }); - - var result = await sutProvider.Sut.DeleteUsersAsync(deletingUser.OrganizationId, organizationUserIds, deletingUser.UserId); - Assert.Contains("Only owners can delete other owners.", result[0].Item2); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUsers_LastOwner( - [OrganizationUser(status: OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser orgUser, - SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - - var organizationUsers = new[] { orgUser }; - var organizationUserIds = organizationUsers.Select(u => u.Id); - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); - organizationUserRepository.GetManyByOrganizationAsync(orgUser.OrganizationId, OrganizationUserType.Owner).Returns(organizationUsers); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.DeleteUsersAsync(orgUser.OrganizationId, organizationUserIds, null)); - Assert.Contains("Organization must have at least one confirmed owner.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task DeleteUsers_Success( - [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser deletingUser, - [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser orgUser1, OrganizationUser orgUser2, - SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var currentContext = sutProvider.GetDependency(); - - orgUser1.OrganizationId = orgUser2.OrganizationId = deletingUser.OrganizationId; - var organizationUsers = new[] { orgUser1, orgUser2 }; - var organizationUserIds = organizationUsers.Select(u => u.Id); - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); - organizationUserRepository.GetByIdAsync(deletingUser.Id).Returns(deletingUser); - organizationUserRepository.GetManyByOrganizationAsync(deletingUser.OrganizationId, OrganizationUserType.Owner) - .Returns(new[] { deletingUser, orgUser1 }); - currentContext.OrganizationOwner(deletingUser.OrganizationId).Returns(true); - - await sutProvider.Sut.DeleteUsersAsync(deletingUser.OrganizationId, organizationUserIds, deletingUser.UserId); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task ConfirmUser_InvalidStatus(OrganizationUser confirmingUser, - [OrganizationUser(OrganizationUserStatusType.Invited)] OrganizationUser orgUser, string key, - SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var userService = Substitute.For(); - - organizationUserRepository.GetByIdAsync(orgUser.Id).Returns(orgUser); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService)); - Assert.Contains("User not valid.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task ConfirmUser_WrongOrganization(OrganizationUser confirmingUser, - [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, string key, - SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var userService = Substitute.For(); - - organizationUserRepository.GetByIdAsync(orgUser.Id).Returns(orgUser); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.ConfirmUserAsync(confirmingUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService)); - Assert.Contains("User not valid.", exception.Message); - } - - [Theory] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, OrganizationUserType.Admin)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, OrganizationUserType.Owner)] - public async Task ConfirmUserToFree_AlreadyFreeAdminOrOwner_Throws(OrganizationUserType userType, Organization org, OrganizationUser confirmingUser, - [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, - string key, SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var organizationRepository = sutProvider.GetDependency(); - var userService = Substitute.For(); - var userRepository = sutProvider.GetDependency(); - - org.PlanType = PlanType.Free; - orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; - orgUser.UserId = user.Id; - orgUser.Type = userType; - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser }); - organizationUserRepository.GetCountByFreeOrganizationAdminUserAsync(orgUser.UserId.Value).Returns(1); - organizationRepository.GetByIdAsync(org.Id).Returns(org); - userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { user }); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService)); - Assert.Contains("User can only be an admin of one free organization.", exception.Message); - } - - [Theory] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.Custom, OrganizationUserType.Admin)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.Custom, OrganizationUserType.Owner)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseAnnually, OrganizationUserType.Admin)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseAnnually, OrganizationUserType.Owner)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseAnnually2019, OrganizationUserType.Admin)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseAnnually2019, OrganizationUserType.Owner)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseMonthly, OrganizationUserType.Admin)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseMonthly, OrganizationUserType.Owner)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseMonthly2019, OrganizationUserType.Admin)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseMonthly2019, OrganizationUserType.Owner)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.FamiliesAnnually, OrganizationUserType.Admin)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.FamiliesAnnually, OrganizationUserType.Owner)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.FamiliesAnnually2019, OrganizationUserType.Admin)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.FamiliesAnnually2019, OrganizationUserType.Owner)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsAnnually, OrganizationUserType.Admin)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsAnnually, OrganizationUserType.Owner)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsAnnually2019, OrganizationUserType.Admin)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsAnnually2019, OrganizationUserType.Owner)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsMonthly, OrganizationUserType.Admin)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsMonthly, OrganizationUserType.Owner)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsMonthly2019, OrganizationUserType.Admin)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsMonthly2019, OrganizationUserType.Owner)] - public async Task ConfirmUserToNonFree_AlreadyFreeAdminOrOwner_DoesNotThrow(PlanType planType, OrganizationUserType orgUserType, Organization org, OrganizationUser confirmingUser, - [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, - string key, SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var organizationRepository = sutProvider.GetDependency(); - var userService = Substitute.For(); - var userRepository = sutProvider.GetDependency(); - - org.PlanType = planType; - orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; - orgUser.UserId = user.Id; - orgUser.Type = orgUserType; - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser }); - organizationUserRepository.GetCountByFreeOrganizationAdminUserAsync(orgUser.UserId.Value).Returns(1); - organizationRepository.GetByIdAsync(org.Id).Returns(org); - userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { user }); - - await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService); - - await sutProvider.GetDependency().Received(1).LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_Confirmed); - await sutProvider.GetDependency().Received(1).SendOrganizationConfirmedEmailAsync(org.Name, user.Email); - await organizationUserRepository.Received(1).ReplaceManyAsync(Arg.Is>(users => users.Contains(orgUser) && users.Count == 1)); - } - - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task ConfirmUser_SingleOrgPolicy(Organization org, OrganizationUser confirmingUser, - [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, - OrganizationUser orgUserAnotherOrg, [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy, - string key, SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var organizationRepository = sutProvider.GetDependency(); - var policyRepository = sutProvider.GetDependency(); - var userRepository = sutProvider.GetDependency(); - var userService = Substitute.For(); - - org.PlanType = PlanType.EnterpriseAnnually; - orgUser.Status = OrganizationUserStatusType.Accepted; - orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; - orgUser.UserId = orgUserAnotherOrg.UserId = user.Id; - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser }); - organizationUserRepository.GetManyByManyUsersAsync(default).ReturnsForAnyArgs(new[] { orgUserAnotherOrg }); - organizationRepository.GetByIdAsync(org.Id).Returns(org); - userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { user }); - policyRepository.GetManyByOrganizationIdAsync(org.Id).Returns(new[] { singleOrgPolicy }); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService)); - Assert.Contains("User is a member of another organization.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task ConfirmUser_TwoFactorPolicy(Organization org, OrganizationUser confirmingUser, - [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, - OrganizationUser orgUserAnotherOrg, [Policy(PolicyType.TwoFactorAuthentication)] Policy twoFactorPolicy, - string key, SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var organizationRepository = sutProvider.GetDependency(); - var policyRepository = sutProvider.GetDependency(); - var userRepository = sutProvider.GetDependency(); - var userService = Substitute.For(); - - org.PlanType = PlanType.EnterpriseAnnually; - orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; - orgUser.UserId = orgUserAnotherOrg.UserId = user.Id; - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser }); - organizationUserRepository.GetManyByManyUsersAsync(default).ReturnsForAnyArgs(new[] { orgUserAnotherOrg }); - organizationRepository.GetByIdAsync(org.Id).Returns(org); - userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { user }); - policyRepository.GetManyByOrganizationIdAsync(org.Id).Returns(new[] { twoFactorPolicy }); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService)); - Assert.Contains("User does not have two-step login enabled.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task ConfirmUser_Success(Organization org, OrganizationUser confirmingUser, - [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, - [Policy(PolicyType.TwoFactorAuthentication)] Policy twoFactorPolicy, - [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy, string key, SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var organizationRepository = sutProvider.GetDependency(); - var policyRepository = sutProvider.GetDependency(); - var userRepository = sutProvider.GetDependency(); - var userService = Substitute.For(); - - org.PlanType = PlanType.EnterpriseAnnually; - orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; - orgUser.UserId = user.Id; - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser }); - organizationRepository.GetByIdAsync(org.Id).Returns(org); - userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { user }); - policyRepository.GetManyByOrganizationIdAsync(org.Id).Returns(new[] { twoFactorPolicy, singleOrgPolicy }); - userService.TwoFactorIsEnabledAsync(user).Returns(true); - - await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task ConfirmUsers_Success(Organization org, - OrganizationUser confirmingUser, - [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser1, - [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser2, - [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser3, - OrganizationUser anotherOrgUser, User user1, User user2, User user3, - [Policy(PolicyType.TwoFactorAuthentication)] Policy twoFactorPolicy, - [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy, string key, SutProvider sutProvider) - { - var organizationUserRepository = sutProvider.GetDependency(); - var organizationRepository = sutProvider.GetDependency(); - var policyRepository = sutProvider.GetDependency(); - var userRepository = sutProvider.GetDependency(); - var userService = Substitute.For(); - - org.PlanType = PlanType.EnterpriseAnnually; - orgUser1.OrganizationId = orgUser2.OrganizationId = orgUser3.OrganizationId = confirmingUser.OrganizationId = org.Id; - orgUser1.UserId = user1.Id; - orgUser2.UserId = user2.Id; - orgUser3.UserId = user3.Id; - anotherOrgUser.UserId = user3.Id; - var orgUsers = new[] { orgUser1, orgUser2, orgUser3 }; - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(orgUsers); - organizationRepository.GetByIdAsync(org.Id).Returns(org); - userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { user1, user2, user3 }); - policyRepository.GetManyByOrganizationIdAsync(org.Id).Returns(new[] { twoFactorPolicy, singleOrgPolicy }); - userService.TwoFactorIsEnabledAsync(user1).Returns(true); - userService.TwoFactorIsEnabledAsync(user2).Returns(false); - userService.TwoFactorIsEnabledAsync(user3).Returns(true); - organizationUserRepository.GetManyByManyUsersAsync(default) - .ReturnsForAnyArgs(new[] { orgUser1, orgUser2, orgUser3, anotherOrgUser }); - - var keys = orgUsers.ToDictionary(ou => ou.Id, _ => key); - var result = await sutProvider.Sut.ConfirmUsersAsync(confirmingUser.OrganizationId, keys, confirmingUser.Id, userService); - Assert.Contains("", result[0].Item2); - Assert.Contains("User does not have two-step login enabled.", result[1].Item2); - Assert.Contains("User is a member of another organization.", result[2].Item2); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task UpdateOrganizationKeysAsync_WithoutManageResetPassword_Throws(Guid orgId, string publicKey, - string privateKey, SutProvider sutProvider) - { - var currentContext = Substitute.For(); - currentContext.ManageResetPassword(orgId).Returns(false); - - await Assert.ThrowsAsync( - () => sutProvider.Sut.UpdateOrganizationKeysAsync(orgId, publicKey, privateKey)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task UpdateOrganizationKeysAsync_KeysAlreadySet_Throws(Organization org, string publicKey, - string privateKey, SutProvider sutProvider) - { - var currentContext = sutProvider.GetDependency(); - currentContext.ManageResetPassword(org.Id).Returns(true); - - var organizationRepository = sutProvider.GetDependency(); - organizationRepository.GetByIdAsync(org.Id).Returns(org); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.UpdateOrganizationKeysAsync(org.Id, publicKey, privateKey)); - Assert.Contains("Organization Keys already exist", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task UpdateOrganizationKeysAsync_KeysAlreadySet_Success(Organization org, string publicKey, - string privateKey, SutProvider sutProvider) - { - org.PublicKey = null; - org.PrivateKey = null; - - var currentContext = sutProvider.GetDependency(); - currentContext.ManageResetPassword(org.Id).Returns(true); - - var organizationRepository = sutProvider.GetDependency(); - organizationRepository.GetByIdAsync(org.Id).Returns(org); - - await sutProvider.Sut.UpdateOrganizationKeysAsync(org.Id, publicKey, privateKey); - } - - [Theory] - [InlinePaidOrganizationAutoData(PlanType.EnterpriseAnnually, new object[] { "Cannot set max seat autoscaling below seat count", 1, 0, 2 })] - [InlinePaidOrganizationAutoData(PlanType.EnterpriseAnnually, new object[] { "Cannot set max seat autoscaling below seat count", 4, -1, 6 })] - [InlineFreeOrganizationAutoData("Your plan does not allow seat autoscaling", 10, 0, null)] - public async Task UpdateSubscription_BadInputThrows(string expectedMessage, - int? maxAutoscaleSeats, int seatAdjustment, int? currentSeats, Organization organization, SutProvider sutProvider) - { - organization.Seats = currentSeats; - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - - var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscription(organization.Id, - seatAdjustment, maxAutoscaleSeats)); - - Assert.Contains(expectedMessage, exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task UpdateSubscription_NoOrganization_Throws(Guid organizationId, SutProvider sutProvider) - { - sutProvider.GetDependency().GetByIdAsync(organizationId).Returns((Organization)null); - - await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscription(organizationId, 0, null)); - } - - [Theory] - [InlinePaidOrganizationAutoData(0, 100, null, true, "")] - [InlinePaidOrganizationAutoData(0, 100, 100, true, "")] - [InlinePaidOrganizationAutoData(0, null, 100, true, "")] - [InlinePaidOrganizationAutoData(1, 100, null, true, "")] - [InlinePaidOrganizationAutoData(1, 100, 100, false, "Cannot invite new users. Seat limit has been reached")] - public void CanScale(int seatsToAdd, int? currentSeats, int? maxAutoscaleSeats, - bool expectedResult, string expectedFailureMessage, Organization organization, - SutProvider sutProvider) - { - organization.Seats = currentSeats; - organization.MaxAutoscaleSeats = maxAutoscaleSeats; - sutProvider.GetDependency().ManageUsers(organization.Id).Returns(true); - - var (result, failureMessage) = sutProvider.Sut.CanScale(organization, seatsToAdd); - - if (expectedFailureMessage == string.Empty) + var organizationRepository = sutProvider.GetDependency(); + var currentContext = sutProvider.GetDependency(); + + organizationRepository.GetByIdAsync(organization.Id).Returns(organization); + currentContext.OrganizationCustom(organization.Id).Returns(true); + currentContext.ManageUsers(organization.Id).Returns(true); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) })); + Assert.Contains("can not manage admins", exception.Message.ToLowerInvariant()); + } + + [Theory] + [OrganizationInviteAutoData( + inviteeUserType: (int)OrganizationUserType.User, + invitorUserType: (int)OrganizationUserType.Owner + )] + public async Task InviteUser_NoPermissionsObject_Passes(Organization organization, OrganizationUserInvite invite, + OrganizationUser invitor, SutProvider sutProvider) + { + invite.Permissions = null; + invitor.Status = OrganizationUserStatusType.Confirmed; + var organizationRepository = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + var currentContext = sutProvider.GetDependency(); + + organizationRepository.GetByIdAsync(organization.Id).Returns(organization); + organizationUserRepository.GetManyByOrganizationAsync(organization.Id, OrganizationUserType.Owner) + .Returns(new[] { invitor }); + currentContext.OrganizationOwner(organization.Id).Returns(true); + currentContext.ManageUsers(organization.Id).Returns(true); + + await sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, new (OrganizationUserInvite, string)[] { (invite, null) }); + } + + [Theory] + [OrganizationInviteAutoData( + inviteeUserType: (int)OrganizationUserType.User, + invitorUserType: (int)OrganizationUserType.Custom + )] + public async Task InviteUser_Passes(Organization organization, IEnumerable<(OrganizationUserInvite invite, string externalId)> invites, + OrganizationUser invitor, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + SutProvider sutProvider) + { + // Autofixture will add collections for all of the invites, remove the first and for all the rest set all access false + invites.First().invite.AccessAll = true; + invites.First().invite.Collections = null; + invites.Skip(1).ToList().ForEach(i => i.invite.AccessAll = false); + + invitor.Permissions = JsonSerializer.Serialize(new Permissions() { ManageUsers = true }, + new JsonSerializerOptions { - Assert.Empty(failureMessage); - } - else - { - Assert.Contains(expectedFailureMessage, failureMessage); - } - Assert.Equal(expectedResult, result); - } + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + }); - [Theory, PaidOrganizationAutoData] - public void CanScale_FailsOnSelfHosted(Organization organization, - SutProvider sutProvider) + var organizationRepository = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + var currentContext = sutProvider.GetDependency(); + + organizationRepository.GetByIdAsync(organization.Id).Returns(organization); + organizationUserRepository.GetManyByOrganizationAsync(organization.Id, OrganizationUserType.Owner) + .Returns(new[] { owner }); + currentContext.ManageUsers(organization.Id).Returns(true); + + await sutProvider.Sut.InviteUsersAsync(organization.Id, invitor.UserId, invites); + + await sutProvider.GetDependency().Received(1) + .BulkSendOrganizationInviteEmailAsync(organization.Name, + Arg.Is>(v => v.Count() == invites.SelectMany(i => i.invite.Emails).Count())); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveUser_NoUserId_Throws(OrganizationUser user, Guid? savingUserId, + IEnumerable collections, SutProvider sutProvider) + { + user.Id = default(Guid); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveUserAsync(user, savingUserId, collections)); + Assert.Contains("invite the user first", exception.Message.ToLowerInvariant()); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveUser_NoChangeToData_Throws(OrganizationUser user, Guid? savingUserId, + IEnumerable collections, SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + organizationUserRepository.GetByIdAsync(user.Id).Returns(user); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveUserAsync(user, savingUserId, collections)); + Assert.Contains("make changes before saving", exception.Message.ToLowerInvariant()); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveUser_Passes( + OrganizationUser oldUserData, + OrganizationUser newUserData, + IEnumerable collections, + [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser savingUser, + SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var currentContext = sutProvider.GetDependency(); + + newUserData.Id = oldUserData.Id; + newUserData.UserId = oldUserData.UserId; + newUserData.OrganizationId = savingUser.OrganizationId = oldUserData.OrganizationId; + organizationUserRepository.GetByIdAsync(oldUserData.Id).Returns(oldUserData); + organizationUserRepository.GetManyByOrganizationAsync(savingUser.OrganizationId, OrganizationUserType.Owner) + .Returns(new List { savingUser }); + currentContext.OrganizationOwner(savingUser.OrganizationId).Returns(true); + + await sutProvider.Sut.SaveUserAsync(newUserData, savingUser.UserId, collections); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUser_InvalidUser(OrganizationUser organizationUser, OrganizationUser deletingUser, + SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + + organizationUserRepository.GetByIdAsync(organizationUser.Id).Returns(organizationUser); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteUserAsync(Guid.NewGuid(), organizationUser.Id, deletingUser.UserId)); + Assert.Contains("User not valid.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUser_RemoveYourself(OrganizationUser deletingUser, SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + + organizationUserRepository.GetByIdAsync(deletingUser.Id).Returns(deletingUser); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteUserAsync(deletingUser.OrganizationId, deletingUser.Id, deletingUser.UserId)); + Assert.Contains("You cannot remove yourself.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUser_NonOwnerRemoveOwner( + [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser organizationUser, + [OrganizationUser(type: OrganizationUserType.Admin)] OrganizationUser deletingUser, + SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var currentContext = sutProvider.GetDependency(); + + organizationUser.OrganizationId = deletingUser.OrganizationId; + organizationUserRepository.GetByIdAsync(organizationUser.Id).Returns(organizationUser); + currentContext.OrganizationAdmin(deletingUser.OrganizationId).Returns(true); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteUserAsync(deletingUser.OrganizationId, organizationUser.Id, deletingUser.UserId)); + Assert.Contains("Only owners can delete other owners.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUser_LastOwner( + [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser organizationUser, + OrganizationUser deletingUser, + SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + + organizationUser.OrganizationId = deletingUser.OrganizationId; + organizationUserRepository.GetByIdAsync(organizationUser.Id).Returns(organizationUser); + organizationUserRepository.GetManyByOrganizationAsync(deletingUser.OrganizationId, OrganizationUserType.Owner) + .Returns(new[] { organizationUser }); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteUserAsync(deletingUser.OrganizationId, organizationUser.Id, null)); + Assert.Contains("Organization must have at least one confirmed owner.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUser_Success( + OrganizationUser organizationUser, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser deletingUser, + SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var currentContext = sutProvider.GetDependency(); + + organizationUser.OrganizationId = deletingUser.OrganizationId; + organizationUserRepository.GetByIdAsync(organizationUser.Id).Returns(organizationUser); + organizationUserRepository.GetByIdAsync(deletingUser.Id).Returns(deletingUser); + organizationUserRepository.GetManyByOrganizationAsync(deletingUser.OrganizationId, OrganizationUserType.Owner) + .Returns(new[] { deletingUser, organizationUser }); + currentContext.OrganizationOwner(deletingUser.OrganizationId).Returns(true); + + await sutProvider.Sut.DeleteUserAsync(deletingUser.OrganizationId, organizationUser.Id, deletingUser.UserId); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUsers_FilterInvalid(OrganizationUser organizationUser, OrganizationUser deletingUser, + SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var organizationUsers = new[] { organizationUser }; + var organizationUserIds = organizationUsers.Select(u => u.Id); + organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteUsersAsync(deletingUser.OrganizationId, organizationUserIds, deletingUser.UserId)); + Assert.Contains("Users invalid.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUsers_RemoveYourself( + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser orgUser, + OrganizationUser deletingUser, + SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var organizationUsers = new[] { deletingUser }; + var organizationUserIds = organizationUsers.Select(u => u.Id); + organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); + organizationUserRepository.GetManyByOrganizationAsync(default, default).ReturnsForAnyArgs(new[] { orgUser }); + + var result = await sutProvider.Sut.DeleteUsersAsync(deletingUser.OrganizationId, organizationUserIds, deletingUser.UserId); + Assert.Contains("You cannot remove yourself.", result[0].Item2); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUsers_NonOwnerRemoveOwner( + [OrganizationUser(type: OrganizationUserType.Admin)] OrganizationUser deletingUser, + [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser orgUser1, + [OrganizationUser(OrganizationUserStatusType.Confirmed)] OrganizationUser orgUser2, + SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + + orgUser1.OrganizationId = orgUser2.OrganizationId = deletingUser.OrganizationId; + var organizationUsers = new[] { orgUser1 }; + var organizationUserIds = organizationUsers.Select(u => u.Id); + organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); + organizationUserRepository.GetManyByOrganizationAsync(default, default).ReturnsForAnyArgs(new[] { orgUser2 }); + + var result = await sutProvider.Sut.DeleteUsersAsync(deletingUser.OrganizationId, organizationUserIds, deletingUser.UserId); + Assert.Contains("Only owners can delete other owners.", result[0].Item2); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUsers_LastOwner( + [OrganizationUser(status: OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser orgUser, + SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + + var organizationUsers = new[] { orgUser }; + var organizationUserIds = organizationUsers.Select(u => u.Id); + organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); + organizationUserRepository.GetManyByOrganizationAsync(orgUser.OrganizationId, OrganizationUserType.Owner).Returns(organizationUsers); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteUsersAsync(orgUser.OrganizationId, organizationUserIds, null)); + Assert.Contains("Organization must have at least one confirmed owner.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task DeleteUsers_Success( + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser deletingUser, + [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser orgUser1, OrganizationUser orgUser2, + SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var currentContext = sutProvider.GetDependency(); + + orgUser1.OrganizationId = orgUser2.OrganizationId = deletingUser.OrganizationId; + var organizationUsers = new[] { orgUser1, orgUser2 }; + var organizationUserIds = organizationUsers.Select(u => u.Id); + organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); + organizationUserRepository.GetByIdAsync(deletingUser.Id).Returns(deletingUser); + organizationUserRepository.GetManyByOrganizationAsync(deletingUser.OrganizationId, OrganizationUserType.Owner) + .Returns(new[] { deletingUser, orgUser1 }); + currentContext.OrganizationOwner(deletingUser.OrganizationId).Returns(true); + + await sutProvider.Sut.DeleteUsersAsync(deletingUser.OrganizationId, organizationUserIds, deletingUser.UserId); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task ConfirmUser_InvalidStatus(OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Invited)] OrganizationUser orgUser, string key, + SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var userService = Substitute.For(); + + organizationUserRepository.GetByIdAsync(orgUser.Id).Returns(orgUser); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService)); + Assert.Contains("User not valid.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task ConfirmUser_WrongOrganization(OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, string key, + SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var userService = Substitute.For(); + + organizationUserRepository.GetByIdAsync(orgUser.Id).Returns(orgUser); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ConfirmUserAsync(confirmingUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService)); + Assert.Contains("User not valid.", exception.Message); + } + + [Theory] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, OrganizationUserType.Admin)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, OrganizationUserType.Owner)] + public async Task ConfirmUserToFree_AlreadyFreeAdminOrOwner_Throws(OrganizationUserType userType, Organization org, OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, + string key, SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var userService = Substitute.For(); + var userRepository = sutProvider.GetDependency(); + + org.PlanType = PlanType.Free; + orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; + orgUser.UserId = user.Id; + orgUser.Type = userType; + organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser }); + organizationUserRepository.GetCountByFreeOrganizationAdminUserAsync(orgUser.UserId.Value).Returns(1); + organizationRepository.GetByIdAsync(org.Id).Returns(org); + userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { user }); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService)); + Assert.Contains("User can only be an admin of one free organization.", exception.Message); + } + + [Theory] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.Custom, OrganizationUserType.Admin)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.Custom, OrganizationUserType.Owner)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseAnnually, OrganizationUserType.Admin)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseAnnually, OrganizationUserType.Owner)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseAnnually2019, OrganizationUserType.Admin)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseAnnually2019, OrganizationUserType.Owner)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseMonthly, OrganizationUserType.Admin)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseMonthly, OrganizationUserType.Owner)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseMonthly2019, OrganizationUserType.Admin)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.EnterpriseMonthly2019, OrganizationUserType.Owner)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.FamiliesAnnually, OrganizationUserType.Admin)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.FamiliesAnnually, OrganizationUserType.Owner)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.FamiliesAnnually2019, OrganizationUserType.Admin)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.FamiliesAnnually2019, OrganizationUserType.Owner)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsAnnually, OrganizationUserType.Admin)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsAnnually, OrganizationUserType.Owner)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsAnnually2019, OrganizationUserType.Admin)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsAnnually2019, OrganizationUserType.Owner)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsMonthly, OrganizationUserType.Admin)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsMonthly, OrganizationUserType.Owner)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsMonthly2019, OrganizationUserType.Admin)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PlanType.TeamsMonthly2019, OrganizationUserType.Owner)] + public async Task ConfirmUserToNonFree_AlreadyFreeAdminOrOwner_DoesNotThrow(PlanType planType, OrganizationUserType orgUserType, Organization org, OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, + string key, SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var userService = Substitute.For(); + var userRepository = sutProvider.GetDependency(); + + org.PlanType = planType; + orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; + orgUser.UserId = user.Id; + orgUser.Type = orgUserType; + organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser }); + organizationUserRepository.GetCountByFreeOrganizationAdminUserAsync(orgUser.UserId.Value).Returns(1); + organizationRepository.GetByIdAsync(org.Id).Returns(org); + userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { user }); + + await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService); + + await sutProvider.GetDependency().Received(1).LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_Confirmed); + await sutProvider.GetDependency().Received(1).SendOrganizationConfirmedEmailAsync(org.Name, user.Email); + await organizationUserRepository.Received(1).ReplaceManyAsync(Arg.Is>(users => users.Contains(orgUser) && users.Count == 1)); + } + + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task ConfirmUser_SingleOrgPolicy(Organization org, OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, + OrganizationUser orgUserAnotherOrg, [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy, + string key, SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var policyRepository = sutProvider.GetDependency(); + var userRepository = sutProvider.GetDependency(); + var userService = Substitute.For(); + + org.PlanType = PlanType.EnterpriseAnnually; + orgUser.Status = OrganizationUserStatusType.Accepted; + orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; + orgUser.UserId = orgUserAnotherOrg.UserId = user.Id; + organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser }); + organizationUserRepository.GetManyByManyUsersAsync(default).ReturnsForAnyArgs(new[] { orgUserAnotherOrg }); + organizationRepository.GetByIdAsync(org.Id).Returns(org); + userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { user }); + policyRepository.GetManyByOrganizationIdAsync(org.Id).Returns(new[] { singleOrgPolicy }); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService)); + Assert.Contains("User is a member of another organization.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task ConfirmUser_TwoFactorPolicy(Organization org, OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, + OrganizationUser orgUserAnotherOrg, [Policy(PolicyType.TwoFactorAuthentication)] Policy twoFactorPolicy, + string key, SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var policyRepository = sutProvider.GetDependency(); + var userRepository = sutProvider.GetDependency(); + var userService = Substitute.For(); + + org.PlanType = PlanType.EnterpriseAnnually; + orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; + orgUser.UserId = orgUserAnotherOrg.UserId = user.Id; + organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser }); + organizationUserRepository.GetManyByManyUsersAsync(default).ReturnsForAnyArgs(new[] { orgUserAnotherOrg }); + organizationRepository.GetByIdAsync(org.Id).Returns(org); + userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { user }); + policyRepository.GetManyByOrganizationIdAsync(org.Id).Returns(new[] { twoFactorPolicy }); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService)); + Assert.Contains("User does not have two-step login enabled.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task ConfirmUser_Success(Organization org, OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, + [Policy(PolicyType.TwoFactorAuthentication)] Policy twoFactorPolicy, + [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy, string key, SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var policyRepository = sutProvider.GetDependency(); + var userRepository = sutProvider.GetDependency(); + var userService = Substitute.For(); + + org.PlanType = PlanType.EnterpriseAnnually; + orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; + orgUser.UserId = user.Id; + organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser }); + organizationRepository.GetByIdAsync(org.Id).Returns(org); + userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { user }); + policyRepository.GetManyByOrganizationIdAsync(org.Id).Returns(new[] { twoFactorPolicy, singleOrgPolicy }); + userService.TwoFactorIsEnabledAsync(user).Returns(true); + + await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, userService); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task ConfirmUsers_Success(Organization org, + OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser1, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser2, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser3, + OrganizationUser anotherOrgUser, User user1, User user2, User user3, + [Policy(PolicyType.TwoFactorAuthentication)] Policy twoFactorPolicy, + [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy, string key, SutProvider sutProvider) + { + var organizationUserRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var policyRepository = sutProvider.GetDependency(); + var userRepository = sutProvider.GetDependency(); + var userService = Substitute.For(); + + org.PlanType = PlanType.EnterpriseAnnually; + orgUser1.OrganizationId = orgUser2.OrganizationId = orgUser3.OrganizationId = confirmingUser.OrganizationId = org.Id; + orgUser1.UserId = user1.Id; + orgUser2.UserId = user2.Id; + orgUser3.UserId = user3.Id; + anotherOrgUser.UserId = user3.Id; + var orgUsers = new[] { orgUser1, orgUser2, orgUser3 }; + organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(orgUsers); + organizationRepository.GetByIdAsync(org.Id).Returns(org); + userRepository.GetManyAsync(default).ReturnsForAnyArgs(new[] { user1, user2, user3 }); + policyRepository.GetManyByOrganizationIdAsync(org.Id).Returns(new[] { twoFactorPolicy, singleOrgPolicy }); + userService.TwoFactorIsEnabledAsync(user1).Returns(true); + userService.TwoFactorIsEnabledAsync(user2).Returns(false); + userService.TwoFactorIsEnabledAsync(user3).Returns(true); + organizationUserRepository.GetManyByManyUsersAsync(default) + .ReturnsForAnyArgs(new[] { orgUser1, orgUser2, orgUser3, anotherOrgUser }); + + var keys = orgUsers.ToDictionary(ou => ou.Id, _ => key); + var result = await sutProvider.Sut.ConfirmUsersAsync(confirmingUser.OrganizationId, keys, confirmingUser.Id, userService); + Assert.Contains("", result[0].Item2); + Assert.Contains("User does not have two-step login enabled.", result[1].Item2); + Assert.Contains("User is a member of another organization.", result[2].Item2); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task UpdateOrganizationKeysAsync_WithoutManageResetPassword_Throws(Guid orgId, string publicKey, + string privateKey, SutProvider sutProvider) + { + var currentContext = Substitute.For(); + currentContext.ManageResetPassword(orgId).Returns(false); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateOrganizationKeysAsync(orgId, publicKey, privateKey)); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task UpdateOrganizationKeysAsync_KeysAlreadySet_Throws(Organization org, string publicKey, + string privateKey, SutProvider sutProvider) + { + var currentContext = sutProvider.GetDependency(); + currentContext.ManageResetPassword(org.Id).Returns(true); + + var organizationRepository = sutProvider.GetDependency(); + organizationRepository.GetByIdAsync(org.Id).Returns(org); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateOrganizationKeysAsync(org.Id, publicKey, privateKey)); + Assert.Contains("Organization Keys already exist", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task UpdateOrganizationKeysAsync_KeysAlreadySet_Success(Organization org, string publicKey, + string privateKey, SutProvider sutProvider) + { + org.PublicKey = null; + org.PrivateKey = null; + + var currentContext = sutProvider.GetDependency(); + currentContext.ManageResetPassword(org.Id).Returns(true); + + var organizationRepository = sutProvider.GetDependency(); + organizationRepository.GetByIdAsync(org.Id).Returns(org); + + await sutProvider.Sut.UpdateOrganizationKeysAsync(org.Id, publicKey, privateKey); + } + + [Theory] + [InlinePaidOrganizationAutoData(PlanType.EnterpriseAnnually, new object[] { "Cannot set max seat autoscaling below seat count", 1, 0, 2 })] + [InlinePaidOrganizationAutoData(PlanType.EnterpriseAnnually, new object[] { "Cannot set max seat autoscaling below seat count", 4, -1, 6 })] + [InlineFreeOrganizationAutoData("Your plan does not allow seat autoscaling", 10, 0, null)] + public async Task UpdateSubscription_BadInputThrows(string expectedMessage, + int? maxAutoscaleSeats, int seatAdjustment, int? currentSeats, Organization organization, SutProvider sutProvider) + { + organization.Seats = currentSeats; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscription(organization.Id, + seatAdjustment, maxAutoscaleSeats)); + + Assert.Contains(expectedMessage, exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task UpdateSubscription_NoOrganization_Throws(Guid organizationId, SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(organizationId).Returns((Organization)null); + + await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscription(organizationId, 0, null)); + } + + [Theory] + [InlinePaidOrganizationAutoData(0, 100, null, true, "")] + [InlinePaidOrganizationAutoData(0, 100, 100, true, "")] + [InlinePaidOrganizationAutoData(0, null, 100, true, "")] + [InlinePaidOrganizationAutoData(1, 100, null, true, "")] + [InlinePaidOrganizationAutoData(1, 100, 100, false, "Cannot invite new users. Seat limit has been reached")] + public void CanScale(int seatsToAdd, int? currentSeats, int? maxAutoscaleSeats, + bool expectedResult, string expectedFailureMessage, Organization organization, + SutProvider sutProvider) + { + organization.Seats = currentSeats; + organization.MaxAutoscaleSeats = maxAutoscaleSeats; + sutProvider.GetDependency().ManageUsers(organization.Id).Returns(true); + + var (result, failureMessage) = sutProvider.Sut.CanScale(organization, seatsToAdd); + + if (expectedFailureMessage == string.Empty) { - sutProvider.GetDependency().SelfHosted.Returns(true); - var (result, failureMessage) = sutProvider.Sut.CanScale(organization, 10); - - Assert.False(result); - Assert.Contains("Cannot autoscale on self-hosted instance", failureMessage); + Assert.Empty(failureMessage); } - - [Theory, PaidOrganizationAutoData] - public async Task Delete_Success(Organization organization, SutProvider sutProvider) + else { - var organizationRepository = sutProvider.GetDependency(); - var applicationCacheService = sutProvider.GetDependency(); - - await sutProvider.Sut.DeleteAsync(organization); - - await organizationRepository.Received().DeleteAsync(organization); - await applicationCacheService.Received().DeleteOrganizationAbilityAsync(organization.Id); + Assert.Contains(expectedFailureMessage, failureMessage); } + Assert.Equal(expectedResult, result); + } - [Theory, PaidOrganizationAutoData] - public async Task Delete_Fails_KeyConnector(Organization organization, SutProvider sutProvider, - SsoConfig ssoConfig) - { - ssoConfig.Enabled = true; - ssoConfig.SetData(new SsoConfigurationData { KeyConnectorEnabled = true }); - var ssoConfigRepository = sutProvider.GetDependency(); - var organizationRepository = sutProvider.GetDependency(); - var applicationCacheService = sutProvider.GetDependency(); + [Theory, PaidOrganizationAutoData] + public void CanScale_FailsOnSelfHosted(Organization organization, + SutProvider sutProvider) + { + sutProvider.GetDependency().SelfHosted.Returns(true); + var (result, failureMessage) = sutProvider.Sut.CanScale(organization, 10); - ssoConfigRepository.GetByOrganizationIdAsync(organization.Id).Returns(ssoConfig); + Assert.False(result); + Assert.Contains("Cannot autoscale on self-hosted instance", failureMessage); + } - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.DeleteAsync(organization)); + [Theory, PaidOrganizationAutoData] + public async Task Delete_Success(Organization organization, SutProvider sutProvider) + { + var organizationRepository = sutProvider.GetDependency(); + var applicationCacheService = sutProvider.GetDependency(); - Assert.Contains("You cannot delete an Organization that is using Key Connector.", exception.Message); + await sutProvider.Sut.DeleteAsync(organization); - await organizationRepository.DidNotReceiveWithAnyArgs().DeleteAsync(default); - await applicationCacheService.DidNotReceiveWithAnyArgs().DeleteOrganizationAbilityAsync(default); - } + await organizationRepository.Received().DeleteAsync(organization); + await applicationCacheService.Received().DeleteOrganizationAbilityAsync(organization.Id); + } + + [Theory, PaidOrganizationAutoData] + public async Task Delete_Fails_KeyConnector(Organization organization, SutProvider sutProvider, + SsoConfig ssoConfig) + { + ssoConfig.Enabled = true; + ssoConfig.SetData(new SsoConfigurationData { KeyConnectorEnabled = true }); + var ssoConfigRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var applicationCacheService = sutProvider.GetDependency(); + + ssoConfigRepository.GetByOrganizationIdAsync(organization.Id).Returns(ssoConfig); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteAsync(organization)); + + Assert.Contains("You cannot delete an Organization that is using Key Connector.", exception.Message); + + await organizationRepository.DidNotReceiveWithAnyArgs().DeleteAsync(default); + await applicationCacheService.DidNotReceiveWithAnyArgs().DeleteOrganizationAbilityAsync(default); } } diff --git a/test/Core.Test/Services/PolicyServiceTests.cs b/test/Core.Test/Services/PolicyServiceTests.cs index 8f99b816c..29b4285a1 100644 --- a/test/Core.Test/Services/PolicyServiceTests.cs +++ b/test/Core.Test/Services/PolicyServiceTests.cs @@ -10,392 +10,391 @@ using NSubstitute; using Xunit; using PolicyFixtures = Bit.Core.Test.AutoFixture.PolicyFixtures; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +[SutProviderCustomize] +public class PolicyServiceTests { - [SutProviderCustomize] - public class PolicyServiceTests + [Theory, BitAutoData] + public async Task SaveAsync_OrganizationDoesNotExist_ThrowsBadRequest( + [PolicyFixtures.Policy(PolicyType.DisableSend)] Policy policy, SutProvider sutProvider) { - [Theory, BitAutoData] - public async Task SaveAsync_OrganizationDoesNotExist_ThrowsBadRequest( - [PolicyFixtures.Policy(PolicyType.DisableSend)] Policy policy, SutProvider sutProvider) + SetupOrg(sutProvider, policy.OrganizationId, null); + + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(policy, + Substitute.For(), + Substitute.For(), + Guid.NewGuid())); + + Assert.Contains("Organization not found", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .LogPolicyEventAsync(default, default, default); + } + + [Theory, BitAutoData] + public async Task SaveAsync_OrganizationCannotUsePolicies_ThrowsBadRequest( + [PolicyFixtures.Policy(PolicyType.DisableSend)] Policy policy, SutProvider sutProvider) + { + var orgId = Guid.NewGuid(); + + SetupOrg(sutProvider, policy.OrganizationId, new Organization { - SetupOrg(sutProvider, policy.OrganizationId, null); + UsePolicies = false, + }); - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), - Substitute.For(), - Guid.NewGuid())); + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(policy, + Substitute.For(), + Substitute.For(), + Guid.NewGuid())); - Assert.Contains("Organization not found", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + Assert.Contains("cannot use policies", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertAsync(default); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .LogPolicyEventAsync(default, default, default); - } + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .LogPolicyEventAsync(default, default, default); + } - [Theory, BitAutoData] - public async Task SaveAsync_OrganizationCannotUsePolicies_ThrowsBadRequest( - [PolicyFixtures.Policy(PolicyType.DisableSend)] Policy policy, SutProvider sutProvider) + [Theory, BitAutoData] + public async Task SaveAsync_SingleOrg_RequireSsoEnabled_ThrowsBadRequest( + [PolicyFixtures.Policy(PolicyType.SingleOrg)] Policy policy, SutProvider sutProvider) + { + policy.Enabled = false; + + SetupOrg(sutProvider, policy.OrganizationId, new Organization { - var orgId = Guid.NewGuid(); + Id = policy.OrganizationId, + UsePolicies = true, + }); - SetupOrg(sutProvider, policy.OrganizationId, new Organization + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(policy.OrganizationId, PolicyType.RequireSso) + .Returns(Task.FromResult(new Policy { Enabled = true })); + + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(policy, + Substitute.For(), + Substitute.For(), + Guid.NewGuid())); + + Assert.Contains("Single Sign-On Authentication policy is enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .LogPolicyEventAsync(default, default, default); + } + + [Theory, BitAutoData] + public async Task SaveAsync_SingleOrg_VaultTimeoutEnabled_ThrowsBadRequest([PolicyFixtures.Policy(Enums.PolicyType.SingleOrg)] Policy policy, SutProvider sutProvider) + { + policy.Enabled = false; + + SetupOrg(sutProvider, policy.OrganizationId, new Organization + { + Id = policy.OrganizationId, + UsePolicies = true, + }); + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(policy.OrganizationId, Enums.PolicyType.MaximumVaultTimeout) + .Returns(new Policy { Enabled = true }); + + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(policy, + Substitute.For(), + Substitute.For(), + Guid.NewGuid())); + + Assert.Contains("Maximum Vault Timeout policy is enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + } + + [Theory] + [BitAutoData(PolicyType.SingleOrg)] + [BitAutoData(PolicyType.RequireSso)] + public async Task SaveAsync_PolicyRequiredByKeyConnector_DisablePolicy_ThrowsBadRequest( + Enums.PolicyType policyType, + Policy policy, + SutProvider sutProvider) + { + policy.Enabled = false; + policy.Type = policyType; + + SetupOrg(sutProvider, policy.OrganizationId, new Organization + { + Id = policy.OrganizationId, + UsePolicies = true, + }); + + var ssoConfig = new SsoConfig { Enabled = true }; + var data = new SsoConfigurationData { KeyConnectorEnabled = true }; + ssoConfig.SetData(data); + + sutProvider.GetDependency() + .GetByOrganizationIdAsync(policy.OrganizationId) + .Returns(ssoConfig); + + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(policy, + Substitute.For(), + Substitute.For(), + Guid.NewGuid())); + + Assert.Contains("Key Connector is enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + } + + [Theory, BitAutoData] + public async Task SaveAsync_RequireSsoPolicy_NotEnabled_ThrowsBadRequestAsync( + [PolicyFixtures.Policy(Enums.PolicyType.RequireSso)] Policy policy, SutProvider sutProvider) + { + policy.Enabled = true; + + SetupOrg(sutProvider, policy.OrganizationId, new Organization + { + Id = policy.OrganizationId, + UsePolicies = true, + }); + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(policy.OrganizationId, PolicyType.SingleOrg) + .Returns(Task.FromResult(new Policy { Enabled = false })); + + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(policy, + Substitute.For(), + Substitute.For(), + Guid.NewGuid())); + + Assert.Contains("Single Organization policy not enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .LogPolicyEventAsync(default, default, default); + } + + [Theory, BitAutoData] + public async Task SaveAsync_NewPolicy_Created( + [PolicyFixtures.Policy(PolicyType.ResetPassword)] Policy policy, SutProvider sutProvider) + { + policy.Id = default; + + SetupOrg(sutProvider, policy.OrganizationId, new Organization + { + Id = policy.OrganizationId, + UsePolicies = true, + }); + + var utcNow = DateTime.UtcNow; + + await sutProvider.Sut.SaveAsync(policy, Substitute.For(), Substitute.For(), Guid.NewGuid()); + + await sutProvider.GetDependency().Received() + .LogPolicyEventAsync(policy, EventType.Policy_Updated); + + await sutProvider.GetDependency().Received() + .UpsertAsync(policy); + + Assert.True(policy.CreationDate - utcNow < TimeSpan.FromSeconds(1)); + Assert.True(policy.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + } + + [Theory, BitAutoData] + public async Task SaveAsync_VaultTimeoutPolicy_NotEnabled_ThrowsBadRequestAsync( + [PolicyFixtures.Policy(PolicyType.MaximumVaultTimeout)] Policy policy, SutProvider sutProvider) + { + policy.Enabled = true; + + SetupOrg(sutProvider, policy.OrganizationId, new Organization + { + Id = policy.OrganizationId, + UsePolicies = true, + }); + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(policy.OrganizationId, Enums.PolicyType.SingleOrg) + .Returns(Task.FromResult(new Policy { Enabled = false })); + + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(policy, + Substitute.For(), + Substitute.For(), + Guid.NewGuid())); + + Assert.Contains("Single Organization policy not enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .LogPolicyEventAsync(default, default, default); + } + + [Theory, BitAutoData] + public async Task SaveAsync_ExistingPolicy_UpdateTwoFactor( + [PolicyFixtures.Policy(PolicyType.TwoFactorAuthentication)] Policy policy, SutProvider sutProvider) + { + // If the policy that this is updating isn't enabled then do some work now that the current one is enabled + + var org = new Organization + { + Id = policy.OrganizationId, + UsePolicies = true, + Name = "TEST", + }; + + SetupOrg(sutProvider, policy.OrganizationId, org); + + sutProvider.GetDependency() + .GetByIdAsync(policy.Id) + .Returns(new Policy { - UsePolicies = false, + Id = policy.Id, + Type = PolicyType.TwoFactorAuthentication, + Enabled = false, }); - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), - Substitute.For(), - Guid.NewGuid())); - - Assert.Contains("cannot use policies", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .LogPolicyEventAsync(default, default, default); - } - - [Theory, BitAutoData] - public async Task SaveAsync_SingleOrg_RequireSsoEnabled_ThrowsBadRequest( - [PolicyFixtures.Policy(PolicyType.SingleOrg)] Policy policy, SutProvider sutProvider) + var orgUserDetail = new Core.Models.Data.Organizations.OrganizationUsers.OrganizationUserUserDetails { - policy.Enabled = false; + Id = Guid.NewGuid(), + Status = OrganizationUserStatusType.Accepted, + Type = OrganizationUserType.User, + // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync + Email = "test@bitwarden.com", + Name = "TEST", + UserId = Guid.NewGuid(), + }; - SetupOrg(sutProvider, policy.OrganizationId, new Organization + sutProvider.GetDependency() + .GetManyDetailsByOrganizationAsync(policy.OrganizationId) + .Returns(new List { - Id = policy.OrganizationId, - UsePolicies = true, + orgUserDetail, }); - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policy.OrganizationId, PolicyType.RequireSso) - .Returns(Task.FromResult(new Policy { Enabled = true })); + var userService = Substitute.For(); + var organizationService = Substitute.For(); - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), - Substitute.For(), - Guid.NewGuid())); + userService.TwoFactorIsEnabledAsync(orgUserDetail) + .Returns(false); - Assert.Contains("Single Sign-On Authentication policy is enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + var utcNow = DateTime.UtcNow; - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); + var savingUserId = Guid.NewGuid(); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .LogPolicyEventAsync(default, default, default); - } + await sutProvider.Sut.SaveAsync(policy, userService, organizationService, savingUserId); - [Theory, BitAutoData] - public async Task SaveAsync_SingleOrg_VaultTimeoutEnabled_ThrowsBadRequest([PolicyFixtures.Policy(Enums.PolicyType.SingleOrg)] Policy policy, SutProvider sutProvider) + await organizationService.Received() + .DeleteUserAsync(policy.OrganizationId, orgUserDetail.Id, savingUserId); + + await sutProvider.GetDependency().Received() + .SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(org.Name, orgUserDetail.Email); + + await sutProvider.GetDependency().Received() + .LogPolicyEventAsync(policy, EventType.Policy_Updated); + + await sutProvider.GetDependency().Received() + .UpsertAsync(policy); + + Assert.True(policy.CreationDate - utcNow < TimeSpan.FromSeconds(1)); + Assert.True(policy.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + } + + [Theory, BitAutoData] + public async Task SaveAsync_ExistingPolicy_UpdateSingleOrg( + [PolicyFixtures.Policy(PolicyType.TwoFactorAuthentication)] Policy policy, SutProvider sutProvider) + { + // If the policy that this is updating isn't enabled then do some work now that the current one is enabled + + var org = new Organization { - policy.Enabled = false; + Id = policy.OrganizationId, + UsePolicies = true, + Name = "TEST", + }; - SetupOrg(sutProvider, policy.OrganizationId, new Organization + SetupOrg(sutProvider, policy.OrganizationId, org); + + sutProvider.GetDependency() + .GetByIdAsync(policy.Id) + .Returns(new Policy { - Id = policy.OrganizationId, - UsePolicies = true, + Id = policy.Id, + Type = PolicyType.SingleOrg, + Enabled = false, }); - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policy.OrganizationId, Enums.PolicyType.MaximumVaultTimeout) - .Returns(new Policy { Enabled = true }); - - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), - Substitute.For(), - Guid.NewGuid())); - - Assert.Contains("Maximum Vault Timeout policy is enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } - - [Theory] - [BitAutoData(PolicyType.SingleOrg)] - [BitAutoData(PolicyType.RequireSso)] - public async Task SaveAsync_PolicyRequiredByKeyConnector_DisablePolicy_ThrowsBadRequest( - Enums.PolicyType policyType, - Policy policy, - SutProvider sutProvider) + var orgUserDetail = new Core.Models.Data.Organizations.OrganizationUsers.OrganizationUserUserDetails { - policy.Enabled = false; - policy.Type = policyType; + Id = Guid.NewGuid(), + Status = OrganizationUserStatusType.Accepted, + Type = OrganizationUserType.User, + // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync + Email = "test@bitwarden.com", + Name = "TEST", + UserId = Guid.NewGuid(), + }; - SetupOrg(sutProvider, policy.OrganizationId, new Organization + sutProvider.GetDependency() + .GetManyDetailsByOrganizationAsync(policy.OrganizationId) + .Returns(new List { - Id = policy.OrganizationId, - UsePolicies = true, + orgUserDetail, }); - var ssoConfig = new SsoConfig { Enabled = true }; - var data = new SsoConfigurationData { KeyConnectorEnabled = true }; - ssoConfig.SetData(data); + var userService = Substitute.For(); + var organizationService = Substitute.For(); - sutProvider.GetDependency() - .GetByOrganizationIdAsync(policy.OrganizationId) - .Returns(ssoConfig); + userService.TwoFactorIsEnabledAsync(orgUserDetail) + .Returns(false); - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), - Substitute.For(), - Guid.NewGuid())); + var utcNow = DateTime.UtcNow; - Assert.Contains("Key Connector is enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + var savingUserId = Guid.NewGuid(); - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } + await sutProvider.Sut.SaveAsync(policy, userService, organizationService, savingUserId); - [Theory, BitAutoData] - public async Task SaveAsync_RequireSsoPolicy_NotEnabled_ThrowsBadRequestAsync( - [PolicyFixtures.Policy(Enums.PolicyType.RequireSso)] Policy policy, SutProvider sutProvider) - { - policy.Enabled = true; + await sutProvider.GetDependency().Received() + .LogPolicyEventAsync(policy, EventType.Policy_Updated); - SetupOrg(sutProvider, policy.OrganizationId, new Organization - { - Id = policy.OrganizationId, - UsePolicies = true, - }); + await sutProvider.GetDependency().Received() + .UpsertAsync(policy); - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policy.OrganizationId, PolicyType.SingleOrg) - .Returns(Task.FromResult(new Policy { Enabled = false })); + Assert.True(policy.CreationDate - utcNow < TimeSpan.FromSeconds(1)); + Assert.True(policy.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + } - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), - Substitute.For(), - Guid.NewGuid())); - - Assert.Contains("Single Organization policy not enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .LogPolicyEventAsync(default, default, default); - } - - [Theory, BitAutoData] - public async Task SaveAsync_NewPolicy_Created( - [PolicyFixtures.Policy(PolicyType.ResetPassword)] Policy policy, SutProvider sutProvider) - { - policy.Id = default; - - SetupOrg(sutProvider, policy.OrganizationId, new Organization - { - Id = policy.OrganizationId, - UsePolicies = true, - }); - - var utcNow = DateTime.UtcNow; - - await sutProvider.Sut.SaveAsync(policy, Substitute.For(), Substitute.For(), Guid.NewGuid()); - - await sutProvider.GetDependency().Received() - .LogPolicyEventAsync(policy, EventType.Policy_Updated); - - await sutProvider.GetDependency().Received() - .UpsertAsync(policy); - - Assert.True(policy.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(policy.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } - - [Theory, BitAutoData] - public async Task SaveAsync_VaultTimeoutPolicy_NotEnabled_ThrowsBadRequestAsync( - [PolicyFixtures.Policy(PolicyType.MaximumVaultTimeout)] Policy policy, SutProvider sutProvider) - { - policy.Enabled = true; - - SetupOrg(sutProvider, policy.OrganizationId, new Organization - { - Id = policy.OrganizationId, - UsePolicies = true, - }); - - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policy.OrganizationId, Enums.PolicyType.SingleOrg) - .Returns(Task.FromResult(new Policy { Enabled = false })); - - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Substitute.For(), - Substitute.For(), - Guid.NewGuid())); - - Assert.Contains("Single Organization policy not enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .LogPolicyEventAsync(default, default, default); - } - - [Theory, BitAutoData] - public async Task SaveAsync_ExistingPolicy_UpdateTwoFactor( - [PolicyFixtures.Policy(PolicyType.TwoFactorAuthentication)] Policy policy, SutProvider sutProvider) - { - // If the policy that this is updating isn't enabled then do some work now that the current one is enabled - - var org = new Organization - { - Id = policy.OrganizationId, - UsePolicies = true, - Name = "TEST", - }; - - SetupOrg(sutProvider, policy.OrganizationId, org); - - sutProvider.GetDependency() - .GetByIdAsync(policy.Id) - .Returns(new Policy - { - Id = policy.Id, - Type = PolicyType.TwoFactorAuthentication, - Enabled = false, - }); - - var orgUserDetail = new Core.Models.Data.Organizations.OrganizationUsers.OrganizationUserUserDetails - { - Id = Guid.NewGuid(), - Status = OrganizationUserStatusType.Accepted, - Type = OrganizationUserType.User, - // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync - Email = "test@bitwarden.com", - Name = "TEST", - UserId = Guid.NewGuid(), - }; - - sutProvider.GetDependency() - .GetManyDetailsByOrganizationAsync(policy.OrganizationId) - .Returns(new List - { - orgUserDetail, - }); - - var userService = Substitute.For(); - var organizationService = Substitute.For(); - - userService.TwoFactorIsEnabledAsync(orgUserDetail) - .Returns(false); - - var utcNow = DateTime.UtcNow; - - var savingUserId = Guid.NewGuid(); - - await sutProvider.Sut.SaveAsync(policy, userService, organizationService, savingUserId); - - await organizationService.Received() - .DeleteUserAsync(policy.OrganizationId, orgUserDetail.Id, savingUserId); - - await sutProvider.GetDependency().Received() - .SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(org.Name, orgUserDetail.Email); - - await sutProvider.GetDependency().Received() - .LogPolicyEventAsync(policy, EventType.Policy_Updated); - - await sutProvider.GetDependency().Received() - .UpsertAsync(policy); - - Assert.True(policy.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(policy.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } - - [Theory, BitAutoData] - public async Task SaveAsync_ExistingPolicy_UpdateSingleOrg( - [PolicyFixtures.Policy(PolicyType.TwoFactorAuthentication)] Policy policy, SutProvider sutProvider) - { - // If the policy that this is updating isn't enabled then do some work now that the current one is enabled - - var org = new Organization - { - Id = policy.OrganizationId, - UsePolicies = true, - Name = "TEST", - }; - - SetupOrg(sutProvider, policy.OrganizationId, org); - - sutProvider.GetDependency() - .GetByIdAsync(policy.Id) - .Returns(new Policy - { - Id = policy.Id, - Type = PolicyType.SingleOrg, - Enabled = false, - }); - - var orgUserDetail = new Core.Models.Data.Organizations.OrganizationUsers.OrganizationUserUserDetails - { - Id = Guid.NewGuid(), - Status = OrganizationUserStatusType.Accepted, - Type = OrganizationUserType.User, - // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync - Email = "test@bitwarden.com", - Name = "TEST", - UserId = Guid.NewGuid(), - }; - - sutProvider.GetDependency() - .GetManyDetailsByOrganizationAsync(policy.OrganizationId) - .Returns(new List - { - orgUserDetail, - }); - - var userService = Substitute.For(); - var organizationService = Substitute.For(); - - userService.TwoFactorIsEnabledAsync(orgUserDetail) - .Returns(false); - - var utcNow = DateTime.UtcNow; - - var savingUserId = Guid.NewGuid(); - - await sutProvider.Sut.SaveAsync(policy, userService, organizationService, savingUserId); - - await sutProvider.GetDependency().Received() - .LogPolicyEventAsync(policy, EventType.Policy_Updated); - - await sutProvider.GetDependency().Received() - .UpsertAsync(policy); - - Assert.True(policy.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(policy.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } - - private static void SetupOrg(SutProvider sutProvider, Guid organizationId, Organization organization) - { - sutProvider.GetDependency() - .GetByIdAsync(organizationId) - .Returns(Task.FromResult(organization)); - } + private static void SetupOrg(SutProvider sutProvider, Guid organizationId, Organization organization) + { + sutProvider.GetDependency() + .GetByIdAsync(organizationId) + .Returns(Task.FromResult(organization)); } } diff --git a/test/Core.Test/Services/RelayPushNotificationServiceTests.cs b/test/Core.Test/Services/RelayPushNotificationServiceTests.cs index 68b8633e2..ccf5e3d4b 100644 --- a/test/Core.Test/Services/RelayPushNotificationServiceTests.cs +++ b/test/Core.Test/Services/RelayPushNotificationServiceTests.cs @@ -6,41 +6,40 @@ using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class RelayPushNotificationServiceTests { - public class RelayPushNotificationServiceTests + private readonly RelayPushNotificationService _sut; + + private readonly IHttpClientFactory _httpFactory; + private readonly IDeviceRepository _deviceRepository; + private readonly GlobalSettings _globalSettings; + private readonly IHttpContextAccessor _httpContextAccessor; + private readonly ILogger _logger; + + public RelayPushNotificationServiceTests() { - private readonly RelayPushNotificationService _sut; + _httpFactory = Substitute.For(); + _deviceRepository = Substitute.For(); + _globalSettings = new GlobalSettings(); + _httpContextAccessor = Substitute.For(); + _logger = Substitute.For>(); - private readonly IHttpClientFactory _httpFactory; - private readonly IDeviceRepository _deviceRepository; - private readonly GlobalSettings _globalSettings; - private readonly IHttpContextAccessor _httpContextAccessor; - private readonly ILogger _logger; + _sut = new RelayPushNotificationService( + _httpFactory, + _deviceRepository, + _globalSettings, + _httpContextAccessor, + _logger + ); + } - public RelayPushNotificationServiceTests() - { - _httpFactory = Substitute.For(); - _deviceRepository = Substitute.For(); - _globalSettings = new GlobalSettings(); - _httpContextAccessor = Substitute.For(); - _logger = Substitute.For>(); - - _sut = new RelayPushNotificationService( - _httpFactory, - _deviceRepository, - _globalSettings, - _httpContextAccessor, - _logger - ); - } - - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact(Skip = "Needs additional work")] - public void ServiceExists() - { - Assert.NotNull(_sut); - } + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact(Skip = "Needs additional work")] + public void ServiceExists() + { + Assert.NotNull(_sut); } } diff --git a/test/Core.Test/Services/RelayPushRegistrationServiceTests.cs b/test/Core.Test/Services/RelayPushRegistrationServiceTests.cs index 371d50168..926a19bc0 100644 --- a/test/Core.Test/Services/RelayPushRegistrationServiceTests.cs +++ b/test/Core.Test/Services/RelayPushRegistrationServiceTests.cs @@ -4,35 +4,34 @@ using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class RelayPushRegistrationServiceTests { - public class RelayPushRegistrationServiceTests + private readonly RelayPushRegistrationService _sut; + + private readonly IHttpClientFactory _httpFactory; + private readonly GlobalSettings _globalSettings; + private readonly ILogger _logger; + + public RelayPushRegistrationServiceTests() { - private readonly RelayPushRegistrationService _sut; + _globalSettings = new GlobalSettings(); + _httpFactory = Substitute.For(); + _logger = Substitute.For>(); - private readonly IHttpClientFactory _httpFactory; - private readonly GlobalSettings _globalSettings; - private readonly ILogger _logger; + _sut = new RelayPushRegistrationService( + _httpFactory, + _globalSettings, + _logger + ); + } - public RelayPushRegistrationServiceTests() - { - _globalSettings = new GlobalSettings(); - _httpFactory = Substitute.For(); - _logger = Substitute.For>(); - - _sut = new RelayPushRegistrationService( - _httpFactory, - _globalSettings, - _logger - ); - } - - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact(Skip = "Needs additional work")] - public void ServiceExists() - { - Assert.NotNull(_sut); - } + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact(Skip = "Needs additional work")] + public void ServiceExists() + { + Assert.NotNull(_sut); } } diff --git a/test/Core.Test/Services/RepositoryEventWriteServiceTests.cs b/test/Core.Test/Services/RepositoryEventWriteServiceTests.cs index 4ee3460ab..9cfe2c9e8 100644 --- a/test/Core.Test/Services/RepositoryEventWriteServiceTests.cs +++ b/test/Core.Test/Services/RepositoryEventWriteServiceTests.cs @@ -3,27 +3,26 @@ using Bit.Core.Services; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class RepositoryEventWriteServiceTests { - public class RepositoryEventWriteServiceTests + private readonly RepositoryEventWriteService _sut; + + private readonly IEventRepository _eventRepository; + + public RepositoryEventWriteServiceTests() { - private readonly RepositoryEventWriteService _sut; + _eventRepository = Substitute.For(); - private readonly IEventRepository _eventRepository; + _sut = new RepositoryEventWriteService(_eventRepository); + } - public RepositoryEventWriteServiceTests() - { - _eventRepository = Substitute.For(); - - _sut = new RepositoryEventWriteService(_eventRepository); - } - - // Remove this test when we add actual tests. It only proves that - // we've properly constructed the system under test. - [Fact] - public void ServiceExists() - { - Assert.NotNull(_sut); - } + // Remove this test when we add actual tests. It only proves that + // we've properly constructed the system under test. + [Fact] + public void ServiceExists() + { + Assert.NotNull(_sut); } } diff --git a/test/Core.Test/Services/SendGridMailDeliveryServiceTests.cs b/test/Core.Test/Services/SendGridMailDeliveryServiceTests.cs index 8366cc266..3c64e5c40 100644 --- a/test/Core.Test/Services/SendGridMailDeliveryServiceTests.cs +++ b/test/Core.Test/Services/SendGridMailDeliveryServiceTests.cs @@ -8,78 +8,77 @@ using SendGrid; using SendGrid.Helpers.Mail; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class SendGridMailDeliveryServiceTests : IDisposable { - public class SendGridMailDeliveryServiceTests : IDisposable + private readonly SendGridMailDeliveryService _sut; + + private readonly GlobalSettings _globalSettings; + private readonly IWebHostEnvironment _hostingEnvironment; + private readonly ILogger _logger; + private readonly ISendGridClient _sendGridClient; + + public SendGridMailDeliveryServiceTests() { - private readonly SendGridMailDeliveryService _sut; - - private readonly GlobalSettings _globalSettings; - private readonly IWebHostEnvironment _hostingEnvironment; - private readonly ILogger _logger; - private readonly ISendGridClient _sendGridClient; - - public SendGridMailDeliveryServiceTests() + _globalSettings = new GlobalSettings { - _globalSettings = new GlobalSettings + Mail = { - Mail = - { - SendGridApiKey = "SendGridApiKey" - } - }; + SendGridApiKey = "SendGridApiKey" + } + }; - _hostingEnvironment = Substitute.For(); - _logger = Substitute.For>(); - _sendGridClient = Substitute.For(); + _hostingEnvironment = Substitute.For(); + _logger = Substitute.For>(); + _sendGridClient = Substitute.For(); - _sut = new SendGridMailDeliveryService( - _sendGridClient, - _globalSettings, - _hostingEnvironment, - _logger - ); - } + _sut = new SendGridMailDeliveryService( + _sendGridClient, + _globalSettings, + _hostingEnvironment, + _logger + ); + } - public void Dispose() + public void Dispose() + { + _sut?.Dispose(); + } + + [Fact] + public async Task SendEmailAsync_CallsSendEmailAsync_WhenMessageIsValid() + { + var mailMessage = new MailMessage { - _sut?.Dispose(); - } + ToEmails = new List { "ToEmails" }, + BccEmails = new List { "BccEmails" }, + Subject = "Subject", + HtmlContent = "HtmlContent", + TextContent = "TextContent", + Category = "Category" + }; - [Fact] - public async Task SendEmailAsync_CallsSendEmailAsync_WhenMessageIsValid() - { - var mailMessage = new MailMessage + _sendGridClient.SendEmailAsync(Arg.Any()).Returns( + new Response(System.Net.HttpStatusCode.OK, null, null)); + await _sut.SendEmailAsync(mailMessage); + + await _sendGridClient.Received(1).SendEmailAsync( + Arg.Do(msg => { - ToEmails = new List { "ToEmails" }, - BccEmails = new List { "BccEmails" }, - Subject = "Subject", - HtmlContent = "HtmlContent", - TextContent = "TextContent", - Category = "Category" - }; + msg.Received(1).AddTos(new List { new EmailAddress(mailMessage.ToEmails.First()) }); + msg.Received(1).AddBccs(new List { new EmailAddress(mailMessage.ToEmails.First()) }); - _sendGridClient.SendEmailAsync(Arg.Any()).Returns( - new Response(System.Net.HttpStatusCode.OK, null, null)); - await _sut.SendEmailAsync(mailMessage); + Assert.Equal(mailMessage.Subject, msg.Subject); + Assert.Equal(mailMessage.HtmlContent, msg.HtmlContent); + Assert.Equal(mailMessage.TextContent, msg.PlainTextContent); - await _sendGridClient.Received(1).SendEmailAsync( - Arg.Do(msg => - { - msg.Received(1).AddTos(new List { new EmailAddress(mailMessage.ToEmails.First()) }); - msg.Received(1).AddBccs(new List { new EmailAddress(mailMessage.ToEmails.First()) }); + Assert.Contains("type:Cateogry", msg.Categories); + Assert.Contains(msg.Categories, x => x.StartsWith("env:")); + Assert.Contains(msg.Categories, x => x.StartsWith("sender:")); - Assert.Equal(mailMessage.Subject, msg.Subject); - Assert.Equal(mailMessage.HtmlContent, msg.HtmlContent); - Assert.Equal(mailMessage.TextContent, msg.PlainTextContent); - - Assert.Contains("type:Cateogry", msg.Categories); - Assert.Contains(msg.Categories, x => x.StartsWith("env:")); - Assert.Contains(msg.Categories, x => x.StartsWith("sender:")); - - msg.Received(1).SetClickTracking(false, false); - msg.Received(1).SetOpenTracking(false); - })); - } + msg.Received(1).SetClickTracking(false, false); + msg.Received(1).SetOpenTracking(false); + })); } } diff --git a/test/Core.Test/Services/SendServiceTests.cs b/test/Core.Test/Services/SendServiceTests.cs index aed7d2f04..1468bd0b0 100644 --- a/test/Core.Test/Services/SendServiceTests.cs +++ b/test/Core.Test/Services/SendServiceTests.cs @@ -14,749 +14,748 @@ using Microsoft.AspNetCore.Identity; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class SendServiceTests { - public class SendServiceTests + private void SaveSendAsync_Setup(SendType sendType, bool disableSendPolicyAppliesToUser, + SutProvider sutProvider, Send send) { - private void SaveSendAsync_Setup(SendType sendType, bool disableSendPolicyAppliesToUser, - SutProvider sutProvider, Send send) + send.Id = default; + send.Type = sendType; + + sutProvider.GetDependency().GetCountByTypeApplicableToUserIdAsync( + Arg.Any(), PolicyType.DisableSend).Returns(disableSendPolicyAppliesToUser ? 1 : 0); + } + + // Disable Send policy check + + [Theory] + [InlineUserSendAutoData(SendType.File)] + [InlineUserSendAutoData(SendType.Text)] + public async void SaveSendAsync_DisableSend_Applies_throws(SendType sendType, + SutProvider sutProvider, Send send) + { + SaveSendAsync_Setup(sendType, disableSendPolicyAppliesToUser: true, sutProvider, send); + + await Assert.ThrowsAsync(() => sutProvider.Sut.SaveSendAsync(send)); + } + + [Theory] + [InlineUserSendAutoData(SendType.File)] + [InlineUserSendAutoData(SendType.Text)] + public async void SaveSendAsync_DisableSend_DoesntApply_success(SendType sendType, + SutProvider sutProvider, Send send) + { + SaveSendAsync_Setup(sendType, disableSendPolicyAppliesToUser: false, sutProvider, send); + + await sutProvider.Sut.SaveSendAsync(send); + + await sutProvider.GetDependency().Received(1).CreateAsync(send); + } + + // Send Options Policy - Disable Hide Email check + + private void SaveSendAsync_HideEmail_Setup(bool disableHideEmailAppliesToUser, + SutProvider sutProvider, Send send, Policy policy) + { + send.HideEmail = true; + + var sendOptions = new SendOptionsPolicyData { - send.Id = default; - send.Type = sendType; - - sutProvider.GetDependency().GetCountByTypeApplicableToUserIdAsync( - Arg.Any(), PolicyType.DisableSend).Returns(disableSendPolicyAppliesToUser ? 1 : 0); - } - - // Disable Send policy check - - [Theory] - [InlineUserSendAutoData(SendType.File)] - [InlineUserSendAutoData(SendType.Text)] - public async void SaveSendAsync_DisableSend_Applies_throws(SendType sendType, - SutProvider sutProvider, Send send) + DisableHideEmail = disableHideEmailAppliesToUser + }; + policy.Data = JsonSerializer.Serialize(sendOptions, new JsonSerializerOptions { - SaveSendAsync_Setup(sendType, disableSendPolicyAppliesToUser: true, sutProvider, send); + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + }); - await Assert.ThrowsAsync(() => sutProvider.Sut.SaveSendAsync(send)); - } - - [Theory] - [InlineUserSendAutoData(SendType.File)] - [InlineUserSendAutoData(SendType.Text)] - public async void SaveSendAsync_DisableSend_DoesntApply_success(SendType sendType, - SutProvider sutProvider, Send send) - { - SaveSendAsync_Setup(sendType, disableSendPolicyAppliesToUser: false, sutProvider, send); - - await sutProvider.Sut.SaveSendAsync(send); - - await sutProvider.GetDependency().Received(1).CreateAsync(send); - } - - // Send Options Policy - Disable Hide Email check - - private void SaveSendAsync_HideEmail_Setup(bool disableHideEmailAppliesToUser, - SutProvider sutProvider, Send send, Policy policy) - { - send.HideEmail = true; - - var sendOptions = new SendOptionsPolicyData + sutProvider.GetDependency().GetManyByTypeApplicableToUserIdAsync( + Arg.Any(), PolicyType.SendOptions).Returns(new List { - DisableHideEmail = disableHideEmailAppliesToUser - }; - policy.Data = JsonSerializer.Serialize(sendOptions, new JsonSerializerOptions - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + policy, }); + } - sutProvider.GetDependency().GetManyByTypeApplicableToUserIdAsync( - Arg.Any(), PolicyType.SendOptions).Returns(new List - { - policy, - }); - } + [Theory] + [InlineUserSendAutoData(SendType.File)] + [InlineUserSendAutoData(SendType.Text)] + public async void SaveSendAsync_DisableHideEmail_Applies_throws(SendType sendType, + SutProvider sutProvider, Send send, Policy policy) + { + SaveSendAsync_Setup(sendType, false, sutProvider, send); + SaveSendAsync_HideEmail_Setup(true, sutProvider, send, policy); - [Theory] - [InlineUserSendAutoData(SendType.File)] - [InlineUserSendAutoData(SendType.Text)] - public async void SaveSendAsync_DisableHideEmail_Applies_throws(SendType sendType, - SutProvider sutProvider, Send send, Policy policy) + await Assert.ThrowsAsync(() => sutProvider.Sut.SaveSendAsync(send)); + } + + [Theory] + [InlineUserSendAutoData(SendType.File)] + [InlineUserSendAutoData(SendType.Text)] + public async void SaveSendAsync_DisableHideEmail_DoesntApply_success(SendType sendType, + SutProvider sutProvider, Send send, Policy policy) + { + SaveSendAsync_Setup(sendType, false, sutProvider, send); + SaveSendAsync_HideEmail_Setup(false, sutProvider, send, policy); + + await sutProvider.Sut.SaveSendAsync(send); + + await sutProvider.GetDependency().Received(1).CreateAsync(send); + } + + [Theory] + [InlineUserSendAutoData] + [InlineUserSendAutoData] + public async void SaveSendAsync_ExistingSend_Updates(SutProvider sutProvider, + Send send) + { + send.Id = Guid.NewGuid(); + + var now = DateTime.UtcNow; + await sutProvider.Sut.SaveSendAsync(send); + + Assert.True(send.RevisionDate - now < TimeSpan.FromSeconds(1)); + + await sutProvider.GetDependency() + .Received(1) + .UpsertAsync(send); + + await sutProvider.GetDependency() + .Received(1) + .PushSyncSendUpdateAsync(send); + } + + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_TextType_ThrowsBadRequest(SutProvider sutProvider, + Send send) + { + send.Type = SendType.Text; + + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveFileSendAsync(send, null, 0) + ); + + Assert.Contains("not of type \"file\"", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } + + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_EmptyFile_ThrowsBadRequest(SutProvider sutProvider, + Send send) + { + send.Type = SendType.File; + + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveFileSendAsync(send, null, 0) + ); + + Assert.Contains("no file data", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } + + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_UserCannotAccessPremium_ThrowsBadRequest(SutProvider sutProvider, + Send send) + { + var user = new User { - SaveSendAsync_Setup(sendType, false, sutProvider, send); - SaveSendAsync_HideEmail_Setup(true, sutProvider, send, policy); + Id = Guid.NewGuid(), + }; - await Assert.ThrowsAsync(() => sutProvider.Sut.SaveSendAsync(send)); - } + send.UserId = user.Id; + send.Type = SendType.File; - [Theory] - [InlineUserSendAutoData(SendType.File)] - [InlineUserSendAutoData(SendType.Text)] - public async void SaveSendAsync_DisableHideEmail_DoesntApply_success(SendType sendType, - SutProvider sutProvider, Send send, Policy policy) + sutProvider.GetDependency() + .GetByIdAsync(user.Id) + .Returns(user); + + sutProvider.GetDependency() + .CanAccessPremium(user) + .Returns(false); + + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveFileSendAsync(send, null, 1) + ); + + Assert.Contains("must have premium", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } + + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_UserHasUnconfirmedEmail_ThrowsBadRequest(SutProvider sutProvider, + Send send) + { + var user = new User { - SaveSendAsync_Setup(sendType, false, sutProvider, send); - SaveSendAsync_HideEmail_Setup(false, sutProvider, send, policy); + Id = Guid.NewGuid(), + EmailVerified = false, + }; - await sutProvider.Sut.SaveSendAsync(send); + send.UserId = user.Id; + send.Type = SendType.File; - await sutProvider.GetDependency().Received(1).CreateAsync(send); - } + sutProvider.GetDependency() + .GetByIdAsync(user.Id) + .Returns(user); - [Theory] - [InlineUserSendAutoData] - [InlineUserSendAutoData] - public async void SaveSendAsync_ExistingSend_Updates(SutProvider sutProvider, - Send send) + sutProvider.GetDependency() + .CanAccessPremium(user) + .Returns(true); + + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveFileSendAsync(send, null, 1) + ); + + Assert.Contains("must confirm your email", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } + + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_UserCanAccessPremium_HasNoStorage_ThrowsBadRequest(SutProvider sutProvider, + Send send) + { + var user = new User { - send.Id = Guid.NewGuid(); + Id = Guid.NewGuid(), + EmailVerified = true, + Premium = true, + MaxStorageGb = null, + Storage = 0, + }; - var now = DateTime.UtcNow; - await sutProvider.Sut.SaveSendAsync(send); + send.UserId = user.Id; + send.Type = SendType.File; - Assert.True(send.RevisionDate - now < TimeSpan.FromSeconds(1)); + sutProvider.GetDependency() + .GetByIdAsync(user.Id) + .Returns(user); - await sutProvider.GetDependency() - .Received(1) - .UpsertAsync(send); + sutProvider.GetDependency() + .CanAccessPremium(user) + .Returns(true); - await sutProvider.GetDependency() - .Received(1) - .PushSyncSendUpdateAsync(send); - } + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveFileSendAsync(send, null, 1) + ); - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_TextType_ThrowsBadRequest(SutProvider sutProvider, - Send send) + Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } + + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_UserCanAccessPremium_StorageFull_ThrowsBadRequest(SutProvider sutProvider, + Send send) + { + var user = new User { - send.Type = SendType.Text; + Id = Guid.NewGuid(), + EmailVerified = true, + Premium = true, + MaxStorageGb = 2, + Storage = 2 * UserTests.Multiplier, + }; - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.SaveFileSendAsync(send, null, 0) - ); + send.UserId = user.Id; + send.Type = SendType.File; - Assert.Contains("not of type \"file\"", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } + sutProvider.GetDependency() + .GetByIdAsync(user.Id) + .Returns(user); - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_EmptyFile_ThrowsBadRequest(SutProvider sutProvider, - Send send) + sutProvider.GetDependency() + .CanAccessPremium(user) + .Returns(true); + + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveFileSendAsync(send, null, 1) + ); + + Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } + + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_UserCanAccessPremium_IsNotPremium_IsSelfHosted_GiantFile_ThrowsBadRequest(SutProvider sutProvider, + Send send) + { + var user = new User { - send.Type = SendType.File; + Id = Guid.NewGuid(), + EmailVerified = true, + Premium = false, + }; - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.SaveFileSendAsync(send, null, 0) - ); + send.UserId = user.Id; + send.Type = SendType.File; - Assert.Contains("no file data", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } + sutProvider.GetDependency() + .GetByIdAsync(user.Id) + .Returns(user); - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_UserCannotAccessPremium_ThrowsBadRequest(SutProvider sutProvider, - Send send) + sutProvider.GetDependency() + .CanAccessPremium(user) + .Returns(true); + + sutProvider.GetDependency() + .SelfHosted = true; + + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveFileSendAsync(send, null, 11000 * UserTests.Multiplier) + ); + + Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } + + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_UserCanAccessPremium_IsNotPremium_IsNotSelfHosted_TwoGigabyteFile_ThrowsBadRequest(SutProvider sutProvider, + Send send) + { + var user = new User { - var user = new User - { - Id = Guid.NewGuid(), - }; + Id = Guid.NewGuid(), + EmailVerified = true, + Premium = false, + }; - send.UserId = user.Id; - send.Type = SendType.File; + send.UserId = user.Id; + send.Type = SendType.File; - sutProvider.GetDependency() - .GetByIdAsync(user.Id) - .Returns(user); + sutProvider.GetDependency() + .GetByIdAsync(user.Id) + .Returns(user); - sutProvider.GetDependency() - .CanAccessPremium(user) - .Returns(false); + sutProvider.GetDependency() + .CanAccessPremium(user) + .Returns(true); - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.SaveFileSendAsync(send, null, 1) - ); + sutProvider.GetDependency() + .SelfHosted = false; - Assert.Contains("must have premium", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveFileSendAsync(send, null, 2 * UserTests.Multiplier) + ); - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_UserHasUnconfirmedEmail_ThrowsBadRequest(SutProvider sutProvider, - Send send) + Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } + + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_ThroughOrg_MaxStorageIsNull_ThrowsBadRequest(SutProvider sutProvider, + Send send) + { + var org = new Organization { - var user = new User - { - Id = Guid.NewGuid(), - EmailVerified = false, - }; + Id = Guid.NewGuid(), + MaxStorageGb = null, + }; - send.UserId = user.Id; - send.Type = SendType.File; + send.UserId = null; + send.OrganizationId = org.Id; + send.Type = SendType.File; - sutProvider.GetDependency() - .GetByIdAsync(user.Id) - .Returns(user); + sutProvider.GetDependency() + .GetByIdAsync(org.Id) + .Returns(org); - sutProvider.GetDependency() - .CanAccessPremium(user) - .Returns(true); + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveFileSendAsync(send, null, 1) + ); - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.SaveFileSendAsync(send, null, 1) - ); + Assert.Contains("organization cannot use file sends", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } - Assert.Contains("must confirm your email", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } - - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_UserCanAccessPremium_HasNoStorage_ThrowsBadRequest(SutProvider sutProvider, - Send send) + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_ThroughOrg_MaxStorageIsNull_TwoGBFile_ThrowsBadRequest(SutProvider sutProvider, + Send send) + { + var org = new Organization { - var user = new User - { - Id = Guid.NewGuid(), - EmailVerified = true, - Premium = true, - MaxStorageGb = null, - Storage = 0, - }; + Id = Guid.NewGuid(), + MaxStorageGb = null, + }; - send.UserId = user.Id; - send.Type = SendType.File; + send.UserId = null; + send.OrganizationId = org.Id; + send.Type = SendType.File; - sutProvider.GetDependency() - .GetByIdAsync(user.Id) - .Returns(user); + sutProvider.GetDependency() + .GetByIdAsync(org.Id) + .Returns(org); - sutProvider.GetDependency() - .CanAccessPremium(user) - .Returns(true); + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveFileSendAsync(send, null, 1) + ); - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.SaveFileSendAsync(send, null, 1) - ); + Assert.Contains("organization cannot use file sends", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } - Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } - - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_UserCanAccessPremium_StorageFull_ThrowsBadRequest(SutProvider sutProvider, - Send send) + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_ThroughOrg_MaxStorageIsOneGB_TwoGBFile_ThrowsBadRequest(SutProvider sutProvider, + Send send) + { + var org = new Organization { - var user = new User - { - Id = Guid.NewGuid(), - EmailVerified = true, - Premium = true, - MaxStorageGb = 2, - Storage = 2 * UserTests.Multiplier, - }; + Id = Guid.NewGuid(), + MaxStorageGb = 1, + }; - send.UserId = user.Id; - send.Type = SendType.File; + send.UserId = null; + send.OrganizationId = org.Id; + send.Type = SendType.File; - sutProvider.GetDependency() - .GetByIdAsync(user.Id) - .Returns(user); + sutProvider.GetDependency() + .GetByIdAsync(org.Id) + .Returns(org); - sutProvider.GetDependency() - .CanAccessPremium(user) - .Returns(true); + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveFileSendAsync(send, null, 2 * UserTests.Multiplier) + ); - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.SaveFileSendAsync(send, null, 1) - ); + Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } - Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } - - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_UserCanAccessPremium_IsNotPremium_IsSelfHosted_GiantFile_ThrowsBadRequest(SutProvider sutProvider, - Send send) + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_HasEnouphStorage_Success(SutProvider sutProvider, + Send send) + { + var user = new User { - var user = new User - { - Id = Guid.NewGuid(), - EmailVerified = true, - Premium = false, - }; + Id = Guid.NewGuid(), + EmailVerified = true, + MaxStorageGb = 10, + }; - send.UserId = user.Id; - send.Type = SendType.File; - - sutProvider.GetDependency() - .GetByIdAsync(user.Id) - .Returns(user); - - sutProvider.GetDependency() - .CanAccessPremium(user) - .Returns(true); - - sutProvider.GetDependency() - .SelfHosted = true; - - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.SaveFileSendAsync(send, null, 11000 * UserTests.Multiplier) - ); - - Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } - - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_UserCanAccessPremium_IsNotPremium_IsNotSelfHosted_TwoGigabyteFile_ThrowsBadRequest(SutProvider sutProvider, - Send send) - { - var user = new User - { - Id = Guid.NewGuid(), - EmailVerified = true, - Premium = false, - }; - - send.UserId = user.Id; - send.Type = SendType.File; - - sutProvider.GetDependency() - .GetByIdAsync(user.Id) - .Returns(user); - - sutProvider.GetDependency() - .CanAccessPremium(user) - .Returns(true); - - sutProvider.GetDependency() - .SelfHosted = false; - - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.SaveFileSendAsync(send, null, 2 * UserTests.Multiplier) - ); - - Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } - - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_ThroughOrg_MaxStorageIsNull_ThrowsBadRequest(SutProvider sutProvider, - Send send) - { - var org = new Organization - { - Id = Guid.NewGuid(), - MaxStorageGb = null, - }; - - send.UserId = null; - send.OrganizationId = org.Id; - send.Type = SendType.File; - - sutProvider.GetDependency() - .GetByIdAsync(org.Id) - .Returns(org); - - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.SaveFileSendAsync(send, null, 1) - ); - - Assert.Contains("organization cannot use file sends", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } - - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_ThroughOrg_MaxStorageIsNull_TwoGBFile_ThrowsBadRequest(SutProvider sutProvider, - Send send) - { - var org = new Organization - { - Id = Guid.NewGuid(), - MaxStorageGb = null, - }; - - send.UserId = null; - send.OrganizationId = org.Id; - send.Type = SendType.File; - - sutProvider.GetDependency() - .GetByIdAsync(org.Id) - .Returns(org); - - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.SaveFileSendAsync(send, null, 1) - ); - - Assert.Contains("organization cannot use file sends", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } - - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_ThroughOrg_MaxStorageIsOneGB_TwoGBFile_ThrowsBadRequest(SutProvider sutProvider, - Send send) - { - var org = new Organization - { - Id = Guid.NewGuid(), - MaxStorageGb = 1, - }; - - send.UserId = null; - send.OrganizationId = org.Id; - send.Type = SendType.File; - - sutProvider.GetDependency() - .GetByIdAsync(org.Id) - .Returns(org); - - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.SaveFileSendAsync(send, null, 2 * UserTests.Multiplier) - ); - - Assert.Contains("not enough storage", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } - - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_HasEnouphStorage_Success(SutProvider sutProvider, - Send send) - { - var user = new User - { - Id = Guid.NewGuid(), - EmailVerified = true, - MaxStorageGb = 10, - }; - - var data = new SendFileData - { - - }; - - send.UserId = user.Id; - send.Type = SendType.File; - - var testUrl = "https://test.com/"; - - sutProvider.GetDependency() - .GetByIdAsync(user.Id) - .Returns(user); - - sutProvider.GetDependency() - .CanAccessPremium(user) - .Returns(true); - - sutProvider.GetDependency() - .GetSendFileUploadUrlAsync(send, Arg.Any()) - .Returns(testUrl); - - var utcNow = DateTime.UtcNow; - - var url = await sutProvider.Sut.SaveFileSendAsync(send, data, 1 * UserTests.Multiplier); - - Assert.Equal(testUrl, url); - Assert.True(send.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - - await sutProvider.GetDependency() - .Received(1) - .GetSendFileUploadUrlAsync(send, Arg.Any()); - - await sutProvider.GetDependency() - .Received(1) - .UpsertAsync(send); - - await sutProvider.GetDependency() - .Received(1) - .PushSyncSendUpdateAsync(send); - } - - [Theory] - [InlineUserSendAutoData] - public async void SaveFileSendAsync_HasEnouphStorage_SendFileThrows_CleansUp(SutProvider sutProvider, - Send send) - { - var user = new User - { - Id = Guid.NewGuid(), - EmailVerified = true, - MaxStorageGb = 10, - }; - - var data = new SendFileData - { - - }; - - send.UserId = user.Id; - send.Type = SendType.File; - - sutProvider.GetDependency() - .GetByIdAsync(user.Id) - .Returns(user); - - sutProvider.GetDependency() - .CanAccessPremium(user) - .Returns(true); - - sutProvider.GetDependency() - .GetSendFileUploadUrlAsync(send, Arg.Any()) - .Returns(callInfo => throw new Exception("Problem")); - - var utcNow = DateTime.UtcNow; - - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.SaveFileSendAsync(send, data, 1 * UserTests.Multiplier) - ); - - Assert.True(send.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.Equal("Problem", exception.Message); - - await sutProvider.GetDependency() - .Received(1) - .GetSendFileUploadUrlAsync(send, Arg.Any()); - - await sutProvider.GetDependency() - .Received(1) - .UpsertAsync(send); - - await sutProvider.GetDependency() - .Received(1) - .PushSyncSendUpdateAsync(send); - - await sutProvider.GetDependency() - .Received(1) - .DeleteFileAsync(send, Arg.Any()); - } - - [Theory] - [InlineUserSendAutoData] - public async void UpdateFileToExistingSendAsync_SendNull_ThrowsBadRequest(SutProvider sutProvider) + var data = new SendFileData { - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(), null) - ); + }; - Assert.Contains("does not have file data", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } + send.UserId = user.Id; + send.Type = SendType.File; - [Theory] - [InlineUserSendAutoData] - public async void UpdateFileToExistingSendAsync_SendDataNull_ThrowsBadRequest(SutProvider sutProvider, - Send send) + var testUrl = "https://test.com/"; + + sutProvider.GetDependency() + .GetByIdAsync(user.Id) + .Returns(user); + + sutProvider.GetDependency() + .CanAccessPremium(user) + .Returns(true); + + sutProvider.GetDependency() + .GetSendFileUploadUrlAsync(send, Arg.Any()) + .Returns(testUrl); + + var utcNow = DateTime.UtcNow; + + var url = await sutProvider.Sut.SaveFileSendAsync(send, data, 1 * UserTests.Multiplier); + + Assert.Equal(testUrl, url); + Assert.True(send.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + + await sutProvider.GetDependency() + .Received(1) + .GetSendFileUploadUrlAsync(send, Arg.Any()); + + await sutProvider.GetDependency() + .Received(1) + .UpsertAsync(send); + + await sutProvider.GetDependency() + .Received(1) + .PushSyncSendUpdateAsync(send); + } + + [Theory] + [InlineUserSendAutoData] + public async void SaveFileSendAsync_HasEnouphStorage_SendFileThrows_CleansUp(SutProvider sutProvider, + Send send) + { + var user = new User { - send.Data = null; + Id = Guid.NewGuid(), + EmailVerified = true, + MaxStorageGb = 10, + }; - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(), send) - ); - - Assert.Contains("does not have file data", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } - - [Theory] - [InlineUserSendAutoData] - public async void UpdateFileToExistingSendAsync_NotFileType_ThrowsBadRequest(SutProvider sutProvider, - Send send) + var data = new SendFileData { - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(), send) - ); - Assert.Contains("not a file type send", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); - } + }; - [Theory] - [InlineUserSendAutoData] - public async void UpdateFileToExistingSendAsync_Success(SutProvider sutProvider, - Send send) + send.UserId = user.Id; + send.Type = SendType.File; + + sutProvider.GetDependency() + .GetByIdAsync(user.Id) + .Returns(user); + + sutProvider.GetDependency() + .CanAccessPremium(user) + .Returns(true); + + sutProvider.GetDependency() + .GetSendFileUploadUrlAsync(send, Arg.Any()) + .Returns(callInfo => throw new Exception("Problem")); + + var utcNow = DateTime.UtcNow; + + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveFileSendAsync(send, data, 1 * UserTests.Multiplier) + ); + + Assert.True(send.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + Assert.Equal("Problem", exception.Message); + + await sutProvider.GetDependency() + .Received(1) + .GetSendFileUploadUrlAsync(send, Arg.Any()); + + await sutProvider.GetDependency() + .Received(1) + .UpsertAsync(send); + + await sutProvider.GetDependency() + .Received(1) + .PushSyncSendUpdateAsync(send); + + await sutProvider.GetDependency() + .Received(1) + .DeleteFileAsync(send, Arg.Any()); + } + + [Theory] + [InlineUserSendAutoData] + public async void UpdateFileToExistingSendAsync_SendNull_ThrowsBadRequest(SutProvider sutProvider) + { + + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(), null) + ); + + Assert.Contains("does not have file data", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } + + [Theory] + [InlineUserSendAutoData] + public async void UpdateFileToExistingSendAsync_SendDataNull_ThrowsBadRequest(SutProvider sutProvider, + Send send) + { + send.Data = null; + + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(), send) + ); + + Assert.Contains("does not have file data", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } + + [Theory] + [InlineUserSendAutoData] + public async void UpdateFileToExistingSendAsync_NotFileType_ThrowsBadRequest(SutProvider sutProvider, + Send send) + { + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(), send) + ); + + Assert.Contains("not a file type send", badRequest.Message, StringComparison.InvariantCultureIgnoreCase); + } + + [Theory] + [InlineUserSendAutoData] + public async void UpdateFileToExistingSendAsync_Success(SutProvider sutProvider, + Send send) + { + var fileContents = "Test file content"; + + var sendFileData = new SendFileData { - var fileContents = "Test file content"; + Id = "TEST", + Size = fileContents.Length, + Validated = false, + }; - var sendFileData = new SendFileData - { - Id = "TEST", - Size = fileContents.Length, - Validated = false, - }; + send.Type = SendType.File; + send.Data = JsonSerializer.Serialize(sendFileData); - send.Type = SendType.File; - send.Data = JsonSerializer.Serialize(sendFileData); + sutProvider.GetDependency() + .ValidateFileAsync(send, sendFileData.Id, sendFileData.Size, Arg.Any()) + .Returns((true, sendFileData.Size)); - sutProvider.GetDependency() - .ValidateFileAsync(send, sendFileData.Id, sendFileData.Size, Arg.Any()) - .Returns((true, sendFileData.Size)); + await sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(Encoding.UTF8.GetBytes(fileContents)), send); + } - await sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(Encoding.UTF8.GetBytes(fileContents)), send); - } + [Theory] + [InlineUserSendAutoData] + public async void UpdateFileToExistingSendAsync_InvalidSize(SutProvider sutProvider, + Send send) + { + var fileContents = "Test file content"; - [Theory] - [InlineUserSendAutoData] - public async void UpdateFileToExistingSendAsync_InvalidSize(SutProvider sutProvider, - Send send) + var sendFileData = new SendFileData { - var fileContents = "Test file content"; + Id = "TEST", + Size = fileContents.Length, + }; - var sendFileData = new SendFileData - { - Id = "TEST", - Size = fileContents.Length, - }; + send.Type = SendType.File; + send.Data = JsonSerializer.Serialize(sendFileData); - send.Type = SendType.File; - send.Data = JsonSerializer.Serialize(sendFileData); + sutProvider.GetDependency() + .ValidateFileAsync(send, sendFileData.Id, sendFileData.Size, Arg.Any()) + .Returns((false, sendFileData.Size)); - sutProvider.GetDependency() - .ValidateFileAsync(send, sendFileData.Id, sendFileData.Size, Arg.Any()) - .Returns((false, sendFileData.Size)); + var badRequest = await Assert.ThrowsAsync(() => + sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(Encoding.UTF8.GetBytes(fileContents)), send) + ); + } - var badRequest = await Assert.ThrowsAsync(() => - sutProvider.Sut.UploadFileToExistingSendAsync(new MemoryStream(Encoding.UTF8.GetBytes(fileContents)), send) - ); - } + [Theory] + [InlineUserSendAutoData] + public void SendCanBeAccessed_Success(SutProvider sutProvider, Send send) + { + var now = DateTime.UtcNow; + send.MaxAccessCount = 10; + send.AccessCount = 5; + send.ExpirationDate = now.AddYears(1); + send.DeletionDate = now.AddYears(1); + send.Disabled = false; - [Theory] - [InlineUserSendAutoData] - public void SendCanBeAccessed_Success(SutProvider sutProvider, Send send) - { - var now = DateTime.UtcNow; - send.MaxAccessCount = 10; - send.AccessCount = 5; - send.ExpirationDate = now.AddYears(1); - send.DeletionDate = now.AddYears(1); - send.Disabled = false; + sutProvider.GetDependency>() + .VerifyHashedPassword(Arg.Any(), send.Password, "TEST") + .Returns(PasswordVerificationResult.Success); - sutProvider.GetDependency>() - .VerifyHashedPassword(Arg.Any(), send.Password, "TEST") - .Returns(PasswordVerificationResult.Success); + var (grant, passwordRequiredError, passwordInvalidError) + = sutProvider.Sut.SendCanBeAccessed(send, "TEST"); - var (grant, passwordRequiredError, passwordInvalidError) - = sutProvider.Sut.SendCanBeAccessed(send, "TEST"); + Assert.True(grant); + Assert.False(passwordRequiredError); + Assert.False(passwordInvalidError); + } - Assert.True(grant); - Assert.False(passwordRequiredError); - Assert.False(passwordInvalidError); - } + [Theory] + [InlineUserSendAutoData] + public void SendCanBeAccessed_NullMaxAccess_Success(SutProvider sutProvider, + Send send) + { + var now = DateTime.UtcNow; + send.MaxAccessCount = null; + send.AccessCount = 5; + send.ExpirationDate = now.AddYears(1); + send.DeletionDate = now.AddYears(1); + send.Disabled = false; - [Theory] - [InlineUserSendAutoData] - public void SendCanBeAccessed_NullMaxAccess_Success(SutProvider sutProvider, - Send send) - { - var now = DateTime.UtcNow; - send.MaxAccessCount = null; - send.AccessCount = 5; - send.ExpirationDate = now.AddYears(1); - send.DeletionDate = now.AddYears(1); - send.Disabled = false; + sutProvider.GetDependency>() + .VerifyHashedPassword(Arg.Any(), send.Password, "TEST") + .Returns(PasswordVerificationResult.Success); - sutProvider.GetDependency>() - .VerifyHashedPassword(Arg.Any(), send.Password, "TEST") - .Returns(PasswordVerificationResult.Success); + var (grant, passwordRequiredError, passwordInvalidError) + = sutProvider.Sut.SendCanBeAccessed(send, "TEST"); - var (grant, passwordRequiredError, passwordInvalidError) - = sutProvider.Sut.SendCanBeAccessed(send, "TEST"); + Assert.True(grant); + Assert.False(passwordRequiredError); + Assert.False(passwordInvalidError); + } - Assert.True(grant); - Assert.False(passwordRequiredError); - Assert.False(passwordInvalidError); - } + [Theory] + [InlineUserSendAutoData] + public void SendCanBeAccessed_NullSend_DoesNotGrantAccess(SutProvider sutProvider) + { + sutProvider.GetDependency>() + .VerifyHashedPassword(Arg.Any(), "TEST", "TEST") + .Returns(PasswordVerificationResult.Success); - [Theory] - [InlineUserSendAutoData] - public void SendCanBeAccessed_NullSend_DoesNotGrantAccess(SutProvider sutProvider) - { - sutProvider.GetDependency>() - .VerifyHashedPassword(Arg.Any(), "TEST", "TEST") - .Returns(PasswordVerificationResult.Success); + var (grant, passwordRequiredError, passwordInvalidError) + = sutProvider.Sut.SendCanBeAccessed(null, "TEST"); - var (grant, passwordRequiredError, passwordInvalidError) - = sutProvider.Sut.SendCanBeAccessed(null, "TEST"); + Assert.False(grant); + Assert.False(passwordRequiredError); + Assert.False(passwordInvalidError); + } - Assert.False(grant); - Assert.False(passwordRequiredError); - Assert.False(passwordInvalidError); - } + [Theory] + [InlineUserSendAutoData] + public void SendCanBeAccessed_NullPassword_PasswordRequiredErrorReturnsTrue(SutProvider sutProvider, + Send send) + { + var now = DateTime.UtcNow; + send.MaxAccessCount = null; + send.AccessCount = 5; + send.ExpirationDate = now.AddYears(1); + send.DeletionDate = now.AddYears(1); + send.Disabled = false; + send.Password = "HASH"; - [Theory] - [InlineUserSendAutoData] - public void SendCanBeAccessed_NullPassword_PasswordRequiredErrorReturnsTrue(SutProvider sutProvider, - Send send) - { - var now = DateTime.UtcNow; - send.MaxAccessCount = null; - send.AccessCount = 5; - send.ExpirationDate = now.AddYears(1); - send.DeletionDate = now.AddYears(1); - send.Disabled = false; - send.Password = "HASH"; + sutProvider.GetDependency>() + .VerifyHashedPassword(Arg.Any(), "TEST", "TEST") + .Returns(PasswordVerificationResult.Success); - sutProvider.GetDependency>() - .VerifyHashedPassword(Arg.Any(), "TEST", "TEST") - .Returns(PasswordVerificationResult.Success); + var (grant, passwordRequiredError, passwordInvalidError) + = sutProvider.Sut.SendCanBeAccessed(send, null); - var (grant, passwordRequiredError, passwordInvalidError) - = sutProvider.Sut.SendCanBeAccessed(send, null); + Assert.False(grant); + Assert.True(passwordRequiredError); + Assert.False(passwordInvalidError); + } - Assert.False(grant); - Assert.True(passwordRequiredError); - Assert.False(passwordInvalidError); - } + [Theory] + [InlineUserSendAutoData] + public void SendCanBeAccessed_RehashNeeded_RehashesPassword(SutProvider sutProvider, + Send send) + { + var now = DateTime.UtcNow; + send.MaxAccessCount = null; + send.AccessCount = 5; + send.ExpirationDate = now.AddYears(1); + send.DeletionDate = now.AddYears(1); + send.Disabled = false; + send.Password = "TEST"; - [Theory] - [InlineUserSendAutoData] - public void SendCanBeAccessed_RehashNeeded_RehashesPassword(SutProvider sutProvider, - Send send) - { - var now = DateTime.UtcNow; - send.MaxAccessCount = null; - send.AccessCount = 5; - send.ExpirationDate = now.AddYears(1); - send.DeletionDate = now.AddYears(1); - send.Disabled = false; - send.Password = "TEST"; + sutProvider.GetDependency>() + .VerifyHashedPassword(Arg.Any(), "TEST", "TEST") + .Returns(PasswordVerificationResult.SuccessRehashNeeded); - sutProvider.GetDependency>() - .VerifyHashedPassword(Arg.Any(), "TEST", "TEST") - .Returns(PasswordVerificationResult.SuccessRehashNeeded); + var (grant, passwordRequiredError, passwordInvalidError) + = sutProvider.Sut.SendCanBeAccessed(send, "TEST"); - var (grant, passwordRequiredError, passwordInvalidError) - = sutProvider.Sut.SendCanBeAccessed(send, "TEST"); + sutProvider.GetDependency>() + .Received(1) + .HashPassword(Arg.Any(), "TEST"); - sutProvider.GetDependency>() - .Received(1) - .HashPassword(Arg.Any(), "TEST"); + Assert.True(grant); + Assert.False(passwordRequiredError); + Assert.False(passwordInvalidError); + } - Assert.True(grant); - Assert.False(passwordRequiredError); - Assert.False(passwordInvalidError); - } + [Theory] + [InlineUserSendAutoData] + public void SendCanBeAccessed_VerifyFailed_PasswordInvalidReturnsTrue(SutProvider sutProvider, + Send send) + { + var now = DateTime.UtcNow; + send.MaxAccessCount = null; + send.AccessCount = 5; + send.ExpirationDate = now.AddYears(1); + send.DeletionDate = now.AddYears(1); + send.Disabled = false; + send.Password = "TEST"; - [Theory] - [InlineUserSendAutoData] - public void SendCanBeAccessed_VerifyFailed_PasswordInvalidReturnsTrue(SutProvider sutProvider, - Send send) - { - var now = DateTime.UtcNow; - send.MaxAccessCount = null; - send.AccessCount = 5; - send.ExpirationDate = now.AddYears(1); - send.DeletionDate = now.AddYears(1); - send.Disabled = false; - send.Password = "TEST"; + sutProvider.GetDependency>() + .VerifyHashedPassword(Arg.Any(), "TEST", "TEST") + .Returns(PasswordVerificationResult.Failed); - sutProvider.GetDependency>() - .VerifyHashedPassword(Arg.Any(), "TEST", "TEST") - .Returns(PasswordVerificationResult.Failed); + var (grant, passwordRequiredError, passwordInvalidError) + = sutProvider.Sut.SendCanBeAccessed(send, "TEST"); - var (grant, passwordRequiredError, passwordInvalidError) - = sutProvider.Sut.SendCanBeAccessed(send, "TEST"); - - Assert.False(grant); - Assert.False(passwordRequiredError); - Assert.True(passwordInvalidError); - } + Assert.False(grant); + Assert.False(passwordRequiredError); + Assert.True(passwordInvalidError); } } diff --git a/test/Core.Test/Services/SsoConfigServiceTests.cs b/test/Core.Test/Services/SsoConfigServiceTests.cs index 475a2a1c5..fa5cb904a 100644 --- a/test/Core.Test/Services/SsoConfigServiceTests.cs +++ b/test/Core.Test/Services/SsoConfigServiceTests.cs @@ -9,309 +9,308 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class SsoConfigServiceTests { - public class SsoConfigServiceTests + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_ExistingItem_UpdatesRevisionDateOnly(SutProvider sutProvider, + Organization organization) { - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_ExistingItem_UpdatesRevisionDateOnly(SutProvider sutProvider, - Organization organization) + var utcNow = DateTime.UtcNow; + + var ssoConfig = new SsoConfig { - var utcNow = DateTime.UtcNow; + Id = 1, + Data = "{}", + Enabled = true, + OrganizationId = organization.Id, + CreationDate = utcNow.AddDays(-10), + RevisionDate = utcNow.AddDays(-10), + }; - var ssoConfig = new SsoConfig - { - Id = 1, - Data = "{}", - Enabled = true, - OrganizationId = organization.Id, - CreationDate = utcNow.AddDays(-10), - RevisionDate = utcNow.AddDays(-10), - }; + sutProvider.GetDependency() + .UpsertAsync(ssoConfig).Returns(Task.CompletedTask); - sutProvider.GetDependency() - .UpsertAsync(ssoConfig).Returns(Task.CompletedTask); + await sutProvider.Sut.SaveAsync(ssoConfig, organization); - await sutProvider.Sut.SaveAsync(ssoConfig, organization); + await sutProvider.GetDependency().Received() + .UpsertAsync(ssoConfig); - await sutProvider.GetDependency().Received() - .UpsertAsync(ssoConfig); + Assert.Equal(utcNow.AddDays(-10), ssoConfig.CreationDate); + Assert.True(ssoConfig.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + } - Assert.Equal(utcNow.AddDays(-10), ssoConfig.CreationDate); - Assert.True(ssoConfig.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_NewItem_UpdatesCreationAndRevisionDate(SutProvider sutProvider, + Organization organization) + { + var utcNow = DateTime.UtcNow; - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_NewItem_UpdatesCreationAndRevisionDate(SutProvider sutProvider, - Organization organization) + var ssoConfig = new SsoConfig { - var utcNow = DateTime.UtcNow; + Id = default, + Data = "{}", + Enabled = true, + OrganizationId = organization.Id, + CreationDate = utcNow.AddDays(-10), + RevisionDate = utcNow.AddDays(-10), + }; - var ssoConfig = new SsoConfig - { - Id = default, - Data = "{}", - Enabled = true, - OrganizationId = organization.Id, - CreationDate = utcNow.AddDays(-10), - RevisionDate = utcNow.AddDays(-10), - }; + sutProvider.GetDependency() + .UpsertAsync(ssoConfig).Returns(Task.CompletedTask); - sutProvider.GetDependency() - .UpsertAsync(ssoConfig).Returns(Task.CompletedTask); + await sutProvider.Sut.SaveAsync(ssoConfig, organization); - await sutProvider.Sut.SaveAsync(ssoConfig, organization); + await sutProvider.GetDependency().Received() + .UpsertAsync(ssoConfig); - await sutProvider.GetDependency().Received() - .UpsertAsync(ssoConfig); + Assert.True(ssoConfig.CreationDate - utcNow < TimeSpan.FromSeconds(1)); + Assert.True(ssoConfig.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); + } - Assert.True(ssoConfig.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(ssoConfig.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_PreventDisablingKeyConnector(SutProvider sutProvider, + Organization organization) + { + var utcNow = DateTime.UtcNow; - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_PreventDisablingKeyConnector(SutProvider sutProvider, - Organization organization) + var oldSsoConfig = new SsoConfig { - var utcNow = DateTime.UtcNow; - - var oldSsoConfig = new SsoConfig + Id = 1, + Data = new SsoConfigurationData { - Id = 1, - Data = new SsoConfigurationData - { - KeyConnectorEnabled = true, - }.Serialize(), - Enabled = true, - OrganizationId = organization.Id, - CreationDate = utcNow.AddDays(-10), - RevisionDate = utcNow.AddDays(-10), - }; + KeyConnectorEnabled = true, + }.Serialize(), + Enabled = true, + OrganizationId = organization.Id, + CreationDate = utcNow.AddDays(-10), + RevisionDate = utcNow.AddDays(-10), + }; - var newSsoConfig = new SsoConfig - { - Id = 1, - Data = "{}", - Enabled = true, - OrganizationId = organization.Id, - CreationDate = utcNow.AddDays(-10), - RevisionDate = utcNow, - }; - - var ssoConfigRepository = sutProvider.GetDependency(); - ssoConfigRepository.GetByOrganizationIdAsync(organization.Id).Returns(oldSsoConfig); - ssoConfigRepository.UpsertAsync(newSsoConfig).Returns(Task.CompletedTask); - sutProvider.GetDependency().GetManyDetailsByOrganizationAsync(organization.Id) - .Returns(new[] { new OrganizationUserUserDetails { UsesKeyConnector = true } }); - - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(newSsoConfig, organization)); - - Assert.Contains("Key Connector cannot be disabled at this moment.", exception.Message); - - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_AllowDisablingKeyConnectorWhenNoUserIsUsingIt( - SutProvider sutProvider, Organization organization) + var newSsoConfig = new SsoConfig { - var utcNow = DateTime.UtcNow; + Id = 1, + Data = "{}", + Enabled = true, + OrganizationId = organization.Id, + CreationDate = utcNow.AddDays(-10), + RevisionDate = utcNow, + }; - var oldSsoConfig = new SsoConfig - { - Id = 1, - Data = new SsoConfigurationData - { - KeyConnectorEnabled = true, - }.Serialize(), - Enabled = true, - OrganizationId = organization.Id, - CreationDate = utcNow.AddDays(-10), - RevisionDate = utcNow.AddDays(-10), - }; + var ssoConfigRepository = sutProvider.GetDependency(); + ssoConfigRepository.GetByOrganizationIdAsync(organization.Id).Returns(oldSsoConfig); + ssoConfigRepository.UpsertAsync(newSsoConfig).Returns(Task.CompletedTask); + sutProvider.GetDependency().GetManyDetailsByOrganizationAsync(organization.Id) + .Returns(new[] { new OrganizationUserUserDetails { UsesKeyConnector = true } }); - var newSsoConfig = new SsoConfig - { - Id = 1, - Data = "{}", - Enabled = true, - OrganizationId = organization.Id, - CreationDate = utcNow.AddDays(-10), - RevisionDate = utcNow, - }; + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(newSsoConfig, organization)); - var ssoConfigRepository = sutProvider.GetDependency(); - ssoConfigRepository.GetByOrganizationIdAsync(organization.Id).Returns(oldSsoConfig); - ssoConfigRepository.UpsertAsync(newSsoConfig).Returns(Task.CompletedTask); - sutProvider.GetDependency().GetManyDetailsByOrganizationAsync(organization.Id) - .Returns(new[] { new OrganizationUserUserDetails { UsesKeyConnector = false } }); + Assert.Contains("Key Connector cannot be disabled at this moment.", exception.Message); - await sutProvider.Sut.SaveAsync(newSsoConfig, organization); - } + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_KeyConnector_SingleOrgNotEnabled_Throws(SutProvider sutProvider, - Organization organization) + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_AllowDisablingKeyConnectorWhenNoUserIsUsingIt( + SutProvider sutProvider, Organization organization) + { + var utcNow = DateTime.UtcNow; + + var oldSsoConfig = new SsoConfig { - var utcNow = DateTime.UtcNow; - - var ssoConfig = new SsoConfig + Id = 1, + Data = new SsoConfigurationData { - Id = default, - Data = new SsoConfigurationData - { - KeyConnectorEnabled = true, - }.Serialize(), - Enabled = true, - OrganizationId = organization.Id, - CreationDate = utcNow.AddDays(-10), - RevisionDate = utcNow.AddDays(-10), - }; + KeyConnectorEnabled = true, + }.Serialize(), + Enabled = true, + OrganizationId = organization.Id, + CreationDate = utcNow.AddDays(-10), + RevisionDate = utcNow.AddDays(-10), + }; - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(ssoConfig, organization)); - - Assert.Contains("Key Connector requires the Single Organization policy to be enabled.", exception.Message); - - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_KeyConnector_SsoPolicyNotEnabled_Throws(SutProvider sutProvider, - Organization organization) + var newSsoConfig = new SsoConfig { - var utcNow = DateTime.UtcNow; + Id = 1, + Data = "{}", + Enabled = true, + OrganizationId = organization.Id, + CreationDate = utcNow.AddDays(-10), + RevisionDate = utcNow, + }; - var ssoConfig = new SsoConfig - { - Id = default, - Data = new SsoConfigurationData - { - KeyConnectorEnabled = true, - }.Serialize(), - Enabled = true, - OrganizationId = organization.Id, - CreationDate = utcNow.AddDays(-10), - RevisionDate = utcNow.AddDays(-10), - }; + var ssoConfigRepository = sutProvider.GetDependency(); + ssoConfigRepository.GetByOrganizationIdAsync(organization.Id).Returns(oldSsoConfig); + ssoConfigRepository.UpsertAsync(newSsoConfig).Returns(Task.CompletedTask); + sutProvider.GetDependency().GetManyDetailsByOrganizationAsync(organization.Id) + .Returns(new[] { new OrganizationUserUserDetails { UsesKeyConnector = false } }); - sutProvider.GetDependency().GetByOrganizationIdTypeAsync( - Arg.Any(), Enums.PolicyType.SingleOrg).Returns(new Policy - { - Enabled = true - }); + await sutProvider.Sut.SaveAsync(newSsoConfig, organization); + } - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(ssoConfig, organization)); + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_KeyConnector_SingleOrgNotEnabled_Throws(SutProvider sutProvider, + Organization organization) + { + var utcNow = DateTime.UtcNow; - Assert.Contains("Key Connector requires the Single Sign-On Authentication policy to be enabled.", exception.Message); - - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_KeyConnector_SsoConfigNotEnabled_Throws(SutProvider sutProvider, - Organization organization) + var ssoConfig = new SsoConfig { - var utcNow = DateTime.UtcNow; - - var ssoConfig = new SsoConfig + Id = default, + Data = new SsoConfigurationData { - Id = default, - Data = new SsoConfigurationData - { - KeyConnectorEnabled = true, - }.Serialize(), - Enabled = false, - OrganizationId = organization.Id, - CreationDate = utcNow.AddDays(-10), - RevisionDate = utcNow.AddDays(-10), - }; + KeyConnectorEnabled = true, + }.Serialize(), + Enabled = true, + OrganizationId = organization.Id, + CreationDate = utcNow.AddDays(-10), + RevisionDate = utcNow.AddDays(-10), + }; - sutProvider.GetDependency().GetByOrganizationIdTypeAsync( - Arg.Any(), Arg.Any()).Returns(new Policy - { - Enabled = true - }); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(ssoConfig, organization)); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(ssoConfig, organization)); + Assert.Contains("Key Connector requires the Single Organization policy to be enabled.", exception.Message); - Assert.Contains("You must enable SSO to use Key Connector.", exception.Message); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + } - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_KeyConnector_SsoPolicyNotEnabled_Throws(SutProvider sutProvider, + Organization organization) + { + var utcNow = DateTime.UtcNow; - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_KeyConnector_KeyConnectorAbilityNotEnabled_Throws(SutProvider sutProvider, - Organization organization) + var ssoConfig = new SsoConfig { - var utcNow = DateTime.UtcNow; - - organization.UseKeyConnector = false; - var ssoConfig = new SsoConfig + Id = default, + Data = new SsoConfigurationData { - Id = default, - Data = new SsoConfigurationData - { - KeyConnectorEnabled = true, - }.Serialize(), - Enabled = true, - OrganizationId = organization.Id, - CreationDate = utcNow.AddDays(-10), - RevisionDate = utcNow.AddDays(-10), - }; + KeyConnectorEnabled = true, + }.Serialize(), + Enabled = true, + OrganizationId = organization.Id, + CreationDate = utcNow.AddDays(-10), + RevisionDate = utcNow.AddDays(-10), + }; - sutProvider.GetDependency().GetByOrganizationIdTypeAsync( - Arg.Any(), Arg.Any()).Returns(new Policy - { - Enabled = true, - }); + sutProvider.GetDependency().GetByOrganizationIdTypeAsync( + Arg.Any(), Enums.PolicyType.SingleOrg).Returns(new Policy + { + Enabled = true + }); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(ssoConfig, organization)); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(ssoConfig, organization)); - Assert.Contains("Organization cannot use Key Connector.", exception.Message); + Assert.Contains("Key Connector requires the Single Sign-On Authentication policy to be enabled.", exception.Message); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SaveAsync_KeyConnector_Success(SutProvider sutProvider, - Organization organization) + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_KeyConnector_SsoConfigNotEnabled_Throws(SutProvider sutProvider, + Organization organization) + { + var utcNow = DateTime.UtcNow; + + var ssoConfig = new SsoConfig { - var utcNow = DateTime.UtcNow; - - organization.UseKeyConnector = true; - var ssoConfig = new SsoConfig + Id = default, + Data = new SsoConfigurationData + { + KeyConnectorEnabled = true, + }.Serialize(), + Enabled = false, + OrganizationId = organization.Id, + CreationDate = utcNow.AddDays(-10), + RevisionDate = utcNow.AddDays(-10), + }; + + sutProvider.GetDependency().GetByOrganizationIdTypeAsync( + Arg.Any(), Arg.Any()).Returns(new Policy + { + Enabled = true + }); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(ssoConfig, organization)); + + Assert.Contains("You must enable SSO to use Key Connector.", exception.Message); + + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_KeyConnector_KeyConnectorAbilityNotEnabled_Throws(SutProvider sutProvider, + Organization organization) + { + var utcNow = DateTime.UtcNow; + + organization.UseKeyConnector = false; + var ssoConfig = new SsoConfig + { + Id = default, + Data = new SsoConfigurationData + { + KeyConnectorEnabled = true, + }.Serialize(), + Enabled = true, + OrganizationId = organization.Id, + CreationDate = utcNow.AddDays(-10), + RevisionDate = utcNow.AddDays(-10), + }; + + sutProvider.GetDependency().GetByOrganizationIdTypeAsync( + Arg.Any(), Arg.Any()).Returns(new Policy { - Id = default, - Data = new SsoConfigurationData - { - KeyConnectorEnabled = true, - }.Serialize(), Enabled = true, - OrganizationId = organization.Id, - CreationDate = utcNow.AddDays(-10), - RevisionDate = utcNow.AddDays(-10), - }; + }); - sutProvider.GetDependency().GetByOrganizationIdTypeAsync( - Arg.Any(), Arg.Any()).Returns(new Policy - { - Enabled = true, - }); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(ssoConfig, organization)); - await sutProvider.Sut.SaveAsync(ssoConfig, organization); + Assert.Contains("Organization cannot use Key Connector.", exception.Message); - await sutProvider.GetDependency().ReceivedWithAnyArgs() - .UpsertAsync(default); - } + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .UpsertAsync(default); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SaveAsync_KeyConnector_Success(SutProvider sutProvider, + Organization organization) + { + var utcNow = DateTime.UtcNow; + + organization.UseKeyConnector = true; + var ssoConfig = new SsoConfig + { + Id = default, + Data = new SsoConfigurationData + { + KeyConnectorEnabled = true, + }.Serialize(), + Enabled = true, + OrganizationId = organization.Id, + CreationDate = utcNow.AddDays(-10), + RevisionDate = utcNow.AddDays(-10), + }; + + sutProvider.GetDependency().GetByOrganizationIdTypeAsync( + Arg.Any(), Arg.Any()).Returns(new Policy + { + Enabled = true, + }); + + await sutProvider.Sut.SaveAsync(ssoConfig, organization); + + await sutProvider.GetDependency().ReceivedWithAnyArgs() + .UpsertAsync(default); } } diff --git a/test/Core.Test/Services/StripePaymentServiceTests.cs b/test/Core.Test/Services/StripePaymentServiceTests.cs index 0c4ea5c03..a14f183d4 100644 --- a/test/Core.Test/Services/StripePaymentServiceTests.cs +++ b/test/Core.Test/Services/StripePaymentServiceTests.cs @@ -12,363 +12,362 @@ using NSubstitute; using Xunit; using PaymentMethodType = Bit.Core.Enums.PaymentMethodType; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class StripePaymentServiceTests { - public class StripePaymentServiceTests + [Theory] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.BitPay)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.BitPay)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.Credit)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.WireTransfer)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.AppleInApp)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.GoogleInApp)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.Check)] + public async void PurchaseOrganizationAsync_Invalid(PaymentMethodType paymentMethodType, SutProvider sutProvider) { - [Theory] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.BitPay)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.BitPay)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.Credit)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.WireTransfer)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.AppleInApp)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.GoogleInApp)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, PaymentMethodType.Check)] - public async void PurchaseOrganizationAsync_Invalid(PaymentMethodType paymentMethodType, SutProvider sutProvider) + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.PurchaseOrganizationAsync(null, paymentMethodType, null, null, 0, 0, false, null)); + + Assert.Equal("Payment method is not supported at this time.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async void PurchaseOrganizationAsync_Stripe(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) + { + var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); + + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer { - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.PurchaseOrganizationAsync(null, paymentMethodType, null, null, 0, 0, false, null)); - - Assert.Equal("Payment method is not supported at this time.", exception.Message); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async void PurchaseOrganizationAsync_Stripe(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) + Id = "C-1", + }); + stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription { - var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); + Id = "S-1", + CurrentPeriodEnd = DateTime.Today.AddDays(10), + }); - var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer - { - Id = "C-1", - }); - stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription - { - Id = "S-1", - CurrentPeriodEnd = DateTime.Today.AddDays(10), - }); + var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo); - var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo); + Assert.Null(result); + Assert.Equal(GatewayType.Stripe, organization.Gateway); + Assert.Equal("C-1", organization.GatewayCustomerId); + Assert.Equal("S-1", organization.GatewaySubscriptionId); + Assert.True(organization.Enabled); + Assert.Equal(DateTime.Today.AddDays(10), organization.ExpirationDate); - Assert.Null(result); - Assert.Equal(GatewayType.Stripe, organization.Gateway); - Assert.Equal("C-1", organization.GatewayCustomerId); - Assert.Equal("S-1", organization.GatewaySubscriptionId); - Assert.True(organization.Enabled); - Assert.Equal(DateTime.Today.AddDays(10), organization.ExpirationDate); + await stripeAdapter.Received().CustomerCreateAsync(Arg.Is(c => + c.Description == organization.BusinessName && + c.Email == organization.BillingEmail && + c.Source == paymentToken && + c.PaymentMethod == null && + !c.Metadata.Any() && + c.InvoiceSettings.DefaultPaymentMethod == null && + c.Address.Country == taxInfo.BillingAddressCountry && + c.Address.PostalCode == taxInfo.BillingAddressPostalCode && + c.Address.Line1 == taxInfo.BillingAddressLine1 && + c.Address.Line2 == taxInfo.BillingAddressLine2 && + c.Address.City == taxInfo.BillingAddressCity && + c.Address.State == taxInfo.BillingAddressState && + c.TaxIdData == null + )); - await stripeAdapter.Received().CustomerCreateAsync(Arg.Is(c => - c.Description == organization.BusinessName && - c.Email == organization.BillingEmail && - c.Source == paymentToken && - c.PaymentMethod == null && - !c.Metadata.Any() && - c.InvoiceSettings.DefaultPaymentMethod == null && - c.Address.Country == taxInfo.BillingAddressCountry && - c.Address.PostalCode == taxInfo.BillingAddressPostalCode && - c.Address.Line1 == taxInfo.BillingAddressLine1 && - c.Address.Line2 == taxInfo.BillingAddressLine2 && - c.Address.City == taxInfo.BillingAddressCity && - c.Address.State == taxInfo.BillingAddressState && - c.TaxIdData == null - )); + await stripeAdapter.Received().SubscriptionCreateAsync(Arg.Is(s => + s.Customer == "C-1" && + s.Expand[0] == "latest_invoice.payment_intent" && + s.Metadata[organization.GatewayIdField()] == organization.Id.ToString() && + s.Items.Count == 0 + )); + } - await stripeAdapter.Received().SubscriptionCreateAsync(Arg.Is(s => - s.Customer == "C-1" && - s.Expand[0] == "latest_invoice.payment_intent" && - s.Metadata[organization.GatewayIdField()] == organization.Id.ToString() && - s.Items.Count == 0 - )); - } + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async void PurchaseOrganizationAsync_Stripe_PM(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) + { + var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); + paymentToken = "pm_" + paymentToken; - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async void PurchaseOrganizationAsync_Stripe_PM(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer { - var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); - paymentToken = "pm_" + paymentToken; - - var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer - { - Id = "C-1", - }); - stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription - { - Id = "S-1", - CurrentPeriodEnd = DateTime.Today.AddDays(10), - }); - - var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo); - - Assert.Null(result); - Assert.Equal(GatewayType.Stripe, organization.Gateway); - Assert.Equal("C-1", organization.GatewayCustomerId); - Assert.Equal("S-1", organization.GatewaySubscriptionId); - Assert.True(organization.Enabled); - Assert.Equal(DateTime.Today.AddDays(10), organization.ExpirationDate); - - await stripeAdapter.Received().CustomerCreateAsync(Arg.Is(c => - c.Description == organization.BusinessName && - c.Email == organization.BillingEmail && - c.Source == null && - c.PaymentMethod == paymentToken && - !c.Metadata.Any() && - c.InvoiceSettings.DefaultPaymentMethod == paymentToken && - c.Address.Country == taxInfo.BillingAddressCountry && - c.Address.PostalCode == taxInfo.BillingAddressPostalCode && - c.Address.Line1 == taxInfo.BillingAddressLine1 && - c.Address.Line2 == taxInfo.BillingAddressLine2 && - c.Address.City == taxInfo.BillingAddressCity && - c.Address.State == taxInfo.BillingAddressState && - c.TaxIdData == null - )); - - await stripeAdapter.Received().SubscriptionCreateAsync(Arg.Is(s => - s.Customer == "C-1" && - s.Expand[0] == "latest_invoice.payment_intent" && - s.Metadata[organization.GatewayIdField()] == organization.Id.ToString() && - s.Items.Count == 0 - )); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async void PurchaseOrganizationAsync_Stripe_TaxRate(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) + Id = "C-1", + }); + stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription { - var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); + Id = "S-1", + CurrentPeriodEnd = DateTime.Today.AddDays(10), + }); - var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer - { - Id = "C-1", - }); - stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription - { - Id = "S-1", - CurrentPeriodEnd = DateTime.Today.AddDays(10), - }); - sutProvider.GetDependency().GetByLocationAsync(Arg.Is(t => - t.Country == taxInfo.BillingAddressCountry && t.PostalCode == taxInfo.BillingAddressPostalCode)) - .Returns(new List { new() { Id = "T-1" } }); + var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo); - var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo); + Assert.Null(result); + Assert.Equal(GatewayType.Stripe, organization.Gateway); + Assert.Equal("C-1", organization.GatewayCustomerId); + Assert.Equal("S-1", organization.GatewaySubscriptionId); + Assert.True(organization.Enabled); + Assert.Equal(DateTime.Today.AddDays(10), organization.ExpirationDate); - Assert.Null(result); + await stripeAdapter.Received().CustomerCreateAsync(Arg.Is(c => + c.Description == organization.BusinessName && + c.Email == organization.BillingEmail && + c.Source == null && + c.PaymentMethod == paymentToken && + !c.Metadata.Any() && + c.InvoiceSettings.DefaultPaymentMethod == paymentToken && + c.Address.Country == taxInfo.BillingAddressCountry && + c.Address.PostalCode == taxInfo.BillingAddressPostalCode && + c.Address.Line1 == taxInfo.BillingAddressLine1 && + c.Address.Line2 == taxInfo.BillingAddressLine2 && + c.Address.City == taxInfo.BillingAddressCity && + c.Address.State == taxInfo.BillingAddressState && + c.TaxIdData == null + )); - await stripeAdapter.Received().SubscriptionCreateAsync(Arg.Is(s => - s.DefaultTaxRates.Count == 1 && - s.DefaultTaxRates[0] == "T-1" - )); - } + await stripeAdapter.Received().SubscriptionCreateAsync(Arg.Is(s => + s.Customer == "C-1" && + s.Expand[0] == "latest_invoice.payment_intent" && + s.Metadata[organization.GatewayIdField()] == organization.Id.ToString() && + s.Items.Count == 0 + )); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async void PurchaseOrganizationAsync_Stripe_Declined(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async void PurchaseOrganizationAsync_Stripe_TaxRate(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) + { + var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); + + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer { - var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); - paymentToken = "pm_" + paymentToken; + Id = "C-1", + }); + stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription + { + Id = "S-1", + CurrentPeriodEnd = DateTime.Today.AddDays(10), + }); + sutProvider.GetDependency().GetByLocationAsync(Arg.Is(t => + t.Country == taxInfo.BillingAddressCountry && t.PostalCode == taxInfo.BillingAddressPostalCode)) + .Returns(new List { new() { Id = "T-1" } }); - var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer + var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo); + + Assert.Null(result); + + await stripeAdapter.Received().SubscriptionCreateAsync(Arg.Is(s => + s.DefaultTaxRates.Count == 1 && + s.DefaultTaxRates[0] == "T-1" + )); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async void PurchaseOrganizationAsync_Stripe_Declined(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) + { + var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); + paymentToken = "pm_" + paymentToken; + + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer + { + Id = "C-1", + }); + stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription + { + Id = "S-1", + CurrentPeriodEnd = DateTime.Today.AddDays(10), + Status = "incomplete", + LatestInvoice = new Stripe.Invoice { - Id = "C-1", - }); - stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription - { - Id = "S-1", - CurrentPeriodEnd = DateTime.Today.AddDays(10), - Status = "incomplete", - LatestInvoice = new Stripe.Invoice + PaymentIntent = new Stripe.PaymentIntent { - PaymentIntent = new Stripe.PaymentIntent - { - Status = "requires_payment_method", - }, + Status = "requires_payment_method", }, - }); + }, + }); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo)); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo)); - Assert.Equal("Payment method was declined.", exception.Message); + Assert.Equal("Payment method was declined.", exception.Message); - await stripeAdapter.Received(1).CustomerDeleteAsync("C-1"); - } + await stripeAdapter.Received(1).CustomerDeleteAsync("C-1"); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async void PurchaseOrganizationAsync_Stripe_RequiresAction(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async void PurchaseOrganizationAsync_Stripe_RequiresAction(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) + { + var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); + + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer { - var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); - - var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer + Id = "C-1", + }); + stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription + { + Id = "S-1", + CurrentPeriodEnd = DateTime.Today.AddDays(10), + Status = "incomplete", + LatestInvoice = new Stripe.Invoice { - Id = "C-1", - }); - stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription - { - Id = "S-1", - CurrentPeriodEnd = DateTime.Today.AddDays(10), - Status = "incomplete", - LatestInvoice = new Stripe.Invoice + PaymentIntent = new Stripe.PaymentIntent { - PaymentIntent = new Stripe.PaymentIntent - { - Status = "requires_action", - ClientSecret = "clientSecret", - }, + Status = "requires_action", + ClientSecret = "clientSecret", }, - }); + }, + }); - var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo); + var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo); - Assert.Equal("clientSecret", result); - Assert.False(organization.Enabled); - } + Assert.Equal("clientSecret", result); + Assert.False(organization.Enabled); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async void PurchaseOrganizationAsync_Paypal(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async void PurchaseOrganizationAsync_Paypal(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) + { + var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); + + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer { - var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); - - var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer - { - Id = "C-1", - }); - stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription - { - Id = "S-1", - CurrentPeriodEnd = DateTime.Today.AddDays(10), - }); - - var customer = Substitute.For(); - customer.Id.ReturnsForAnyArgs("Braintree-Id"); - customer.PaymentMethods.ReturnsForAnyArgs(new[] { Substitute.For() }); - var customerResult = Substitute.For>(); - customerResult.IsSuccess().Returns(true); - customerResult.Target.ReturnsForAnyArgs(customer); - - var braintreeGateway = sutProvider.GetDependency(); - braintreeGateway.Customer.CreateAsync(default).ReturnsForAnyArgs(customerResult); - - var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.PayPal, paymentToken, plan, 0, 0, false, taxInfo); - - Assert.Null(result); - Assert.Equal(GatewayType.Stripe, organization.Gateway); - Assert.Equal("C-1", organization.GatewayCustomerId); - Assert.Equal("S-1", organization.GatewaySubscriptionId); - Assert.True(organization.Enabled); - Assert.Equal(DateTime.Today.AddDays(10), organization.ExpirationDate); - - await stripeAdapter.Received().CustomerCreateAsync(Arg.Is(c => - c.Description == organization.BusinessName && - c.Email == organization.BillingEmail && - c.PaymentMethod == null && - c.Metadata.Count == 1 && - c.Metadata["btCustomerId"] == "Braintree-Id" && - c.InvoiceSettings.DefaultPaymentMethod == null && - c.Address.Country == taxInfo.BillingAddressCountry && - c.Address.PostalCode == taxInfo.BillingAddressPostalCode && - c.Address.Line1 == taxInfo.BillingAddressLine1 && - c.Address.Line2 == taxInfo.BillingAddressLine2 && - c.Address.City == taxInfo.BillingAddressCity && - c.Address.State == taxInfo.BillingAddressState && - c.TaxIdData == null - )); - - await stripeAdapter.Received().SubscriptionCreateAsync(Arg.Is(s => - s.Customer == "C-1" && - s.Expand[0] == "latest_invoice.payment_intent" && - s.Metadata[organization.GatewayIdField()] == organization.Id.ToString() && - s.Items.Count == 0 - )); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async void PurchaseOrganizationAsync_Paypal_FailedCreate(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) + Id = "C-1", + }); + stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription { - var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); + Id = "S-1", + CurrentPeriodEnd = DateTime.Today.AddDays(10), + }); - var customerResult = Substitute.For>(); - customerResult.IsSuccess().Returns(false); + var customer = Substitute.For(); + customer.Id.ReturnsForAnyArgs("Braintree-Id"); + customer.PaymentMethods.ReturnsForAnyArgs(new[] { Substitute.For() }); + var customerResult = Substitute.For>(); + customerResult.IsSuccess().Returns(true); + customerResult.Target.ReturnsForAnyArgs(customer); - var braintreeGateway = sutProvider.GetDependency(); - braintreeGateway.Customer.CreateAsync(default).ReturnsForAnyArgs(customerResult); + var braintreeGateway = sutProvider.GetDependency(); + braintreeGateway.Customer.CreateAsync(default).ReturnsForAnyArgs(customerResult); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.PayPal, paymentToken, plan, 0, 0, false, taxInfo)); + var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.PayPal, paymentToken, plan, 0, 0, false, taxInfo); - Assert.Equal("Failed to create PayPal customer record.", exception.Message); - } + Assert.Null(result); + Assert.Equal(GatewayType.Stripe, organization.Gateway); + Assert.Equal("C-1", organization.GatewayCustomerId); + Assert.Equal("S-1", organization.GatewaySubscriptionId); + Assert.True(organization.Enabled); + Assert.Equal(DateTime.Today.AddDays(10), organization.ExpirationDate); - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async void PurchaseOrganizationAsync_PayPal_Declined(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) + await stripeAdapter.Received().CustomerCreateAsync(Arg.Is(c => + c.Description == organization.BusinessName && + c.Email == organization.BillingEmail && + c.PaymentMethod == null && + c.Metadata.Count == 1 && + c.Metadata["btCustomerId"] == "Braintree-Id" && + c.InvoiceSettings.DefaultPaymentMethod == null && + c.Address.Country == taxInfo.BillingAddressCountry && + c.Address.PostalCode == taxInfo.BillingAddressPostalCode && + c.Address.Line1 == taxInfo.BillingAddressLine1 && + c.Address.Line2 == taxInfo.BillingAddressLine2 && + c.Address.City == taxInfo.BillingAddressCity && + c.Address.State == taxInfo.BillingAddressState && + c.TaxIdData == null + )); + + await stripeAdapter.Received().SubscriptionCreateAsync(Arg.Is(s => + s.Customer == "C-1" && + s.Expand[0] == "latest_invoice.payment_intent" && + s.Metadata[organization.GatewayIdField()] == organization.Id.ToString() && + s.Items.Count == 0 + )); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async void PurchaseOrganizationAsync_Paypal_FailedCreate(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) + { + var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); + + var customerResult = Substitute.For>(); + customerResult.IsSuccess().Returns(false); + + var braintreeGateway = sutProvider.GetDependency(); + braintreeGateway.Customer.CreateAsync(default).ReturnsForAnyArgs(customerResult); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.PayPal, paymentToken, plan, 0, 0, false, taxInfo)); + + Assert.Equal("Failed to create PayPal customer record.", exception.Message); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async void PurchaseOrganizationAsync_PayPal_Declined(SutProvider sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo) + { + var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); + paymentToken = "pm_" + paymentToken; + + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer { - var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); - paymentToken = "pm_" + paymentToken; - - var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer + Id = "C-1", + }); + stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription + { + Id = "S-1", + CurrentPeriodEnd = DateTime.Today.AddDays(10), + Status = "incomplete", + LatestInvoice = new Stripe.Invoice { - Id = "C-1", - }); - stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription - { - Id = "S-1", - CurrentPeriodEnd = DateTime.Today.AddDays(10), - Status = "incomplete", - LatestInvoice = new Stripe.Invoice + PaymentIntent = new Stripe.PaymentIntent { - PaymentIntent = new Stripe.PaymentIntent - { - Status = "requires_payment_method", - }, + Status = "requires_payment_method", }, - }); + }, + }); - var customer = Substitute.For(); - customer.Id.ReturnsForAnyArgs("Braintree-Id"); - customer.PaymentMethods.ReturnsForAnyArgs(new[] { Substitute.For() }); - var customerResult = Substitute.For>(); - customerResult.IsSuccess().Returns(true); - customerResult.Target.ReturnsForAnyArgs(customer); + var customer = Substitute.For(); + customer.Id.ReturnsForAnyArgs("Braintree-Id"); + customer.PaymentMethods.ReturnsForAnyArgs(new[] { Substitute.For() }); + var customerResult = Substitute.For>(); + customerResult.IsSuccess().Returns(true); + customerResult.Target.ReturnsForAnyArgs(customer); - var braintreeGateway = sutProvider.GetDependency(); - braintreeGateway.Customer.CreateAsync(default).ReturnsForAnyArgs(customerResult); + var braintreeGateway = sutProvider.GetDependency(); + braintreeGateway.Customer.CreateAsync(default).ReturnsForAnyArgs(customerResult); - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.PayPal, paymentToken, plan, 0, 0, false, taxInfo)); + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.PayPal, paymentToken, plan, 0, 0, false, taxInfo)); - Assert.Equal("Payment method was declined.", exception.Message); + Assert.Equal("Payment method was declined.", exception.Message); - await stripeAdapter.Received(1).CustomerDeleteAsync("C-1"); - await braintreeGateway.Customer.Received(1).DeleteAsync("Braintree-Id"); - } + await stripeAdapter.Received(1).CustomerDeleteAsync("C-1"); + await braintreeGateway.Customer.Received(1).DeleteAsync("Braintree-Id"); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async void UpgradeFreeOrganizationAsync_Success(SutProvider sutProvider, - Organization organization, TaxInfo taxInfo) + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async void UpgradeFreeOrganizationAsync_Success(SutProvider sutProvider, + Organization organization, TaxInfo taxInfo) + { + organization.GatewaySubscriptionId = null; + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter.CustomerGetAsync(default).ReturnsForAnyArgs(new Stripe.Customer { - organization.GatewaySubscriptionId = null; - var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerGetAsync(default).ReturnsForAnyArgs(new Stripe.Customer + Id = "C-1", + Metadata = new Dictionary { - Id = "C-1", - Metadata = new Dictionary - { - { "btCustomerId", "B-123" }, - } - }); - stripeAdapter.InvoiceUpcomingAsync(default).ReturnsForAnyArgs(new Stripe.Invoice - { - PaymentIntent = new Stripe.PaymentIntent { Status = "requires_payment_method", }, - AmountDue = 0 - }); - stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription { }); + { "btCustomerId", "B-123" }, + } + }); + stripeAdapter.InvoiceUpcomingAsync(default).ReturnsForAnyArgs(new Stripe.Invoice + { + PaymentIntent = new Stripe.PaymentIntent { Status = "requires_payment_method", }, + AmountDue = 0 + }); + stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription { }); - var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); - var result = await sutProvider.Sut.UpgradeFreeOrganizationAsync(organization, plan, 0, 0, false, taxInfo); + var plan = StaticStore.Plans.First(p => p.Type == PlanType.EnterpriseAnnually); + var result = await sutProvider.Sut.UpgradeFreeOrganizationAsync(organization, plan, 0, 0, false, taxInfo); - Assert.Null(result); - } + Assert.Null(result); } } diff --git a/test/Core.Test/Services/UserServiceTests.cs b/test/Core.Test/Services/UserServiceTests.cs index 10a4beac5..5e82e4c40 100644 --- a/test/Core.Test/Services/UserServiceTests.cs +++ b/test/Core.Test/Services/UserServiceTests.cs @@ -14,376 +14,375 @@ using NSubstitute; using NSubstitute.ReceivedExtensions; using Xunit; -namespace Bit.Core.Test.Services +namespace Bit.Core.Test.Services; + +public class UserServiceTests { - public class UserServiceTests + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task UpdateLicenseAsync_Success(SutProvider sutProvider, + User user, UserLicense userLicense) { - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task UpdateLicenseAsync_Success(SutProvider sutProvider, - User user, UserLicense userLicense) + using var tempDir = new TempDirectory(); + + var now = DateTime.UtcNow; + userLicense.Issued = now.AddDays(-10); + userLicense.Expires = now.AddDays(10); + userLicense.Version = 1; + userLicense.Premium = true; + + user.EmailVerified = true; + user.Email = userLicense.Email; + + sutProvider.GetDependency().SelfHosted = true; + sutProvider.GetDependency().LicenseDirectory = tempDir.Directory; + sutProvider.GetDependency() + .VerifyLicense(userLicense) + .Returns(true); + + await sutProvider.Sut.UpdateLicenseAsync(user, userLicense); + + var filePath = Path.Combine(tempDir.Directory, "user", $"{user.Id}.json"); + Assert.True(File.Exists(filePath)); + var document = JsonDocument.Parse(File.OpenRead(filePath)); + var root = document.RootElement; + Assert.Equal(JsonValueKind.Object, root.ValueKind); + // Sort of a lazy way to test that it is indented but not sure of a better way + Assert.Contains('\n', root.GetRawText()); + AssertHelper.AssertJsonProperty(root, "LicenseKey", JsonValueKind.String); + AssertHelper.AssertJsonProperty(root, "Id", JsonValueKind.String); + AssertHelper.AssertJsonProperty(root, "Premium", JsonValueKind.True); + var versionProp = AssertHelper.AssertJsonProperty(root, "Version", JsonValueKind.Number); + Assert.Equal(1, versionProp.GetInt32()); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SendTwoFactorEmailAsync_Success(SutProvider sutProvider, User user) + { + var email = user.Email.ToLowerInvariant(); + var token = "thisisatokentocompare"; + + var userTwoFactorTokenProvider = Substitute.For>(); + userTwoFactorTokenProvider + .CanGenerateTwoFactorTokenAsync(Arg.Any>(), user) + .Returns(Task.FromResult(true)); + userTwoFactorTokenProvider + .GenerateAsync("2faEmail:" + email, Arg.Any>(), user) + .Returns(Task.FromResult(token)); + + sutProvider.Sut.RegisterTokenProvider("Email", userTwoFactorTokenProvider); + + user.SetTwoFactorProviders(new Dictionary { - using var tempDir = new TempDirectory(); - - var now = DateTime.UtcNow; - userLicense.Issued = now.AddDays(-10); - userLicense.Expires = now.AddDays(10); - userLicense.Version = 1; - userLicense.Premium = true; - - user.EmailVerified = true; - user.Email = userLicense.Email; - - sutProvider.GetDependency().SelfHosted = true; - sutProvider.GetDependency().LicenseDirectory = tempDir.Directory; - sutProvider.GetDependency() - .VerifyLicense(userLicense) - .Returns(true); - - await sutProvider.Sut.UpdateLicenseAsync(user, userLicense); - - var filePath = Path.Combine(tempDir.Directory, "user", $"{user.Id}.json"); - Assert.True(File.Exists(filePath)); - var document = JsonDocument.Parse(File.OpenRead(filePath)); - var root = document.RootElement; - Assert.Equal(JsonValueKind.Object, root.ValueKind); - // Sort of a lazy way to test that it is indented but not sure of a better way - Assert.Contains('\n', root.GetRawText()); - AssertHelper.AssertJsonProperty(root, "LicenseKey", JsonValueKind.String); - AssertHelper.AssertJsonProperty(root, "Id", JsonValueKind.String); - AssertHelper.AssertJsonProperty(root, "Premium", JsonValueKind.True); - var versionProp = AssertHelper.AssertJsonProperty(root, "Version", JsonValueKind.Number); - Assert.Equal(1, versionProp.GetInt32()); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SendTwoFactorEmailAsync_Success(SutProvider sutProvider, User user) - { - var email = user.Email.ToLowerInvariant(); - var token = "thisisatokentocompare"; - - var userTwoFactorTokenProvider = Substitute.For>(); - userTwoFactorTokenProvider - .CanGenerateTwoFactorTokenAsync(Arg.Any>(), user) - .Returns(Task.FromResult(true)); - userTwoFactorTokenProvider - .GenerateAsync("2faEmail:" + email, Arg.Any>(), user) - .Returns(Task.FromResult(token)); - - sutProvider.Sut.RegisterTokenProvider("Email", userTwoFactorTokenProvider); - - user.SetTwoFactorProviders(new Dictionary + [TwoFactorProviderType.Email] = new TwoFactorProvider { - [TwoFactorProviderType.Email] = new TwoFactorProvider - { - MetaData = new Dictionary { ["Email"] = email }, - Enabled = true - } - }); - await sutProvider.Sut.SendTwoFactorEmailAsync(user); + MetaData = new Dictionary { ["Email"] = email }, + Enabled = true + } + }); + await sutProvider.Sut.SendTwoFactorEmailAsync(user); - await sutProvider.GetDependency() - .Received(1) - .SendTwoFactorEmailAsync(email, token); - } + await sutProvider.GetDependency() + .Received(1) + .SendTwoFactorEmailAsync(email, token); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SendTwoFactorEmailBecauseNewDeviceLoginAsync_Success(SutProvider sutProvider, User user) + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SendTwoFactorEmailBecauseNewDeviceLoginAsync_Success(SutProvider sutProvider, User user) + { + var email = user.Email.ToLowerInvariant(); + var token = "thisisatokentocompare"; + + var userTwoFactorTokenProvider = Substitute.For>(); + userTwoFactorTokenProvider + .CanGenerateTwoFactorTokenAsync(Arg.Any>(), user) + .Returns(Task.FromResult(true)); + userTwoFactorTokenProvider + .GenerateAsync("2faEmail:" + email, Arg.Any>(), user) + .Returns(Task.FromResult(token)); + + sutProvider.Sut.RegisterTokenProvider("Email", userTwoFactorTokenProvider); + + user.SetTwoFactorProviders(new Dictionary { - var email = user.Email.ToLowerInvariant(); - var token = "thisisatokentocompare"; - - var userTwoFactorTokenProvider = Substitute.For>(); - userTwoFactorTokenProvider - .CanGenerateTwoFactorTokenAsync(Arg.Any>(), user) - .Returns(Task.FromResult(true)); - userTwoFactorTokenProvider - .GenerateAsync("2faEmail:" + email, Arg.Any>(), user) - .Returns(Task.FromResult(token)); - - sutProvider.Sut.RegisterTokenProvider("Email", userTwoFactorTokenProvider); - - user.SetTwoFactorProviders(new Dictionary + [TwoFactorProviderType.Email] = new TwoFactorProvider { - [TwoFactorProviderType.Email] = new TwoFactorProvider - { - MetaData = new Dictionary { ["Email"] = email }, - Enabled = true - } - }); - await sutProvider.Sut.SendTwoFactorEmailAsync(user, true); + MetaData = new Dictionary { ["Email"] = email }, + Enabled = true + } + }); + await sutProvider.Sut.SendTwoFactorEmailAsync(user, true); - await sutProvider.GetDependency() - .Received(1) - .SendNewDeviceLoginTwoFactorEmailAsync(email, token); - } + await sutProvider.GetDependency() + .Received(1) + .SendNewDeviceLoginTwoFactorEmailAsync(email, token); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SendTwoFactorEmailAsync_ExceptionBecauseNoProviderOnUser(SutProvider sutProvider, User user) + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SendTwoFactorEmailAsync_ExceptionBecauseNoProviderOnUser(SutProvider sutProvider, User user) + { + user.TwoFactorProviders = null; + + await Assert.ThrowsAsync("No email.", () => sutProvider.Sut.SendTwoFactorEmailAsync(user)); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SendTwoFactorEmailAsync_ExceptionBecauseNoProviderMetadataOnUser(SutProvider sutProvider, User user) + { + user.SetTwoFactorProviders(new Dictionary { - user.TwoFactorProviders = null; - - await Assert.ThrowsAsync("No email.", () => sutProvider.Sut.SendTwoFactorEmailAsync(user)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SendTwoFactorEmailAsync_ExceptionBecauseNoProviderMetadataOnUser(SutProvider sutProvider, User user) - { - user.SetTwoFactorProviders(new Dictionary + [TwoFactorProviderType.Email] = new TwoFactorProvider { - [TwoFactorProviderType.Email] = new TwoFactorProvider - { - MetaData = null, - Enabled = true - } - }); + MetaData = null, + Enabled = true + } + }); - await Assert.ThrowsAsync("No email.", () => sutProvider.Sut.SendTwoFactorEmailAsync(user)); - } + await Assert.ThrowsAsync("No email.", () => sutProvider.Sut.SendTwoFactorEmailAsync(user)); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task SendTwoFactorEmailAsync_ExceptionBecauseNoProviderEmailMetadataOnUser(SutProvider sutProvider, User user) + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task SendTwoFactorEmailAsync_ExceptionBecauseNoProviderEmailMetadataOnUser(SutProvider sutProvider, User user) + { + user.SetTwoFactorProviders(new Dictionary { - user.SetTwoFactorProviders(new Dictionary + [TwoFactorProviderType.Email] = new TwoFactorProvider { - [TwoFactorProviderType.Email] = new TwoFactorProvider - { - MetaData = new Dictionary { ["qweqwe"] = user.Email.ToLowerInvariant() }, - Enabled = true - } - }); + MetaData = new Dictionary { ["qweqwe"] = user.Email.ToLowerInvariant() }, + Enabled = true + } + }); - await Assert.ThrowsAsync("No email.", () => sutProvider.Sut.SendTwoFactorEmailAsync(user)); - } + await Assert.ThrowsAsync("No email.", () => sutProvider.Sut.SendTwoFactorEmailAsync(user)); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task Needs2FABecauseNewDeviceAsync_ReturnsTrue(SutProvider sutProvider, User user) + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task Needs2FABecauseNewDeviceAsync_ReturnsTrue(SutProvider sutProvider, User user) + { + user.EmailVerified = true; + user.TwoFactorProviders = null; + user.UnknownDeviceVerificationEnabled = true; + const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; + const string deviceIdInRepo = "ea29126c-91b7-4cc4-8ce6-00105b37f64a"; + + sutProvider.GetDependency() + .GetManyByUserIdAsync(user.Id) + .Returns(Task.FromResult>(new List + { + new Device { Identifier = deviceIdInRepo } + })); + + + sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(true); + + Assert.True(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "password")); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task Needs2FABecauseNewDeviceAsync_ReturnsFalse_When_GranType_Is_AuthorizationCode(SutProvider sutProvider, User user) + { + user.EmailVerified = true; + user.TwoFactorProviders = null; + const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; + const string deviceIdInRepo = "ea29126c-91b7-4cc4-8ce6-00105b37f64a"; + + sutProvider.GetDependency() + .GetManyByUserIdAsync(user.Id) + .Returns(Task.FromResult>(new List + { + new Device { Identifier = deviceIdInRepo } + })); + + Assert.False(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "authorization_code")); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task Needs2FABecauseNewDeviceAsync_ReturnsFalse_When_Email_Is_Not_Verified(SutProvider sutProvider, User user) + { + user.EmailVerified = false; + user.TwoFactorProviders = null; + const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; + const string deviceIdInRepo = "ea29126c-91b7-4cc4-8ce6-00105b37f64a"; + + sutProvider.GetDependency() + .GetManyByUserIdAsync(user.Id) + .Returns(Task.FromResult>(new List + { + new Device { Identifier = deviceIdInRepo } + })); + + Assert.False(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "password")); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task Needs2FABecauseNewDeviceAsync_ReturnsFalse_When_Is_The_First_Device(SutProvider sutProvider, User user) + { + user.EmailVerified = true; + user.TwoFactorProviders = null; + const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; + + sutProvider.GetDependency() + .GetManyByUserIdAsync(user.Id) + .Returns(Task.FromResult>(new List())); + + Assert.False(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "password")); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task Needs2FABecauseNewDeviceAsync_ReturnsFalse_When_DeviceId_Is_Already_In_Repo(SutProvider sutProvider, User user) + { + user.EmailVerified = true; + user.TwoFactorProviders = null; + const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; + + sutProvider.GetDependency() + .GetManyByUserIdAsync(user.Id) + .Returns(Task.FromResult>(new List + { + new Device { Identifier = deviceIdToCheck } + })); + + Assert.False(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "password")); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task Needs2FABecauseNewDeviceAsync_ReturnsFalse_When_GlobalSettings_2FA_EmailOnNewDeviceLogin_Is_Disabled(SutProvider sutProvider, User user) + { + user.EmailVerified = true; + user.TwoFactorProviders = null; + const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; + const string deviceIdInRepo = "ea29126c-91b7-4cc4-8ce6-00105b37f64a"; + + sutProvider.GetDependency() + .GetManyByUserIdAsync(user.Id) + .Returns(Task.FromResult>(new List + { + new Device { Identifier = deviceIdInRepo } + })); + + sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(false); + + Assert.False(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "password")); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async Task Needs2FABecauseNewDeviceAsync_ReturnsFalse_When_UnknownDeviceVerification_Is_Disabled(SutProvider sutProvider, User user) + { + user.EmailVerified = true; + user.TwoFactorProviders = null; + user.UnknownDeviceVerificationEnabled = false; + const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; + const string deviceIdInRepo = "ea29126c-91b7-4cc4-8ce6-00105b37f64a"; + + sutProvider.GetDependency() + .GetManyByUserIdAsync(user.Id) + .Returns(Task.FromResult>(new List + { + new Device { Identifier = deviceIdInRepo } + })); + + sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(true); + + Assert.False(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "password")); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public void CanEditDeviceVerificationSettings_ReturnsTrue(SutProvider sutProvider, User user) + { + user.EmailVerified = true; + user.TwoFactorProviders = null; + + sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(true); + + Assert.True(sutProvider.Sut.CanEditDeviceVerificationSettings(user)); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public void CanEditDeviceVerificationSettings_ReturnsFalse_When_GlobalSettings_2FA_EmailOnNewDeviceLogin_Is_Disabled(SutProvider sutProvider, User user) + { + user.EmailVerified = true; + user.TwoFactorProviders = null; + + sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(false); + + Assert.False(sutProvider.Sut.CanEditDeviceVerificationSettings(user)); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public void CanEditDeviceVerificationSettings_ReturnsFalse_When_Email_Is_Not_Verified(SutProvider sutProvider, User user) + { + user.EmailVerified = false; + user.TwoFactorProviders = null; + + sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(true); + + Assert.False(sutProvider.Sut.CanEditDeviceVerificationSettings(user)); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public void CanEditDeviceVerificationSettings_ReturnsFalse_When_User_Uses_Key_Connector(SutProvider sutProvider, User user) + { + user.EmailVerified = true; + user.TwoFactorProviders = null; + user.UsesKeyConnector = true; + + sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(true); + + Assert.False(sutProvider.Sut.CanEditDeviceVerificationSettings(user)); + } + + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public void CanEditDeviceVerificationSettings_ReturnsFalse_When_User_Has_A_2FA_Already_Set_Up(SutProvider sutProvider, User user) + { + user.EmailVerified = true; + user.SetTwoFactorProviders(new Dictionary { - user.EmailVerified = true; - user.TwoFactorProviders = null; - user.UnknownDeviceVerificationEnabled = true; - const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; - const string deviceIdInRepo = "ea29126c-91b7-4cc4-8ce6-00105b37f64a"; - - sutProvider.GetDependency() - .GetManyByUserIdAsync(user.Id) - .Returns(Task.FromResult>(new List - { - new Device { Identifier = deviceIdInRepo } - })); - - - sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(true); - - Assert.True(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "password")); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task Needs2FABecauseNewDeviceAsync_ReturnsFalse_When_GranType_Is_AuthorizationCode(SutProvider sutProvider, User user) - { - user.EmailVerified = true; - user.TwoFactorProviders = null; - const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; - const string deviceIdInRepo = "ea29126c-91b7-4cc4-8ce6-00105b37f64a"; - - sutProvider.GetDependency() - .GetManyByUserIdAsync(user.Id) - .Returns(Task.FromResult>(new List - { - new Device { Identifier = deviceIdInRepo } - })); - - Assert.False(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "authorization_code")); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task Needs2FABecauseNewDeviceAsync_ReturnsFalse_When_Email_Is_Not_Verified(SutProvider sutProvider, User user) - { - user.EmailVerified = false; - user.TwoFactorProviders = null; - const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; - const string deviceIdInRepo = "ea29126c-91b7-4cc4-8ce6-00105b37f64a"; - - sutProvider.GetDependency() - .GetManyByUserIdAsync(user.Id) - .Returns(Task.FromResult>(new List - { - new Device { Identifier = deviceIdInRepo } - })); - - Assert.False(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "password")); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task Needs2FABecauseNewDeviceAsync_ReturnsFalse_When_Is_The_First_Device(SutProvider sutProvider, User user) - { - user.EmailVerified = true; - user.TwoFactorProviders = null; - const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; - - sutProvider.GetDependency() - .GetManyByUserIdAsync(user.Id) - .Returns(Task.FromResult>(new List())); - - Assert.False(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "password")); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task Needs2FABecauseNewDeviceAsync_ReturnsFalse_When_DeviceId_Is_Already_In_Repo(SutProvider sutProvider, User user) - { - user.EmailVerified = true; - user.TwoFactorProviders = null; - const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; - - sutProvider.GetDependency() - .GetManyByUserIdAsync(user.Id) - .Returns(Task.FromResult>(new List - { - new Device { Identifier = deviceIdToCheck } - })); - - Assert.False(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "password")); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task Needs2FABecauseNewDeviceAsync_ReturnsFalse_When_GlobalSettings_2FA_EmailOnNewDeviceLogin_Is_Disabled(SutProvider sutProvider, User user) - { - user.EmailVerified = true; - user.TwoFactorProviders = null; - const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; - const string deviceIdInRepo = "ea29126c-91b7-4cc4-8ce6-00105b37f64a"; - - sutProvider.GetDependency() - .GetManyByUserIdAsync(user.Id) - .Returns(Task.FromResult>(new List - { - new Device { Identifier = deviceIdInRepo } - })); - - sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(false); - - Assert.False(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "password")); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async Task Needs2FABecauseNewDeviceAsync_ReturnsFalse_When_UnknownDeviceVerification_Is_Disabled(SutProvider sutProvider, User user) - { - user.EmailVerified = true; - user.TwoFactorProviders = null; - user.UnknownDeviceVerificationEnabled = false; - const string deviceIdToCheck = "7b01b586-b210-499f-8d52-0c3fdaa646fc"; - const string deviceIdInRepo = "ea29126c-91b7-4cc4-8ce6-00105b37f64a"; - - sutProvider.GetDependency() - .GetManyByUserIdAsync(user.Id) - .Returns(Task.FromResult>(new List - { - new Device { Identifier = deviceIdInRepo } - })); - - sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(true); - - Assert.False(await sutProvider.Sut.Needs2FABecauseNewDeviceAsync(user, deviceIdToCheck, "password")); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public void CanEditDeviceVerificationSettings_ReturnsTrue(SutProvider sutProvider, User user) - { - user.EmailVerified = true; - user.TwoFactorProviders = null; - - sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(true); - - Assert.True(sutProvider.Sut.CanEditDeviceVerificationSettings(user)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public void CanEditDeviceVerificationSettings_ReturnsFalse_When_GlobalSettings_2FA_EmailOnNewDeviceLogin_Is_Disabled(SutProvider sutProvider, User user) - { - user.EmailVerified = true; - user.TwoFactorProviders = null; - - sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(false); - - Assert.False(sutProvider.Sut.CanEditDeviceVerificationSettings(user)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public void CanEditDeviceVerificationSettings_ReturnsFalse_When_Email_Is_Not_Verified(SutProvider sutProvider, User user) - { - user.EmailVerified = false; - user.TwoFactorProviders = null; - - sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(true); - - Assert.False(sutProvider.Sut.CanEditDeviceVerificationSettings(user)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public void CanEditDeviceVerificationSettings_ReturnsFalse_When_User_Uses_Key_Connector(SutProvider sutProvider, User user) - { - user.EmailVerified = true; - user.TwoFactorProviders = null; - user.UsesKeyConnector = true; - - sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(true); - - Assert.False(sutProvider.Sut.CanEditDeviceVerificationSettings(user)); - } - - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public void CanEditDeviceVerificationSettings_ReturnsFalse_When_User_Has_A_2FA_Already_Set_Up(SutProvider sutProvider, User user) - { - user.EmailVerified = true; - user.SetTwoFactorProviders(new Dictionary + [TwoFactorProviderType.Email] = new TwoFactorProvider { - [TwoFactorProviderType.Email] = new TwoFactorProvider - { - MetaData = new Dictionary { ["Email"] = "asdfasf" }, - Enabled = true - } - }); + MetaData = new Dictionary { ["Email"] = "asdfasf" }, + Enabled = true + } + }); - sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(true); + sutProvider.GetDependency().TwoFactorAuth.EmailOnNewDeviceLogin.Returns(true); - Assert.False(sutProvider.Sut.CanEditDeviceVerificationSettings(user)); - } + Assert.False(sutProvider.Sut.CanEditDeviceVerificationSettings(user)); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async void HasPremiumFromOrganization_Returns_False_If_No_Orgs(SutProvider sutProvider, User user) - { - sutProvider.GetDependency().GetManyByUserAsync(user.Id).Returns(new List()); - Assert.False(await sutProvider.Sut.HasPremiumFromOrganization(user)); + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async void HasPremiumFromOrganization_Returns_False_If_No_Orgs(SutProvider sutProvider, User user) + { + sutProvider.GetDependency().GetManyByUserAsync(user.Id).Returns(new List()); + Assert.False(await sutProvider.Sut.HasPremiumFromOrganization(user)); - } + } - [Theory] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, false, true)] - [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, true, false)] - public async void HasPremiumFromOrganization_Returns_False_If_Org_Not_Eligible(bool orgEnabled, bool orgUsersGetPremium, SutProvider sutProvider, User user, OrganizationUser orgUser, Organization organization) - { - orgUser.OrganizationId = organization.Id; - organization.Enabled = orgEnabled; - organization.UsersGetPremium = orgUsersGetPremium; - var orgAbilities = new Dictionary() { { organization.Id, new OrganizationAbility(organization) } }; + [Theory] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, false, true)] + [InlineCustomAutoData(new[] { typeof(SutProviderCustomization) }, true, false)] + public async void HasPremiumFromOrganization_Returns_False_If_Org_Not_Eligible(bool orgEnabled, bool orgUsersGetPremium, SutProvider sutProvider, User user, OrganizationUser orgUser, Organization organization) + { + orgUser.OrganizationId = organization.Id; + organization.Enabled = orgEnabled; + organization.UsersGetPremium = orgUsersGetPremium; + var orgAbilities = new Dictionary() { { organization.Id, new OrganizationAbility(organization) } }; - sutProvider.GetDependency().GetManyByUserAsync(user.Id).Returns(new List() { orgUser }); - sutProvider.GetDependency().GetOrganizationAbilitiesAsync().Returns(orgAbilities); + sutProvider.GetDependency().GetManyByUserAsync(user.Id).Returns(new List() { orgUser }); + sutProvider.GetDependency().GetOrganizationAbilitiesAsync().Returns(orgAbilities); - Assert.False(await sutProvider.Sut.HasPremiumFromOrganization(user)); - } + Assert.False(await sutProvider.Sut.HasPremiumFromOrganization(user)); + } - [Theory, CustomAutoData(typeof(SutProviderCustomization))] - public async void HasPremiumFromOrganization_Returns_True_If_Org_Eligible(SutProvider sutProvider, User user, OrganizationUser orgUser, Organization organization) - { - orgUser.OrganizationId = organization.Id; - organization.Enabled = true; - organization.UsersGetPremium = true; - var orgAbilities = new Dictionary() { { organization.Id, new OrganizationAbility(organization) } }; + [Theory, CustomAutoData(typeof(SutProviderCustomization))] + public async void HasPremiumFromOrganization_Returns_True_If_Org_Eligible(SutProvider sutProvider, User user, OrganizationUser orgUser, Organization organization) + { + orgUser.OrganizationId = organization.Id; + organization.Enabled = true; + organization.UsersGetPremium = true; + var orgAbilities = new Dictionary() { { organization.Id, new OrganizationAbility(organization) } }; - sutProvider.GetDependency().GetManyByUserAsync(user.Id).Returns(new List() { orgUser }); - sutProvider.GetDependency().GetOrganizationAbilitiesAsync().Returns(orgAbilities); + sutProvider.GetDependency().GetManyByUserAsync(user.Id).Returns(new List() { orgUser }); + sutProvider.GetDependency().GetOrganizationAbilitiesAsync().Returns(orgAbilities); - Assert.True(await sutProvider.Sut.HasPremiumFromOrganization(user)); - } + Assert.True(await sutProvider.Sut.HasPremiumFromOrganization(user)); } } diff --git a/test/Core.Test/TempDirectory.cs b/test/Core.Test/TempDirectory.cs index 9a1cd86af..832d8c79c 100644 --- a/test/Core.Test/TempDirectory.cs +++ b/test/Core.Test/TempDirectory.cs @@ -1,39 +1,38 @@ -namespace Bit.Core.Test +namespace Bit.Core.Test; + +public class TempDirectory : IDisposable { - public class TempDirectory : IDisposable + public string Directory { get; private set; } + + public TempDirectory() { - public string Directory { get; private set; } - - public TempDirectory() - { - Directory = Path.Combine(Path.GetTempPath(), $"bitwarden_{Guid.NewGuid().ToString().Replace("-", "")}"); - } - - public override string ToString() => Directory; - - #region IDisposable implementation - ~TempDirectory() - { - Dispose(false); - } - - public void Dispose() - { - Dispose(true); - GC.SuppressFinalize(this); - } - - public void Dispose(bool disposing) - { - if (disposing) - { - try - { - System.IO.Directory.Delete(Directory, true); - } - catch { } - } - } - # endregion + Directory = Path.Combine(Path.GetTempPath(), $"bitwarden_{Guid.NewGuid().ToString().Replace("-", "")}"); } + + public override string ToString() => Directory; + + #region IDisposable implementation + ~TempDirectory() + { + Dispose(false); + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + public void Dispose(bool disposing) + { + if (disposing) + { + try + { + System.IO.Directory.Delete(Directory, true); + } + catch { } + } + } + # endregion } diff --git a/test/Core.Test/Tokens/DataProtectorTokenFactoryTests.cs b/test/Core.Test/Tokens/DataProtectorTokenFactoryTests.cs index 8a75a0790..3837ae026 100644 --- a/test/Core.Test/Tokens/DataProtectorTokenFactoryTests.cs +++ b/test/Core.Test/Tokens/DataProtectorTokenFactoryTests.cs @@ -7,122 +7,121 @@ using Bit.Test.Common.Helpers; using Microsoft.AspNetCore.DataProtection; using Xunit; -namespace Bit.Core.Test.Tokens +namespace Bit.Core.Test.Tokens; + +[SutProviderCustomize] +public class DataProtectorTokenFactoryTests { - [SutProviderCustomize] - public class DataProtectorTokenFactoryTests + public static SutProvider> GetSutProvider() { - public static SutProvider> GetSutProvider() - { - var fixture = new Fixture(); - return new SutProvider>(fixture) - .SetDependency(fixture.Create()) - .Create(); - } - - [Theory, BitAutoData] - public void CanRoundTripTokenables(TestTokenable tokenable) - { - var sutProvider = GetSutProvider(); - - var token = sutProvider.Sut.Protect(tokenable); - var recoveredTokenable = sutProvider.Sut.Unprotect(token); - - AssertHelper.AssertPropertyEqual(tokenable, recoveredTokenable); - } - - [Theory, BitAutoData] - public void PrependsClearText(TestTokenable tokenable) - { - var sutProvider = GetSutProvider(); - - var token = sutProvider.Sut.Protect(tokenable); - - Assert.StartsWith(sutProvider.GetDependency("clearTextPrefix"), token); - } - - [Theory, BitAutoData] - public void EncryptsToken(TestTokenable tokenable) - { - var sutProvider = GetSutProvider(); - var prefix = sutProvider.GetDependency("clearTextPrefix"); - - var token = sutProvider.Sut.Protect(tokenable); - - Assert.NotEqual(new Token(token).RemovePrefix(prefix), tokenable.ToToken()); - } - - [Theory, BitAutoData] - public void ThrowsIfUnprotectFails(TestTokenable tokenable) - { - var sutProvider = GetSutProvider(); - - var token = sutProvider.Sut.Protect(tokenable); - token += "stuff to make sure decryption fails"; - - Assert.Throws(() => sutProvider.Sut.Unprotect(token)); - } - - [Theory, BitAutoData] - public void TryUnprotect_FalseIfUnprotectFails(TestTokenable tokenable) - { - var sutProvider = GetSutProvider(); - var token = sutProvider.Sut.Protect(tokenable) + "fail decryption"; - - var result = sutProvider.Sut.TryUnprotect(token, out var data); - - Assert.False(result); - Assert.Null(data); - } - - [Theory, BitAutoData] - public void TokenValid_FalseIfUnprotectFails(TestTokenable tokenable) - { - var sutProvider = GetSutProvider(); - var token = sutProvider.Sut.Protect(tokenable) + "fail decryption"; - - var result = sutProvider.Sut.TokenValid(token); - - Assert.False(result); - } - - - [Theory, BitAutoData] - public void TokenValid_FalseIfTokenInvalid(TestTokenable tokenable) - { - var sutProvider = GetSutProvider(); - - tokenable.ForceInvalid = true; - var token = sutProvider.Sut.Protect(tokenable); - - var result = sutProvider.Sut.TokenValid(token); - - Assert.False(result); - } - - [Theory, BitAutoData] - public void TryUnprotect_TrueIfSuccess(TestTokenable tokenable) - { - var sutProvider = GetSutProvider(); - var token = sutProvider.Sut.Protect(tokenable); - - var result = sutProvider.Sut.TryUnprotect(token, out var data); - - Assert.True(result); - AssertHelper.AssertPropertyEqual(tokenable, data); - } - - [Theory, BitAutoData] - public void TokenValid_TrueIfSuccess(TestTokenable tokenable) - { - tokenable.ForceInvalid = false; - var sutProvider = GetSutProvider(); - var token = sutProvider.Sut.Protect(tokenable); - - var result = sutProvider.Sut.TokenValid(token); - - Assert.True(result); - } - + var fixture = new Fixture(); + return new SutProvider>(fixture) + .SetDependency(fixture.Create()) + .Create(); } + + [Theory, BitAutoData] + public void CanRoundTripTokenables(TestTokenable tokenable) + { + var sutProvider = GetSutProvider(); + + var token = sutProvider.Sut.Protect(tokenable); + var recoveredTokenable = sutProvider.Sut.Unprotect(token); + + AssertHelper.AssertPropertyEqual(tokenable, recoveredTokenable); + } + + [Theory, BitAutoData] + public void PrependsClearText(TestTokenable tokenable) + { + var sutProvider = GetSutProvider(); + + var token = sutProvider.Sut.Protect(tokenable); + + Assert.StartsWith(sutProvider.GetDependency("clearTextPrefix"), token); + } + + [Theory, BitAutoData] + public void EncryptsToken(TestTokenable tokenable) + { + var sutProvider = GetSutProvider(); + var prefix = sutProvider.GetDependency("clearTextPrefix"); + + var token = sutProvider.Sut.Protect(tokenable); + + Assert.NotEqual(new Token(token).RemovePrefix(prefix), tokenable.ToToken()); + } + + [Theory, BitAutoData] + public void ThrowsIfUnprotectFails(TestTokenable tokenable) + { + var sutProvider = GetSutProvider(); + + var token = sutProvider.Sut.Protect(tokenable); + token += "stuff to make sure decryption fails"; + + Assert.Throws(() => sutProvider.Sut.Unprotect(token)); + } + + [Theory, BitAutoData] + public void TryUnprotect_FalseIfUnprotectFails(TestTokenable tokenable) + { + var sutProvider = GetSutProvider(); + var token = sutProvider.Sut.Protect(tokenable) + "fail decryption"; + + var result = sutProvider.Sut.TryUnprotect(token, out var data); + + Assert.False(result); + Assert.Null(data); + } + + [Theory, BitAutoData] + public void TokenValid_FalseIfUnprotectFails(TestTokenable tokenable) + { + var sutProvider = GetSutProvider(); + var token = sutProvider.Sut.Protect(tokenable) + "fail decryption"; + + var result = sutProvider.Sut.TokenValid(token); + + Assert.False(result); + } + + + [Theory, BitAutoData] + public void TokenValid_FalseIfTokenInvalid(TestTokenable tokenable) + { + var sutProvider = GetSutProvider(); + + tokenable.ForceInvalid = true; + var token = sutProvider.Sut.Protect(tokenable); + + var result = sutProvider.Sut.TokenValid(token); + + Assert.False(result); + } + + [Theory, BitAutoData] + public void TryUnprotect_TrueIfSuccess(TestTokenable tokenable) + { + var sutProvider = GetSutProvider(); + var token = sutProvider.Sut.Protect(tokenable); + + var result = sutProvider.Sut.TryUnprotect(token, out var data); + + Assert.True(result); + AssertHelper.AssertPropertyEqual(tokenable, data); + } + + [Theory, BitAutoData] + public void TokenValid_TrueIfSuccess(TestTokenable tokenable) + { + tokenable.ForceInvalid = false; + var sutProvider = GetSutProvider(); + var token = sutProvider.Sut.Protect(tokenable); + + var result = sutProvider.Sut.TokenValid(token); + + Assert.True(result); + } + } diff --git a/test/Core.Test/Tokens/ExpiringTokenTests.cs b/test/Core.Test/Tokens/ExpiringTokenTests.cs index 33ce91178..9154b65b6 100644 --- a/test/Core.Test/Tokens/ExpiringTokenTests.cs +++ b/test/Core.Test/Tokens/ExpiringTokenTests.cs @@ -3,69 +3,68 @@ using AutoFixture.Xunit2; using Bit.Core.Utilities; using Xunit; -namespace Bit.Core.Test.Tokens +namespace Bit.Core.Test.Tokens; + +public class ExpiringTokenTests { - public class ExpiringTokenTests + [Theory, AutoData] + public void ExpirationSerializesToEpochMilliseconds(DateTime expirationDate) { - [Theory, AutoData] - public void ExpirationSerializesToEpochMilliseconds(DateTime expirationDate) + var sut = new TestExpiringTokenable { - var sut = new TestExpiringTokenable - { - ExpirationDate = expirationDate - }; + ExpirationDate = expirationDate + }; - var result = JsonSerializer.Serialize(sut); - var expectedDate = CoreHelpers.ToEpocMilliseconds(expirationDate); + var result = JsonSerializer.Serialize(sut); + var expectedDate = CoreHelpers.ToEpocMilliseconds(expirationDate); - Assert.Contains($"\"ExpirationDate\":{expectedDate}", result); - } + Assert.Contains($"\"ExpirationDate\":{expectedDate}", result); + } - [Theory, AutoData] - public void ExpirationSerializationRoundTrip(DateTime expirationDate) + [Theory, AutoData] + public void ExpirationSerializationRoundTrip(DateTime expirationDate) + { + var sut = new TestExpiringTokenable { - var sut = new TestExpiringTokenable - { - ExpirationDate = expirationDate - }; + ExpirationDate = expirationDate + }; - var intermediate = JsonSerializer.Serialize(sut); - var result = JsonSerializer.Deserialize(intermediate); + var intermediate = JsonSerializer.Serialize(sut); + var result = JsonSerializer.Deserialize(intermediate); - Assert.Equal(sut.ExpirationDate, result.ExpirationDate, TimeSpan.FromMilliseconds(100)); - } + Assert.Equal(sut.ExpirationDate, result.ExpirationDate, TimeSpan.FromMilliseconds(100)); + } - [Fact] - public void InvalidIfPastExpiryDate() + [Fact] + public void InvalidIfPastExpiryDate() + { + var sut = new TestExpiringTokenable { - var sut = new TestExpiringTokenable - { - ExpirationDate = DateTime.UtcNow.AddHours(-1) - }; + ExpirationDate = DateTime.UtcNow.AddHours(-1) + }; - Assert.False(sut.Valid); - } + Assert.False(sut.Valid); + } - [Fact] - public void ValidIfWithinExpirationAndTokenReportsValid() + [Fact] + public void ValidIfWithinExpirationAndTokenReportsValid() + { + var sut = new TestExpiringTokenable { - var sut = new TestExpiringTokenable - { - ExpirationDate = DateTime.UtcNow.AddHours(1) - }; + ExpirationDate = DateTime.UtcNow.AddHours(1) + }; - Assert.True(sut.Valid); - } + Assert.True(sut.Valid); + } - [Fact] - public void HonorsTokenIsValidAbstractMember() + [Fact] + public void HonorsTokenIsValidAbstractMember() + { + var sut = new TestExpiringTokenable(forceInvalid: true) { - var sut = new TestExpiringTokenable(forceInvalid: true) - { - ExpirationDate = DateTime.UtcNow.AddHours(1) - }; + ExpirationDate = DateTime.UtcNow.AddHours(1) + }; - Assert.False(sut.Valid); - } + Assert.False(sut.Valid); } } diff --git a/test/Core.Test/Tokens/TestTokenable.cs b/test/Core.Test/Tokens/TestTokenable.cs index 7e73cd5e9..c8dee643b 100644 --- a/test/Core.Test/Tokens/TestTokenable.cs +++ b/test/Core.Test/Tokens/TestTokenable.cs @@ -1,26 +1,25 @@ using System.Text.Json.Serialization; using Bit.Core.Tokens; -namespace Bit.Core.Test.Tokens +namespace Bit.Core.Test.Tokens; + +public class TestTokenable : Tokenable { - public class TestTokenable : Tokenable - { - public bool ForceInvalid { get; set; } = false; + public bool ForceInvalid { get; set; } = false; - [JsonIgnore] - public override bool Valid => !ForceInvalid; - } - - public class TestExpiringTokenable : ExpiringTokenable - { - private bool _forceInvalid; - - public TestExpiringTokenable() : this(false) { } - - public TestExpiringTokenable(bool forceInvalid) - { - _forceInvalid = forceInvalid; - } - protected override bool TokenIsValid() => !_forceInvalid; - } + [JsonIgnore] + public override bool Valid => !ForceInvalid; +} + +public class TestExpiringTokenable : ExpiringTokenable +{ + private bool _forceInvalid; + + public TestExpiringTokenable() : this(false) { } + + public TestExpiringTokenable(bool forceInvalid) + { + _forceInvalid = forceInvalid; + } + protected override bool TokenIsValid() => !_forceInvalid; } diff --git a/test/Core.Test/Tokens/TokenTests.cs b/test/Core.Test/Tokens/TokenTests.cs index bc1ad8568..1afad2412 100644 --- a/test/Core.Test/Tokens/TokenTests.cs +++ b/test/Core.Test/Tokens/TokenTests.cs @@ -2,38 +2,37 @@ using Bit.Core.Tokens; using Xunit; -namespace Bit.Core.Test.Tokens +namespace Bit.Core.Test.Tokens; + +public class TokenTests { - public class TokenTests + [Theory, AutoData] + public void InitializeWithString_ReturnsString(string initString) { - [Theory, AutoData] - public void InitializeWithString_ReturnsString(string initString) - { - var token = new Token(initString); + var token = new Token(initString); - Assert.Equal(initString, token.ToString()); - } + Assert.Equal(initString, token.ToString()); + } - [Theory, AutoData] - public void AddsPrefix(Token token, string prefix) - { - Assert.Equal($"{prefix}{token.ToString()}", token.WithPrefix(prefix).ToString()); - } + [Theory, AutoData] + public void AddsPrefix(Token token, string prefix) + { + Assert.Equal($"{prefix}{token.ToString()}", token.WithPrefix(prefix).ToString()); + } - [Theory, AutoData] - public void RemovePrefix_WithPrefix_RemovesPrefix(string initString, string prefix) - { - var token = new Token(initString).WithPrefix(prefix); + [Theory, AutoData] + public void RemovePrefix_WithPrefix_RemovesPrefix(string initString, string prefix) + { + var token = new Token(initString).WithPrefix(prefix); - Assert.Equal(initString, token.RemovePrefix(prefix).ToString()); - } + Assert.Equal(initString, token.RemovePrefix(prefix).ToString()); + } - [Theory, AutoData] - public void RemovePrefix_WithoutPrefix_Throws(Token token, string prefix) - { - var exception = Assert.Throws(() => token.RemovePrefix(prefix)); + [Theory, AutoData] + public void RemovePrefix_WithoutPrefix_Throws(Token token, string prefix) + { + var exception = Assert.Throws(() => token.RemovePrefix(prefix)); - Assert.Equal($"Expected prefix, {prefix}, was not present.", exception.Message); - } + Assert.Equal($"Expected prefix, {prefix}, was not present.", exception.Message); } } diff --git a/test/Core.Test/Utilities/ClaimsExtensionsTests.cs b/test/Core.Test/Utilities/ClaimsExtensionsTests.cs index d6b5c90db..665c64779 100644 --- a/test/Core.Test/Utilities/ClaimsExtensionsTests.cs +++ b/test/Core.Test/Utilities/ClaimsExtensionsTests.cs @@ -2,36 +2,35 @@ using Bit.Core.Utilities; using Xunit; -namespace Bit.Core.Test.Utilities +namespace Bit.Core.Test.Utilities; + +public class ClaimsExtensionsTests { - public class ClaimsExtensionsTests + [Fact] + public void HasSSOIdP_Returns_True_When_The_Claims_Has_One_Of_Type_IdP_And_Value_Sso() { - [Fact] - public void HasSSOIdP_Returns_True_When_The_Claims_Has_One_Of_Type_IdP_And_Value_Sso() - { - var claims = new List { new Claim("idp", "sso") }; - Assert.True(claims.HasSsoIdP()); - } + var claims = new List { new Claim("idp", "sso") }; + Assert.True(claims.HasSsoIdP()); + } - [Fact] - public void HasSSOIdP_Returns_False_When_The_Claims_Has_One_Of_Type_IdP_And_Value_Is_Not_Sso() - { - var claims = new List { new Claim("idp", "asdfasfd") }; - Assert.False(claims.HasSsoIdP()); - } + [Fact] + public void HasSSOIdP_Returns_False_When_The_Claims_Has_One_Of_Type_IdP_And_Value_Is_Not_Sso() + { + var claims = new List { new Claim("idp", "asdfasfd") }; + Assert.False(claims.HasSsoIdP()); + } - [Fact] - public void HasSSOIdP_Returns_False_When_The_Claims_Has_No_One_Of_Type_IdP() - { - var claims = new List { new Claim("qweqweq", "sso") }; - Assert.False(claims.HasSsoIdP()); - } + [Fact] + public void HasSSOIdP_Returns_False_When_The_Claims_Has_No_One_Of_Type_IdP() + { + var claims = new List { new Claim("qweqweq", "sso") }; + Assert.False(claims.HasSsoIdP()); + } - [Fact] - public void HasSSOIdP_Returns_False_When_The_Claims_Are_Empty() - { - var claims = new List(); - Assert.False(claims.HasSsoIdP()); - } + [Fact] + public void HasSSOIdP_Returns_False_When_The_Claims_Are_Empty() + { + var claims = new List(); + Assert.False(claims.HasSsoIdP()); } } diff --git a/test/Core.Test/Utilities/CoreHelpersTests.cs b/test/Core.Test/Utilities/CoreHelpersTests.cs index 37b9c22df..76db48fe3 100644 --- a/test/Core.Test/Utilities/CoreHelpersTests.cs +++ b/test/Core.Test/Utilities/CoreHelpersTests.cs @@ -12,434 +12,433 @@ using IdentityModel; using Microsoft.AspNetCore.DataProtection; using Xunit; -namespace Bit.Core.Test.Utilities +namespace Bit.Core.Test.Utilities; + +public class CoreHelpersTests { - public class CoreHelpersTests + public static IEnumerable _epochTestCases = new[] { - public static IEnumerable _epochTestCases = new[] - { - new object[] {new DateTime(2020, 12, 30, 11, 49, 12, DateTimeKind.Utc), 1609328952000L}, - }; + new object[] {new DateTime(2020, 12, 30, 11, 49, 12, DateTimeKind.Utc), 1609328952000L}, + }; - [Fact] - public void GenerateComb_Success() - { - // Arrange & Act - var comb = CoreHelpers.GenerateComb(); + [Fact] + public void GenerateComb_Success() + { + // Arrange & Act + var comb = CoreHelpers.GenerateComb(); - // Assert - Assert.NotEqual(Guid.Empty, comb); - // TODO: Add more asserts to make sure important aspects of - // the comb are working properly + // Assert + Assert.NotEqual(Guid.Empty, comb); + // TODO: Add more asserts to make sure important aspects of + // the comb are working properly + } + + public static IEnumerable GenerateCombCases = new[] + { + new object[] + { + Guid.Parse("a58db474-43d8-42f1-b4ee-0c17647cd0c0"), // Input Guid + new DateTime(2022, 3, 12, 12, 12, 0, DateTimeKind.Utc), // Input Time + Guid.Parse("a58db474-43d8-42f1-b4ee-ae5600c90cc1"), // Expected Comb + }, + new object[] + { + Guid.Parse("f776e6ee-511f-4352-bb28-88513002bdeb"), + new DateTime(2021, 5, 10, 10, 52, 0, DateTimeKind.Utc), + Guid.Parse("f776e6ee-511f-4352-bb28-ad2400b313c1"), + }, + new object[] + { + Guid.Parse("51a25fc7-3cad-497d-8e2f-8d77011648a1"), + new DateTime(1999, 2, 26, 16, 53, 13, DateTimeKind.Utc), + Guid.Parse("51a25fc7-3cad-497d-8e2f-8d77011649cd"), + }, + new object[] + { + Guid.Parse("bfb8f353-3b32-4a9e-bef6-24fe0b54bfb0"), + new DateTime(2024, 10, 20, 1, 32, 16, DateTimeKind.Utc), + Guid.Parse("bfb8f353-3b32-4a9e-bef6-b20f00195780"), + } + }; + + [Theory] + [MemberData(nameof(GenerateCombCases))] + public void GenerateComb_WithInputs_Success(Guid inputGuid, DateTime inputTime, Guid expectedComb) + { + var comb = CoreHelpers.GenerateComb(inputGuid, inputTime); + + Assert.Equal(expectedComb, comb); + } + + [Theory] + [InlineData(2, 5, new[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 0 })] + [InlineData(2, 3, new[] { 1, 2, 3, 4, 5 })] + [InlineData(2, 1, new[] { 1, 2 })] + [InlineData(1, 1, new[] { 1 })] + [InlineData(2, 2, new[] { 1, 2, 3 })] + public void Batch_Success(int batchSize, int totalBatches, int[] collection) + { + // Arrange + var remainder = collection.Length % batchSize; + + // Act + var batches = collection.Batch(batchSize); + + // Assert + Assert.Equal(totalBatches, batches.Count()); + + foreach (var batch in batches.Take(totalBatches - 1)) + { + Assert.Equal(batchSize, batch.Count()); } - public static IEnumerable GenerateCombCases = new[] - { - new object[] + Assert.Equal(batches.Last().Count(), remainder == 0 ? batchSize : remainder); + } + + /* + [Fact] + public void ToGuidIdArrayTVP_Success() + { + // Arrange + var item0 = Guid.NewGuid(); + var item1 = Guid.NewGuid(); + + var ids = new[] { item0, item1 }; + + // Act + var dt = ids.ToGuidIdArrayTVP(); + + // Assert + Assert.Single(dt.Columns); + Assert.Equal("GuidId", dt.Columns[0].ColumnName); + Assert.Equal(2, dt.Rows.Count); + Assert.Equal(item0, dt.Rows[0][0]); + Assert.Equal(item1, dt.Rows[1][0]); + } + */ + + // TODO: Test the other ToArrayTVP Methods + + [Theory] + [InlineData("12345&6789", "123456789")] + [InlineData("abcdef", "ABCDEF")] + [InlineData("1!@#$%&*()_+", "1")] + [InlineData("\u00C6123abc\u00C7", "123ABC")] + [InlineData("123\u00C6ABC", "123ABC")] + [InlineData("\r\nHello", "E")] + [InlineData("\tdef", "DEF")] + [InlineData("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV1234567890", "ABCDEFABCDEF1234567890")] + public void CleanCertificateThumbprint_Success(string input, string output) + { + // Arrange & Act + var sanitizedInput = CoreHelpers.CleanCertificateThumbprint(input); + + // Assert + Assert.Equal(output, sanitizedInput); + } + + // TODO: Add more tests + [Theory] + [MemberData(nameof(_epochTestCases))] + public void ToEpocMilliseconds_Success(DateTime date, long milliseconds) + { + // Act & Assert + Assert.Equal(milliseconds, CoreHelpers.ToEpocMilliseconds(date)); + } + + [Theory] + [MemberData(nameof(_epochTestCases))] + public void FromEpocMilliseconds(DateTime date, long milliseconds) + { + // Act & Assert + Assert.Equal(date, CoreHelpers.FromEpocMilliseconds(milliseconds)); + } + + [Fact] + public void SecureRandomString_Success() + { + // Arrange & Act + var @string = CoreHelpers.SecureRandomString(8); + + // Assert + // TODO: Should probably add more Asserts down the line + Assert.Equal(8, @string.Length); + } + + [Theory] + [InlineData(1, "1 Bytes")] + [InlineData(-5L, "-5 Bytes")] + [InlineData(1023L, "1023 Bytes")] + [InlineData(1024L, "1 KB")] + [InlineData(1025L, "1 KB")] + [InlineData(-1023L, "-1023 Bytes")] + [InlineData(-1024L, "-1 KB")] + [InlineData(-1025L, "-1 KB")] + [InlineData(1048575L, "1024 KB")] + [InlineData(1048576L, "1 MB")] + [InlineData(1048577L, "1 MB")] + [InlineData(-1048575L, "-1024 KB")] + [InlineData(-1048576L, "-1 MB")] + [InlineData(-1048577L, "-1 MB")] + [InlineData(1073741823L, "1024 MB")] + [InlineData(1073741824L, "1 GB")] + [InlineData(1073741825L, "1 GB")] + [InlineData(-1073741823L, "-1024 MB")] + [InlineData(-1073741824L, "-1 GB")] + [InlineData(-1073741825L, "-1 GB")] + [InlineData(long.MaxValue, "8589934592 GB")] + public void ReadableBytesSize_Success(long size, string readable) + { + // Act & Assert + Assert.Equal(readable, CoreHelpers.ReadableBytesSize(size)); + } + + [Fact] + public void CloneObject_Success() + { + var original = new { Message = "Message" }; + + var copy = CoreHelpers.CloneObject(original); + + Assert.Equal(original.Message, copy.Message); + } + + [Fact] + public void ExtendQuery_AddNewParameter_Success() + { + // Arrange + var uri = new Uri("https://bitwarden.com/?param1=value1"); + + // Act + var newUri = CoreHelpers.ExtendQuery(uri, + new Dictionary { { "param2", "value2" } }); + + // Assert + Assert.Equal("https://bitwarden.com/?param1=value1¶m2=value2", newUri.ToString()); + } + + [Fact] + public void ExtendQuery_AddTwoNewParameters_Success() + { + // Arrange + var uri = new Uri("https://bitwarden.com/?param1=value1"); + + // Act + var newUri = CoreHelpers.ExtendQuery(uri, + new Dictionary { - Guid.Parse("a58db474-43d8-42f1-b4ee-0c17647cd0c0"), // Input Guid - new DateTime(2022, 3, 12, 12, 12, 0, DateTimeKind.Utc), // Input Time - Guid.Parse("a58db474-43d8-42f1-b4ee-ae5600c90cc1"), // Expected Comb - }, - new object[] - { - Guid.Parse("f776e6ee-511f-4352-bb28-88513002bdeb"), - new DateTime(2021, 5, 10, 10, 52, 0, DateTimeKind.Utc), - Guid.Parse("f776e6ee-511f-4352-bb28-ad2400b313c1"), - }, - new object[] - { - Guid.Parse("51a25fc7-3cad-497d-8e2f-8d77011648a1"), - new DateTime(1999, 2, 26, 16, 53, 13, DateTimeKind.Utc), - Guid.Parse("51a25fc7-3cad-497d-8e2f-8d77011649cd"), - }, - new object[] - { - Guid.Parse("bfb8f353-3b32-4a9e-bef6-24fe0b54bfb0"), - new DateTime(2024, 10, 20, 1, 32, 16, DateTimeKind.Utc), - Guid.Parse("bfb8f353-3b32-4a9e-bef6-b20f00195780"), - } - }; + { "param2", "value2" }, + { "param3", "value3" } + }); - [Theory] - [MemberData(nameof(GenerateCombCases))] - public void GenerateComb_WithInputs_Success(Guid inputGuid, DateTime inputTime, Guid expectedComb) + // Assert + Assert.Equal("https://bitwarden.com/?param1=value1¶m2=value2¶m3=value3", newUri.ToString()); + } + + [Fact] + public void ExtendQuery_AddExistingParameter_Success() + { + // Arrange + var uri = new Uri("https://bitwarden.com/?param1=value1¶m2=value2"); + + // Act + var newUri = CoreHelpers.ExtendQuery(uri, + new Dictionary { { "param1", "test_value" } }); + + // Assert + Assert.Equal("https://bitwarden.com/?param1=test_value¶m2=value2", newUri.ToString()); + } + + [Fact] + public void ExtendQuery_AddNoParameters_Success() + { + // Arrange + const string startingUri = "https://bitwarden.com/?param1=value1"; + + var uri = new Uri(startingUri); + + // Act + var newUri = CoreHelpers.ExtendQuery(uri, new Dictionary()); + + // Assert + Assert.Equal(startingUri, newUri.ToString()); + } + + [Theory] + [InlineData("bücher.com", "xn--bcher-kva.com")] + [InlineData("bücher.cömé", "xn--bcher-kva.xn--cm-cja4c")] + [InlineData("hello@bücher.com", "hello@xn--bcher-kva.com")] + [InlineData("hello@world.cömé", "hello@world.xn--cm-cja4c")] + [InlineData("hello@bücher.cömé", "hello@xn--bcher-kva.xn--cm-cja4c")] + [InlineData("ascii.com", "ascii.com")] + [InlineData("", "")] + [InlineData(null, null)] + public void PunyEncode_Success(string text, string expected) + { + var actual = CoreHelpers.PunyEncode(text); + Assert.Equal(expected, actual); + } + + [Fact] + public void GetEmbeddedResourceContentsAsync_Success() + { + var fileContents = CoreHelpers.GetEmbeddedResourceContentsAsync("data.embeddedResource.txt"); + Assert.Equal("Contents of embeddedResource.txt\n", fileContents.Replace("\r\n", "\n")); + } + + [Theory, CustomAutoData(typeof(UserFixture))] + public void BuildIdentityClaims_BaseClaims_Success(User user, bool isPremium) + { + var expected = new Dictionary { - var comb = CoreHelpers.GenerateComb(inputGuid, inputTime); + { "premium", isPremium ? "true" : "false" }, + { JwtClaimTypes.Email, user.Email }, + { JwtClaimTypes.EmailVerified, user.EmailVerified ? "true" : "false" }, + { JwtClaimTypes.Name, user.Name }, + { "sstamp", user.SecurityStamp }, + }.ToList(); - Assert.Equal(expectedComb, comb); + var actual = CoreHelpers.BuildIdentityClaims(user, Array.Empty(), + Array.Empty(), isPremium); + + foreach (var claim in expected) + { + Assert.Contains(claim, actual); } + Assert.Equal(expected.Count, actual.Count); + } - [Theory] - [InlineData(2, 5, new[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 0 })] - [InlineData(2, 3, new[] { 1, 2, 3, 4, 5 })] - [InlineData(2, 1, new[] { 1, 2 })] - [InlineData(1, 1, new[] { 1 })] - [InlineData(2, 2, new[] { 1, 2, 3 })] - public void Batch_Success(int batchSize, int totalBatches, int[] collection) + [Theory, CustomAutoData(typeof(UserFixture))] + public void BuildIdentityClaims_NonCustomOrganizationUserType_Success(User user) + { + var fixture = new Fixture().WithAutoNSubstitutions(); + foreach (var organizationUserType in Enum.GetValues().Except(new[] { OrganizationUserType.Custom })) { - // Arrange - var remainder = collection.Length % batchSize; - - // Act - var batches = collection.Batch(batchSize); - - // Assert - Assert.Equal(totalBatches, batches.Count()); - - foreach (var batch in batches.Take(totalBatches - 1)) - { - Assert.Equal(batchSize, batch.Count()); - } - - Assert.Equal(batches.Last().Count(), remainder == 0 ? batchSize : remainder); - } - - /* - [Fact] - public void ToGuidIdArrayTVP_Success() - { - // Arrange - var item0 = Guid.NewGuid(); - var item1 = Guid.NewGuid(); - - var ids = new[] { item0, item1 }; - - // Act - var dt = ids.ToGuidIdArrayTVP(); - - // Assert - Assert.Single(dt.Columns); - Assert.Equal("GuidId", dt.Columns[0].ColumnName); - Assert.Equal(2, dt.Rows.Count); - Assert.Equal(item0, dt.Rows[0][0]); - Assert.Equal(item1, dt.Rows[1][0]); - } - */ - - // TODO: Test the other ToArrayTVP Methods - - [Theory] - [InlineData("12345&6789", "123456789")] - [InlineData("abcdef", "ABCDEF")] - [InlineData("1!@#$%&*()_+", "1")] - [InlineData("\u00C6123abc\u00C7", "123ABC")] - [InlineData("123\u00C6ABC", "123ABC")] - [InlineData("\r\nHello", "E")] - [InlineData("\tdef", "DEF")] - [InlineData("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV1234567890", "ABCDEFABCDEF1234567890")] - public void CleanCertificateThumbprint_Success(string input, string output) - { - // Arrange & Act - var sanitizedInput = CoreHelpers.CleanCertificateThumbprint(input); - - // Assert - Assert.Equal(output, sanitizedInput); - } - - // TODO: Add more tests - [Theory] - [MemberData(nameof(_epochTestCases))] - public void ToEpocMilliseconds_Success(DateTime date, long milliseconds) - { - // Act & Assert - Assert.Equal(milliseconds, CoreHelpers.ToEpocMilliseconds(date)); - } - - [Theory] - [MemberData(nameof(_epochTestCases))] - public void FromEpocMilliseconds(DateTime date, long milliseconds) - { - // Act & Assert - Assert.Equal(date, CoreHelpers.FromEpocMilliseconds(milliseconds)); - } - - [Fact] - public void SecureRandomString_Success() - { - // Arrange & Act - var @string = CoreHelpers.SecureRandomString(8); - - // Assert - // TODO: Should probably add more Asserts down the line - Assert.Equal(8, @string.Length); - } - - [Theory] - [InlineData(1, "1 Bytes")] - [InlineData(-5L, "-5 Bytes")] - [InlineData(1023L, "1023 Bytes")] - [InlineData(1024L, "1 KB")] - [InlineData(1025L, "1 KB")] - [InlineData(-1023L, "-1023 Bytes")] - [InlineData(-1024L, "-1 KB")] - [InlineData(-1025L, "-1 KB")] - [InlineData(1048575L, "1024 KB")] - [InlineData(1048576L, "1 MB")] - [InlineData(1048577L, "1 MB")] - [InlineData(-1048575L, "-1024 KB")] - [InlineData(-1048576L, "-1 MB")] - [InlineData(-1048577L, "-1 MB")] - [InlineData(1073741823L, "1024 MB")] - [InlineData(1073741824L, "1 GB")] - [InlineData(1073741825L, "1 GB")] - [InlineData(-1073741823L, "-1024 MB")] - [InlineData(-1073741824L, "-1 GB")] - [InlineData(-1073741825L, "-1 GB")] - [InlineData(long.MaxValue, "8589934592 GB")] - public void ReadableBytesSize_Success(long size, string readable) - { - // Act & Assert - Assert.Equal(readable, CoreHelpers.ReadableBytesSize(size)); - } - - [Fact] - public void CloneObject_Success() - { - var original = new { Message = "Message" }; - - var copy = CoreHelpers.CloneObject(original); - - Assert.Equal(original.Message, copy.Message); - } - - [Fact] - public void ExtendQuery_AddNewParameter_Success() - { - // Arrange - var uri = new Uri("https://bitwarden.com/?param1=value1"); - - // Act - var newUri = CoreHelpers.ExtendQuery(uri, - new Dictionary { { "param2", "value2" } }); - - // Assert - Assert.Equal("https://bitwarden.com/?param1=value1¶m2=value2", newUri.ToString()); - } - - [Fact] - public void ExtendQuery_AddTwoNewParameters_Success() - { - // Arrange - var uri = new Uri("https://bitwarden.com/?param1=value1"); - - // Act - var newUri = CoreHelpers.ExtendQuery(uri, - new Dictionary - { - { "param2", "value2" }, - { "param3", "value3" } - }); - - // Assert - Assert.Equal("https://bitwarden.com/?param1=value1¶m2=value2¶m3=value3", newUri.ToString()); - } - - [Fact] - public void ExtendQuery_AddExistingParameter_Success() - { - // Arrange - var uri = new Uri("https://bitwarden.com/?param1=value1¶m2=value2"); - - // Act - var newUri = CoreHelpers.ExtendQuery(uri, - new Dictionary { { "param1", "test_value" } }); - - // Assert - Assert.Equal("https://bitwarden.com/?param1=test_value¶m2=value2", newUri.ToString()); - } - - [Fact] - public void ExtendQuery_AddNoParameters_Success() - { - // Arrange - const string startingUri = "https://bitwarden.com/?param1=value1"; - - var uri = new Uri(startingUri); - - // Act - var newUri = CoreHelpers.ExtendQuery(uri, new Dictionary()); - - // Assert - Assert.Equal(startingUri, newUri.ToString()); - } - - [Theory] - [InlineData("bücher.com", "xn--bcher-kva.com")] - [InlineData("bücher.cömé", "xn--bcher-kva.xn--cm-cja4c")] - [InlineData("hello@bücher.com", "hello@xn--bcher-kva.com")] - [InlineData("hello@world.cömé", "hello@world.xn--cm-cja4c")] - [InlineData("hello@bücher.cömé", "hello@xn--bcher-kva.xn--cm-cja4c")] - [InlineData("ascii.com", "ascii.com")] - [InlineData("", "")] - [InlineData(null, null)] - public void PunyEncode_Success(string text, string expected) - { - var actual = CoreHelpers.PunyEncode(text); - Assert.Equal(expected, actual); - } - - [Fact] - public void GetEmbeddedResourceContentsAsync_Success() - { - var fileContents = CoreHelpers.GetEmbeddedResourceContentsAsync("data.embeddedResource.txt"); - Assert.Equal("Contents of embeddedResource.txt\n", fileContents.Replace("\r\n", "\n")); - } - - [Theory, CustomAutoData(typeof(UserFixture))] - public void BuildIdentityClaims_BaseClaims_Success(User user, bool isPremium) - { - var expected = new Dictionary - { - { "premium", isPremium ? "true" : "false" }, - { JwtClaimTypes.Email, user.Email }, - { JwtClaimTypes.EmailVerified, user.EmailVerified ? "true" : "false" }, - { JwtClaimTypes.Name, user.Name }, - { "sstamp", user.SecurityStamp }, - }.ToList(); - - var actual = CoreHelpers.BuildIdentityClaims(user, Array.Empty(), - Array.Empty(), isPremium); - - foreach (var claim in expected) - { - Assert.Contains(claim, actual); - } - Assert.Equal(expected.Count, actual.Count); - } - - [Theory, CustomAutoData(typeof(UserFixture))] - public void BuildIdentityClaims_NonCustomOrganizationUserType_Success(User user) - { - var fixture = new Fixture().WithAutoNSubstitutions(); - foreach (var organizationUserType in Enum.GetValues().Except(new[] { OrganizationUserType.Custom })) - { - var org = fixture.Create(); - org.Type = organizationUserType; - - var expected = new KeyValuePair($"org{organizationUserType.ToString().ToLower()}", org.Id.ToString()); - var actual = CoreHelpers.BuildIdentityClaims(user, new[] { org }, Array.Empty(), false); - - Assert.Contains(expected, actual); - } - } - - [Theory, CustomAutoData(typeof(UserFixture))] - public void BuildIdentityClaims_CustomOrganizationUserClaims_Success(User user, CurrentContentOrganization org) - { - var fixture = new Fixture().WithAutoNSubstitutions(); - org.Type = OrganizationUserType.Custom; + var org = fixture.Create(); + org.Type = organizationUserType; + var expected = new KeyValuePair($"org{organizationUserType.ToString().ToLower()}", org.Id.ToString()); var actual = CoreHelpers.BuildIdentityClaims(user, new[] { org }, Array.Empty(), false); - foreach (var (permitted, claimName) in org.Permissions.ClaimsMap) - { - var claim = new KeyValuePair(claimName, org.Id.ToString()); - if (permitted) - { - Assert.Contains(claim, actual); - } - else - { - Assert.DoesNotContain(claim, actual); - } - } - } - - [Theory, CustomAutoData(typeof(UserFixture))] - public void BuildIdentityClaims_ProviderClaims_Success(User user) - { - var fixture = new Fixture().WithAutoNSubstitutions(); - var providers = new List(); - foreach (var providerUserType in Enum.GetValues()) - { - var provider = fixture.Create(); - provider.Type = providerUserType; - providers.Add(provider); - } - - var claims = new List>(); - - if (providers.Any()) - { - foreach (var group in providers.GroupBy(o => o.Type)) - { - switch (group.Key) - { - case ProviderUserType.ProviderAdmin: - foreach (var provider in group) - { - claims.Add(new KeyValuePair("providerprovideradmin", provider.Id.ToString())); - } - break; - case ProviderUserType.ServiceUser: - foreach (var provider in group) - { - claims.Add(new KeyValuePair("providerserviceuser", provider.Id.ToString())); - } - break; - } - } - } - - var actual = CoreHelpers.BuildIdentityClaims(user, Array.Empty(), providers, false); - foreach (var claim in claims) - { - Assert.Contains(claim, actual); - } - } - - public static IEnumerable TokenIsValidData() - { - return new[] - { - new object[] - { - "first_part 476669d4-9642-4af8-9b29-9366efad4ed3 test@email.com {0}", // unprotectedTokenTemplate - "first_part", // firstPart - "test@email.com", // email - Guid.Parse("476669d4-9642-4af8-9b29-9366efad4ed3"), // id - DateTime.UtcNow.AddHours(-1), // creationTime - 12, // expirationInHours - true, // isValid - } - }; - } - - [Theory] - [MemberData(nameof(TokenIsValidData))] - public void TokenIsValid_Success(string unprotectedTokenTemplate, string firstPart, string userEmail, Guid id, DateTime creationTime, double expirationInHours, bool isValid) - { - var protector = new TestDataProtector(string.Format(unprotectedTokenTemplate, CoreHelpers.ToEpocMilliseconds(creationTime))); - - Assert.Equal(isValid, CoreHelpers.TokenIsValid(firstPart, protector, "protected_token", userEmail, id, expirationInHours)); - } - - private class TestDataProtector : IDataProtector - { - private readonly string _token; - public TestDataProtector(string token) - { - _token = token; - } - public IDataProtector CreateProtector(string purpose) => throw new NotImplementedException(); - public byte[] Protect(byte[] plaintext) => throw new NotImplementedException(); - public byte[] Unprotect(byte[] protectedData) - { - return Encoding.UTF8.GetBytes(_token); - } - } - - [Theory] - [InlineData("hi@email.com", "hi@email.com")] // Short email with no room to obfuscate - [InlineData("name@email.com", "na**@email.com")] // Can obfuscate - [InlineData("reallylongnamethatnooneshouldhave@email", "re*******************************@email")] // Really long email and no .com, .net, etc - [InlineData("name@", "name@")] // @ symbol but no domain - [InlineData("", "")] // Empty string - [InlineData(null, null)] // null - public void ObfuscateEmail_Success(string input, string expected) - { - Assert.Equal(expected, CoreHelpers.ObfuscateEmail(input)); + Assert.Contains(expected, actual); } } + + [Theory, CustomAutoData(typeof(UserFixture))] + public void BuildIdentityClaims_CustomOrganizationUserClaims_Success(User user, CurrentContentOrganization org) + { + var fixture = new Fixture().WithAutoNSubstitutions(); + org.Type = OrganizationUserType.Custom; + + var actual = CoreHelpers.BuildIdentityClaims(user, new[] { org }, Array.Empty(), false); + foreach (var (permitted, claimName) in org.Permissions.ClaimsMap) + { + var claim = new KeyValuePair(claimName, org.Id.ToString()); + if (permitted) + { + + Assert.Contains(claim, actual); + } + else + { + Assert.DoesNotContain(claim, actual); + } + } + } + + [Theory, CustomAutoData(typeof(UserFixture))] + public void BuildIdentityClaims_ProviderClaims_Success(User user) + { + var fixture = new Fixture().WithAutoNSubstitutions(); + var providers = new List(); + foreach (var providerUserType in Enum.GetValues()) + { + var provider = fixture.Create(); + provider.Type = providerUserType; + providers.Add(provider); + } + + var claims = new List>(); + + if (providers.Any()) + { + foreach (var group in providers.GroupBy(o => o.Type)) + { + switch (group.Key) + { + case ProviderUserType.ProviderAdmin: + foreach (var provider in group) + { + claims.Add(new KeyValuePair("providerprovideradmin", provider.Id.ToString())); + } + break; + case ProviderUserType.ServiceUser: + foreach (var provider in group) + { + claims.Add(new KeyValuePair("providerserviceuser", provider.Id.ToString())); + } + break; + } + } + } + + var actual = CoreHelpers.BuildIdentityClaims(user, Array.Empty(), providers, false); + foreach (var claim in claims) + { + Assert.Contains(claim, actual); + } + } + + public static IEnumerable TokenIsValidData() + { + return new[] + { + new object[] + { + "first_part 476669d4-9642-4af8-9b29-9366efad4ed3 test@email.com {0}", // unprotectedTokenTemplate + "first_part", // firstPart + "test@email.com", // email + Guid.Parse("476669d4-9642-4af8-9b29-9366efad4ed3"), // id + DateTime.UtcNow.AddHours(-1), // creationTime + 12, // expirationInHours + true, // isValid + } + }; + } + + [Theory] + [MemberData(nameof(TokenIsValidData))] + public void TokenIsValid_Success(string unprotectedTokenTemplate, string firstPart, string userEmail, Guid id, DateTime creationTime, double expirationInHours, bool isValid) + { + var protector = new TestDataProtector(string.Format(unprotectedTokenTemplate, CoreHelpers.ToEpocMilliseconds(creationTime))); + + Assert.Equal(isValid, CoreHelpers.TokenIsValid(firstPart, protector, "protected_token", userEmail, id, expirationInHours)); + } + + private class TestDataProtector : IDataProtector + { + private readonly string _token; + public TestDataProtector(string token) + { + _token = token; + } + public IDataProtector CreateProtector(string purpose) => throw new NotImplementedException(); + public byte[] Protect(byte[] plaintext) => throw new NotImplementedException(); + public byte[] Unprotect(byte[] protectedData) + { + return Encoding.UTF8.GetBytes(_token); + } + } + + [Theory] + [InlineData("hi@email.com", "hi@email.com")] // Short email with no room to obfuscate + [InlineData("name@email.com", "na**@email.com")] // Can obfuscate + [InlineData("reallylongnamethatnooneshouldhave@email", "re*******************************@email")] // Really long email and no .com, .net, etc + [InlineData("name@", "name@")] // @ symbol but no domain + [InlineData("", "")] // Empty string + [InlineData(null, null)] // null + public void ObfuscateEmail_Success(string input, string expected) + { + Assert.Equal(expected, CoreHelpers.ObfuscateEmail(input)); + } } diff --git a/test/Core.Test/Utilities/EncryptedStringAttributeTests.cs b/test/Core.Test/Utilities/EncryptedStringAttributeTests.cs index 09ee18847..c16a983cf 100644 --- a/test/Core.Test/Utilities/EncryptedStringAttributeTests.cs +++ b/test/Core.Test/Utilities/EncryptedStringAttributeTests.cs @@ -1,43 +1,42 @@ using Bit.Core.Utilities; using Xunit; -namespace Bit.Core.Test.Utilities +namespace Bit.Core.Test.Utilities; + +public class EncryptedStringAttributeTests { - public class EncryptedStringAttributeTests + [Theory] + [InlineData(null)] + [InlineData("aXY=|Y3Q=")] // Valid AesCbc256_B64 + [InlineData("aXY=|Y3Q=|cnNhQ3Q=")] // Valid AesCbc128_HmacSha256_B64 + [InlineData("Rsa2048_OaepSha256_B64.cnNhQ3Q=")] + public void IsValid_ReturnsTrue_WhenValid(string input) { - [Theory] - [InlineData(null)] - [InlineData("aXY=|Y3Q=")] // Valid AesCbc256_B64 - [InlineData("aXY=|Y3Q=|cnNhQ3Q=")] // Valid AesCbc128_HmacSha256_B64 - [InlineData("Rsa2048_OaepSha256_B64.cnNhQ3Q=")] - public void IsValid_ReturnsTrue_WhenValid(string input) - { - var sut = new EncryptedStringAttribute(); + var sut = new EncryptedStringAttribute(); - var actual = sut.IsValid(input); + var actual = sut.IsValid(input); - Assert.True(actual); - } + Assert.True(actual); + } - [Theory] - [InlineData("")] - [InlineData(".")] - [InlineData("|")] - [InlineData("!|!")] // Invalid base 64 - [InlineData("Rsa2048_OaepSha1_HmacSha256_B64.1")] // Invalid length - [InlineData("Rsa2048_OaepSha1_HmacSha256_B64.|")] // Empty iv & ct - [InlineData("AesCbc128_HmacSha256_B64.1")] // Invalid length - [InlineData("AesCbc128_HmacSha256_B64.aXY=|Y3Q=|")] // Empty mac - [InlineData("Rsa2048_OaepSha1_HmacSha256_B64.aXY=|Y3Q=|")] // Empty mac - [InlineData("Rsa2048_OaepSha256_B64.1|2")] // Invalid length - [InlineData("Rsa2048_OaepSha1_HmacSha256_B64.aXY=|")] // Empty mac - public void IsValid_ReturnsFalse_WhenInvalid(string input) - { - var sut = new EncryptedStringAttribute(); + [Theory] + [InlineData("")] + [InlineData(".")] + [InlineData("|")] + [InlineData("!|!")] // Invalid base 64 + [InlineData("Rsa2048_OaepSha1_HmacSha256_B64.1")] // Invalid length + [InlineData("Rsa2048_OaepSha1_HmacSha256_B64.|")] // Empty iv & ct + [InlineData("AesCbc128_HmacSha256_B64.1")] // Invalid length + [InlineData("AesCbc128_HmacSha256_B64.aXY=|Y3Q=|")] // Empty mac + [InlineData("Rsa2048_OaepSha1_HmacSha256_B64.aXY=|Y3Q=|")] // Empty mac + [InlineData("Rsa2048_OaepSha256_B64.1|2")] // Invalid length + [InlineData("Rsa2048_OaepSha1_HmacSha256_B64.aXY=|")] // Empty mac + public void IsValid_ReturnsFalse_WhenInvalid(string input) + { + var sut = new EncryptedStringAttribute(); - var actual = sut.IsValid(input); + var actual = sut.IsValid(input); - Assert.False(actual); - } + Assert.False(actual); } } diff --git a/test/Core.Test/Utilities/JsonHelpersTests.cs b/test/Core.Test/Utilities/JsonHelpersTests.cs index 8a9a26614..8c12cf22e 100644 --- a/test/Core.Test/Utilities/JsonHelpersTests.cs +++ b/test/Core.Test/Utilities/JsonHelpersTests.cs @@ -2,65 +2,64 @@ using Bit.Core.Utilities; using Xunit; -namespace Bit.Core.Test.Helpers +namespace Bit.Core.Test.Helpers; + +public class JsonHelpersTests { - public class JsonHelpersTests + private static void CompareJson(T value, JsonSerializerOptions options, Newtonsoft.Json.JsonSerializerSettings settings) { - private static void CompareJson(T value, JsonSerializerOptions options, Newtonsoft.Json.JsonSerializerSettings settings) - { - var stgJson = JsonSerializer.Serialize(value, options); - var nsJson = Newtonsoft.Json.JsonConvert.SerializeObject(value, settings); + var stgJson = JsonSerializer.Serialize(value, options); + var nsJson = Newtonsoft.Json.JsonConvert.SerializeObject(value, settings); - Assert.Equal(stgJson, nsJson); - } - - - [Fact] - public void DefaultJsonOptions() - { - var testObject = new SimpleTestObject - { - Id = 0, - Name = "Test", - }; - - CompareJson(testObject, JsonHelpers.Default, new Newtonsoft.Json.JsonSerializerSettings()); - } - - [Fact] - public void IndentedJsonOptions() - { - var testObject = new SimpleTestObject - { - Id = 10, - Name = "Test Name" - }; - - CompareJson(testObject, JsonHelpers.Indented, new Newtonsoft.Json.JsonSerializerSettings - { - Formatting = Newtonsoft.Json.Formatting.Indented, - }); - } - - [Fact] - public void NullValueHandlingJsonOptions() - { - var testObject = new SimpleTestObject - { - Id = 14, - Name = null, - }; - - CompareJson(testObject, JsonHelpers.IgnoreWritingNull, new Newtonsoft.Json.JsonSerializerSettings - { - NullValueHandling = Newtonsoft.Json.NullValueHandling.Ignore, - }); - } + Assert.Equal(stgJson, nsJson); } - public class SimpleTestObject + + [Fact] + public void DefaultJsonOptions() { - public int Id { get; set; } - public string Name { get; set; } + var testObject = new SimpleTestObject + { + Id = 0, + Name = "Test", + }; + + CompareJson(testObject, JsonHelpers.Default, new Newtonsoft.Json.JsonSerializerSettings()); + } + + [Fact] + public void IndentedJsonOptions() + { + var testObject = new SimpleTestObject + { + Id = 10, + Name = "Test Name" + }; + + CompareJson(testObject, JsonHelpers.Indented, new Newtonsoft.Json.JsonSerializerSettings + { + Formatting = Newtonsoft.Json.Formatting.Indented, + }); + } + + [Fact] + public void NullValueHandlingJsonOptions() + { + var testObject = new SimpleTestObject + { + Id = 14, + Name = null, + }; + + CompareJson(testObject, JsonHelpers.IgnoreWritingNull, new Newtonsoft.Json.JsonSerializerSettings + { + NullValueHandling = Newtonsoft.Json.NullValueHandling.Ignore, + }); } } + +public class SimpleTestObject +{ + public int Id { get; set; } + public string Name { get; set; } +} diff --git a/test/Core.Test/Utilities/PermissiveStringConverterTests.cs b/test/Core.Test/Utilities/PermissiveStringConverterTests.cs index 396d277e4..dc23b1acb 100644 --- a/test/Core.Test/Utilities/PermissiveStringConverterTests.cs +++ b/test/Core.Test/Utilities/PermissiveStringConverterTests.cs @@ -5,165 +5,164 @@ using Bit.Core.Utilities; using Bit.Test.Common.Helpers; using Xunit; -namespace Bit.Core.Test.Utilities +namespace Bit.Core.Test.Utilities; + +public class PermissiveStringConverterTests { - public class PermissiveStringConverterTests + private const string numberJson = "{ \"StringProp\": 1, \"EnumerableStringProp\": [ 2, 3 ]}"; + private const string stringJson = "{ \"StringProp\": \"1\", \"EnumerableStringProp\": [ \"2\", \"3\" ]}"; + private const string nullAndEmptyJson = "{ \"StringProp\": null, \"EnumerableStringProp\": [] }"; + private const string singleValueJson = "{ \"StringProp\": 1, \"EnumerableStringProp\": \"Hello!\" }"; + private const string nullJson = "{ \"StringProp\": null, \"EnumerableStringProp\": null }"; + private const string boolJson = "{ \"StringProp\": true, \"EnumerableStringProp\": [ false, 1.2]}"; + private const string objectJsonOne = "{ \"StringProp\": { \"Message\": \"Hi\"}, \"EnumerableStringProp\": []}"; + private const string objectJsonTwo = "{ \"StringProp\": \"Hi\", \"EnumerableStringProp\": {}}"; + private readonly string bigNumbersJson = + "{ \"StringProp\":" + decimal.MinValue + ", \"EnumerableStringProp\": [" + ulong.MaxValue + ", " + long.MinValue + "]}"; + + [Theory] + [InlineData(numberJson)] + [InlineData(stringJson)] + public void Read_Success(string json) { - private const string numberJson = "{ \"StringProp\": 1, \"EnumerableStringProp\": [ 2, 3 ]}"; - private const string stringJson = "{ \"StringProp\": \"1\", \"EnumerableStringProp\": [ \"2\", \"3\" ]}"; - private const string nullAndEmptyJson = "{ \"StringProp\": null, \"EnumerableStringProp\": [] }"; - private const string singleValueJson = "{ \"StringProp\": 1, \"EnumerableStringProp\": \"Hello!\" }"; - private const string nullJson = "{ \"StringProp\": null, \"EnumerableStringProp\": null }"; - private const string boolJson = "{ \"StringProp\": true, \"EnumerableStringProp\": [ false, 1.2]}"; - private const string objectJsonOne = "{ \"StringProp\": { \"Message\": \"Hi\"}, \"EnumerableStringProp\": []}"; - private const string objectJsonTwo = "{ \"StringProp\": \"Hi\", \"EnumerableStringProp\": {}}"; - private readonly string bigNumbersJson = - "{ \"StringProp\":" + decimal.MinValue + ", \"EnumerableStringProp\": [" + ulong.MaxValue + ", " + long.MinValue + "]}"; - - [Theory] - [InlineData(numberJson)] - [InlineData(stringJson)] - public void Read_Success(string json) - { - var obj = JsonSerializer.Deserialize(json); - Assert.Equal("1", obj.StringProp); - Assert.Equal(2, obj.EnumerableStringProp.Count()); - Assert.Equal("2", obj.EnumerableStringProp.ElementAt(0)); - Assert.Equal("3", obj.EnumerableStringProp.ElementAt(1)); - } - - [Fact] - public void Read_Boolean_Success() - { - var obj = JsonSerializer.Deserialize(boolJson); - Assert.Equal("True", obj.StringProp); - Assert.Equal(2, obj.EnumerableStringProp.Count()); - Assert.Equal("False", obj.EnumerableStringProp.ElementAt(0)); - Assert.Equal("1.2", obj.EnumerableStringProp.ElementAt(1)); - } - - [Fact] - public void Read_Float_Success_Culture() - { - var ci = new CultureInfo("sv-SE"); - Thread.CurrentThread.CurrentCulture = ci; - Thread.CurrentThread.CurrentUICulture = ci; - - var obj = JsonSerializer.Deserialize(boolJson); - Assert.Equal("1.2", obj.EnumerableStringProp.ElementAt(1)); - } - - [Fact] - public void Read_BigNumbers_Success() - { - var obj = JsonSerializer.Deserialize(bigNumbersJson); - Assert.Equal(decimal.MinValue.ToString(), obj.StringProp); - Assert.Equal(2, obj.EnumerableStringProp.Count()); - Assert.Equal(ulong.MaxValue.ToString(), obj.EnumerableStringProp.ElementAt(0)); - Assert.Equal(long.MinValue.ToString(), obj.EnumerableStringProp.ElementAt(1)); - } - - [Fact] - public void Read_SingleValue_Success() - { - var obj = JsonSerializer.Deserialize(singleValueJson); - Assert.Equal("1", obj.StringProp); - Assert.Single(obj.EnumerableStringProp); - Assert.Equal("Hello!", obj.EnumerableStringProp.ElementAt(0)); - } - - [Fact] - public void Read_NullAndEmptyJson_Success() - { - var obj = JsonSerializer.Deserialize(nullAndEmptyJson); - Assert.Null(obj.StringProp); - Assert.Empty(obj.EnumerableStringProp); - } - - [Fact] - public void Read_Null_Success() - { - var obj = JsonSerializer.Deserialize(nullJson); - Assert.Null(obj.StringProp); - Assert.Null(obj.EnumerableStringProp); - } - - [Theory] - [InlineData(objectJsonOne)] - [InlineData(objectJsonTwo)] - public void Read_Object_Throws(string json) - { - var exception = Assert.Throws(() => JsonSerializer.Deserialize(json)); - } - - [Fact] - public void Write_Success() - { - var json = JsonSerializer.Serialize(new TestObject - { - StringProp = "1", - EnumerableStringProp = new List - { - "2", - "3", - }, - }); - - var jsonElement = JsonDocument.Parse(json).RootElement; - - var stringProp = AssertHelper.AssertJsonProperty(jsonElement, "StringProp", JsonValueKind.String); - Assert.Equal("1", stringProp.GetString()); - var list = AssertHelper.AssertJsonProperty(jsonElement, "EnumerableStringProp", JsonValueKind.Array); - Assert.Equal(2, list.GetArrayLength()); - var firstElement = list[0]; - Assert.Equal(JsonValueKind.String, firstElement.ValueKind); - Assert.Equal("2", firstElement.GetString()); - var secondElement = list[1]; - Assert.Equal(JsonValueKind.String, secondElement.ValueKind); - Assert.Equal("3", secondElement.GetString()); - } - - [Fact] - public void Write_Null() - { - // When the values are null the converters aren't actually ran and it automatically serializes null - var json = JsonSerializer.Serialize(new TestObject - { - StringProp = null, - EnumerableStringProp = null, - }); - - var jsonElement = JsonDocument.Parse(json).RootElement; - - AssertHelper.AssertJsonProperty(jsonElement, "StringProp", JsonValueKind.Null); - AssertHelper.AssertJsonProperty(jsonElement, "EnumerableStringProp", JsonValueKind.Null); - } - - [Fact] - public void Write_Empty() - { - // When the values are null the converters aren't actually ran and it automatically serializes null - var json = JsonSerializer.Serialize(new TestObject - { - StringProp = "", - EnumerableStringProp = Enumerable.Empty(), - }); - - var jsonElement = JsonDocument.Parse(json).RootElement; - - var stringVal = AssertHelper.AssertJsonProperty(jsonElement, "StringProp", JsonValueKind.String).GetString(); - Assert.Equal("", stringVal); - var array = AssertHelper.AssertJsonProperty(jsonElement, "EnumerableStringProp", JsonValueKind.Array); - Assert.Equal(0, array.GetArrayLength()); - } + var obj = JsonSerializer.Deserialize(json); + Assert.Equal("1", obj.StringProp); + Assert.Equal(2, obj.EnumerableStringProp.Count()); + Assert.Equal("2", obj.EnumerableStringProp.ElementAt(0)); + Assert.Equal("3", obj.EnumerableStringProp.ElementAt(1)); } - public class TestObject + [Fact] + public void Read_Boolean_Success() { - [JsonConverter(typeof(PermissiveStringConverter))] - public string StringProp { get; set; } + var obj = JsonSerializer.Deserialize(boolJson); + Assert.Equal("True", obj.StringProp); + Assert.Equal(2, obj.EnumerableStringProp.Count()); + Assert.Equal("False", obj.EnumerableStringProp.ElementAt(0)); + Assert.Equal("1.2", obj.EnumerableStringProp.ElementAt(1)); + } - [JsonConverter(typeof(PermissiveStringEnumerableConverter))] - public IEnumerable EnumerableStringProp { get; set; } + [Fact] + public void Read_Float_Success_Culture() + { + var ci = new CultureInfo("sv-SE"); + Thread.CurrentThread.CurrentCulture = ci; + Thread.CurrentThread.CurrentUICulture = ci; + + var obj = JsonSerializer.Deserialize(boolJson); + Assert.Equal("1.2", obj.EnumerableStringProp.ElementAt(1)); + } + + [Fact] + public void Read_BigNumbers_Success() + { + var obj = JsonSerializer.Deserialize(bigNumbersJson); + Assert.Equal(decimal.MinValue.ToString(), obj.StringProp); + Assert.Equal(2, obj.EnumerableStringProp.Count()); + Assert.Equal(ulong.MaxValue.ToString(), obj.EnumerableStringProp.ElementAt(0)); + Assert.Equal(long.MinValue.ToString(), obj.EnumerableStringProp.ElementAt(1)); + } + + [Fact] + public void Read_SingleValue_Success() + { + var obj = JsonSerializer.Deserialize(singleValueJson); + Assert.Equal("1", obj.StringProp); + Assert.Single(obj.EnumerableStringProp); + Assert.Equal("Hello!", obj.EnumerableStringProp.ElementAt(0)); + } + + [Fact] + public void Read_NullAndEmptyJson_Success() + { + var obj = JsonSerializer.Deserialize(nullAndEmptyJson); + Assert.Null(obj.StringProp); + Assert.Empty(obj.EnumerableStringProp); + } + + [Fact] + public void Read_Null_Success() + { + var obj = JsonSerializer.Deserialize(nullJson); + Assert.Null(obj.StringProp); + Assert.Null(obj.EnumerableStringProp); + } + + [Theory] + [InlineData(objectJsonOne)] + [InlineData(objectJsonTwo)] + public void Read_Object_Throws(string json) + { + var exception = Assert.Throws(() => JsonSerializer.Deserialize(json)); + } + + [Fact] + public void Write_Success() + { + var json = JsonSerializer.Serialize(new TestObject + { + StringProp = "1", + EnumerableStringProp = new List + { + "2", + "3", + }, + }); + + var jsonElement = JsonDocument.Parse(json).RootElement; + + var stringProp = AssertHelper.AssertJsonProperty(jsonElement, "StringProp", JsonValueKind.String); + Assert.Equal("1", stringProp.GetString()); + var list = AssertHelper.AssertJsonProperty(jsonElement, "EnumerableStringProp", JsonValueKind.Array); + Assert.Equal(2, list.GetArrayLength()); + var firstElement = list[0]; + Assert.Equal(JsonValueKind.String, firstElement.ValueKind); + Assert.Equal("2", firstElement.GetString()); + var secondElement = list[1]; + Assert.Equal(JsonValueKind.String, secondElement.ValueKind); + Assert.Equal("3", secondElement.GetString()); + } + + [Fact] + public void Write_Null() + { + // When the values are null the converters aren't actually ran and it automatically serializes null + var json = JsonSerializer.Serialize(new TestObject + { + StringProp = null, + EnumerableStringProp = null, + }); + + var jsonElement = JsonDocument.Parse(json).RootElement; + + AssertHelper.AssertJsonProperty(jsonElement, "StringProp", JsonValueKind.Null); + AssertHelper.AssertJsonProperty(jsonElement, "EnumerableStringProp", JsonValueKind.Null); + } + + [Fact] + public void Write_Empty() + { + // When the values are null the converters aren't actually ran and it automatically serializes null + var json = JsonSerializer.Serialize(new TestObject + { + StringProp = "", + EnumerableStringProp = Enumerable.Empty(), + }); + + var jsonElement = JsonDocument.Parse(json).RootElement; + + var stringVal = AssertHelper.AssertJsonProperty(jsonElement, "StringProp", JsonValueKind.String).GetString(); + Assert.Equal("", stringVal); + var array = AssertHelper.AssertJsonProperty(jsonElement, "EnumerableStringProp", JsonValueKind.Array); + Assert.Equal(0, array.GetArrayLength()); } } + +public class TestObject +{ + [JsonConverter(typeof(PermissiveStringConverter))] + public string StringProp { get; set; } + + [JsonConverter(typeof(PermissiveStringEnumerableConverter))] + public IEnumerable EnumerableStringProp { get; set; } +} diff --git a/test/Core.Test/Utilities/SelfHostedAttributeTests.cs b/test/Core.Test/Utilities/SelfHostedAttributeTests.cs index 4261cf7f5..564c32839 100644 --- a/test/Core.Test/Utilities/SelfHostedAttributeTests.cs +++ b/test/Core.Test/Utilities/SelfHostedAttributeTests.cs @@ -10,83 +10,82 @@ using Microsoft.Extensions.DependencyInjection; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Utilities +namespace Bit.Core.Test.Utilities; + +public class SelfHostedAttributeTests { - public class SelfHostedAttributeTests + [Fact] + public void NotSelfHosted_Throws_When_SelfHosted() { - [Fact] - public void NotSelfHosted_Throws_When_SelfHosted() + // Arrange + var sha = new SelfHostedAttribute { NotSelfHostedOnly = true }; + + // Act & Assert + Assert.Throws(() => sha.OnActionExecuting(GetContext(selfHosted: true))); + } + + [Fact] + public void NotSelfHosted_Success_When_NotSelfHosted() + { + // Arrange + var sha = new SelfHostedAttribute { NotSelfHostedOnly = true }; + + // Act + sha.OnActionExecuting(GetContext(selfHosted: false)); + + // Assert + // The Assert here is just NOT throwing an exception + } + + + [Fact] + public void SelfHosted_Success_When_SelfHosted() + { + // Arrange + var sha = new SelfHostedAttribute { SelfHostedOnly = true }; + + // Act + sha.OnActionExecuting(GetContext(selfHosted: true)); + + // Assert + // The Assert here is just NOT throwing an exception + } + + [Fact] + public void SelfHosted_Throws_When_NotSelfHosted() + { + // Arrange + var sha = new SelfHostedAttribute { SelfHostedOnly = true }; + + // Act & Assert + Assert.Throws(() => sha.OnActionExecuting(GetContext(selfHosted: false))); + } + + + // This generates a ActionExecutingContext with the needed injected + // service with the given value. + private ActionExecutingContext GetContext(bool selfHosted) + { + IServiceCollection services = new ServiceCollection(); + + var globalSettings = new GlobalSettings { - // Arrange - var sha = new SelfHostedAttribute { NotSelfHostedOnly = true }; + SelfHosted = selfHosted + }; - // Act & Assert - Assert.Throws(() => sha.OnActionExecuting(GetContext(selfHosted: true))); - } + services.AddSingleton(globalSettings); - [Fact] - public void NotSelfHosted_Success_When_NotSelfHosted() - { - // Arrange - var sha = new SelfHostedAttribute { NotSelfHostedOnly = true }; + var httpContext = new DefaultHttpContext(); + httpContext.RequestServices = services.BuildServiceProvider(); - // Act - sha.OnActionExecuting(GetContext(selfHosted: false)); + var context = Substitute.For( + Substitute.For(httpContext, + new RouteData(), + Substitute.For()), + new List(), + new Dictionary(), + Substitute.For()); - // Assert - // The Assert here is just NOT throwing an exception - } - - - [Fact] - public void SelfHosted_Success_When_SelfHosted() - { - // Arrange - var sha = new SelfHostedAttribute { SelfHostedOnly = true }; - - // Act - sha.OnActionExecuting(GetContext(selfHosted: true)); - - // Assert - // The Assert here is just NOT throwing an exception - } - - [Fact] - public void SelfHosted_Throws_When_NotSelfHosted() - { - // Arrange - var sha = new SelfHostedAttribute { SelfHostedOnly = true }; - - // Act & Assert - Assert.Throws(() => sha.OnActionExecuting(GetContext(selfHosted: false))); - } - - - // This generates a ActionExecutingContext with the needed injected - // service with the given value. - private ActionExecutingContext GetContext(bool selfHosted) - { - IServiceCollection services = new ServiceCollection(); - - var globalSettings = new GlobalSettings - { - SelfHosted = selfHosted - }; - - services.AddSingleton(globalSettings); - - var httpContext = new DefaultHttpContext(); - httpContext.RequestServices = services.BuildServiceProvider(); - - var context = Substitute.For( - Substitute.For(httpContext, - new RouteData(), - Substitute.For()), - new List(), - new Dictionary(), - Substitute.For()); - - return context; - } + return context; } } diff --git a/test/Core.Test/Utilities/StrictEmailAddressAttributeTests.cs b/test/Core.Test/Utilities/StrictEmailAddressAttributeTests.cs index 6fac59562..bcd3efcc1 100644 --- a/test/Core.Test/Utilities/StrictEmailAddressAttributeTests.cs +++ b/test/Core.Test/Utilities/StrictEmailAddressAttributeTests.cs @@ -1,59 +1,58 @@ using Bit.Core.Utilities; using Xunit; -namespace Bit.Core.Test.Utilities +namespace Bit.Core.Test.Utilities; + +public class StrictEmailAttributeTests { - public class StrictEmailAttributeTests + [Theory] + [InlineData("hello@world.com")] // regular email address + [InlineData("hello@world.planet.com")] // subdomain + [InlineData("hello+1@world.com")] // alias + [InlineData("hello.there@world.com")] // period in local-part + [InlineData("hello@wörldé.com")] // unicode domain + [InlineData("hello@world.cömé")] // unicode top-level domain + public void IsValid_ReturnsTrueWhenValid(string email) { - [Theory] - [InlineData("hello@world.com")] // regular email address - [InlineData("hello@world.planet.com")] // subdomain - [InlineData("hello+1@world.com")] // alias - [InlineData("hello.there@world.com")] // period in local-part - [InlineData("hello@wörldé.com")] // unicode domain - [InlineData("hello@world.cömé")] // unicode top-level domain - public void IsValid_ReturnsTrueWhenValid(string email) - { - var sut = new StrictEmailAddressAttribute(); + var sut = new StrictEmailAddressAttribute(); - var actual = sut.IsValid(email); + var actual = sut.IsValid(email); - Assert.True(actual); - } + Assert.True(actual); + } - [Theory] - [InlineData(null)] // null - [InlineData("hello@world.com\t")] // trailing tab char - [InlineData("\thello@world.com")] // leading tab char - [InlineData("hel\tlo@world.com")] // local-part tab char - [InlineData("hello@world.com\b")] // trailing backspace char - [InlineData("\" \"hello@world.com")] // leading spaces in quotes - [InlineData("hello@world.com\" \"")] // trailing spaces in quotes - [InlineData("hel\" \"lo@world.com")] // local-part spaces in quotes - [InlineData("hello there@world.com")] // unescaped unquoted spaces - [InlineData("Hello ")] // friendly from - [InlineData("")] // wrapped angle brackets - [InlineData("hello(com)there@world.com")] // comment - [InlineData("hello@world.com.")] // trailing period - [InlineData(".hello@world.com")] // leading period - [InlineData("hello@world.com;")] // trailing semicolon - [InlineData(";hello@world.com")] // leading semicolon - [InlineData("hello@world.com; hello@world.com")] // semicolon separated list - [InlineData("hello@world.com, hello@world.com")] // comma separated list - [InlineData("hellothere@worldcom")] // dotless domain - [InlineData("hello.there@worldcom")] // dotless domain - [InlineData("hellothere@.worldcom")] // domain beginning with dot - [InlineData("hellothere@worldcom.")] // domain ending in dot - [InlineData("hellothere@world.com-")] // domain ending in hyphen - [InlineData("hellö@world.com")] // unicode at end of local-part - [InlineData("héllo@world.com")] // unicode in middle of local-part - public void IsValid_ReturnsFalseWhenInvalid(string email) - { - var sut = new StrictEmailAddressAttribute(); + [Theory] + [InlineData(null)] // null + [InlineData("hello@world.com\t")] // trailing tab char + [InlineData("\thello@world.com")] // leading tab char + [InlineData("hel\tlo@world.com")] // local-part tab char + [InlineData("hello@world.com\b")] // trailing backspace char + [InlineData("\" \"hello@world.com")] // leading spaces in quotes + [InlineData("hello@world.com\" \"")] // trailing spaces in quotes + [InlineData("hel\" \"lo@world.com")] // local-part spaces in quotes + [InlineData("hello there@world.com")] // unescaped unquoted spaces + [InlineData("Hello ")] // friendly from + [InlineData("")] // wrapped angle brackets + [InlineData("hello(com)there@world.com")] // comment + [InlineData("hello@world.com.")] // trailing period + [InlineData(".hello@world.com")] // leading period + [InlineData("hello@world.com;")] // trailing semicolon + [InlineData(";hello@world.com")] // leading semicolon + [InlineData("hello@world.com; hello@world.com")] // semicolon separated list + [InlineData("hello@world.com, hello@world.com")] // comma separated list + [InlineData("hellothere@worldcom")] // dotless domain + [InlineData("hello.there@worldcom")] // dotless domain + [InlineData("hellothere@.worldcom")] // domain beginning with dot + [InlineData("hellothere@worldcom.")] // domain ending in dot + [InlineData("hellothere@world.com-")] // domain ending in hyphen + [InlineData("hellö@world.com")] // unicode at end of local-part + [InlineData("héllo@world.com")] // unicode in middle of local-part + public void IsValid_ReturnsFalseWhenInvalid(string email) + { + var sut = new StrictEmailAddressAttribute(); - var actual = sut.IsValid(email); + var actual = sut.IsValid(email); - Assert.False(actual); - } + Assert.False(actual); } } diff --git a/test/Core.Test/Utilities/StrictEmailAddressListAttributeTests.cs b/test/Core.Test/Utilities/StrictEmailAddressListAttributeTests.cs index 2f31a75dc..2ec5a4568 100644 --- a/test/Core.Test/Utilities/StrictEmailAddressListAttributeTests.cs +++ b/test/Core.Test/Utilities/StrictEmailAddressListAttributeTests.cs @@ -1,54 +1,53 @@ using Bit.Core.Utilities; using Xunit; -namespace Bit.Core.Test.Utilities +namespace Bit.Core.Test.Utilities; + +public class StrictEmailAddressListAttributeTests { - public class StrictEmailAddressListAttributeTests + public static List EmailList => new() { - public static List EmailList => new() - { - new object[] { new List { "test@domain.com", "test@sub.domain.com", "hello@world.planet.com" }, true }, - new object[] { new List { "/hello@world.com", "hello@##world.pla net.com", "''thello@world.com" }, false }, - new object[] { new List { "/hello.com", "test@domain.com", "''thello@world.com" }, false }, - new object[] { new List { "héllö@world.com", "hello@world.planet.com", "hello@world.planet.com" }, false }, - new object[] { new List { }, false }, - new object[] { new List - { - "test1@domain.com", "test2@domain.com", "test3@domain.com", "test4@domain.com", "test5@domain.com", - "test6@domain.com", "test7@domain.com", "test8@domain.com", "test9@domain.com", "test10@domain.com", - "test11@domain.com", "test12@domain.com", "test13@domain.com", "test14@domain.com", "test15@domain.com", - "test16@domain.com", "test17@domain.com", "test18@domain.com", "test19@domain.com", "test20@domain.com", - "test21@domain.com", "test22@domain.com", "test23@domain.com", "test24@domain.com", "test25@domain.com", - }, false }, - new object[] { new List - { - "test1domaincomtest2domaincomtest3domaincomtest4domaincomtest5domaincomtest6domaincomtest7domaincomtest8domaincomtest9domaincomtest10domaincomtest1domaincomtest2domaincomtest3domaincomtest4domaincomtest5domaincomtest6domaincomtest7domaincomtest8domaincomtest9domaincomtest10domaincom@test.com", - "test@domain.com" - }, false } // > 256 character email + new object[] { new List { "test@domain.com", "test@sub.domain.com", "hello@world.planet.com" }, true }, + new object[] { new List { "/hello@world.com", "hello@##world.pla net.com", "''thello@world.com" }, false }, + new object[] { new List { "/hello.com", "test@domain.com", "''thello@world.com" }, false }, + new object[] { new List { "héllö@world.com", "hello@world.planet.com", "hello@world.planet.com" }, false }, + new object[] { new List { }, false }, + new object[] { new List + { + "test1@domain.com", "test2@domain.com", "test3@domain.com", "test4@domain.com", "test5@domain.com", + "test6@domain.com", "test7@domain.com", "test8@domain.com", "test9@domain.com", "test10@domain.com", + "test11@domain.com", "test12@domain.com", "test13@domain.com", "test14@domain.com", "test15@domain.com", + "test16@domain.com", "test17@domain.com", "test18@domain.com", "test19@domain.com", "test20@domain.com", + "test21@domain.com", "test22@domain.com", "test23@domain.com", "test24@domain.com", "test25@domain.com", + }, false }, + new object[] { new List + { + "test1domaincomtest2domaincomtest3domaincomtest4domaincomtest5domaincomtest6domaincomtest7domaincomtest8domaincomtest9domaincomtest10domaincomtest1domaincomtest2domaincomtest3domaincomtest4domaincomtest5domaincomtest6domaincomtest7domaincomtest8domaincomtest9domaincomtest10domaincom@test.com", + "test@domain.com" + }, false } // > 256 character email - }; + }; - [Theory] - [MemberData(nameof(EmailList))] - public void IsListValid_ReturnsTrue_WhenValid(List emailList, bool valid) - { - var sut = new StrictEmailAddressListAttribute(); + [Theory] + [MemberData(nameof(EmailList))] + public void IsListValid_ReturnsTrue_WhenValid(List emailList, bool valid) + { + var sut = new StrictEmailAddressListAttribute(); - var actual = sut.IsValid(emailList); + var actual = sut.IsValid(emailList); - Assert.Equal(actual, valid); - } + Assert.Equal(actual, valid); + } - [Theory] - [InlineData("single@email.com", false)] - [InlineData(null, false)] - public void IsValid_ReturnsTrue_WhenValid(string email, bool valid) - { - var sut = new StrictEmailAddressListAttribute(); + [Theory] + [InlineData("single@email.com", false)] + [InlineData(null, false)] + public void IsValid_ReturnsTrue_WhenValid(string email, bool valid) + { + var sut = new StrictEmailAddressListAttribute(); - var actual = sut.IsValid(email); + var actual = sut.IsValid(email); - Assert.Equal(actual, valid); - } + Assert.Equal(actual, valid); } } diff --git a/test/Icons.Test/Resources/VerifyResources.cs b/test/Icons.Test/Resources/VerifyResources.cs index ad5d8d681..208bd5077 100644 --- a/test/Icons.Test/Resources/VerifyResources.cs +++ b/test/Icons.Test/Resources/VerifyResources.cs @@ -1,20 +1,19 @@ using Xunit; -namespace Bit.Icons.Test.Resources -{ - public class VerifyResources - { - [Theory] - [InlineData("Bit.Icons.Resources.public_suffix_list.dat")] - public void Resources_FoundAndReadable(string resourceName) - { - var assembly = typeof(Program).Assembly; +namespace Bit.Icons.Test.Resources; - using (var resource = assembly.GetManifestResourceStream(resourceName)) - { - Assert.NotNull(resource); - Assert.True(resource.CanRead); - } +public class VerifyResources +{ + [Theory] + [InlineData("Bit.Icons.Resources.public_suffix_list.dat")] + public void Resources_FoundAndReadable(string resourceName) + { + var assembly = typeof(Program).Assembly; + + using (var resource = assembly.GetManifestResourceStream(resourceName)) + { + Assert.NotNull(resource); + Assert.True(resource.CanRead); } } } diff --git a/test/Icons.Test/Services/IconFetchingServiceTests.cs b/test/Icons.Test/Services/IconFetchingServiceTests.cs index ed317fb62..59f25af24 100644 --- a/test/Icons.Test/Services/IconFetchingServiceTests.cs +++ b/test/Icons.Test/Services/IconFetchingServiceTests.cs @@ -3,49 +3,48 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Xunit; -namespace Bit.Icons.Test.Services +namespace Bit.Icons.Test.Services; + +public class IconFetchingServiceTests { - public class IconFetchingServiceTests + [Theory] + [InlineData("www.google.com")] // https site + [InlineData("neverssl.com")] // http site + [InlineData("ameritrade.com")] + [InlineData("icloud.com")] + [InlineData("bofa.com", Skip = "Broken in pipeline for .NET 6. Tracking link: https://bitwarden.atlassian.net/browse/PS-982")] + public async Task GetIconAsync_Success(string domain) { - [Theory] - [InlineData("www.google.com")] // https site - [InlineData("neverssl.com")] // http site - [InlineData("ameritrade.com")] - [InlineData("icloud.com")] - [InlineData("bofa.com", Skip = "Broken in pipeline for .NET 6. Tracking link: https://bitwarden.atlassian.net/browse/PS-982")] - public async Task GetIconAsync_Success(string domain) + var sut = new IconFetchingService(GetLogger()); + var result = await sut.GetIconAsync(domain); + + Assert.NotNull(result); + Assert.NotNull(result.Icon); + } + + [Theory] + [InlineData("1.1.1.1")] + [InlineData("")] + [InlineData("localhost")] + public async Task GetIconAsync_ReturnsNull(string domain) + { + var sut = new IconFetchingService(GetLogger()); + var result = await sut.GetIconAsync(domain); + + Assert.Null(result); + } + + private static ILogger GetLogger() + { + var services = new ServiceCollection(); + services.AddLogging(b => { - var sut = new IconFetchingService(GetLogger()); - var result = await sut.GetIconAsync(domain); + b.ClearProviders(); + b.AddDebug(); + }); - Assert.NotNull(result); - Assert.NotNull(result.Icon); - } + var provider = services.BuildServiceProvider(); - [Theory] - [InlineData("1.1.1.1")] - [InlineData("")] - [InlineData("localhost")] - public async Task GetIconAsync_ReturnsNull(string domain) - { - var sut = new IconFetchingService(GetLogger()); - var result = await sut.GetIconAsync(domain); - - Assert.Null(result); - } - - private static ILogger GetLogger() - { - var services = new ServiceCollection(); - services.AddLogging(b => - { - b.ClearProviders(); - b.AddDebug(); - }); - - var provider = services.BuildServiceProvider(); - - return provider.GetRequiredService>(); - } + return provider.GetRequiredService>(); } } diff --git a/test/Identity.IntegrationTest/Controllers/AccountsControllerTests.cs b/test/Identity.IntegrationTest/Controllers/AccountsControllerTests.cs index 31cab0e3c..3d03d39a9 100644 --- a/test/Identity.IntegrationTest/Controllers/AccountsControllerTests.cs +++ b/test/Identity.IntegrationTest/Controllers/AccountsControllerTests.cs @@ -3,33 +3,32 @@ using Bit.IntegrationTestCommon.Factories; using Microsoft.EntityFrameworkCore; using Xunit; -namespace Bit.Identity.IntegrationTest.Controllers +namespace Bit.Identity.IntegrationTest.Controllers; + +public class AccountsControllerTests : IClassFixture { - public class AccountsControllerTests : IClassFixture + private readonly IdentityApplicationFactory _factory; + + public AccountsControllerTests(IdentityApplicationFactory factory) { - private readonly IdentityApplicationFactory _factory; + _factory = factory; + } - public AccountsControllerTests(IdentityApplicationFactory factory) + [Fact] + public async Task PostRegister_Success() + { + var context = await _factory.RegisterAsync(new RegisterRequestModel { - _factory = factory; - } + Email = "test+register@email.com", + MasterPasswordHash = "master_password_hash" + }); - [Fact] - public async Task PostRegister_Success() - { - var context = await _factory.RegisterAsync(new RegisterRequestModel - { - Email = "test+register@email.com", - MasterPasswordHash = "master_password_hash" - }); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); - Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + var database = _factory.GetDatabaseContext(); + var user = await database.Users + .SingleAsync(u => u.Email == "test+register@email.com"); - var database = _factory.GetDatabaseContext(); - var user = await database.Users - .SingleAsync(u => u.Email == "test+register@email.com"); - - Assert.NotNull(user); - } + Assert.NotNull(user); } } diff --git a/test/Identity.IntegrationTest/Endpoints/IdentityServerTests.cs b/test/Identity.IntegrationTest/Endpoints/IdentityServerTests.cs index 1d7d0dd8d..14600ba26 100644 --- a/test/Identity.IntegrationTest/Endpoints/IdentityServerTests.cs +++ b/test/Identity.IntegrationTest/Endpoints/IdentityServerTests.cs @@ -9,51 +9,425 @@ using Bit.Test.Common.Helpers; using Microsoft.EntityFrameworkCore; using Xunit; -namespace Bit.Identity.IntegrationTest.Endpoints +namespace Bit.Identity.IntegrationTest.Endpoints; + +public class IdentityServerTests : IClassFixture { - public class IdentityServerTests : IClassFixture + private const int SecondsInMinute = 60; + private const int MinutesInHour = 60; + private const int SecondsInHour = SecondsInMinute * MinutesInHour; + private readonly IdentityApplicationFactory _factory; + + public IdentityServerTests(IdentityApplicationFactory factory) { - private const int SecondsInMinute = 60; - private const int MinutesInHour = 60; - private const int SecondsInHour = SecondsInMinute * MinutesInHour; - private readonly IdentityApplicationFactory _factory; + _factory = factory; + } - public IdentityServerTests(IdentityApplicationFactory factory) + [Fact] + public async Task WellKnownEndpoint_Success() + { + var context = await _factory.Server.GetAsync("/.well-known/openid-configuration"); + + using var body = await AssertHelper.AssertResponseTypeIs(context); + var endpointRoot = body.RootElement; + + // WARNING: Edits to this file should NOT just be made to "get the test to work" they should be made when intentional + // changes were made to this endpoint and proper testing will take place to ensure clients are backwards compatible + // or loss of functionality is properly noted. + await using var fs = File.OpenRead("openid-configuration.json"); + using var knownConfiguration = await JsonSerializer.DeserializeAsync(fs); + var knownConfigurationRoot = knownConfiguration.RootElement; + + AssertHelper.AssertEqualJson(endpointRoot, knownConfigurationRoot); + } + + [Fact] + public async Task TokenEndpoint_GrantTypePassword_Success() + { + var deviceId = "92b9d953-b9b6-4eaf-9d3e-11d57144dfeb"; + var username = "test+tokenpassword@email.com"; + + await _factory.RegisterAsync(new RegisterRequestModel { - _factory = factory; + Email = username, + MasterPasswordHash = "master_password_hash" + }); + + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "scope", "api offline_access" }, + { "client_id", "web" }, + { "deviceType", DeviceTypeAsString(DeviceType.FirefoxBrowser) }, + { "deviceIdentifier", deviceId }, + { "deviceName", "firefox" }, + { "grant_type", "password" }, + { "username", username }, + { "password", "master_password_hash" }, + }), context => context.SetAuthEmail(username)); + + using var body = await AssertDefaultTokenBodyAsync(context); + var root = body.RootElement; + AssertRefreshTokenExists(root); + AssertHelper.AssertJsonProperty(root, "ForcePasswordReset", JsonValueKind.False); + AssertHelper.AssertJsonProperty(root, "ResetMasterPassword", JsonValueKind.False); + var kdf = AssertHelper.AssertJsonProperty(root, "Kdf", JsonValueKind.Number).GetInt32(); + Assert.Equal(0, kdf); + var kdfIterations = AssertHelper.AssertJsonProperty(root, "KdfIterations", JsonValueKind.Number).GetInt32(); + Assert.Equal(5000, kdfIterations); + } + + [Fact] + public async Task TokenEndpoint_GrantTypePassword_NoAuthEmailHeader_Fails() + { + var deviceId = "92b9d953-b9b6-4eaf-9d3e-11d57144dfeb"; + var username = "test+noauthemailheader@email.com"; + + await _factory.RegisterAsync(new RegisterRequestModel + { + Email = username, + MasterPasswordHash = "master_password_hash", + }); + + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "scope", "api offline_access" }, + { "client_id", "web" }, + { "deviceType", DeviceTypeAsString(DeviceType.FirefoxBrowser) }, + { "deviceIdentifier", deviceId }, + { "deviceName", "firefox" }, + { "grant_type", "password" }, + { "username", username }, + { "password", "master_password_hash" }, + })); + + Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); + + var body = await AssertHelper.AssertResponseTypeIs(context); + var root = body.RootElement; + + var error = AssertHelper.AssertJsonProperty(root, "error", JsonValueKind.String).GetString(); + Assert.Equal("invalid_grant", error); + AssertHelper.AssertJsonProperty(root, "error_description", JsonValueKind.String); + } + + [Fact] + public async Task TokenEndpoint_GrantTypePassword_InvalidBase64AuthEmailHeader_Fails() + { + var deviceId = "92b9d953-b9b6-4eaf-9d3e-11d57144dfeb"; + var username = "test+badauthheader@email.com"; + + await _factory.RegisterAsync(new RegisterRequestModel + { + Email = username, + MasterPasswordHash = "master_password_hash", + }); + + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "scope", "api offline_access" }, + { "client_id", "web" }, + { "deviceType", DeviceTypeAsString(DeviceType.FirefoxBrowser) }, + { "deviceIdentifier", deviceId }, + { "deviceName", "firefox" }, + { "grant_type", "password" }, + { "username", username }, + { "password", "master_password_hash" }, + }), context => context.Request.Headers.Add("Auth-Email", "bad_value")); + + Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); + + var body = await AssertHelper.AssertResponseTypeIs(context); + var root = body.RootElement; + + var error = AssertHelper.AssertJsonProperty(root, "error", JsonValueKind.String).GetString(); + Assert.Equal("invalid_grant", error); + AssertHelper.AssertJsonProperty(root, "error_description", JsonValueKind.String); + } + + [Fact] + public async Task TokenEndpoint_GrantTypePassword_WrongAuthEmailHeader_Fails() + { + var deviceId = "92b9d953-b9b6-4eaf-9d3e-11d57144dfeb"; + var username = "test+badauthheader@email.com"; + + await _factory.RegisterAsync(new RegisterRequestModel + { + Email = username, + MasterPasswordHash = "master_password_hash", + }); + + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "scope", "api offline_access" }, + { "client_id", "web" }, + { "deviceType", DeviceTypeAsString(DeviceType.FirefoxBrowser) }, + { "deviceIdentifier", deviceId }, + { "deviceName", "firefox" }, + { "grant_type", "password" }, + { "username", username }, + { "password", "master_password_hash" }, + }), context => context.SetAuthEmail("bad_value")); + + Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); + + var body = await AssertHelper.AssertResponseTypeIs(context); + var root = body.RootElement; + + var error = AssertHelper.AssertJsonProperty(root, "error", JsonValueKind.String).GetString(); + Assert.Equal("invalid_grant", error); + AssertHelper.AssertJsonProperty(root, "error_description", JsonValueKind.String); + } + + [Fact] + public async Task TokenEndpoint_GrantTypeRefreshToken_Success() + { + var deviceId = "5a7b19df-0c9d-46bf-a104-8034b5a17182"; + var username = "test+tokenrefresh@email.com"; + + await _factory.RegisterAsync(new RegisterRequestModel + { + Email = username, + MasterPasswordHash = "master_password_hash", + }); + + var (_, refreshToken) = await _factory.TokenFromPasswordAsync(username, "master_password_hash", deviceId); + + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "grant_type", "refresh_token" }, + { "client_id", "web" }, + { "refresh_token", refreshToken }, + })); + + using var body = await AssertDefaultTokenBodyAsync(context); + AssertRefreshTokenExists(body.RootElement); + } + + [Fact] + public async Task TokenEndpoint_GrantTypeClientCredentials_Success() + { + var username = "test+tokenclientcredentials@email.com"; + var deviceId = "8f14a393-edfe-40ba-8c67-a856cb89c509"; + + await _factory.RegisterAsync(new RegisterRequestModel + { + Email = username, + MasterPasswordHash = "master_password_hash", + }); + + var database = _factory.GetDatabaseContext(); + var user = await database.Users + .FirstAsync(u => u.Email == username); + + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "grant_type", "client_credentials" }, + { "client_id", $"user.{user.Id}" }, + { "client_secret", user.ApiKey }, + { "scope", "api" }, + { "DeviceIdentifier", deviceId }, + { "DeviceType", DeviceTypeAsString(DeviceType.FirefoxBrowser) }, + { "DeviceName", "firefox" }, + })); + + await AssertDefaultTokenBodyAsync(context, "api"); + } + + [Theory, BitAutoData] + public async Task TokenEndpoint_GrantTypeClientCredentials_AsOrganization_Success(Organization organization, OrganizationApiKey organizationApiKey) + { + var orgRepo = _factory.Services.GetRequiredService(); + organization.Enabled = true; + organization.UseApi = true; + organization = await orgRepo.CreateAsync(organization); + organizationApiKey.OrganizationId = organization.Id; + organizationApiKey.Type = OrganizationApiKeyType.Default; + + var orgApiKeyRepo = _factory.Services.GetRequiredService(); + await orgApiKeyRepo.CreateAsync(organizationApiKey); + + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "grant_type", "client_credentials" }, + { "client_id", $"organization.{organization.Id}" }, + { "client_secret", organizationApiKey.ApiKey }, + { "scope", "api.organization" }, + })); + + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + await AssertDefaultTokenBodyAsync(context, "api.organization"); + } + + [Fact] + public async Task TokenEndpoint_GrantTypeClientCredentials_AsOrganization_BadOrgId_Fails() + { + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "grant_type", "client_credentials" }, + { "client_id", "organization.bad_guid_zz&" }, + { "client_secret", "something" }, + { "scope", "api.organization" }, + })); + + Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); + + var errorBody = await AssertHelper.AssertResponseTypeIs(context); + var error = AssertHelper.AssertJsonProperty(errorBody.RootElement, "error", JsonValueKind.String).GetString(); + Assert.Equal("invalid_client", error); + } + + /// + /// This test currently does not test any code that is not covered by other tests but + /// it shows that we probably have some dead code in + /// for installation, organization, and user they split on a '.' but have already checked that at least one + /// '.' exists in the client_id by checking it with + /// I believe that idParts.Length > 1 will ALWAYS return true + /// + [Fact] + public async Task TokenEndpoint_GrantTypeClientCredentials_AsOrganization_NoIdPart_Fails() + { + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "grant_type", "client_credentials" }, + { "client_id", "organization." }, + { "client_secret", "something" }, + { "scope", "api.organization" }, + })); + + Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); + + var errorBody = await AssertHelper.AssertResponseTypeIs(context); + var error = AssertHelper.AssertJsonProperty(errorBody.RootElement, "error", JsonValueKind.String).GetString(); + Assert.Equal("invalid_client", error); + } + + [Fact] + public async Task TokenEndpoint_GrantTypeClientCredentials_AsOrganization_OrgDoesNotExist_Fails() + { + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "grant_type", "client_credentials" }, + { "client_id", $"organization.{Guid.NewGuid()}" }, + { "client_secret", "something" }, + { "scope", "api.organization" }, + })); + + Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); + + var errorBody = await AssertHelper.AssertResponseTypeIs(context); + var error = AssertHelper.AssertJsonProperty(errorBody.RootElement, "error", JsonValueKind.String).GetString(); + Assert.Equal("invalid_client", error); + } + + [Theory, BitAutoData] + public async Task TokenEndpoint_GrantTypeClientCredentials_AsInstallation_InstallationExists_Succeeds(Installation installation) + { + var installationRepo = _factory.Services.GetRequiredService(); + installation = await installationRepo.CreateAsync(installation); + + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "grant_type", "client_credentials" }, + { "client_id", $"installation.{installation.Id}" }, + { "client_secret", installation.Key }, + { "scope", "api.push" }, + })); + + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + await AssertDefaultTokenBodyAsync(context, "api.push", 24 * SecondsInHour); + } + + [Fact] + public async Task TokenEndpoint_GrantTypeClientCredentials_AsInstallation_InstallationDoesNotExist_Fails() + { + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "grant_type", "client_credentials" }, + { "client_id", $"installation.{Guid.NewGuid()}" }, + { "client_secret", "something" }, + { "scope", "api.push" }, + })); + + Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); + + var errorBody = await AssertHelper.AssertResponseTypeIs(context); + var error = AssertHelper.AssertJsonProperty(errorBody.RootElement, "error", JsonValueKind.String).GetString(); + Assert.Equal("invalid_client", error); + } + + [Fact] + public async Task TokenEndpoint_GrantTypeClientCredentials_AsInstallation_BadInsallationId_Fails() + { + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "grant_type", "client_credentials" }, + { "client_id", "organization.bad_guid_zz&" }, + { "client_secret", "something" }, + { "scope", "api.organization" }, + })); + + Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); + + var errorBody = await AssertHelper.AssertResponseTypeIs(context); + var error = AssertHelper.AssertJsonProperty(errorBody.RootElement, "error", JsonValueKind.String).GetString(); + Assert.Equal("invalid_client", error); + } + + /// + [Fact] + public async Task TokenEndpoint_GrantTypeClientCredentials_AsInstallation_NoIdPart_Fails() + { + var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + { + { "grant_type", "client_credentials" }, + { "client_id", "installation." }, + { "client_secret", "something" }, + { "scope", "api.push" }, + })); + + Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); + + var errorBody = await AssertHelper.AssertResponseTypeIs(context); + var error = AssertHelper.AssertJsonProperty(errorBody.RootElement, "error", JsonValueKind.String).GetString(); + Assert.Equal("invalid_client", error); + } + + [Fact] + public async Task TokenEndpoint_ToQuickInOneSecond_BlockRequest() + { + const int AmountInOneSecondAllowed = 5; + + // The rule we are testing is 10 requests in 1 second + var username = "test+ratelimiting@email.com"; + var deviceId = "8f14a393-edfe-40ba-8c67-a856cb89c509"; + + await _factory.RegisterAsync(new RegisterRequestModel + { + Email = username, + MasterPasswordHash = "master_password_hash", + }); + + var database = _factory.GetDatabaseContext(); + var user = await database.Users + .FirstAsync(u => u.Email == username); + + var tasks = new Task[AmountInOneSecondAllowed + 1]; + + for (var i = 0; i < AmountInOneSecondAllowed + 1; i++) + { + // Queue all the amount of calls allowed plus 1 + tasks[i] = MakeRequest(); } - [Fact] - public async Task WellKnownEndpoint_Success() + var responses = (await Task.WhenAll(tasks)).ToList(); + + Assert.Equal(5, responses.Count(c => c.Response.StatusCode == StatusCodes.Status200OK)); + Assert.Equal(1, responses.Count(c => c.Response.StatusCode == StatusCodes.Status429TooManyRequests)); + + Task MakeRequest() { - var context = await _factory.Server.GetAsync("/.well-known/openid-configuration"); - - using var body = await AssertHelper.AssertResponseTypeIs(context); - var endpointRoot = body.RootElement; - - // WARNING: Edits to this file should NOT just be made to "get the test to work" they should be made when intentional - // changes were made to this endpoint and proper testing will take place to ensure clients are backwards compatible - // or loss of functionality is properly noted. - await using var fs = File.OpenRead("openid-configuration.json"); - using var knownConfiguration = await JsonSerializer.DeserializeAsync(fs); - var knownConfigurationRoot = knownConfiguration.RootElement; - - AssertHelper.AssertEqualJson(endpointRoot, knownConfigurationRoot); - } - - [Fact] - public async Task TokenEndpoint_GrantTypePassword_Success() - { - var deviceId = "92b9d953-b9b6-4eaf-9d3e-11d57144dfeb"; - var username = "test+tokenpassword@email.com"; - - await _factory.RegisterAsync(new RegisterRequestModel - { - Email = username, - MasterPasswordHash = "master_password_hash" - }); - - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary + return _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary { { "scope", "api offline_access" }, { "client_id", "web" }, @@ -63,434 +437,59 @@ namespace Bit.Identity.IntegrationTest.Endpoints { "grant_type", "password" }, { "username", username }, { "password", "master_password_hash" }, - }), context => context.SetAuthEmail(username)); - - using var body = await AssertDefaultTokenBodyAsync(context); - var root = body.RootElement; - AssertRefreshTokenExists(root); - AssertHelper.AssertJsonProperty(root, "ForcePasswordReset", JsonValueKind.False); - AssertHelper.AssertJsonProperty(root, "ResetMasterPassword", JsonValueKind.False); - var kdf = AssertHelper.AssertJsonProperty(root, "Kdf", JsonValueKind.Number).GetInt32(); - Assert.Equal(0, kdf); - var kdfIterations = AssertHelper.AssertJsonProperty(root, "KdfIterations", JsonValueKind.Number).GetInt32(); - Assert.Equal(5000, kdfIterations); - } - - [Fact] - public async Task TokenEndpoint_GrantTypePassword_NoAuthEmailHeader_Fails() - { - var deviceId = "92b9d953-b9b6-4eaf-9d3e-11d57144dfeb"; - var username = "test+noauthemailheader@email.com"; - - await _factory.RegisterAsync(new RegisterRequestModel - { - Email = username, - MasterPasswordHash = "master_password_hash", - }); - - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "scope", "api offline_access" }, - { "client_id", "web" }, - { "deviceType", DeviceTypeAsString(DeviceType.FirefoxBrowser) }, - { "deviceIdentifier", deviceId }, - { "deviceName", "firefox" }, - { "grant_type", "password" }, - { "username", username }, - { "password", "master_password_hash" }, - })); - - Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); - - var body = await AssertHelper.AssertResponseTypeIs(context); - var root = body.RootElement; - - var error = AssertHelper.AssertJsonProperty(root, "error", JsonValueKind.String).GetString(); - Assert.Equal("invalid_grant", error); - AssertHelper.AssertJsonProperty(root, "error_description", JsonValueKind.String); - } - - [Fact] - public async Task TokenEndpoint_GrantTypePassword_InvalidBase64AuthEmailHeader_Fails() - { - var deviceId = "92b9d953-b9b6-4eaf-9d3e-11d57144dfeb"; - var username = "test+badauthheader@email.com"; - - await _factory.RegisterAsync(new RegisterRequestModel - { - Email = username, - MasterPasswordHash = "master_password_hash", - }); - - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "scope", "api offline_access" }, - { "client_id", "web" }, - { "deviceType", DeviceTypeAsString(DeviceType.FirefoxBrowser) }, - { "deviceIdentifier", deviceId }, - { "deviceName", "firefox" }, - { "grant_type", "password" }, - { "username", username }, - { "password", "master_password_hash" }, - }), context => context.Request.Headers.Add("Auth-Email", "bad_value")); - - Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); - - var body = await AssertHelper.AssertResponseTypeIs(context); - var root = body.RootElement; - - var error = AssertHelper.AssertJsonProperty(root, "error", JsonValueKind.String).GetString(); - Assert.Equal("invalid_grant", error); - AssertHelper.AssertJsonProperty(root, "error_description", JsonValueKind.String); - } - - [Fact] - public async Task TokenEndpoint_GrantTypePassword_WrongAuthEmailHeader_Fails() - { - var deviceId = "92b9d953-b9b6-4eaf-9d3e-11d57144dfeb"; - var username = "test+badauthheader@email.com"; - - await _factory.RegisterAsync(new RegisterRequestModel - { - Email = username, - MasterPasswordHash = "master_password_hash", - }); - - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "scope", "api offline_access" }, - { "client_id", "web" }, - { "deviceType", DeviceTypeAsString(DeviceType.FirefoxBrowser) }, - { "deviceIdentifier", deviceId }, - { "deviceName", "firefox" }, - { "grant_type", "password" }, - { "username", username }, - { "password", "master_password_hash" }, - }), context => context.SetAuthEmail("bad_value")); - - Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); - - var body = await AssertHelper.AssertResponseTypeIs(context); - var root = body.RootElement; - - var error = AssertHelper.AssertJsonProperty(root, "error", JsonValueKind.String).GetString(); - Assert.Equal("invalid_grant", error); - AssertHelper.AssertJsonProperty(root, "error_description", JsonValueKind.String); - } - - [Fact] - public async Task TokenEndpoint_GrantTypeRefreshToken_Success() - { - var deviceId = "5a7b19df-0c9d-46bf-a104-8034b5a17182"; - var username = "test+tokenrefresh@email.com"; - - await _factory.RegisterAsync(new RegisterRequestModel - { - Email = username, - MasterPasswordHash = "master_password_hash", - }); - - var (_, refreshToken) = await _factory.TokenFromPasswordAsync(username, "master_password_hash", deviceId); - - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "grant_type", "refresh_token" }, - { "client_id", "web" }, - { "refresh_token", refreshToken }, - })); - - using var body = await AssertDefaultTokenBodyAsync(context); - AssertRefreshTokenExists(body.RootElement); - } - - [Fact] - public async Task TokenEndpoint_GrantTypeClientCredentials_Success() - { - var username = "test+tokenclientcredentials@email.com"; - var deviceId = "8f14a393-edfe-40ba-8c67-a856cb89c509"; - - await _factory.RegisterAsync(new RegisterRequestModel - { - Email = username, - MasterPasswordHash = "master_password_hash", - }); - - var database = _factory.GetDatabaseContext(); - var user = await database.Users - .FirstAsync(u => u.Email == username); - - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "grant_type", "client_credentials" }, - { "client_id", $"user.{user.Id}" }, - { "client_secret", user.ApiKey }, - { "scope", "api" }, - { "DeviceIdentifier", deviceId }, - { "DeviceType", DeviceTypeAsString(DeviceType.FirefoxBrowser) }, - { "DeviceName", "firefox" }, - })); - - await AssertDefaultTokenBodyAsync(context, "api"); - } - - [Theory, BitAutoData] - public async Task TokenEndpoint_GrantTypeClientCredentials_AsOrganization_Success(Organization organization, OrganizationApiKey organizationApiKey) - { - var orgRepo = _factory.Services.GetRequiredService(); - organization.Enabled = true; - organization.UseApi = true; - organization = await orgRepo.CreateAsync(organization); - organizationApiKey.OrganizationId = organization.Id; - organizationApiKey.Type = OrganizationApiKeyType.Default; - - var orgApiKeyRepo = _factory.Services.GetRequiredService(); - await orgApiKeyRepo.CreateAsync(organizationApiKey); - - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "grant_type", "client_credentials" }, - { "client_id", $"organization.{organization.Id}" }, - { "client_secret", organizationApiKey.ApiKey }, - { "scope", "api.organization" }, - })); - - Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); - - await AssertDefaultTokenBodyAsync(context, "api.organization"); - } - - [Fact] - public async Task TokenEndpoint_GrantTypeClientCredentials_AsOrganization_BadOrgId_Fails() - { - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "grant_type", "client_credentials" }, - { "client_id", "organization.bad_guid_zz&" }, - { "client_secret", "something" }, - { "scope", "api.organization" }, - })); - - Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); - - var errorBody = await AssertHelper.AssertResponseTypeIs(context); - var error = AssertHelper.AssertJsonProperty(errorBody.RootElement, "error", JsonValueKind.String).GetString(); - Assert.Equal("invalid_client", error); - } - - /// - /// This test currently does not test any code that is not covered by other tests but - /// it shows that we probably have some dead code in - /// for installation, organization, and user they split on a '.' but have already checked that at least one - /// '.' exists in the client_id by checking it with - /// I believe that idParts.Length > 1 will ALWAYS return true - /// - [Fact] - public async Task TokenEndpoint_GrantTypeClientCredentials_AsOrganization_NoIdPart_Fails() - { - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "grant_type", "client_credentials" }, - { "client_id", "organization." }, - { "client_secret", "something" }, - { "scope", "api.organization" }, - })); - - Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); - - var errorBody = await AssertHelper.AssertResponseTypeIs(context); - var error = AssertHelper.AssertJsonProperty(errorBody.RootElement, "error", JsonValueKind.String).GetString(); - Assert.Equal("invalid_client", error); - } - - [Fact] - public async Task TokenEndpoint_GrantTypeClientCredentials_AsOrganization_OrgDoesNotExist_Fails() - { - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "grant_type", "client_credentials" }, - { "client_id", $"organization.{Guid.NewGuid()}" }, - { "client_secret", "something" }, - { "scope", "api.organization" }, - })); - - Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); - - var errorBody = await AssertHelper.AssertResponseTypeIs(context); - var error = AssertHelper.AssertJsonProperty(errorBody.RootElement, "error", JsonValueKind.String).GetString(); - Assert.Equal("invalid_client", error); - } - - [Theory, BitAutoData] - public async Task TokenEndpoint_GrantTypeClientCredentials_AsInstallation_InstallationExists_Succeeds(Installation installation) - { - var installationRepo = _factory.Services.GetRequiredService(); - installation = await installationRepo.CreateAsync(installation); - - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "grant_type", "client_credentials" }, - { "client_id", $"installation.{installation.Id}" }, - { "client_secret", installation.Key }, - { "scope", "api.push" }, - })); - - Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); - await AssertDefaultTokenBodyAsync(context, "api.push", 24 * SecondsInHour); - } - - [Fact] - public async Task TokenEndpoint_GrantTypeClientCredentials_AsInstallation_InstallationDoesNotExist_Fails() - { - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "grant_type", "client_credentials" }, - { "client_id", $"installation.{Guid.NewGuid()}" }, - { "client_secret", "something" }, - { "scope", "api.push" }, - })); - - Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); - - var errorBody = await AssertHelper.AssertResponseTypeIs(context); - var error = AssertHelper.AssertJsonProperty(errorBody.RootElement, "error", JsonValueKind.String).GetString(); - Assert.Equal("invalid_client", error); - } - - [Fact] - public async Task TokenEndpoint_GrantTypeClientCredentials_AsInstallation_BadInsallationId_Fails() - { - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "grant_type", "client_credentials" }, - { "client_id", "organization.bad_guid_zz&" }, - { "client_secret", "something" }, - { "scope", "api.organization" }, - })); - - Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); - - var errorBody = await AssertHelper.AssertResponseTypeIs(context); - var error = AssertHelper.AssertJsonProperty(errorBody.RootElement, "error", JsonValueKind.String).GetString(); - Assert.Equal("invalid_client", error); - } - - /// - [Fact] - public async Task TokenEndpoint_GrantTypeClientCredentials_AsInstallation_NoIdPart_Fails() - { - var context = await _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "grant_type", "client_credentials" }, - { "client_id", "installation." }, - { "client_secret", "something" }, - { "scope", "api.push" }, - })); - - Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); - - var errorBody = await AssertHelper.AssertResponseTypeIs(context); - var error = AssertHelper.AssertJsonProperty(errorBody.RootElement, "error", JsonValueKind.String).GetString(); - Assert.Equal("invalid_client", error); - } - - [Fact] - public async Task TokenEndpoint_ToQuickInOneSecond_BlockRequest() - { - const int AmountInOneSecondAllowed = 5; - - // The rule we are testing is 10 requests in 1 second - var username = "test+ratelimiting@email.com"; - var deviceId = "8f14a393-edfe-40ba-8c67-a856cb89c509"; - - await _factory.RegisterAsync(new RegisterRequestModel - { - Email = username, - MasterPasswordHash = "master_password_hash", - }); - - var database = _factory.GetDatabaseContext(); - var user = await database.Users - .FirstAsync(u => u.Email == username); - - var tasks = new Task[AmountInOneSecondAllowed + 1]; - - for (var i = 0; i < AmountInOneSecondAllowed + 1; i++) - { - // Queue all the amount of calls allowed plus 1 - tasks[i] = MakeRequest(); - } - - var responses = (await Task.WhenAll(tasks)).ToList(); - - Assert.Equal(5, responses.Count(c => c.Response.StatusCode == StatusCodes.Status200OK)); - Assert.Equal(1, responses.Count(c => c.Response.StatusCode == StatusCodes.Status429TooManyRequests)); - - Task MakeRequest() - { - return _factory.Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "scope", "api offline_access" }, - { "client_id", "web" }, - { "deviceType", DeviceTypeAsString(DeviceType.FirefoxBrowser) }, - { "deviceIdentifier", deviceId }, - { "deviceName", "firefox" }, - { "grant_type", "password" }, - { "username", username }, - { "password", "master_password_hash" }, - }), context => context.SetAuthEmail(username).SetIp("1.1.1.2")); - } - } - - private static string DeviceTypeAsString(DeviceType deviceType) - { - return ((int)deviceType).ToString(); - } - - private static async Task AssertDefaultTokenBodyAsync(HttpContext httpContext, string expectedScope = "api offline_access", int expectedExpiresIn = SecondsInHour * 1) - { - var body = await AssertHelper.AssertResponseTypeIs(httpContext); - var root = body.RootElement; - - Assert.Equal(JsonValueKind.Object, root.ValueKind); - AssertAccessTokenExists(root); - AssertExpiresIn(root, expectedExpiresIn); - AssertTokenType(root); - AssertScope(root, expectedScope); - return body; - } - - private static void AssertTokenType(JsonElement tokenResponse) - { - var tokenTypeProperty = AssertHelper.AssertJsonProperty(tokenResponse, "token_type", JsonValueKind.String).GetString(); - Assert.Equal("Bearer", tokenTypeProperty); - } - - private static int AssertExpiresIn(JsonElement tokenResponse, int expectedExpiresIn = 3600) - { - var expiresIn = AssertHelper.AssertJsonProperty(tokenResponse, "expires_in", JsonValueKind.Number).GetInt32(); - Assert.Equal(expectedExpiresIn, expiresIn); - return expiresIn; - } - - private static string AssertAccessTokenExists(JsonElement tokenResponse) - { - return AssertHelper.AssertJsonProperty(tokenResponse, "access_token", JsonValueKind.String).GetString(); - } - - private static string AssertRefreshTokenExists(JsonElement tokenResponse) - { - return AssertHelper.AssertJsonProperty(tokenResponse, "refresh_token", JsonValueKind.String).GetString(); - } - - private static string AssertScopeExists(JsonElement tokenResponse) - { - return AssertHelper.AssertJsonProperty(tokenResponse, "scope", JsonValueKind.String).GetString(); - } - - private static void AssertScope(JsonElement tokenResponse, string expectedScope) - { - var actualScope = AssertScopeExists(tokenResponse); - Assert.Equal(expectedScope, actualScope); + }), context => context.SetAuthEmail(username).SetIp("1.1.1.2")); } } + + private static string DeviceTypeAsString(DeviceType deviceType) + { + return ((int)deviceType).ToString(); + } + + private static async Task AssertDefaultTokenBodyAsync(HttpContext httpContext, string expectedScope = "api offline_access", int expectedExpiresIn = SecondsInHour * 1) + { + var body = await AssertHelper.AssertResponseTypeIs(httpContext); + var root = body.RootElement; + + Assert.Equal(JsonValueKind.Object, root.ValueKind); + AssertAccessTokenExists(root); + AssertExpiresIn(root, expectedExpiresIn); + AssertTokenType(root); + AssertScope(root, expectedScope); + return body; + } + + private static void AssertTokenType(JsonElement tokenResponse) + { + var tokenTypeProperty = AssertHelper.AssertJsonProperty(tokenResponse, "token_type", JsonValueKind.String).GetString(); + Assert.Equal("Bearer", tokenTypeProperty); + } + + private static int AssertExpiresIn(JsonElement tokenResponse, int expectedExpiresIn = 3600) + { + var expiresIn = AssertHelper.AssertJsonProperty(tokenResponse, "expires_in", JsonValueKind.Number).GetInt32(); + Assert.Equal(expectedExpiresIn, expiresIn); + return expiresIn; + } + + private static string AssertAccessTokenExists(JsonElement tokenResponse) + { + return AssertHelper.AssertJsonProperty(tokenResponse, "access_token", JsonValueKind.String).GetString(); + } + + private static string AssertRefreshTokenExists(JsonElement tokenResponse) + { + return AssertHelper.AssertJsonProperty(tokenResponse, "refresh_token", JsonValueKind.String).GetString(); + } + + private static string AssertScopeExists(JsonElement tokenResponse) + { + return AssertHelper.AssertJsonProperty(tokenResponse, "scope", JsonValueKind.String).GetString(); + } + + private static void AssertScope(JsonElement tokenResponse, string expectedScope) + { + var actualScope = AssertScopeExists(tokenResponse); + Assert.Equal(expectedScope, actualScope); + } } diff --git a/test/Identity.Test/Controllers/AccountsControllerTests.cs b/test/Identity.Test/Controllers/AccountsControllerTests.cs index 6fa8f493c..54b585654 100644 --- a/test/Identity.Test/Controllers/AccountsControllerTests.cs +++ b/test/Identity.Test/Controllers/AccountsControllerTests.cs @@ -11,102 +11,101 @@ using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Identity.Test.Controllers +namespace Bit.Identity.Test.Controllers; + +public class AccountsControllerTests : IDisposable { - public class AccountsControllerTests : IDisposable + + private readonly AccountsController _sut; + private readonly ILogger _logger; + private readonly IUserRepository _userRepository; + private readonly IUserService _userService; + + public AccountsControllerTests() { + _logger = Substitute.For>(); + _userRepository = Substitute.For(); + _userService = Substitute.For(); + _sut = new AccountsController( + _logger, + _userRepository, + _userService + ); + } - private readonly AccountsController _sut; - private readonly ILogger _logger; - private readonly IUserRepository _userRepository; - private readonly IUserService _userService; + public void Dispose() + { + _sut?.Dispose(); + } - public AccountsControllerTests() + [Fact] + public async Task PostPrelogin_WhenUserExists_ShouldReturnUserKdfInfo() + { + var userKdfInfo = new UserKdfInformation { - _logger = Substitute.For>(); - _userRepository = Substitute.For(); - _userService = Substitute.For(); - _sut = new AccountsController( - _logger, - _userRepository, - _userService - ); - } + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 5000 + }; + _userRepository.GetKdfInformationByEmailAsync(Arg.Any()).Returns(Task.FromResult(userKdfInfo)); - public void Dispose() + var response = await _sut.PostPrelogin(new PreloginRequestModel { Email = "user@example.com" }); + + Assert.Equal(userKdfInfo.Kdf, response.Kdf); + Assert.Equal(userKdfInfo.KdfIterations, response.KdfIterations); + } + + [Fact] + public async Task PostPrelogin_WhenUserDoesNotExist_ShouldDefaultToSha256And100000Iterations() + { + _userRepository.GetKdfInformationByEmailAsync(Arg.Any()).Returns(Task.FromResult(null!)); + + var response = await _sut.PostPrelogin(new PreloginRequestModel { Email = "user@example.com" }); + + Assert.Equal(KdfType.PBKDF2_SHA256, response.Kdf); + Assert.Equal(100000, response.KdfIterations); + } + + [Fact] + public async Task PostRegister_ShouldRegisterUser() + { + var passwordHash = "abcdef"; + var token = "123456"; + var userGuid = new Guid(); + _userService.RegisterUserAsync(Arg.Any(), passwordHash, token, userGuid) + .Returns(Task.FromResult(IdentityResult.Success)); + var request = new RegisterRequestModel { - _sut?.Dispose(); - } + Name = "Example User", + Email = "user@example.com", + MasterPasswordHash = passwordHash, + MasterPasswordHint = "example", + Token = token, + OrganizationUserId = userGuid + }; - [Fact] - public async Task PostPrelogin_WhenUserExists_ShouldReturnUserKdfInfo() + await _sut.PostRegister(request); + + await _userService.Received(1).RegisterUserAsync(Arg.Any(), passwordHash, token, userGuid); + } + + [Fact] + public async Task PostRegister_WhenUserServiceFails_ShouldThrowBadRequestException() + { + var passwordHash = "abcdef"; + var token = "123456"; + var userGuid = new Guid(); + _userService.RegisterUserAsync(Arg.Any(), passwordHash, token, userGuid) + .Returns(Task.FromResult(IdentityResult.Failed())); + var request = new RegisterRequestModel { - var userKdfInfo = new UserKdfInformation - { - Kdf = KdfType.PBKDF2_SHA256, - KdfIterations = 5000 - }; - _userRepository.GetKdfInformationByEmailAsync(Arg.Any()).Returns(Task.FromResult(userKdfInfo)); + Name = "Example User", + Email = "user@example.com", + MasterPasswordHash = passwordHash, + MasterPasswordHint = "example", + Token = token, + OrganizationUserId = userGuid + }; - var response = await _sut.PostPrelogin(new PreloginRequestModel { Email = "user@example.com" }); - - Assert.Equal(userKdfInfo.Kdf, response.Kdf); - Assert.Equal(userKdfInfo.KdfIterations, response.KdfIterations); - } - - [Fact] - public async Task PostPrelogin_WhenUserDoesNotExist_ShouldDefaultToSha256And100000Iterations() - { - _userRepository.GetKdfInformationByEmailAsync(Arg.Any()).Returns(Task.FromResult(null!)); - - var response = await _sut.PostPrelogin(new PreloginRequestModel { Email = "user@example.com" }); - - Assert.Equal(KdfType.PBKDF2_SHA256, response.Kdf); - Assert.Equal(100000, response.KdfIterations); - } - - [Fact] - public async Task PostRegister_ShouldRegisterUser() - { - var passwordHash = "abcdef"; - var token = "123456"; - var userGuid = new Guid(); - _userService.RegisterUserAsync(Arg.Any(), passwordHash, token, userGuid) - .Returns(Task.FromResult(IdentityResult.Success)); - var request = new RegisterRequestModel - { - Name = "Example User", - Email = "user@example.com", - MasterPasswordHash = passwordHash, - MasterPasswordHint = "example", - Token = token, - OrganizationUserId = userGuid - }; - - await _sut.PostRegister(request); - - await _userService.Received(1).RegisterUserAsync(Arg.Any(), passwordHash, token, userGuid); - } - - [Fact] - public async Task PostRegister_WhenUserServiceFails_ShouldThrowBadRequestException() - { - var passwordHash = "abcdef"; - var token = "123456"; - var userGuid = new Guid(); - _userService.RegisterUserAsync(Arg.Any(), passwordHash, token, userGuid) - .Returns(Task.FromResult(IdentityResult.Failed())); - var request = new RegisterRequestModel - { - Name = "Example User", - Email = "user@example.com", - MasterPasswordHash = passwordHash, - MasterPasswordHint = "example", - Token = token, - OrganizationUserId = userGuid - }; - - await Assert.ThrowsAsync(() => _sut.PostRegister(request)); - } + await Assert.ThrowsAsync(() => _sut.PostRegister(request)); } } diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/CipherFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/CipherFixtures.cs index a027dc240..13d222316 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/CipherFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/CipherFixtures.cs @@ -9,112 +9,111 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; + +internal class CipherBuilder : ISpecimenBuilder { - internal class CipherBuilder : ISpecimenBuilder + public bool OrganizationOwned { get; set; } + public object Create(object request, ISpecimenContext context) { - public bool OrganizationOwned { get; set; } - public object Create(object request, ISpecimenContext context) + if (context == null) { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - var type = request as Type; - if (type == null || (type != typeof(Cipher) && type != typeof(List))) - { - return new NoSpecimen(); - } - - var fixture = new Fixture(); - fixture.Customizations.Insert(0, new MaxLengthStringRelay()); - - if (!OrganizationOwned) - { - fixture.Customize(composer => composer - .Without(c => c.OrganizationId)); - } - - // Can't test valid Favorites and Folders without creating those values inide each test, - // since we won't have any UserIds until the test is running & creating data - fixture.Customize(c => c - .Without(e => e.Favorites) - .Without(e => e.Folders)); - // - var serializerOptions = new JsonSerializerOptions() - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase - }; - - if (type == typeof(Cipher)) - { - var obj = fixture.WithAutoNSubstitutions().Create(); - var cipherData = fixture.WithAutoNSubstitutions().Create(); - var cipherAttachements = fixture.WithAutoNSubstitutions().Create>(); - obj.Data = JsonSerializer.Serialize(cipherData, serializerOptions); - obj.Attachments = JsonSerializer.Serialize(cipherAttachements, serializerOptions); - - return obj; - } - if (type == typeof(List)) - { - var ciphers = fixture.WithAutoNSubstitutions().CreateMany().ToArray(); - for (var i = 0; i < ciphers.Count(); i++) - { - var cipherData = fixture.WithAutoNSubstitutions().Create(); - var cipherAttachements = fixture.WithAutoNSubstitutions().Create>(); - ciphers[i].Data = JsonSerializer.Serialize(cipherData, serializerOptions); - ciphers[i].Attachments = JsonSerializer.Serialize(cipherAttachements, serializerOptions); - } - - return ciphers; - } + throw new ArgumentNullException(nameof(context)); + } + var type = request as Type; + if (type == null || (type != typeof(Cipher) && type != typeof(List))) + { return new NoSpecimen(); } - } - internal class EfCipher : ICustomization - { - public bool OrganizationOwned { get; set; } - public void Customize(IFixture fixture) + var fixture = new Fixture(); + fixture.Customizations.Insert(0, new MaxLengthStringRelay()); + + if (!OrganizationOwned) { - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new CipherBuilder() - { - OrganizationOwned = OrganizationOwned - }); - fixture.Customizations.Add(new UserBuilder()); - fixture.Customizations.Add(new OrganizationBuilder()); - fixture.Customizations.Add(new OrganizationUserBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customize(composer => composer + .Without(c => c.OrganizationId)); } - } - internal class EfUserCipherAutoDataAttribute : CustomAutoDataAttribute - { - public EfUserCipherAutoDataAttribute() : base(new SutProviderCustomization(), new EfCipher()) - { } - } - - internal class EfOrganizationCipherAutoDataAttribute : CustomAutoDataAttribute - { - public EfOrganizationCipherAutoDataAttribute() : base(new SutProviderCustomization(), new EfCipher() + // Can't test valid Favorites and Folders without creating those values inide each test, + // since we won't have any UserIds until the test is running & creating data + fixture.Customize(c => c + .Without(e => e.Favorites) + .Without(e => e.Folders)); + // + var serializerOptions = new JsonSerializerOptions() { - OrganizationOwned = true, - }) - { } - } + PropertyNamingPolicy = JsonNamingPolicy.CamelCase + }; - internal class InlineEfCipherAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineEfCipherAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfCipher) }, values) - { } + if (type == typeof(Cipher)) + { + var obj = fixture.WithAutoNSubstitutions().Create(); + var cipherData = fixture.WithAutoNSubstitutions().Create(); + var cipherAttachements = fixture.WithAutoNSubstitutions().Create>(); + obj.Data = JsonSerializer.Serialize(cipherData, serializerOptions); + obj.Attachments = JsonSerializer.Serialize(cipherAttachements, serializerOptions); + + return obj; + } + if (type == typeof(List)) + { + var ciphers = fixture.WithAutoNSubstitutions().CreateMany().ToArray(); + for (var i = 0; i < ciphers.Count(); i++) + { + var cipherData = fixture.WithAutoNSubstitutions().Create(); + var cipherAttachements = fixture.WithAutoNSubstitutions().Create>(); + ciphers[i].Data = JsonSerializer.Serialize(cipherData, serializerOptions); + ciphers[i].Attachments = JsonSerializer.Serialize(cipherAttachements, serializerOptions); + } + + return ciphers; + } + + return new NoSpecimen(); } } + +internal class EfCipher : ICustomization +{ + public bool OrganizationOwned { get; set; } + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new CipherBuilder() + { + OrganizationOwned = OrganizationOwned + }); + fixture.Customizations.Add(new UserBuilder()); + fixture.Customizations.Add(new OrganizationBuilder()); + fixture.Customizations.Add(new OrganizationUserBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } +} + +internal class EfUserCipherAutoDataAttribute : CustomAutoDataAttribute +{ + public EfUserCipherAutoDataAttribute() : base(new SutProviderCustomization(), new EfCipher()) + { } +} + +internal class EfOrganizationCipherAutoDataAttribute : CustomAutoDataAttribute +{ + public EfOrganizationCipherAutoDataAttribute() : base(new SutProviderCustomization(), new EfCipher() + { + OrganizationOwned = true, + }) + { } +} + +internal class InlineEfCipherAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineEfCipherAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfCipher) }, values) + { } +} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/CollectionCipherFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/CollectionCipherFixtures.cs index 873e42439..89ffccb2b 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/CollectionCipherFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/CollectionCipherFixtures.cs @@ -7,57 +7,56 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; + +internal class CollectionCipherBuilder : ISpecimenBuilder { - internal class CollectionCipherBuilder : ISpecimenBuilder + public object Create(object request, ISpecimenContext context) { - public object Create(object request, ISpecimenContext context) + if (context == null) { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - var type = request as Type; - if (type == null || type != typeof(CollectionCipher)) - { - return new NoSpecimen(); - } - - var fixture = new Fixture(); - fixture.Customizations.Insert(0, new MaxLengthStringRelay()); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + throw new ArgumentNullException(nameof(context)); } - } - internal class EfCollectionCipher : ICustomization - { - public void Customize(IFixture fixture) + var type = request as Type; + if (type == null || type != typeof(CollectionCipher)) { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new CollectionCipherBuilder()); - fixture.Customizations.Add(new CollectionBuilder()); - fixture.Customizations.Add(new CipherBuilder()); - fixture.Customizations.Add(new UserBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); + return new NoSpecimen(); } - } - internal class EfCollectionCipherAutoDataAttribute : CustomAutoDataAttribute - { - public EfCollectionCipherAutoDataAttribute() : base(new SutProviderCustomization(), new EfCollectionCipher()) - { } - } - - internal class InlineEfCollectionCipherAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineEfCollectionCipherAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfCollectionCipher) }, values) - { } + var fixture = new Fixture(); + fixture.Customizations.Insert(0, new MaxLengthStringRelay()); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; } } + +internal class EfCollectionCipher : ICustomization +{ + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new CollectionCipherBuilder()); + fixture.Customizations.Add(new CollectionBuilder()); + fixture.Customizations.Add(new CipherBuilder()); + fixture.Customizations.Add(new UserBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } +} + +internal class EfCollectionCipherAutoDataAttribute : CustomAutoDataAttribute +{ + public EfCollectionCipherAutoDataAttribute() : base(new SutProviderCustomization(), new EfCollectionCipher()) + { } +} + +internal class InlineEfCollectionCipherAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineEfCollectionCipherAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfCollectionCipher) }, values) + { } +} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/CollectionFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/CollectionFixtures.cs index 1d96bccdc..4cb6cfbd4 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/CollectionFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/CollectionFixtures.cs @@ -6,53 +6,52 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; + +internal class CollectionBuilder : ISpecimenBuilder { - internal class CollectionBuilder : ISpecimenBuilder + public object Create(object request, ISpecimenContext context) { - public object Create(object request, ISpecimenContext context) + if (context == null) { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - var type = request as Type; - if (type == null || type != typeof(Collection)) - { - return new NoSpecimen(); - } - - var fixture = new Fixture(); - fixture.Customizations.Insert(0, new MaxLengthStringRelay()); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + throw new ArgumentNullException(nameof(context)); } - } - internal class EfCollection : ICustomization - { - public void Customize(IFixture fixture) + var type = request as Type; + if (type == null || type != typeof(Collection)) { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new CollectionBuilder()); - fixture.Customizations.Add(new OrganizationBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); + return new NoSpecimen(); } - } - internal class EfCollectionAutoDataAttribute : CustomAutoDataAttribute - { - public EfCollectionAutoDataAttribute() : base(new SutProviderCustomization(), new EfCollection()) - { } - } - - internal class InlineEfCollectionAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineEfCollectionAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfCollection) }, values) - { } + var fixture = new Fixture(); + fixture.Customizations.Insert(0, new MaxLengthStringRelay()); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; } } + +internal class EfCollection : ICustomization +{ + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new CollectionBuilder()); + fixture.Customizations.Add(new OrganizationBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } +} + +internal class EfCollectionAutoDataAttribute : CustomAutoDataAttribute +{ + public EfCollectionAutoDataAttribute() : base(new SutProviderCustomization(), new EfCollection()) + { } +} + +internal class InlineEfCollectionAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineEfCollectionAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfCollection) }, values) + { } +} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/DeviceFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/DeviceFixtures.cs index 9100af6a8..da5b5b767 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/DeviceFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/DeviceFixtures.cs @@ -7,54 +7,53 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; + +internal class DeviceBuilder : ISpecimenBuilder { - internal class DeviceBuilder : ISpecimenBuilder + public object Create(object request, ISpecimenContext context) { - public object Create(object request, ISpecimenContext context) + if (context == null) { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - var type = request as Type; - if (type == null || type != typeof(Device)) - { - return new NoSpecimen(); - } - - var fixture = new Fixture(); - fixture.Customizations.Insert(0, new MaxLengthStringRelay()); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + throw new ArgumentNullException(nameof(context)); } - } - internal class EfDevice : ICustomization - { - public void Customize(IFixture fixture) + var type = request as Type; + if (type == null || type != typeof(Device)) { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new DeviceBuilder()); - fixture.Customizations.Add(new UserBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); + return new NoSpecimen(); } - } - internal class EfDeviceAutoDataAttribute : CustomAutoDataAttribute - { - public EfDeviceAutoDataAttribute() : base(new SutProviderCustomization(), new EfDevice()) - { } - } - - internal class InlineEfDeviceAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineEfDeviceAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfDevice) }, values) - { } + var fixture = new Fixture(); + fixture.Customizations.Insert(0, new MaxLengthStringRelay()); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; } } +internal class EfDevice : ICustomization +{ + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new DeviceBuilder()); + fixture.Customizations.Add(new UserBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } +} + +internal class EfDeviceAutoDataAttribute : CustomAutoDataAttribute +{ + public EfDeviceAutoDataAttribute() : base(new SutProviderCustomization(), new EfDevice()) + { } +} + +internal class InlineEfDeviceAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineEfDeviceAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfDevice) }, values) + { } +} + diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/EmergencyAccessFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/EmergencyAccessFixtures.cs index 82bc25f75..87a8f796c 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/EmergencyAccessFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/EmergencyAccessFixtures.cs @@ -7,55 +7,54 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; + +internal class EmergencyAccessBuilder : ISpecimenBuilder { - internal class EmergencyAccessBuilder : ISpecimenBuilder + public object Create(object request, ISpecimenContext context) { - public object Create(object request, ISpecimenContext context) + if (context == null) { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - var type = request as Type; - if (type == null || type != typeof(EmergencyAccess)) - { - return new NoSpecimen(); - } - - var fixture = new Fixture(); - fixture.Customizations.Insert(0, new MaxLengthStringRelay()); - var obj = fixture.Create(); - return obj; + throw new ArgumentNullException(nameof(context)); } - } - internal class EfEmergencyAccess : ICustomization - { - public void Customize(IFixture fixture) + var type = request as Type; + if (type == null || type != typeof(EmergencyAccess)) { - // TODO: Make a base EF Customization with IgnoreVirtualMembers/GlobalSettings/All repos and inherit - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new EmergencyAccessBuilder()); - fixture.Customizations.Add(new UserBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); + return new NoSpecimen(); } - } - internal class EfEmergencyAccessAutoDataAttribute : CustomAutoDataAttribute - { - public EfEmergencyAccessAutoDataAttribute() : base(new SutProviderCustomization(), new EfEmergencyAccess()) - { } - } - - internal class InlineEfEmergencyAccessAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineEfEmergencyAccessAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfEmergencyAccess) }, values) - { } + var fixture = new Fixture(); + fixture.Customizations.Insert(0, new MaxLengthStringRelay()); + var obj = fixture.Create(); + return obj; } } +internal class EfEmergencyAccess : ICustomization +{ + public void Customize(IFixture fixture) + { + // TODO: Make a base EF Customization with IgnoreVirtualMembers/GlobalSettings/All repos and inherit + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new EmergencyAccessBuilder()); + fixture.Customizations.Add(new UserBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } +} + +internal class EfEmergencyAccessAutoDataAttribute : CustomAutoDataAttribute +{ + public EfEmergencyAccessAutoDataAttribute() : base(new SutProviderCustomization(), new EfEmergencyAccess()) + { } +} + +internal class InlineEfEmergencyAccessAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineEfEmergencyAccessAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfEmergencyAccess) }, values) + { } +} + diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/EntityFrameworkRepositoryFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/EntityFrameworkRepositoryFixtures.cs index 4c83062b6..4a403b70b 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/EntityFrameworkRepositoryFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/EntityFrameworkRepositoryFixtures.cs @@ -10,113 +10,112 @@ using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; using Moq; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; + +internal class ServiceScopeFactoryBuilder : ISpecimenBuilder { - internal class ServiceScopeFactoryBuilder : ISpecimenBuilder + private DbContextOptions _options { get; set; } + public ServiceScopeFactoryBuilder(DbContextOptions options) { - private DbContextOptions _options { get; set; } - public ServiceScopeFactoryBuilder(DbContextOptions options) - { - _options = options; - } - - public object Create(object request, ISpecimenContext context) - { - var fixture = new Fixture(); - var serviceProvider = new Mock(); - var dbContext = new DatabaseContext(_options); - serviceProvider - .Setup(x => x.GetService(typeof(DatabaseContext))) - .Returns(dbContext); - - var serviceScope = new Mock(); - serviceScope.Setup(x => x.ServiceProvider).Returns(serviceProvider.Object); - - var serviceScopeFactory = new Mock(); - serviceScopeFactory - .Setup(x => x.CreateScope()) - .Returns(serviceScope.Object); - return serviceScopeFactory.Object; - } + _options = options; } - public class EfRepositoryListBuilder : ISpecimenBuilder where T : BaseEntityFrameworkRepository + public object Create(object request, ISpecimenContext context) { - public object Create(object request, ISpecimenContext context) - { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } + var fixture = new Fixture(); + var serviceProvider = new Mock(); + var dbContext = new DatabaseContext(_options); + serviceProvider + .Setup(x => x.GetService(typeof(DatabaseContext))) + .Returns(dbContext); - var t = request as ParameterInfo; - if (t == null || t.ParameterType != typeof(List)) - { - return new NoSpecimen(); - } + var serviceScope = new Mock(); + serviceScope.Setup(x => x.ServiceProvider).Returns(serviceProvider.Object); - var list = new List(); - foreach (var option in DatabaseOptionsFactory.Options) - { - var fixture = new Fixture(); - fixture.Customize(x => x.FromFactory(new ServiceScopeFactoryBuilder(option))); - fixture.Customize(x => x.FromFactory(() => - new MapperConfiguration(cfg => - { - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - cfg.AddProfile(); - }) - .CreateMapper())); - - var repo = fixture.Create(); - list.Add(repo); - } - return list; - } - } - - public class IgnoreVirtualMembersCustomization : ISpecimenBuilder - { - public object Create(object request, ISpecimenContext context) - { - if (context == null) - { - throw new ArgumentNullException("context"); - } - - var pi = request as PropertyInfo; - if (pi == null) - { - return new NoSpecimen(); - } - - if (pi.GetGetMethod().IsVirtual && pi.DeclaringType != typeof(GlobalSettings)) - { - return null; - } - return new NoSpecimen(); - } + var serviceScopeFactory = new Mock(); + serviceScopeFactory + .Setup(x => x.CreateScope()) + .Returns(serviceScope.Object); + return serviceScopeFactory.Object; + } +} + +public class EfRepositoryListBuilder : ISpecimenBuilder where T : BaseEntityFrameworkRepository +{ + public object Create(object request, ISpecimenContext context) + { + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + + var t = request as ParameterInfo; + if (t == null || t.ParameterType != typeof(List)) + { + return new NoSpecimen(); + } + + var list = new List(); + foreach (var option in DatabaseOptionsFactory.Options) + { + var fixture = new Fixture(); + fixture.Customize(x => x.FromFactory(new ServiceScopeFactoryBuilder(option))); + fixture.Customize(x => x.FromFactory(() => + new MapperConfiguration(cfg => + { + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + cfg.AddProfile(); + }) + .CreateMapper())); + + var repo = fixture.Create(); + list.Add(repo); + } + return list; + } +} + +public class IgnoreVirtualMembersCustomization : ISpecimenBuilder +{ + public object Create(object request, ISpecimenContext context) + { + if (context == null) + { + throw new ArgumentNullException("context"); + } + + var pi = request as PropertyInfo; + if (pi == null) + { + return new NoSpecimen(); + } + + if (pi.GetGetMethod().IsVirtual && pi.DeclaringType != typeof(GlobalSettings)) + { + return null; + } + return new NoSpecimen(); } } diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/EventFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/EventFixtures.cs index ecb4f0ef9..70b2e9bc9 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/EventFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/EventFixtures.cs @@ -6,52 +6,51 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; + +internal class EventBuilder : ISpecimenBuilder { - internal class EventBuilder : ISpecimenBuilder + public object Create(object request, ISpecimenContext context) { - public object Create(object request, ISpecimenContext context) + if (context == null) { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - var type = request as Type; - if (type == null || type != typeof(Event)) - { - return new NoSpecimen(); - } - - var fixture = new Fixture(); - fixture.Customizations.Insert(0, new MaxLengthStringRelay()); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + throw new ArgumentNullException(nameof(context)); } - } - internal class EfEvent : ICustomization - { - public void Customize(IFixture fixture) + var type = request as Type; + if (type == null || type != typeof(Event)) { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new EventBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); + return new NoSpecimen(); } - } - internal class EfEventAutoDataAttribute : CustomAutoDataAttribute - { - public EfEventAutoDataAttribute() : base(new SutProviderCustomization(), new EfEvent()) - { } - } - - internal class InlineEfEventAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineEfEventAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfEvent) }, values) - { } + var fixture = new Fixture(); + fixture.Customizations.Insert(0, new MaxLengthStringRelay()); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; } } +internal class EfEvent : ICustomization +{ + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new EventBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } +} + +internal class EfEventAutoDataAttribute : CustomAutoDataAttribute +{ + public EfEventAutoDataAttribute() : base(new SutProviderCustomization(), new EfEvent()) + { } +} + +internal class InlineEfEventAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineEfEventAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfEvent) }, values) + { } +} + diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/FolderFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/FolderFixtures.cs index 290fffb60..884933ffd 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/FolderFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/FolderFixtures.cs @@ -7,54 +7,53 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; + +internal class FolderBuilder : ISpecimenBuilder { - internal class FolderBuilder : ISpecimenBuilder + public object Create(object request, ISpecimenContext context) { - public object Create(object request, ISpecimenContext context) + if (context == null) { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - var type = request as Type; - if (type == null || type != typeof(Folder)) - { - return new NoSpecimen(); - } - - var fixture = new Fixture(); - fixture.Customizations.Insert(0, new MaxLengthStringRelay()); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + throw new ArgumentNullException(nameof(context)); } - } - internal class EfFolder : ICustomization - { - public void Customize(IFixture fixture) + var type = request as Type; + if (type == null || type != typeof(Folder)) { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new FolderBuilder()); - fixture.Customizations.Add(new UserBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); + return new NoSpecimen(); } - } - internal class EfFolderAutoDataAttribute : CustomAutoDataAttribute - { - public EfFolderAutoDataAttribute() : base(new SutProviderCustomization(), new EfFolder()) - { } - } - - internal class InlineEfFolderAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineEfFolderAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfFolder) }, values) - { } + var fixture = new Fixture(); + fixture.Customizations.Insert(0, new MaxLengthStringRelay()); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; } } +internal class EfFolder : ICustomization +{ + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new FolderBuilder()); + fixture.Customizations.Add(new UserBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } +} + +internal class EfFolderAutoDataAttribute : CustomAutoDataAttribute +{ + public EfFolderAutoDataAttribute() : base(new SutProviderCustomization(), new EfFolder()) + { } +} + +internal class InlineEfFolderAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineEfFolderAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfFolder) }, values) + { } +} + diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/GrantFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/GrantFixtures.cs index 7824426bb..d431132de 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/GrantFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/GrantFixtures.cs @@ -6,51 +6,50 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; + +internal class GrantBuilder : ISpecimenBuilder { - internal class GrantBuilder : ISpecimenBuilder + public object Create(object request, ISpecimenContext context) { - public object Create(object request, ISpecimenContext context) + if (context == null) { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - var type = request as Type; - if (type == null || type != typeof(Grant)) - { - return new NoSpecimen(); - } - - var fixture = new Fixture(); - fixture.Customizations.Insert(0, new MaxLengthStringRelay()); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + throw new ArgumentNullException(nameof(context)); } - } - internal class EfGrant : ICustomization - { - public void Customize(IFixture fixture) + var type = request as Type; + if (type == null || type != typeof(Grant)) { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new GrantBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); + return new NoSpecimen(); } - } - internal class EfGrantAutoDataAttribute : CustomAutoDataAttribute - { - public EfGrantAutoDataAttribute() : base(new SutProviderCustomization(), new EfGrant()) - { } - } - - internal class InlineEfGrantAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineEfGrantAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfGrant) }, values) - { } + var fixture = new Fixture(); + fixture.Customizations.Insert(0, new MaxLengthStringRelay()); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; } } + +internal class EfGrant : ICustomization +{ + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new GrantBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } +} + +internal class EfGrantAutoDataAttribute : CustomAutoDataAttribute +{ + public EfGrantAutoDataAttribute() : base(new SutProviderCustomization(), new EfGrant()) + { } +} + +internal class InlineEfGrantAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineEfGrantAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfGrant) }, values) + { } +} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/GroupFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/GroupFixtures.cs index cfb232ab1..c6cca4901 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/GroupFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/GroupFixtures.cs @@ -6,53 +6,52 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; + +internal class GroupBuilder : ISpecimenBuilder { - internal class GroupBuilder : ISpecimenBuilder + public object Create(object request, ISpecimenContext context) { - public object Create(object request, ISpecimenContext context) + if (context == null) { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - var type = request as Type; - if (type == null || type != typeof(Group)) - { - return new NoSpecimen(); - } - - var fixture = new Fixture(); - fixture.Customizations.Insert(0, new MaxLengthStringRelay()); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + throw new ArgumentNullException(nameof(context)); } - } - internal class EfGroup : ICustomization - { - public void Customize(IFixture fixture) + var type = request as Type; + if (type == null || type != typeof(Group)) { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new GroupBuilder()); - fixture.Customizations.Add(new OrganizationBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); + return new NoSpecimen(); } - } - internal class EfGroupAutoDataAttribute : CustomAutoDataAttribute - { - public EfGroupAutoDataAttribute() : base(new SutProviderCustomization(), new EfGroup()) - { } - } - - internal class InlineEfGroupAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineEfGroupAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfGroup) }, values) - { } + var fixture = new Fixture(); + fixture.Customizations.Insert(0, new MaxLengthStringRelay()); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; } } + +internal class EfGroup : ICustomization +{ + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new GroupBuilder()); + fixture.Customizations.Add(new OrganizationBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } +} + +internal class EfGroupAutoDataAttribute : CustomAutoDataAttribute +{ + public EfGroupAutoDataAttribute() : base(new SutProviderCustomization(), new EfGroup()) + { } +} + +internal class InlineEfGroupAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineEfGroupAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfGroup) }, values) + { } +} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/GroupUserFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/GroupUserFixtures.cs index d7303b59c..2b68cde32 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/GroupUserFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/GroupUserFixtures.cs @@ -5,51 +5,50 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; + +internal class GroupUserBuilder : ISpecimenBuilder { - internal class GroupUserBuilder : ISpecimenBuilder + public object Create(object request, ISpecimenContext context) { - public object Create(object request, ISpecimenContext context) + if (context == null) { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - var type = request as Type; - if (type == null || type != typeof(GroupUser)) - { - return new NoSpecimen(); - } - - var fixture = new Fixture(); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + throw new ArgumentNullException(nameof(context)); } - } - internal class EfGroupUser : ICustomization - { - public void Customize(IFixture fixture) + var type = request as Type; + if (type == null || type != typeof(GroupUser)) { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new GroupUserBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); + return new NoSpecimen(); } - } - internal class EfGroupUserAutoDataAttribute : CustomAutoDataAttribute - { - public EfGroupUserAutoDataAttribute() : base(new SutProviderCustomization(), new EfGroupUser()) - { } - } - - internal class InlineEfGroupUserAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineEfGroupUserAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfGroupUser) }, values) - { } + var fixture = new Fixture(); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; } } +internal class EfGroupUser : ICustomization +{ + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new GroupUserBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } +} + +internal class EfGroupUserAutoDataAttribute : CustomAutoDataAttribute +{ + public EfGroupUserAutoDataAttribute() : base(new SutProviderCustomization(), new EfGroupUser()) + { } +} + +internal class InlineEfGroupUserAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineEfGroupUserAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfGroupUser) }, values) + { } +} + diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/InstallationFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/InstallationFixtures.cs index 1a8c54627..c090a2e38 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/InstallationFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/InstallationFixtures.cs @@ -5,51 +5,50 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; + +internal class InstallationBuilder : ISpecimenBuilder { - internal class InstallationBuilder : ISpecimenBuilder + public object Create(object request, ISpecimenContext context) { - public object Create(object request, ISpecimenContext context) + if (context == null) { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - var type = request as Type; - if (type == null || type != typeof(Installation)) - { - return new NoSpecimen(); - } - - var fixture = new Fixture(); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + throw new ArgumentNullException(nameof(context)); } - } - internal class EfInstallation : ICustomization - { - public void Customize(IFixture fixture) + var type = request as Type; + if (type == null || type != typeof(Installation)) { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new InstallationBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); + return new NoSpecimen(); } - } - internal class EfInstallationAutoDataAttribute : CustomAutoDataAttribute - { - public EfInstallationAutoDataAttribute() : base(new SutProviderCustomization(), new EfInstallation()) - { } - } - - internal class InlineEfInstallationAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineEfInstallationAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfInstallation) }, values) - { } + var fixture = new Fixture(); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; } } +internal class EfInstallation : ICustomization +{ + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new InstallationBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } +} + +internal class EfInstallationAutoDataAttribute : CustomAutoDataAttribute +{ + public EfInstallationAutoDataAttribute() : base(new SutProviderCustomization(), new EfInstallation()) + { } +} + +internal class InlineEfInstallationAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineEfInstallationAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfInstallation) }, values) + { } +} + diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/OrganizationFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/OrganizationFixtures.cs index f09760390..800ee14d2 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/OrganizationFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/OrganizationFixtures.cs @@ -7,52 +7,51 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; + +internal class OrganizationBuilder : ISpecimenBuilder { - internal class OrganizationBuilder : ISpecimenBuilder + public object Create(object request, ISpecimenContext context) { - public object Create(object request, ISpecimenContext context) + if (context == null) { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - var type = request as Type; - if (type == null || type != typeof(Organization)) - { - return new NoSpecimen(); - } - - var fixture = new Fixture(); - var providers = fixture.Create>(); - var organization = new Fixture().WithAutoNSubstitutions().Create(); - organization.SetTwoFactorProviders(providers); - return organization; + throw new ArgumentNullException(nameof(context)); } - } - internal class EfOrganization : ICustomization - { - public void Customize(IFixture fixture) + var type = request as Type; + if (type == null || type != typeof(Organization)) { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new OrganizationBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); + return new NoSpecimen(); } - } - internal class EfOrganizationAutoDataAttribute : CustomAutoDataAttribute - { - public EfOrganizationAutoDataAttribute() : base(new SutProviderCustomization(), new EfOrganization()) - { } - } - - internal class InlineEfOrganizationAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineEfOrganizationAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfOrganization) }, values) - { } + var fixture = new Fixture(); + var providers = fixture.Create>(); + var organization = new Fixture().WithAutoNSubstitutions().Create(); + organization.SetTwoFactorProviders(providers); + return organization; } } + +internal class EfOrganization : ICustomization +{ + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new OrganizationBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } +} + +internal class EfOrganizationAutoDataAttribute : CustomAutoDataAttribute +{ + public EfOrganizationAutoDataAttribute() : base(new SutProviderCustomization(), new EfOrganization()) + { } +} + +internal class InlineEfOrganizationAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineEfOrganizationAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfOrganization) }, values) + { } +} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/OrganizationSponsorshipFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/OrganizationSponsorshipFixtures.cs index c4b97ad4e..ede2d2129 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/OrganizationSponsorshipFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/OrganizationSponsorshipFixtures.cs @@ -5,52 +5,51 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; + +internal class OrganizationSponsorshipBuilder : ISpecimenBuilder { - internal class OrganizationSponsorshipBuilder : ISpecimenBuilder + public object Create(object request, ISpecimenContext context) { - public object Create(object request, ISpecimenContext context) + if (context == null) { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - var type = request as Type; - if (type == null || type != typeof(OrganizationSponsorship)) - { - return new NoSpecimen(); - } - - var fixture = new Fixture(); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + throw new ArgumentNullException(nameof(context)); } - } - internal class EfOrganizationSponsorship : ICustomization - { - public void Customize(IFixture fixture) + var type = request as Type; + if (type == null || type != typeof(OrganizationSponsorship)) { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new OrganizationSponsorshipBuilder()); - fixture.Customizations.Add(new OrganizationUserBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); + return new NoSpecimen(); } - } - internal class EfOrganizationSponsorshipAutoDataAttribute : CustomAutoDataAttribute - { - public EfOrganizationSponsorshipAutoDataAttribute() : base(new SutProviderCustomization(), new EfOrganizationSponsorship(), new EfOrganization()) - { } - } - - internal class InlineEfOrganizationSponsorshipAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineEfOrganizationSponsorshipAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfOrganizationSponsorship), typeof(EfOrganization) }, values) - { } + var fixture = new Fixture(); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; } } + +internal class EfOrganizationSponsorship : ICustomization +{ + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new OrganizationSponsorshipBuilder()); + fixture.Customizations.Add(new OrganizationUserBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } +} + +internal class EfOrganizationSponsorshipAutoDataAttribute : CustomAutoDataAttribute +{ + public EfOrganizationSponsorshipAutoDataAttribute() : base(new SutProviderCustomization(), new EfOrganizationSponsorship(), new EfOrganization()) + { } +} + +internal class InlineEfOrganizationSponsorshipAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineEfOrganizationSponsorshipAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfOrganizationSponsorship), typeof(EfOrganization) }, values) + { } +} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/OrganizationUserFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/OrganizationUserFixtures.cs index 1ae72117e..c457a463d 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/OrganizationUserFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/OrganizationUserFixtures.cs @@ -11,73 +11,72 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture -{ - internal class OrganizationUserBuilder : ISpecimenBuilder - { - public object Create(object request, ISpecimenContext context) - { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; - var type = request as Type; - if (type == typeof(OrganizationUserCustomization)) +internal class OrganizationUserBuilder : ISpecimenBuilder +{ + public object Create(object request, ISpecimenContext context) + { + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + + var type = request as Type; + if (type == typeof(OrganizationUserCustomization)) + { + var fixture = new Fixture(); + var orgUser = fixture.WithAutoNSubstitutions().Create(); + var orgUserPermissions = fixture.WithAutoNSubstitutions().Create(); + orgUser.Permissions = JsonSerializer.Serialize(orgUserPermissions, new JsonSerializerOptions() { - var fixture = new Fixture(); - var orgUser = fixture.WithAutoNSubstitutions().Create(); + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + }); + return orgUser; + } + else if (type == typeof(List)) + { + var fixture = new Fixture(); + var orgUsers = fixture.WithAutoNSubstitutions().CreateMany(2); + foreach (var orgUser in orgUsers) + { + var providers = fixture.Create>(); var orgUserPermissions = fixture.WithAutoNSubstitutions().Create(); orgUser.Permissions = JsonSerializer.Serialize(orgUserPermissions, new JsonSerializerOptions() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase, }); - return orgUser; } - else if (type == typeof(List)) - { - var fixture = new Fixture(); - var orgUsers = fixture.WithAutoNSubstitutions().CreateMany(2); - foreach (var orgUser in orgUsers) - { - var providers = fixture.Create>(); - var orgUserPermissions = fixture.WithAutoNSubstitutions().Create(); - orgUser.Permissions = JsonSerializer.Serialize(orgUserPermissions, new JsonSerializerOptions() - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - }); - } - return orgUsers; - } - return new NoSpecimen(); + return orgUsers; } - } - - internal class EfOrganizationUser : ICustomization - { - public void Customize(IFixture fixture) - { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new OrganizationUserBuilder()); - fixture.Customizations.Add(new OrganizationBuilder()); - fixture.Customizations.Add(new UserBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - } - } - - internal class EfOrganizationUserAutoDataAttribute : CustomAutoDataAttribute - { - public EfOrganizationUserAutoDataAttribute() : base(new SutProviderCustomization(), new EfOrganizationUser()) - { } - } - - internal class InlineEfOrganizationUserAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineEfOrganizationUserAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfOrganizationUser) }, values) - { } + return new NoSpecimen(); } } + +internal class EfOrganizationUser : ICustomization +{ + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new OrganizationUserBuilder()); + fixture.Customizations.Add(new OrganizationBuilder()); + fixture.Customizations.Add(new UserBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } +} + +internal class EfOrganizationUserAutoDataAttribute : CustomAutoDataAttribute +{ + public EfOrganizationUserAutoDataAttribute() : base(new SutProviderCustomization(), new EfOrganizationUser()) + { } +} + +internal class InlineEfOrganizationUserAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineEfOrganizationUserAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfOrganizationUser) }, values) + { } +} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/PolicyFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/PolicyFixtures.cs index 0b6424d54..70cea3e01 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/PolicyFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/PolicyFixtures.cs @@ -5,76 +5,75 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; + +internal class PolicyBuilder : ISpecimenBuilder { - internal class PolicyBuilder : ISpecimenBuilder + public object Create(object request, ISpecimenContext context) { - public object Create(object request, ISpecimenContext context) + if (context == null) { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - var type = request as Type; - if (type == null || type != typeof(Policy)) - { - return new NoSpecimen(); - } - - var fixture = new Fixture(); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + throw new ArgumentNullException(nameof(context)); } - } - internal class EfPolicy : ICustomization - { - public void Customize(IFixture fixture) + var type = request as Type; + if (type == null || type != typeof(Policy)) { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new PolicyBuilder()); - fixture.Customizations.Add(new OrganizationBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); + return new NoSpecimen(); } - } - internal class EfPolicyApplicableToUser : ICustomization - { - public void Customize(IFixture fixture) - { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new PolicyBuilder()); - fixture.Customizations.Add(new OrganizationBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - } - } - - internal class EfPolicyAutoDataAttribute : CustomAutoDataAttribute - { - public EfPolicyAutoDataAttribute() : base(new SutProviderCustomization(), new EfPolicy()) - { } - } - - internal class EfPolicyApplicableToUserInlineAutoDataAttribute : InlineCustomAutoDataAttribute - { - public EfPolicyApplicableToUserInlineAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), typeof(EfPolicyApplicableToUser) }, values) - { } - } - - internal class InlineEfPolicyAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineEfPolicyAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfPolicy) }, values) - { } + var fixture = new Fixture(); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; } } + +internal class EfPolicy : ICustomization +{ + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new PolicyBuilder()); + fixture.Customizations.Add(new OrganizationBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } +} + +internal class EfPolicyApplicableToUser : ICustomization +{ + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new PolicyBuilder()); + fixture.Customizations.Add(new OrganizationBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } +} + +internal class EfPolicyAutoDataAttribute : CustomAutoDataAttribute +{ + public EfPolicyAutoDataAttribute() : base(new SutProviderCustomization(), new EfPolicy()) + { } +} + +internal class EfPolicyApplicableToUserInlineAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public EfPolicyApplicableToUserInlineAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), typeof(EfPolicyApplicableToUser) }, values) + { } +} + +internal class InlineEfPolicyAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineEfPolicyAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfPolicy) }, values) + { } +} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/Relays/MaxLengthStringRelay.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/Relays/MaxLengthStringRelay.cs index e2a3812cc..75f03e34b 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/Relays/MaxLengthStringRelay.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/Relays/MaxLengthStringRelay.cs @@ -2,40 +2,39 @@ using System.Reflection; using AutoFixture.Kernel; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture.Relays +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture.Relays; + +// Creates a string the same length as any availible MaxLength data annotation +// Modified version of the StringLenfthRelay provided by AutoFixture +// https://github.com/AutoFixture/AutoFixture/blob/master/Src/AutoFixture/DataAnnotations/StringLengthAttributeRelay.cs +public class MaxLengthStringRelay : ISpecimenBuilder { - // Creates a string the same length as any availible MaxLength data annotation - // Modified version of the StringLenfthRelay provided by AutoFixture - // https://github.com/AutoFixture/AutoFixture/blob/master/Src/AutoFixture/DataAnnotations/StringLengthAttributeRelay.cs - public class MaxLengthStringRelay : ISpecimenBuilder + public object Create(object request, ISpecimenContext context) { - public object Create(object request, ISpecimenContext context) + if (request == null) { - if (request == null) - { - return new NoSpecimen(); - } - - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - var p = request as PropertyInfo; - if (p == null) - { - return new NoSpecimen(); - } - - var a = (MaxLengthAttribute)p.GetCustomAttributes(typeof(MaxLengthAttribute), false).SingleOrDefault(); - - if (a == null) - { - return new NoSpecimen(); - } - - return context.Resolve(new ConstrainedStringRequest(a.Length, a.Length)); + return new NoSpecimen(); } + + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + + var p = request as PropertyInfo; + if (p == null) + { + return new NoSpecimen(); + } + + var a = (MaxLengthAttribute)p.GetCustomAttributes(typeof(MaxLengthAttribute), false).SingleOrDefault(); + + if (a == null) + { + return new NoSpecimen(); + } + + return context.Resolve(new ConstrainedStringRequest(a.Length, a.Length)); } } diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/SendFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/SendFixtures.cs index 222ea4ac0..162bdf6e5 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/SendFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/SendFixtures.cs @@ -7,64 +7,63 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; + +internal class SendBuilder : ISpecimenBuilder { - internal class SendBuilder : ISpecimenBuilder + public bool OrganizationOwned { get; set; } + public object Create(object request, ISpecimenContext context) { - public bool OrganizationOwned { get; set; } - public object Create(object request, ISpecimenContext context) + if (context == null) { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - var type = request as Type; - if (type == null || type != typeof(Send)) - { - return new NoSpecimen(); - } - - var fixture = new Fixture(); - fixture.Customizations.Insert(0, new MaxLengthStringRelay()); - if (!OrganizationOwned) - { - fixture.Customize(composer => composer - .Without(c => c.OrganizationId)); - } - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + throw new ArgumentNullException(nameof(context)); } - } - internal class EfSend : ICustomization - { - public bool OrganizationOwned { get; set; } - public void Customize(IFixture fixture) + var type = request as Type; + if (type == null || type != typeof(Send)) { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new SendBuilder()); - fixture.Customizations.Add(new UserBuilder()); - fixture.Customizations.Add(new OrganizationBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); + return new NoSpecimen(); } - } - internal class EfUserSendAutoDataAttribute : CustomAutoDataAttribute - { - public EfUserSendAutoDataAttribute() : base(new SutProviderCustomization(), new EfSend()) - { } - } - - internal class EfOrganizationSendAutoDataAttribute : CustomAutoDataAttribute - { - public EfOrganizationSendAutoDataAttribute() : base(new SutProviderCustomization(), new EfSend() + var fixture = new Fixture(); + fixture.Customizations.Insert(0, new MaxLengthStringRelay()); + if (!OrganizationOwned) { - OrganizationOwned = true, - }) - { } + fixture.Customize(composer => composer + .Without(c => c.OrganizationId)); + } + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; } } + +internal class EfSend : ICustomization +{ + public bool OrganizationOwned { get; set; } + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new SendBuilder()); + fixture.Customizations.Add(new UserBuilder()); + fixture.Customizations.Add(new OrganizationBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } +} + +internal class EfUserSendAutoDataAttribute : CustomAutoDataAttribute +{ + public EfUserSendAutoDataAttribute() : base(new SutProviderCustomization(), new EfSend()) + { } +} + +internal class EfOrganizationSendAutoDataAttribute : CustomAutoDataAttribute +{ + public EfOrganizationSendAutoDataAttribute() : base(new SutProviderCustomization(), new EfSend() + { + OrganizationOwned = true, + }) + { } +} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/SsoConfigFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/SsoConfigFixtures.cs index 83f3064f3..4cad2154f 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/SsoConfigFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/SsoConfigFixtures.cs @@ -6,54 +6,53 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; + +internal class SsoConfigBuilder : ISpecimenBuilder { - internal class SsoConfigBuilder : ISpecimenBuilder + public object Create(object request, ISpecimenContext context) { - public object Create(object request, ISpecimenContext context) + if (context == null) { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - var type = request as Type; - if (type == null || type != typeof(SsoConfig)) - { - return new NoSpecimen(); - } - - var fixture = new Fixture(); - var ssoConfig = fixture.WithAutoNSubstitutions().Create(); - var ssoConfigData = fixture.WithAutoNSubstitutions().Create(); - ssoConfig.SetData(ssoConfigData); - return ssoConfig; + throw new ArgumentNullException(nameof(context)); } - } - internal class EfSsoConfig : ICustomization - { - public void Customize(IFixture fixture) + var type = request as Type; + if (type == null || type != typeof(SsoConfig)) { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new OrganizationBuilder()); - fixture.Customizations.Add(new SsoConfigBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); + return new NoSpecimen(); } - } - internal class EfSsoConfigAutoDataAttribute : CustomAutoDataAttribute - { - public EfSsoConfigAutoDataAttribute() : base(new SutProviderCustomization(), new EfSsoConfig()) - { } - } - - internal class InlineEfSsoConfigAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineEfSsoConfigAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfSsoConfig) }, values) - { } + var fixture = new Fixture(); + var ssoConfig = fixture.WithAutoNSubstitutions().Create(); + var ssoConfigData = fixture.WithAutoNSubstitutions().Create(); + ssoConfig.SetData(ssoConfigData); + return ssoConfig; } } + +internal class EfSsoConfig : ICustomization +{ + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new OrganizationBuilder()); + fixture.Customizations.Add(new SsoConfigBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } +} + +internal class EfSsoConfigAutoDataAttribute : CustomAutoDataAttribute +{ + public EfSsoConfigAutoDataAttribute() : base(new SutProviderCustomization(), new EfSsoConfig()) + { } +} + +internal class InlineEfSsoConfigAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineEfSsoConfigAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfSsoConfig) }, values) + { } +} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/SsoUserFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/SsoUserFixtures.cs index 32b6ddf24..f2712e018 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/SsoUserFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/SsoUserFixtures.cs @@ -5,33 +5,32 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; + +internal class EfSsoUser : ICustomization { - internal class EfSsoUser : ICustomization + public void Customize(IFixture fixture) { - public void Customize(IFixture fixture) - { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new UserBuilder()); - fixture.Customizations.Add(new OrganizationBuilder()); - fixture.Customize(composer => composer.Without(ou => ou.Id)); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - } - } - - internal class EfSsoUserAutoDataAttribute : CustomAutoDataAttribute - { - public EfSsoUserAutoDataAttribute() : base(new SutProviderCustomization(), new EfSsoUser()) - { } - } - - internal class InlineEfSsoUserAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineEfSsoUserAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfSsoUser) }, values) - { } + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new UserBuilder()); + fixture.Customizations.Add(new OrganizationBuilder()); + fixture.Customize(composer => composer.Without(ou => ou.Id)); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); } } + +internal class EfSsoUserAutoDataAttribute : CustomAutoDataAttribute +{ + public EfSsoUserAutoDataAttribute() : base(new SutProviderCustomization(), new EfSsoUser()) + { } +} + +internal class InlineEfSsoUserAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineEfSsoUserAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfSsoUser) }, values) + { } +} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/TaxRateFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/TaxRateFixtures.cs index b22c6d8c2..c8cd8c692 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/TaxRateFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/TaxRateFixtures.cs @@ -6,52 +6,51 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; + +internal class TaxRateBuilder : ISpecimenBuilder { - internal class TaxRateBuilder : ISpecimenBuilder + public object Create(object request, ISpecimenContext context) { - public object Create(object request, ISpecimenContext context) + if (context == null) { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - var type = request as Type; - if (type == null || type != typeof(TaxRate)) - { - return new NoSpecimen(); - } - - var fixture = new Fixture(); - fixture.Customizations.Insert(0, new MaxLengthStringRelay()); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + throw new ArgumentNullException(nameof(context)); } - } - internal class EfTaxRate : ICustomization - { - public void Customize(IFixture fixture) + var type = request as Type; + if (type == null || type != typeof(TaxRate)) { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new TaxRateBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); + return new NoSpecimen(); } - } - internal class EfTaxRateAutoDataAttribute : CustomAutoDataAttribute - { - public EfTaxRateAutoDataAttribute() : base(new SutProviderCustomization(), new EfTaxRate()) - { } - } - - internal class InlineEfTaxRateAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineEfTaxRateAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfTaxRate) }, values) - { } + var fixture = new Fixture(); + fixture.Customizations.Insert(0, new MaxLengthStringRelay()); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; } } +internal class EfTaxRate : ICustomization +{ + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new TaxRateBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } +} + +internal class EfTaxRateAutoDataAttribute : CustomAutoDataAttribute +{ + public EfTaxRateAutoDataAttribute() : base(new SutProviderCustomization(), new EfTaxRate()) + { } +} + +internal class InlineEfTaxRateAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineEfTaxRateAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfTaxRate) }, values) + { } +} + diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/TransactionFixutres.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/TransactionFixutres.cs index 437cdcd2a..7dbe42fc1 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/TransactionFixutres.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/TransactionFixutres.cs @@ -7,64 +7,63 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; + +internal class TransactionBuilder : ISpecimenBuilder { - internal class TransactionBuilder : ISpecimenBuilder + public bool OrganizationOwned { get; set; } + public object Create(object request, ISpecimenContext context) { - public bool OrganizationOwned { get; set; } - public object Create(object request, ISpecimenContext context) + if (context == null) { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - var type = request as Type; - if (type == null || type != typeof(Transaction)) - { - return new NoSpecimen(); - } - - var fixture = new Fixture(); - if (!OrganizationOwned) - { - fixture.Customize(composer => composer - .Without(c => c.OrganizationId)); - } - fixture.Customizations.Add(new MaxLengthStringRelay()); - var obj = fixture.WithAutoNSubstitutions().Create(); - return obj; + throw new ArgumentNullException(nameof(context)); } - } - internal class EfTransaction : ICustomization - { - public bool OrganizationOwned { get; set; } - public void Customize(IFixture fixture) + var type = request as Type; + if (type == null || type != typeof(Transaction)) { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - fixture.Customizations.Add(new GlobalSettingsBuilder()); - fixture.Customizations.Add(new TransactionBuilder()); - fixture.Customizations.Add(new UserBuilder()); - fixture.Customizations.Add(new OrganizationBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); + return new NoSpecimen(); } - } - internal class EfUserTransactionAutoDataAttribute : CustomAutoDataAttribute - { - public EfUserTransactionAutoDataAttribute() : base(new SutProviderCustomization(), new EfTransaction()) - { } - } - - internal class EfOrganizationTransactionAutoDataAttribute : CustomAutoDataAttribute - { - public EfOrganizationTransactionAutoDataAttribute() : base(new SutProviderCustomization(), new EfTransaction() + var fixture = new Fixture(); + if (!OrganizationOwned) { - OrganizationOwned = true, - }) - { } + fixture.Customize(composer => composer + .Without(c => c.OrganizationId)); + } + fixture.Customizations.Add(new MaxLengthStringRelay()); + var obj = fixture.WithAutoNSubstitutions().Create(); + return obj; } } + +internal class EfTransaction : ICustomization +{ + public bool OrganizationOwned { get; set; } + public void Customize(IFixture fixture) + { + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + fixture.Customizations.Add(new GlobalSettingsBuilder()); + fixture.Customizations.Add(new TransactionBuilder()); + fixture.Customizations.Add(new UserBuilder()); + fixture.Customizations.Add(new OrganizationBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + } +} + +internal class EfUserTransactionAutoDataAttribute : CustomAutoDataAttribute +{ + public EfUserTransactionAutoDataAttribute() : base(new SutProviderCustomization(), new EfTransaction()) + { } +} + +internal class EfOrganizationTransactionAutoDataAttribute : CustomAutoDataAttribute +{ + public EfOrganizationTransactionAutoDataAttribute() : base(new SutProviderCustomization(), new EfTransaction() + { + OrganizationOwned = true, + }) + { } +} diff --git a/test/Infrastructure.EFIntegration.Test/AutoFixture/UserFixtures.cs b/test/Infrastructure.EFIntegration.Test/AutoFixture/UserFixtures.cs index f54b7b758..98222e8f3 100644 --- a/test/Infrastructure.EFIntegration.Test/AutoFixture/UserFixtures.cs +++ b/test/Infrastructure.EFIntegration.Test/AutoFixture/UserFixtures.cs @@ -3,30 +3,29 @@ using Bit.Core.Test.AutoFixture.UserFixtures; using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Test.Common.AutoFixture.Attributes; -namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture +namespace Bit.Infrastructure.EFIntegration.Test.AutoFixture; + +internal class EfUser : UserFixture { - internal class EfUser : UserFixture + public override void Customize(IFixture fixture) { - public override void Customize(IFixture fixture) - { - fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); - base.Customize(fixture); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - fixture.Customizations.Add(new EfRepositoryListBuilder()); - } - } - - internal class EfUserAutoDataAttribute : CustomAutoDataAttribute - { - public EfUserAutoDataAttribute() : base(new SutProviderCustomization(), new EfUser()) - { } - } - - internal class InlineEfUserAutoDataAttribute : InlineCustomAutoDataAttribute - { - public InlineEfUserAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), - typeof(EfUser) }, values) - { } + fixture.Customizations.Add(new IgnoreVirtualMembersCustomization()); + base.Customize(fixture); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); + fixture.Customizations.Add(new EfRepositoryListBuilder()); } } + +internal class EfUserAutoDataAttribute : CustomAutoDataAttribute +{ + public EfUserAutoDataAttribute() : base(new SutProviderCustomization(), new EfUser()) + { } +} + +internal class InlineEfUserAutoDataAttribute : InlineCustomAutoDataAttribute +{ + public InlineEfUserAutoDataAttribute(params object[] values) : base(new[] { typeof(SutProviderCustomization), + typeof(EfUser) }, values) + { } +} diff --git a/test/Infrastructure.EFIntegration.Test/Helpers/DatabaseOptionsFactory.cs b/test/Infrastructure.EFIntegration.Test/Helpers/DatabaseOptionsFactory.cs index 25ac5912b..fbf0d9828 100644 --- a/test/Infrastructure.EFIntegration.Test/Helpers/DatabaseOptionsFactory.cs +++ b/test/Infrastructure.EFIntegration.Test/Helpers/DatabaseOptionsFactory.cs @@ -2,25 +2,24 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Microsoft.EntityFrameworkCore; -namespace Bit.Infrastructure.EFIntegration.Test.Helpers -{ - public static class DatabaseOptionsFactory - { - public static List> Options { get; } = new(); +namespace Bit.Infrastructure.EFIntegration.Test.Helpers; - static DatabaseOptionsFactory() +public static class DatabaseOptionsFactory +{ + public static List> Options { get; } = new(); + + static DatabaseOptionsFactory() + { + var globalSettings = GlobalSettingsFactory.GlobalSettings; + if (!string.IsNullOrWhiteSpace(GlobalSettingsFactory.GlobalSettings.PostgreSql?.ConnectionString)) { - var globalSettings = GlobalSettingsFactory.GlobalSettings; - if (!string.IsNullOrWhiteSpace(GlobalSettingsFactory.GlobalSettings.PostgreSql?.ConnectionString)) - { - AppContext.SetSwitch("Npgsql.EnableLegacyTimestampBehavior", true); - Options.Add(new DbContextOptionsBuilder().UseNpgsql(globalSettings.PostgreSql.ConnectionString).Options); - } - if (!string.IsNullOrWhiteSpace(GlobalSettingsFactory.GlobalSettings.MySql?.ConnectionString)) - { - var mySqlConnectionString = globalSettings.MySql.ConnectionString; - Options.Add(new DbContextOptionsBuilder().UseMySql(mySqlConnectionString, ServerVersion.AutoDetect(mySqlConnectionString)).Options); - } + AppContext.SetSwitch("Npgsql.EnableLegacyTimestampBehavior", true); + Options.Add(new DbContextOptionsBuilder().UseNpgsql(globalSettings.PostgreSql.ConnectionString).Options); + } + if (!string.IsNullOrWhiteSpace(GlobalSettingsFactory.GlobalSettings.MySql?.ConnectionString)) + { + var mySqlConnectionString = globalSettings.MySql.ConnectionString; + Options.Add(new DbContextOptionsBuilder().UseMySql(mySqlConnectionString, ServerVersion.AutoDetect(mySqlConnectionString)).Options); } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/CipherRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/CipherRepositoryTests.cs index 21e9f4ee1..9b70bffe7 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/CipherRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/CipherRepositoryTests.cs @@ -9,184 +9,183 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories +namespace Bit.Infrastructure.EFIntegration.Test.Repositories; + +public class CipherRepositoryTests { - public class CipherRepositoryTests + [Theory(Skip = "Run ad-hoc"), EfUserCipherAutoData] + public async void RefreshDb(List suts) { - [Theory(Skip = "Run ad-hoc"), EfUserCipherAutoData] - public async void RefreshDb(List suts) + foreach (var sut in suts) { - foreach (var sut in suts) - { - await sut.RefreshDb(); - } + await sut.RefreshDb(); } + } - [CiSkippedTheory, EfUserCipherAutoData, EfOrganizationCipherAutoData] - public async void CreateAsync_Works_DataMatches(Cipher cipher, User user, Organization org, - CipherCompare equalityComparer, List suts, List efUserRepos, - List efOrgRepos, SqlRepo.CipherRepository sqlCipherRepo, - SqlRepo.UserRepository sqlUserRepo, SqlRepo.OrganizationRepository sqlOrgRepo) + [CiSkippedTheory, EfUserCipherAutoData, EfOrganizationCipherAutoData] + public async void CreateAsync_Works_DataMatches(Cipher cipher, User user, Organization org, + CipherCompare equalityComparer, List suts, List efUserRepos, + List efOrgRepos, SqlRepo.CipherRepository sqlCipherRepo, + SqlRepo.UserRepository sqlUserRepo, SqlRepo.OrganizationRepository sqlOrgRepo) + { + var savedCiphers = new List(); + foreach (var sut in suts) { - var savedCiphers = new List(); - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); + var i = suts.IndexOf(sut); - var efUser = await efUserRepos[i].CreateAsync(user); - sut.ClearChangeTracking(); - cipher.UserId = efUser.Id; - - if (cipher.OrganizationId.HasValue) - { - var efOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); - cipher.OrganizationId = efOrg.Id; - } - - var postEfCipher = await sut.CreateAsync(cipher); - sut.ClearChangeTracking(); - - var savedCipher = await sut.GetByIdAsync(postEfCipher.Id); - savedCiphers.Add(savedCipher); - } - - var sqlUser = await sqlUserRepo.CreateAsync(user); - cipher.UserId = sqlUser.Id; + var efUser = await efUserRepos[i].CreateAsync(user); + sut.ClearChangeTracking(); + cipher.UserId = efUser.Id; if (cipher.OrganizationId.HasValue) { - var sqlOrg = await sqlOrgRepo.CreateAsync(org); - cipher.OrganizationId = sqlOrg.Id; - } - - var sqlCipher = await sqlCipherRepo.CreateAsync(cipher); - var savedSqlCipher = await sqlCipherRepo.GetByIdAsync(sqlCipher.Id); - savedCiphers.Add(savedSqlCipher); - - var distinctItems = savedCiphers.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); - } - - [CiSkippedTheory, EfUserCipherAutoData] - public async void CreateAsync_BumpsUserAccountRevisionDate(Cipher cipher, User user, List suts, List efUserRepos) - { - var bumpedUsers = new List(); - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); - - var efUser = await efUserRepos[i].CreateAsync(user); - efUserRepos[i].ClearChangeTracking(); - cipher.UserId = efUser.Id; - cipher.OrganizationId = null; - - var postEfCipher = await sut.CreateAsync(cipher); - sut.ClearChangeTracking(); - - var bumpedUser = await efUserRepos[i].GetByIdAsync(efUser.Id); - bumpedUsers.Add(bumpedUser); - } - - Assert.True(bumpedUsers.All(u => u.AccountRevisionDate.ToShortDateString() == DateTime.UtcNow.ToShortDateString())); - } - - [CiSkippedTheory, EfOrganizationCipherAutoData] - public async void CreateAsync_BumpsOrgUserAccountRevisionDates(Cipher cipher, List users, - List orgUsers, Collection collection, Organization org, List suts, List efUserRepos, List efOrgRepos, - List efOrgUserRepos, List efCollectionRepos) - { - var savedCiphers = new List(); - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); - - var efUsers = await efUserRepos[i].CreateMany(users); - efUserRepos[i].ClearChangeTracking(); var efOrg = await efOrgRepos[i].CreateAsync(org); - efOrgRepos[i].ClearChangeTracking(); - - cipher.OrganizationId = efOrg.Id; - - collection.OrganizationId = efOrg.Id; - var efCollection = await efCollectionRepos[i].CreateAsync(collection); - efCollectionRepos[i].ClearChangeTracking(); - - IEnumerable[] lists = { efUsers, orgUsers }; - var maxOrgUsers = lists.Min(l => l.Count()); - - orgUsers = orgUsers.Take(maxOrgUsers).ToList(); - efUsers = efUsers.Take(maxOrgUsers).ToList(); - - for (var j = 0; j < maxOrgUsers; j++) - { - orgUsers[j].OrganizationId = efOrg.Id; - orgUsers[j].UserId = efUsers[j].Id; - } - - orgUsers = await efOrgUserRepos[i].CreateMany(orgUsers); - - var selectionReadOnlyList = new List(); - orgUsers.ForEach(ou => selectionReadOnlyList.Add(new SelectionReadOnly() { Id = ou.Id })); - - await efCollectionRepos[i].UpdateUsersAsync(efCollection.Id, selectionReadOnlyList); - efCollectionRepos[i].ClearChangeTracking(); - - foreach (var ou in orgUsers) - { - var collectionUser = new CollectionUser() - { - CollectionId = efCollection.Id, - OrganizationUserId = ou.Id - }; - } - - cipher.UserId = null; - var postEfCipher = await sut.CreateAsync(cipher); sut.ClearChangeTracking(); - - var query = new UserBumpAccountRevisionDateByCipherIdQuery(cipher); - var modifiedUsers = await sut.Run(query).ToListAsync(); - Assert.True(modifiedUsers - .All(u => u.AccountRevisionDate.ToShortDateString() == - DateTime.UtcNow.ToShortDateString())); + cipher.OrganizationId = efOrg.Id; } + + var postEfCipher = await sut.CreateAsync(cipher); + sut.ClearChangeTracking(); + + var savedCipher = await sut.GetByIdAsync(postEfCipher.Id); + savedCiphers.Add(savedCipher); } - [CiSkippedTheory, EfUserCipherAutoData, EfOrganizationCipherAutoData] - public async void DeleteAsync_CipherIsDeleted( - Cipher cipher, - User user, - Organization org, - List suts, - List efUserRepos, - List efOrgRepos - ) + var sqlUser = await sqlUserRepo.CreateAsync(user); + cipher.UserId = sqlUser.Id; + + if (cipher.OrganizationId.HasValue) { - foreach (var sut in suts) + var sqlOrg = await sqlOrgRepo.CreateAsync(org); + cipher.OrganizationId = sqlOrg.Id; + } + + var sqlCipher = await sqlCipherRepo.CreateAsync(cipher); + var savedSqlCipher = await sqlCipherRepo.GetByIdAsync(sqlCipher.Id); + savedCiphers.Add(savedSqlCipher); + + var distinctItems = savedCiphers.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); + } + + [CiSkippedTheory, EfUserCipherAutoData] + public async void CreateAsync_BumpsUserAccountRevisionDate(Cipher cipher, User user, List suts, List efUserRepos) + { + var bumpedUsers = new List(); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); + + var efUser = await efUserRepos[i].CreateAsync(user); + efUserRepos[i].ClearChangeTracking(); + cipher.UserId = efUser.Id; + cipher.OrganizationId = null; + + var postEfCipher = await sut.CreateAsync(cipher); + sut.ClearChangeTracking(); + + var bumpedUser = await efUserRepos[i].GetByIdAsync(efUser.Id); + bumpedUsers.Add(bumpedUser); + } + + Assert.True(bumpedUsers.All(u => u.AccountRevisionDate.ToShortDateString() == DateTime.UtcNow.ToShortDateString())); + } + + [CiSkippedTheory, EfOrganizationCipherAutoData] + public async void CreateAsync_BumpsOrgUserAccountRevisionDates(Cipher cipher, List users, + List orgUsers, Collection collection, Organization org, List suts, List efUserRepos, List efOrgRepos, + List efOrgUserRepos, List efCollectionRepos) + { + var savedCiphers = new List(); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); + + var efUsers = await efUserRepos[i].CreateMany(users); + efUserRepos[i].ClearChangeTracking(); + var efOrg = await efOrgRepos[i].CreateAsync(org); + efOrgRepos[i].ClearChangeTracking(); + + cipher.OrganizationId = efOrg.Id; + + collection.OrganizationId = efOrg.Id; + var efCollection = await efCollectionRepos[i].CreateAsync(collection); + efCollectionRepos[i].ClearChangeTracking(); + + IEnumerable[] lists = { efUsers, orgUsers }; + var maxOrgUsers = lists.Min(l => l.Count()); + + orgUsers = orgUsers.Take(maxOrgUsers).ToList(); + efUsers = efUsers.Take(maxOrgUsers).ToList(); + + for (var j = 0; j < maxOrgUsers; j++) { - var i = suts.IndexOf(sut); - - var postEfOrg = await efOrgRepos[i].CreateAsync(org); - efOrgRepos[i].ClearChangeTracking(); - var postEfUser = await efUserRepos[i].CreateAsync(user); - efUserRepos[i].ClearChangeTracking(); - - if (cipher.OrganizationId.HasValue) - { - cipher.OrganizationId = postEfOrg.Id; - } - cipher.UserId = postEfUser.Id; - - await sut.CreateAsync(cipher); - sut.ClearChangeTracking(); - - await sut.DeleteAsync(cipher); - sut.ClearChangeTracking(); - - var savedCipher = await sut.GetByIdAsync(cipher.Id); - Assert.True(savedCipher == null); + orgUsers[j].OrganizationId = efOrg.Id; + orgUsers[j].UserId = efUsers[j].Id; } + + orgUsers = await efOrgUserRepos[i].CreateMany(orgUsers); + + var selectionReadOnlyList = new List(); + orgUsers.ForEach(ou => selectionReadOnlyList.Add(new SelectionReadOnly() { Id = ou.Id })); + + await efCollectionRepos[i].UpdateUsersAsync(efCollection.Id, selectionReadOnlyList); + efCollectionRepos[i].ClearChangeTracking(); + + foreach (var ou in orgUsers) + { + var collectionUser = new CollectionUser() + { + CollectionId = efCollection.Id, + OrganizationUserId = ou.Id + }; + } + + cipher.UserId = null; + var postEfCipher = await sut.CreateAsync(cipher); + sut.ClearChangeTracking(); + + var query = new UserBumpAccountRevisionDateByCipherIdQuery(cipher); + var modifiedUsers = await sut.Run(query).ToListAsync(); + Assert.True(modifiedUsers + .All(u => u.AccountRevisionDate.ToShortDateString() == + DateTime.UtcNow.ToShortDateString())); + } + } + + [CiSkippedTheory, EfUserCipherAutoData, EfOrganizationCipherAutoData] + public async void DeleteAsync_CipherIsDeleted( + Cipher cipher, + User user, + Organization org, + List suts, + List efUserRepos, + List efOrgRepos + ) + { + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); + + var postEfOrg = await efOrgRepos[i].CreateAsync(org); + efOrgRepos[i].ClearChangeTracking(); + var postEfUser = await efUserRepos[i].CreateAsync(user); + efUserRepos[i].ClearChangeTracking(); + + if (cipher.OrganizationId.HasValue) + { + cipher.OrganizationId = postEfOrg.Id; + } + cipher.UserId = postEfUser.Id; + + await sut.CreateAsync(cipher); + sut.ClearChangeTracking(); + + await sut.DeleteAsync(cipher); + sut.ClearChangeTracking(); + + var savedCipher = await sut.GetByIdAsync(cipher.Id); + Assert.True(savedCipher == null); } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/CollectionRepository.cs b/test/Infrastructure.EFIntegration.Test/Repositories/CollectionRepository.cs index ed2bcf74b..1fb20c684 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/CollectionRepository.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/CollectionRepository.cs @@ -6,45 +6,44 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories +namespace Bit.Infrastructure.EFIntegration.Test.Repositories; + +public class CollectionRepositoryTests { - public class CollectionRepositoryTests + [CiSkippedTheory, EfCollectionAutoData] + public async void CreateAsync_Works_DataMatches( + Collection collection, + Organization organization, + CollectionCompare equalityComparer, + List suts, + List efOrganizationRepos, + SqlRepo.CollectionRepository sqlCollectionRepo, + SqlRepo.OrganizationRepository sqlOrganizationRepo + ) { - [CiSkippedTheory, EfCollectionAutoData] - public async void CreateAsync_Works_DataMatches( - Collection collection, - Organization organization, - CollectionCompare equalityComparer, - List suts, - List efOrganizationRepos, - SqlRepo.CollectionRepository sqlCollectionRepo, - SqlRepo.OrganizationRepository sqlOrganizationRepo - ) + var savedCollections = new List(); + foreach (var sut in suts) { - var savedCollections = new List(); - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); - var efOrganization = await efOrganizationRepos[i].CreateAsync(organization); - sut.ClearChangeTracking(); + var i = suts.IndexOf(sut); + var efOrganization = await efOrganizationRepos[i].CreateAsync(organization); + sut.ClearChangeTracking(); - collection.OrganizationId = efOrganization.Id; - var postEfCollection = await sut.CreateAsync(collection); - sut.ClearChangeTracking(); + collection.OrganizationId = efOrganization.Id; + var postEfCollection = await sut.CreateAsync(collection); + sut.ClearChangeTracking(); - var savedCollection = await sut.GetByIdAsync(postEfCollection.Id); - savedCollections.Add(savedCollection); - } - - var sqlOrganization = await sqlOrganizationRepo.CreateAsync(organization); - collection.OrganizationId = sqlOrganization.Id; - - var sqlCollection = await sqlCollectionRepo.CreateAsync(collection); - var savedSqlCollection = await sqlCollectionRepo.GetByIdAsync(sqlCollection.Id); - savedCollections.Add(savedSqlCollection); - - var distinctItems = savedCollections.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); + var savedCollection = await sut.GetByIdAsync(postEfCollection.Id); + savedCollections.Add(savedCollection); } + + var sqlOrganization = await sqlOrganizationRepo.CreateAsync(organization); + collection.OrganizationId = sqlOrganization.Id; + + var sqlCollection = await sqlCollectionRepo.CreateAsync(collection); + var savedSqlCollection = await sqlCollectionRepo.GetByIdAsync(sqlCollection.Id); + savedCollections.Add(savedSqlCollection); + + var distinctItems = savedCollections.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/DeviceRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/DeviceRepositoryTests.cs index 4c5de177c..fc1f5c8b3 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/DeviceRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/DeviceRepositoryTests.cs @@ -6,42 +6,41 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories +namespace Bit.Infrastructure.EFIntegration.Test.Repositories; + +public class DeviceRepositoryTests { - public class DeviceRepositoryTests + [CiSkippedTheory, EfDeviceAutoData] + public async void CreateAsync_Works_DataMatches(Device device, User user, + DeviceCompare equalityComparer, List suts, + List efUserRepos, SqlRepo.DeviceRepository sqlDeviceRepo, + SqlRepo.UserRepository sqlUserRepo) { - [CiSkippedTheory, EfDeviceAutoData] - public async void CreateAsync_Works_DataMatches(Device device, User user, - DeviceCompare equalityComparer, List suts, - List efUserRepos, SqlRepo.DeviceRepository sqlDeviceRepo, - SqlRepo.UserRepository sqlUserRepo) + var savedDevices = new List(); + foreach (var sut in suts) { - var savedDevices = new List(); - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); + var i = suts.IndexOf(sut); - var efUser = await efUserRepos[i].CreateAsync(user); - device.UserId = efUser.Id; - sut.ClearChangeTracking(); + var efUser = await efUserRepos[i].CreateAsync(user); + device.UserId = efUser.Id; + sut.ClearChangeTracking(); - var postEfDevice = await sut.CreateAsync(device); - sut.ClearChangeTracking(); + var postEfDevice = await sut.CreateAsync(device); + sut.ClearChangeTracking(); - var savedDevice = await sut.GetByIdAsync(postEfDevice.Id); - savedDevices.Add(savedDevice); - } - - var sqlUser = await sqlUserRepo.CreateAsync(user); - device.UserId = sqlUser.Id; - - var sqlDevice = await sqlDeviceRepo.CreateAsync(device); - var savedSqlDevice = await sqlDeviceRepo.GetByIdAsync(sqlDevice.Id); - savedDevices.Add(savedSqlDevice); - - var distinctItems = savedDevices.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); + var savedDevice = await sut.GetByIdAsync(postEfDevice.Id); + savedDevices.Add(savedDevice); } + var sqlUser = await sqlUserRepo.CreateAsync(user); + device.UserId = sqlUser.Id; + + var sqlDevice = await sqlDeviceRepo.CreateAsync(device); + var savedSqlDevice = await sqlDeviceRepo.GetByIdAsync(sqlDevice.Id); + savedDevices.Add(savedSqlDevice); + + var distinctItems = savedDevices.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } + } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EmergencyAccessRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EmergencyAccessRepositoryTests.cs index 1bb31d476..d014d463a 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EmergencyAccessRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EmergencyAccessRepositoryTests.cs @@ -6,54 +6,53 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories +namespace Bit.Infrastructure.EFIntegration.Test.Repositories; + +public class EmergencyAccessRepositoryTests { - public class EmergencyAccessRepositoryTests + [CiSkippedTheory, EfEmergencyAccessAutoData] + public async void CreateAsync_Works_DataMatches( + EmergencyAccess emergencyAccess, + List users, + EmergencyAccessCompare equalityComparer, + List suts, + List efUserRepos, + SqlRepo.EmergencyAccessRepository sqlEmergencyAccessRepo, + SqlRepo.UserRepository sqlUserRepo + ) { - [CiSkippedTheory, EfEmergencyAccessAutoData] - public async void CreateAsync_Works_DataMatches( - EmergencyAccess emergencyAccess, - List users, - EmergencyAccessCompare equalityComparer, - List suts, - List efUserRepos, - SqlRepo.EmergencyAccessRepository sqlEmergencyAccessRepo, - SqlRepo.UserRepository sqlUserRepo - ) + var savedEmergencyAccesss = new List(); + foreach (var sut in suts) { - var savedEmergencyAccesss = new List(); - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); - - for (int j = 0; j < users.Count; j++) - { - users[j] = await efUserRepos[i].CreateAsync(users[j]); - } - sut.ClearChangeTracking(); - - emergencyAccess.GrantorId = users[0].Id; - emergencyAccess.GranteeId = users[0].Id; - var postEfEmergencyAccess = await sut.CreateAsync(emergencyAccess); - sut.ClearChangeTracking(); - - var savedEmergencyAccess = await sut.GetByIdAsync(postEfEmergencyAccess.Id); - savedEmergencyAccesss.Add(savedEmergencyAccess); - } + var i = suts.IndexOf(sut); for (int j = 0; j < users.Count; j++) { - users[j] = await sqlUserRepo.CreateAsync(users[j]); + users[j] = await efUserRepos[i].CreateAsync(users[j]); } + sut.ClearChangeTracking(); emergencyAccess.GrantorId = users[0].Id; emergencyAccess.GranteeId = users[0].Id; - var sqlEmergencyAccess = await sqlEmergencyAccessRepo.CreateAsync(emergencyAccess); - var savedSqlEmergencyAccess = await sqlEmergencyAccessRepo.GetByIdAsync(sqlEmergencyAccess.Id); - savedEmergencyAccesss.Add(savedSqlEmergencyAccess); + var postEfEmergencyAccess = await sut.CreateAsync(emergencyAccess); + sut.ClearChangeTracking(); - var distinctItems = savedEmergencyAccesss.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); + var savedEmergencyAccess = await sut.GetByIdAsync(postEfEmergencyAccess.Id); + savedEmergencyAccesss.Add(savedEmergencyAccess); } + + for (int j = 0; j < users.Count; j++) + { + users[j] = await sqlUserRepo.CreateAsync(users[j]); + } + + emergencyAccess.GrantorId = users[0].Id; + emergencyAccess.GranteeId = users[0].Id; + var sqlEmergencyAccess = await sqlEmergencyAccessRepo.CreateAsync(emergencyAccess); + var savedSqlEmergencyAccess = await sqlEmergencyAccessRepo.GetByIdAsync(sqlEmergencyAccess.Id); + savedEmergencyAccesss.Add(savedSqlEmergencyAccess); + + var distinctItems = savedEmergencyAccesss.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/CipherCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/CipherCompare.cs index f5be069bd..230b51dd6 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/CipherCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/CipherCompare.cs @@ -1,21 +1,20 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers -{ - public class CipherCompare : IEqualityComparer - { - public bool Equals(Cipher x, Cipher y) - { - return x.Type == y.Type && - x.Data == y.Data && - x.Favorites == y.Favorites && - x.Attachments == y.Attachments; - } +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - public int GetHashCode([DisallowNull] Cipher obj) - { - return base.GetHashCode(); - } +public class CipherCompare : IEqualityComparer +{ + public bool Equals(Cipher x, Cipher y) + { + return x.Type == y.Type && + x.Data == y.Data && + x.Favorites == y.Favorites && + x.Attachments == y.Attachments; + } + + public int GetHashCode([DisallowNull] Cipher obj) + { + return base.GetHashCode(); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/CollectionCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/CollectionCompare.cs index a7cef8f6d..56cb0acf7 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/CollectionCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/CollectionCompare.cs @@ -1,19 +1,18 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers -{ - public class CollectionCompare : IEqualityComparer - { - public bool Equals(Collection x, Collection y) - { - return x.Name == y.Name && - x.ExternalId == y.ExternalId; - } +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - public int GetHashCode([DisallowNull] Collection obj) - { - return base.GetHashCode(); - } +public class CollectionCompare : IEqualityComparer +{ + public bool Equals(Collection x, Collection y) + { + return x.Name == y.Name && + x.ExternalId == y.ExternalId; + } + + public int GetHashCode([DisallowNull] Collection obj) + { + return base.GetHashCode(); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/DeviceCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/DeviceCompare.cs index ac8a24d20..086199b38 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/DeviceCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/DeviceCompare.cs @@ -1,21 +1,20 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers -{ - public class DeviceCompare : IEqualityComparer - { - public bool Equals(Device x, Device y) - { - return x.Name == y.Name && - x.Type == y.Type && - x.Identifier == y.Identifier && - x.PushToken == y.PushToken; - } +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - public int GetHashCode([DisallowNull] Device obj) - { - return base.GetHashCode(); - } +public class DeviceCompare : IEqualityComparer +{ + public bool Equals(Device x, Device y) + { + return x.Name == y.Name && + x.Type == y.Type && + x.Identifier == y.Identifier && + x.PushToken == y.PushToken; + } + + public int GetHashCode([DisallowNull] Device obj) + { + return base.GetHashCode(); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/EmergencyAccessCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/EmergencyAccessCompare.cs index bc2592f43..eb182d6e9 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/EmergencyAccessCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/EmergencyAccessCompare.cs @@ -1,24 +1,23 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers -{ - public class EmergencyAccessCompare : IEqualityComparer - { - public bool Equals(EmergencyAccess x, EmergencyAccess y) - { - return x.Email == y.Email && - x.KeyEncrypted == y.KeyEncrypted && - x.Type == y.Type && - x.Status == y.Status && - x.WaitTimeDays == y.WaitTimeDays && - x.RecoveryInitiatedDate == y.RecoveryInitiatedDate && - x.LastNotificationDate == y.LastNotificationDate; - } +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - public int GetHashCode([DisallowNull] EmergencyAccess obj) - { - return base.GetHashCode(); - } +public class EmergencyAccessCompare : IEqualityComparer +{ + public bool Equals(EmergencyAccess x, EmergencyAccess y) + { + return x.Email == y.Email && + x.KeyEncrypted == y.KeyEncrypted && + x.Type == y.Type && + x.Status == y.Status && + x.WaitTimeDays == y.WaitTimeDays && + x.RecoveryInitiatedDate == y.RecoveryInitiatedDate && + x.LastNotificationDate == y.LastNotificationDate; + } + + public int GetHashCode([DisallowNull] EmergencyAccess obj) + { + return base.GetHashCode(); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/EventCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/EventCompare.cs index a42f8cb5e..e414f7c25 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/EventCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/EventCompare.cs @@ -1,20 +1,19 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers -{ - public class EventCompare : IEqualityComparer - { - public bool Equals(Event x, Event y) - { - return x.Date.ToShortDateString() == y.Date.ToShortDateString() && - x.Type == y.Type && - x.IpAddress == y.IpAddress; - } +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - public int GetHashCode([DisallowNull] Event obj) - { - return base.GetHashCode(); - } +public class EventCompare : IEqualityComparer +{ + public bool Equals(Event x, Event y) + { + return x.Date.ToShortDateString() == y.Date.ToShortDateString() && + x.Type == y.Type && + x.IpAddress == y.IpAddress; + } + + public int GetHashCode([DisallowNull] Event obj) + { + return base.GetHashCode(); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/FolderCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/FolderCompare.cs index 61e261f8a..2bdb71385 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/FolderCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/FolderCompare.cs @@ -1,18 +1,17 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers -{ - public class FolderCompare : IEqualityComparer - { - public bool Equals(Folder x, Folder y) - { - return x.Name == y.Name; - } +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - public int GetHashCode([DisallowNull] Folder obj) - { - return base.GetHashCode(); - } +public class FolderCompare : IEqualityComparer +{ + public bool Equals(Folder x, Folder y) + { + return x.Name == y.Name; + } + + public int GetHashCode([DisallowNull] Folder obj) + { + return base.GetHashCode(); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/GrantCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/GrantCompare.cs index 978d4d62d..762157716 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/GrantCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/GrantCompare.cs @@ -1,25 +1,24 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers -{ - public class GrantCompare : IEqualityComparer - { - public bool Equals(Grant x, Grant y) - { - return x.Key == y.Key && - x.Type == y.Type && - x.SubjectId == y.SubjectId && - x.ClientId == y.ClientId && - x.Description == y.Description && - x.ExpirationDate == y.ExpirationDate && - x.ConsumedDate == y.ConsumedDate && - x.Data == y.Data; - } +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - public int GetHashCode([DisallowNull] Grant obj) - { - return base.GetHashCode(); - } +public class GrantCompare : IEqualityComparer +{ + public bool Equals(Grant x, Grant y) + { + return x.Key == y.Key && + x.Type == y.Type && + x.SubjectId == y.SubjectId && + x.ClientId == y.ClientId && + x.Description == y.Description && + x.ExpirationDate == y.ExpirationDate && + x.ConsumedDate == y.ConsumedDate && + x.Data == y.Data; + } + + public int GetHashCode([DisallowNull] Grant obj) + { + return base.GetHashCode(); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/GroupCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/GroupCompare.cs index aa2e1ae89..dcb0be2ff 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/GroupCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/GroupCompare.cs @@ -1,20 +1,19 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers -{ - public class GroupCompare : IEqualityComparer - { - public bool Equals(Group x, Group y) - { - return x.Name == y.Name && - x.AccessAll == y.AccessAll && - x.ExternalId == y.ExternalId; - } +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - public int GetHashCode([DisallowNull] Group obj) - { - return base.GetHashCode(); - } +public class GroupCompare : IEqualityComparer +{ + public bool Equals(Group x, Group y) + { + return x.Name == y.Name && + x.AccessAll == y.AccessAll && + x.ExternalId == y.ExternalId; + } + + public int GetHashCode([DisallowNull] Group obj) + { + return base.GetHashCode(); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/InstallationCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/InstallationCompare.cs index 38a92daa3..7794785b3 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/InstallationCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/InstallationCompare.cs @@ -1,20 +1,19 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers -{ - public class InstallationCompare : IEqualityComparer - { - public bool Equals(Installation x, Installation y) - { - return x.Email == y.Email && - x.Key == y.Key && - x.Enabled == y.Enabled; - } +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - public int GetHashCode([DisallowNull] Installation obj) - { - return base.GetHashCode(); - } +public class InstallationCompare : IEqualityComparer +{ + public bool Equals(Installation x, Installation y) + { + return x.Email == y.Email && + x.Key == y.Key && + x.Enabled == y.Enabled; + } + + public int GetHashCode([DisallowNull] Installation obj) + { + return base.GetHashCode(); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/OrganizationCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/OrganizationCompare.cs index a8f32643e..f1879937a 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/OrganizationCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/OrganizationCompare.cs @@ -1,54 +1,53 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers -{ - public class OrganizationCompare : IEqualityComparer - { - public bool Equals(Organization x, Organization y) - { - var a = x.ExpirationDate.ToString(); - var b = y.ExpirationDate.ToString(); - return x.Identifier.Equals(y.Identifier) && - x.Name.Equals(y.Name) && - x.BusinessName.Equals(y.BusinessName) && - x.BusinessAddress1.Equals(y.BusinessAddress1) && - x.BusinessAddress2.Equals(y.BusinessAddress2) && - x.BusinessAddress3.Equals(y.BusinessAddress3) && - x.BusinessCountry.Equals(y.BusinessCountry) && - x.BusinessTaxNumber.Equals(y.BusinessTaxNumber) && - x.BillingEmail.Equals(y.BillingEmail) && - x.Plan.Equals(y.Plan) && - x.PlanType.Equals(y.PlanType) && - x.Seats.Equals(y.Seats) && - x.MaxCollections.Equals(y.MaxCollections) && - x.UsePolicies.Equals(y.UsePolicies) && - x.UseSso.Equals(y.UseSso) && - x.UseKeyConnector.Equals(y.UseKeyConnector) && - x.UseScim.Equals(y.UseScim) && - x.UseGroups.Equals(y.UseGroups) && - x.UseDirectory.Equals(y.UseDirectory) && - x.UseEvents.Equals(y.UseEvents) && - x.UseTotp.Equals(y.UseTotp) && - x.Use2fa.Equals(y.Use2fa) && - x.UseApi.Equals(y.UseApi) && - x.SelfHost.Equals(y.SelfHost) && - x.UsersGetPremium.Equals(y.UsersGetPremium) && - x.Storage.Equals(y.Storage) && - x.MaxStorageGb.Equals(y.MaxStorageGb) && - x.Gateway.Equals(y.Gateway) && - x.GatewayCustomerId.Equals(y.GatewayCustomerId) && - x.GatewaySubscriptionId.Equals(y.GatewaySubscriptionId) && - x.ReferenceData.Equals(y.ReferenceData) && - x.Enabled.Equals(y.Enabled) && - x.LicenseKey.Equals(y.LicenseKey) && - x.TwoFactorProviders.Equals(y.TwoFactorProviders) && - x.ExpirationDate.ToString().Equals(y.ExpirationDate.ToString()); - } +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - public int GetHashCode([DisallowNull] Organization obj) - { - return base.GetHashCode(); - } +public class OrganizationCompare : IEqualityComparer +{ + public bool Equals(Organization x, Organization y) + { + var a = x.ExpirationDate.ToString(); + var b = y.ExpirationDate.ToString(); + return x.Identifier.Equals(y.Identifier) && + x.Name.Equals(y.Name) && + x.BusinessName.Equals(y.BusinessName) && + x.BusinessAddress1.Equals(y.BusinessAddress1) && + x.BusinessAddress2.Equals(y.BusinessAddress2) && + x.BusinessAddress3.Equals(y.BusinessAddress3) && + x.BusinessCountry.Equals(y.BusinessCountry) && + x.BusinessTaxNumber.Equals(y.BusinessTaxNumber) && + x.BillingEmail.Equals(y.BillingEmail) && + x.Plan.Equals(y.Plan) && + x.PlanType.Equals(y.PlanType) && + x.Seats.Equals(y.Seats) && + x.MaxCollections.Equals(y.MaxCollections) && + x.UsePolicies.Equals(y.UsePolicies) && + x.UseSso.Equals(y.UseSso) && + x.UseKeyConnector.Equals(y.UseKeyConnector) && + x.UseScim.Equals(y.UseScim) && + x.UseGroups.Equals(y.UseGroups) && + x.UseDirectory.Equals(y.UseDirectory) && + x.UseEvents.Equals(y.UseEvents) && + x.UseTotp.Equals(y.UseTotp) && + x.Use2fa.Equals(y.Use2fa) && + x.UseApi.Equals(y.UseApi) && + x.SelfHost.Equals(y.SelfHost) && + x.UsersGetPremium.Equals(y.UsersGetPremium) && + x.Storage.Equals(y.Storage) && + x.MaxStorageGb.Equals(y.MaxStorageGb) && + x.Gateway.Equals(y.Gateway) && + x.GatewayCustomerId.Equals(y.GatewayCustomerId) && + x.GatewaySubscriptionId.Equals(y.GatewaySubscriptionId) && + x.ReferenceData.Equals(y.ReferenceData) && + x.Enabled.Equals(y.Enabled) && + x.LicenseKey.Equals(y.LicenseKey) && + x.TwoFactorProviders.Equals(y.TwoFactorProviders) && + x.ExpirationDate.ToString().Equals(y.ExpirationDate.ToString()); + } + + public int GetHashCode([DisallowNull] Organization obj) + { + return base.GetHashCode(); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/OrganizationSponsorshipCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/OrganizationSponsorshipCompare.cs index c90aaf065..e17e76592 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/OrganizationSponsorshipCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/OrganizationSponsorshipCompare.cs @@ -1,23 +1,22 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers -{ - public class OrganizationSponsorshipCompare : IEqualityComparer - { - public bool Equals(OrganizationSponsorship x, OrganizationSponsorship y) - { - return x.SponsoringOrganizationId.Equals(y.SponsoringOrganizationId) && - x.SponsoringOrganizationUserId.Equals(y.SponsoringOrganizationUserId) && - x.SponsoredOrganizationId.Equals(y.SponsoredOrganizationId) && - x.OfferedToEmail.Equals(y.OfferedToEmail) && - x.ToDelete.Equals(y.ToDelete) && - x.ValidUntil.ToString().Equals(y.ValidUntil.ToString()); - } +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - public int GetHashCode([DisallowNull] OrganizationSponsorship obj) - { - return base.GetHashCode(); - } +public class OrganizationSponsorshipCompare : IEqualityComparer +{ + public bool Equals(OrganizationSponsorship x, OrganizationSponsorship y) + { + return x.SponsoringOrganizationId.Equals(y.SponsoringOrganizationId) && + x.SponsoringOrganizationUserId.Equals(y.SponsoringOrganizationUserId) && + x.SponsoredOrganizationId.Equals(y.SponsoredOrganizationId) && + x.OfferedToEmail.Equals(y.OfferedToEmail) && + x.ToDelete.Equals(y.ToDelete) && + x.ValidUntil.ToString().Equals(y.ValidUntil.ToString()); + } + + public int GetHashCode([DisallowNull] OrganizationSponsorship obj) + { + return base.GetHashCode(); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/OrganizationUserCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/OrganizationUserCompare.cs index bb7895a2f..6d947cc6c 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/OrganizationUserCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/OrganizationUserCompare.cs @@ -1,23 +1,22 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers -{ - public class OrganizationUserCompare : IEqualityComparer - { - public bool Equals(OrganizationUser x, OrganizationUser y) - { - return x.Email == y.Email && - x.Status == y.Status && - x.Type == y.Type && - x.AccessAll == y.AccessAll && - x.ExternalId == y.ExternalId && - x.Permissions == y.Permissions; - } +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - public int GetHashCode([DisallowNull] OrganizationUser obj) - { - return base.GetHashCode(); - } +public class OrganizationUserCompare : IEqualityComparer +{ + public bool Equals(OrganizationUser x, OrganizationUser y) + { + return x.Email == y.Email && + x.Status == y.Status && + x.Type == y.Type && + x.AccessAll == y.AccessAll && + x.ExternalId == y.ExternalId && + x.Permissions == y.Permissions; + } + + public int GetHashCode([DisallowNull] OrganizationUser obj) + { + return base.GetHashCode(); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/PolicyCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/PolicyCompare.cs index 758675c5a..f3bd7dc7a 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/PolicyCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/PolicyCompare.cs @@ -1,29 +1,28 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers -{ - public class PolicyCompare : IEqualityComparer - { - public bool Equals(Policy x, Policy y) - { - return x.Type == y.Type && - x.Data == y.Data && - x.Enabled == y.Enabled; - } +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - public int GetHashCode([DisallowNull] Policy obj) - { - return base.GetHashCode(); - } +public class PolicyCompare : IEqualityComparer +{ + public bool Equals(Policy x, Policy y) + { + return x.Type == y.Type && + x.Data == y.Data && + x.Enabled == y.Enabled; } - public class PolicyCompareIncludingOrganization : PolicyCompare + public int GetHashCode([DisallowNull] Policy obj) { - public new bool Equals(Policy x, Policy y) - { - return base.Equals(x, y) && - x.OrganizationId == y.OrganizationId; - } + return base.GetHashCode(); + } +} + +public class PolicyCompareIncludingOrganization : PolicyCompare +{ + public new bool Equals(Policy x, Policy y) + { + return base.Equals(x, y) && + x.OrganizationId == y.OrganizationId; } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/SendCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/SendCompare.cs index 705799779..b4723051c 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/SendCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/SendCompare.cs @@ -1,27 +1,26 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers -{ - public class SendCompare : IEqualityComparer - { - public bool Equals(Send x, Send y) - { - return x.Type == y.Type && - x.Data == y.Data && - x.Key == y.Key && - x.Password == y.Password && - x.MaxAccessCount == y.MaxAccessCount && - x.AccessCount == y.AccessCount && - x.ExpirationDate?.ToShortDateString() == y.ExpirationDate?.ToShortDateString() && - x.DeletionDate.ToShortDateString() == y.DeletionDate.ToShortDateString() && - x.Disabled == y.Disabled && - x.HideEmail == y.HideEmail; - } +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - public int GetHashCode([DisallowNull] Send obj) - { - return base.GetHashCode(); - } +public class SendCompare : IEqualityComparer +{ + public bool Equals(Send x, Send y) + { + return x.Type == y.Type && + x.Data == y.Data && + x.Key == y.Key && + x.Password == y.Password && + x.MaxAccessCount == y.MaxAccessCount && + x.AccessCount == y.AccessCount && + x.ExpirationDate?.ToShortDateString() == y.ExpirationDate?.ToShortDateString() && + x.DeletionDate.ToShortDateString() == y.DeletionDate.ToShortDateString() && + x.Disabled == y.Disabled && + x.HideEmail == y.HideEmail; + } + + public int GetHashCode([DisallowNull] Send obj) + { + return base.GetHashCode(); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/SsoConfigCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/SsoConfigCompare.cs index 8d6accd86..766b8c685 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/SsoConfigCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/SsoConfigCompare.cs @@ -1,20 +1,19 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers -{ - public class SsoConfigCompare : IEqualityComparer - { - public bool Equals(SsoConfig x, SsoConfig y) - { - return x.Enabled == y.Enabled && - x.OrganizationId == y.OrganizationId && - x.Data == y.Data; - } +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - public int GetHashCode([DisallowNull] SsoConfig obj) - { - return base.GetHashCode(); - } +public class SsoConfigCompare : IEqualityComparer +{ + public bool Equals(SsoConfig x, SsoConfig y) + { + return x.Enabled == y.Enabled && + x.OrganizationId == y.OrganizationId && + x.Data == y.Data; + } + + public int GetHashCode([DisallowNull] SsoConfig obj) + { + return base.GetHashCode(); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/SsoUserCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/SsoUserCompare.cs index a50054514..fffd512c6 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/SsoUserCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/SsoUserCompare.cs @@ -1,18 +1,17 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers -{ - public class SsoUserCompare : IEqualityComparer - { - public bool Equals(SsoUser x, SsoUser y) - { - return x.ExternalId == y.ExternalId; - } +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - public int GetHashCode([DisallowNull] SsoUser obj) - { - return base.GetHashCode(); - } +public class SsoUserCompare : IEqualityComparer +{ + public bool Equals(SsoUser x, SsoUser y) + { + return x.ExternalId == y.ExternalId; + } + + public int GetHashCode([DisallowNull] SsoUser obj) + { + return base.GetHashCode(); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/TaxRateCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/TaxRateCompare.cs index c2305b959..ff3c0a600 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/TaxRateCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/TaxRateCompare.cs @@ -1,22 +1,21 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers -{ - public class TaxRateCompare : IEqualityComparer - { - public bool Equals(TaxRate x, TaxRate y) - { - return x.Country == y.Country && - x.State == y.State && - x.PostalCode == y.PostalCode && - x.Rate == y.Rate && - x.Active == y.Active; - } +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - public int GetHashCode([DisallowNull] TaxRate obj) - { - return base.GetHashCode(); - } +public class TaxRateCompare : IEqualityComparer +{ + public bool Equals(TaxRate x, TaxRate y) + { + return x.Country == y.Country && + x.State == y.State && + x.PostalCode == y.PostalCode && + x.Rate == y.Rate && + x.Active == y.Active; + } + + public int GetHashCode([DisallowNull] TaxRate obj) + { + return base.GetHashCode(); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/TransactionCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/TransactionCompare.cs index 2ce594ec4..fadcdf5b1 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/TransactionCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/TransactionCompare.cs @@ -1,24 +1,23 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers -{ - public class TransactionCompare : IEqualityComparer - { - public bool Equals(Transaction x, Transaction y) - { - return x.Type == y.Type && - x.Amount == y.Amount && - x.Refunded == y.Refunded && - x.Details == y.Details && - x.PaymentMethodType == y.PaymentMethodType && - x.Gateway == y.Gateway && - x.GatewayId == y.GatewayId; - } +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - public int GetHashCode([DisallowNull] Transaction obj) - { - return base.GetHashCode(); - } +public class TransactionCompare : IEqualityComparer +{ + public bool Equals(Transaction x, Transaction y) + { + return x.Type == y.Type && + x.Amount == y.Amount && + x.Refunded == y.Refunded && + x.Details == y.Details && + x.PaymentMethodType == y.PaymentMethodType && + x.Gateway == y.Gateway && + x.GatewayId == y.GatewayId; + } + + public int GetHashCode([DisallowNull] Transaction obj) + { + return base.GetHashCode(); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/UserCompare.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/UserCompare.cs index 311d4a01f..90a6af51b 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/UserCompare.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/UserCompare.cs @@ -1,40 +1,39 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers -{ - public class UserCompare : IEqualityComparer - { - public bool Equals(User x, User y) - { - return x.Name == y.Name && - x.Email == y.Email && - x.EmailVerified == y.EmailVerified && - x.MasterPassword == y.MasterPassword && - x.MasterPasswordHint == y.MasterPasswordHint && - x.Culture == y.Culture && - x.SecurityStamp == y.SecurityStamp && - x.TwoFactorProviders == y.TwoFactorProviders && - x.TwoFactorRecoveryCode == y.TwoFactorRecoveryCode && - x.EquivalentDomains == y.EquivalentDomains && - x.Key == y.Key && - x.PublicKey == y.PublicKey && - x.PrivateKey == y.PrivateKey && - x.Premium == y.Premium && - x.Storage == y.Storage && - x.MaxStorageGb == y.MaxStorageGb && - x.Gateway == y.Gateway && - x.GatewayCustomerId == y.GatewayCustomerId && - x.ReferenceData == y.ReferenceData && - x.LicenseKey == y.LicenseKey && - x.ApiKey == y.ApiKey && - x.Kdf == y.Kdf && - x.KdfIterations == y.KdfIterations; - } +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - public int GetHashCode([DisallowNull] User obj) - { - return base.GetHashCode(); - } +public class UserCompare : IEqualityComparer +{ + public bool Equals(User x, User y) + { + return x.Name == y.Name && + x.Email == y.Email && + x.EmailVerified == y.EmailVerified && + x.MasterPassword == y.MasterPassword && + x.MasterPasswordHint == y.MasterPasswordHint && + x.Culture == y.Culture && + x.SecurityStamp == y.SecurityStamp && + x.TwoFactorProviders == y.TwoFactorProviders && + x.TwoFactorRecoveryCode == y.TwoFactorRecoveryCode && + x.EquivalentDomains == y.EquivalentDomains && + x.Key == y.Key && + x.PublicKey == y.PublicKey && + x.PrivateKey == y.PrivateKey && + x.Premium == y.Premium && + x.Storage == y.Storage && + x.MaxStorageGb == y.MaxStorageGb && + x.Gateway == y.Gateway && + x.GatewayCustomerId == y.GatewayCustomerId && + x.ReferenceData == y.ReferenceData && + x.LicenseKey == y.LicenseKey && + x.ApiKey == y.ApiKey && + x.Kdf == y.Kdf && + x.KdfIterations == y.KdfIterations; + } + + public int GetHashCode([DisallowNull] User obj) + { + return base.GetHashCode(); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/UserKdfInformation.cs b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/UserKdfInformation.cs index 143903de3..079d37c3f 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/UserKdfInformation.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/EqualityComparers/UserKdfInformation.cs @@ -1,19 +1,18 @@ using System.Diagnostics.CodeAnalysis; using Bit.Core.Models.Data; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers -{ - public class UserKdfInformationCompare : IEqualityComparer - { - public bool Equals(UserKdfInformation x, UserKdfInformation y) - { - return x.Kdf == y.Kdf && - x.KdfIterations == y.KdfIterations; - } +namespace Bit.Infrastructure.EFIntegration.Test.Repositories.EqualityComparers; - public int GetHashCode([DisallowNull] UserKdfInformation obj) - { - return base.GetHashCode(); - } +public class UserKdfInformationCompare : IEqualityComparer +{ + public bool Equals(UserKdfInformation x, UserKdfInformation y) + { + return x.Kdf == y.Kdf && + x.KdfIterations == y.KdfIterations; + } + + public int GetHashCode([DisallowNull] UserKdfInformation obj) + { + return base.GetHashCode(); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/FolderRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/FolderRepositoryTests.cs index ae3f4fe9b..53edbd3c4 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/FolderRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/FolderRepositoryTests.cs @@ -6,44 +6,43 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories +namespace Bit.Infrastructure.EFIntegration.Test.Repositories; + +public class FolderRepositoryTests { - public class FolderRepositoryTests + [CiSkippedTheory, EfFolderAutoData] + public async void CreateAsync_Works_DataMatches( + Folder folder, + User user, + FolderCompare equalityComparer, + List suts, + List efUserRepos, + SqlRepo.FolderRepository sqlFolderRepo, + SqlRepo.UserRepository sqlUserRepo) { - [CiSkippedTheory, EfFolderAutoData] - public async void CreateAsync_Works_DataMatches( - Folder folder, - User user, - FolderCompare equalityComparer, - List suts, - List efUserRepos, - SqlRepo.FolderRepository sqlFolderRepo, - SqlRepo.UserRepository sqlUserRepo) + var savedFolders = new List(); + foreach (var sut in suts) { - var savedFolders = new List(); - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); + var i = suts.IndexOf(sut); - var efUser = await efUserRepos[i].CreateAsync(user); - sut.ClearChangeTracking(); + var efUser = await efUserRepos[i].CreateAsync(user); + sut.ClearChangeTracking(); - folder.UserId = efUser.Id; - var postEfFolder = await sut.CreateAsync(folder); - sut.ClearChangeTracking(); + folder.UserId = efUser.Id; + var postEfFolder = await sut.CreateAsync(folder); + sut.ClearChangeTracking(); - var savedFolder = await sut.GetByIdAsync(folder.Id); - savedFolders.Add(savedFolder); - } - - var sqlUser = await sqlUserRepo.CreateAsync(user); - - folder.UserId = sqlUser.Id; - var sqlFolder = await sqlFolderRepo.CreateAsync(folder); - savedFolders.Add(await sqlFolderRepo.GetByIdAsync(sqlFolder.Id)); - - var distinctItems = savedFolders.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); + var savedFolder = await sut.GetByIdAsync(folder.Id); + savedFolders.Add(savedFolder); } + + var sqlUser = await sqlUserRepo.CreateAsync(user); + + folder.UserId = sqlUser.Id; + var sqlFolder = await sqlFolderRepo.CreateAsync(folder); + savedFolders.Add(await sqlFolderRepo.GetByIdAsync(sqlFolder.Id)); + + var distinctItems = savedFolders.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/InstallationRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/InstallationRepositoryTests.cs index 90b8d5bbc..9827b0c03 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/InstallationRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/InstallationRepositoryTests.cs @@ -6,34 +6,33 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories +namespace Bit.Infrastructure.EFIntegration.Test.Repositories; + +public class InstallationRepositoryTests { - public class InstallationRepositoryTests + [CiSkippedTheory, EfInstallationAutoData] + public async void CreateAsync_Works_DataMatches( + Installation installation, + InstallationCompare equalityComparer, + List suts, + SqlRepo.InstallationRepository sqlInstallationRepo + ) { - [CiSkippedTheory, EfInstallationAutoData] - public async void CreateAsync_Works_DataMatches( - Installation installation, - InstallationCompare equalityComparer, - List suts, - SqlRepo.InstallationRepository sqlInstallationRepo - ) + var savedInstallations = new List(); + foreach (var sut in suts) { - var savedInstallations = new List(); - foreach (var sut in suts) - { - var postEfInstallation = await sut.CreateAsync(installation); - sut.ClearChangeTracking(); + var postEfInstallation = await sut.CreateAsync(installation); + sut.ClearChangeTracking(); - var savedInstallation = await sut.GetByIdAsync(postEfInstallation.Id); - savedInstallations.Add(savedInstallation); - } - - var sqlInstallation = await sqlInstallationRepo.CreateAsync(installation); - var savedSqlInstallation = await sqlInstallationRepo.GetByIdAsync(sqlInstallation.Id); - savedInstallations.Add(savedSqlInstallation); - - var distinctItems = savedInstallations.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); + var savedInstallation = await sut.GetByIdAsync(postEfInstallation.Id); + savedInstallations.Add(savedInstallation); } + + var sqlInstallation = await sqlInstallationRepo.CreateAsync(installation); + var savedSqlInstallation = await sqlInstallationRepo.GetByIdAsync(sqlInstallation.Id); + savedInstallations.Add(savedSqlInstallation); + + var distinctItems = savedInstallations.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/OrganizationRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/OrganizationRepositoryTests.cs index eb6713afb..04e314d56 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/OrganizationRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/OrganizationRepositoryTests.cs @@ -7,144 +7,143 @@ using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using Organization = Bit.Core.Entities.Organization; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories +namespace Bit.Infrastructure.EFIntegration.Test.Repositories; + +public class OrganizationRepositoryTests { - public class OrganizationRepositoryTests + [CiSkippedTheory, EfOrganizationAutoData] + public async void CreateAsync_Works_DataMatches( + Organization organization, + SqlRepo.OrganizationRepository sqlOrganizationRepo, OrganizationCompare equalityComparer, + List suts) { - [CiSkippedTheory, EfOrganizationAutoData] - public async void CreateAsync_Works_DataMatches( - Organization organization, - SqlRepo.OrganizationRepository sqlOrganizationRepo, OrganizationCompare equalityComparer, - List suts) + var savedOrganizations = new List(); + foreach (var sut in suts) { - var savedOrganizations = new List(); - foreach (var sut in suts) - { - var postEfOrganization = await sut.CreateAsync(organization); - sut.ClearChangeTracking(); + var postEfOrganization = await sut.CreateAsync(organization); + sut.ClearChangeTracking(); - var savedOrganization = await sut.GetByIdAsync(organization.Id); - savedOrganizations.Add(savedOrganization); - } - - var sqlOrganization = await sqlOrganizationRepo.CreateAsync(organization); - savedOrganizations.Add(await sqlOrganizationRepo.GetByIdAsync(sqlOrganization.Id)); - - var distinctItems = savedOrganizations.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); + var savedOrganization = await sut.GetByIdAsync(organization.Id); + savedOrganizations.Add(savedOrganization); } - [CiSkippedTheory, EfOrganizationAutoData] - public async void ReplaceAsync_Works_DataMatches(Organization postOrganization, - Organization replaceOrganization, SqlRepo.OrganizationRepository sqlOrganizationRepo, - OrganizationCompare equalityComparer, List suts) + var sqlOrganization = await sqlOrganizationRepo.CreateAsync(organization); + savedOrganizations.Add(await sqlOrganizationRepo.GetByIdAsync(sqlOrganization.Id)); + + var distinctItems = savedOrganizations.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); + } + + [CiSkippedTheory, EfOrganizationAutoData] + public async void ReplaceAsync_Works_DataMatches(Organization postOrganization, + Organization replaceOrganization, SqlRepo.OrganizationRepository sqlOrganizationRepo, + OrganizationCompare equalityComparer, List suts) + { + var savedOrganizations = new List(); + foreach (var sut in suts) { - var savedOrganizations = new List(); - foreach (var sut in suts) - { - var postEfOrganization = await sut.CreateAsync(postOrganization); - sut.ClearChangeTracking(); + var postEfOrganization = await sut.CreateAsync(postOrganization); + sut.ClearChangeTracking(); - replaceOrganization.Id = postEfOrganization.Id; - await sut.ReplaceAsync(replaceOrganization); - sut.ClearChangeTracking(); + replaceOrganization.Id = postEfOrganization.Id; + await sut.ReplaceAsync(replaceOrganization); + sut.ClearChangeTracking(); - var replacedOrganization = await sut.GetByIdAsync(replaceOrganization.Id); - savedOrganizations.Add(replacedOrganization); - } - - var postSqlOrganization = await sqlOrganizationRepo.CreateAsync(postOrganization); - replaceOrganization.Id = postSqlOrganization.Id; - await sqlOrganizationRepo.ReplaceAsync(replaceOrganization); - savedOrganizations.Add(await sqlOrganizationRepo.GetByIdAsync(replaceOrganization.Id)); - - var distinctItems = savedOrganizations.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); + var replacedOrganization = await sut.GetByIdAsync(replaceOrganization.Id); + savedOrganizations.Add(replacedOrganization); } - [CiSkippedTheory, EfOrganizationAutoData] - public async void DeleteAsync_Works_DataMatches(Organization organization, - SqlRepo.OrganizationRepository sqlOrganizationRepo, List suts) + var postSqlOrganization = await sqlOrganizationRepo.CreateAsync(postOrganization); + replaceOrganization.Id = postSqlOrganization.Id; + await sqlOrganizationRepo.ReplaceAsync(replaceOrganization); + savedOrganizations.Add(await sqlOrganizationRepo.GetByIdAsync(replaceOrganization.Id)); + + var distinctItems = savedOrganizations.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); + } + + [CiSkippedTheory, EfOrganizationAutoData] + public async void DeleteAsync_Works_DataMatches(Organization organization, + SqlRepo.OrganizationRepository sqlOrganizationRepo, List suts) + { + foreach (var sut in suts) { - foreach (var sut in suts) - { - var postEfOrganization = await sut.CreateAsync(organization); - sut.ClearChangeTracking(); + var postEfOrganization = await sut.CreateAsync(organization); + sut.ClearChangeTracking(); - var savedEfOrganization = await sut.GetByIdAsync(postEfOrganization.Id); - sut.ClearChangeTracking(); - Assert.True(savedEfOrganization != null); + var savedEfOrganization = await sut.GetByIdAsync(postEfOrganization.Id); + sut.ClearChangeTracking(); + Assert.True(savedEfOrganization != null); - await sut.DeleteAsync(savedEfOrganization); - sut.ClearChangeTracking(); + await sut.DeleteAsync(savedEfOrganization); + sut.ClearChangeTracking(); - savedEfOrganization = await sut.GetByIdAsync(savedEfOrganization.Id); - Assert.True(savedEfOrganization == null); - } - - var postSqlOrganization = await sqlOrganizationRepo.CreateAsync(organization); - var savedSqlOrganization = await sqlOrganizationRepo.GetByIdAsync(postSqlOrganization.Id); - Assert.True(savedSqlOrganization != null); - - await sqlOrganizationRepo.DeleteAsync(postSqlOrganization); - savedSqlOrganization = await sqlOrganizationRepo.GetByIdAsync(postSqlOrganization.Id); - Assert.True(savedSqlOrganization == null); + savedEfOrganization = await sut.GetByIdAsync(savedEfOrganization.Id); + Assert.True(savedEfOrganization == null); } - [CiSkippedTheory, EfOrganizationAutoData] - public async void GetByIdentifierAsync_Works_DataMatches(Organization organization, - SqlRepo.OrganizationRepository sqlOrganizationRepo, OrganizationCompare equalityComparer, - List suts) + var postSqlOrganization = await sqlOrganizationRepo.CreateAsync(organization); + var savedSqlOrganization = await sqlOrganizationRepo.GetByIdAsync(postSqlOrganization.Id); + Assert.True(savedSqlOrganization != null); + + await sqlOrganizationRepo.DeleteAsync(postSqlOrganization); + savedSqlOrganization = await sqlOrganizationRepo.GetByIdAsync(postSqlOrganization.Id); + Assert.True(savedSqlOrganization == null); + } + + [CiSkippedTheory, EfOrganizationAutoData] + public async void GetByIdentifierAsync_Works_DataMatches(Organization organization, + SqlRepo.OrganizationRepository sqlOrganizationRepo, OrganizationCompare equalityComparer, + List suts) + { + var returnedOrgs = new List(); + foreach (var sut in suts) { - var returnedOrgs = new List(); - foreach (var sut in suts) - { - var postEfOrg = await sut.CreateAsync(organization); - sut.ClearChangeTracking(); + var postEfOrg = await sut.CreateAsync(organization); + sut.ClearChangeTracking(); - var returnedOrg = await sut.GetByIdentifierAsync(postEfOrg.Identifier.ToUpperInvariant()); - returnedOrgs.Add(returnedOrg); - } - - var postSqlOrg = await sqlOrganizationRepo.CreateAsync(organization); - returnedOrgs.Add(await sqlOrganizationRepo.GetByIdentifierAsync(postSqlOrg.Identifier.ToUpperInvariant())); - - var distinctItems = returnedOrgs.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); + var returnedOrg = await sut.GetByIdentifierAsync(postEfOrg.Identifier.ToUpperInvariant()); + returnedOrgs.Add(returnedOrg); } - [CiSkippedTheory, EfOrganizationAutoData] - public async void GetManyByEnabledAsync_Works_DataMatches(Organization organization, - SqlRepo.OrganizationRepository sqlOrganizationRepo, List suts) + var postSqlOrg = await sqlOrganizationRepo.CreateAsync(organization); + returnedOrgs.Add(await sqlOrganizationRepo.GetByIdentifierAsync(postSqlOrg.Identifier.ToUpperInvariant())); + + var distinctItems = returnedOrgs.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); + } + + [CiSkippedTheory, EfOrganizationAutoData] + public async void GetManyByEnabledAsync_Works_DataMatches(Organization organization, + SqlRepo.OrganizationRepository sqlOrganizationRepo, List suts) + { + var returnedOrgs = new List(); + foreach (var sut in suts) { - var returnedOrgs = new List(); - foreach (var sut in suts) - { - var postEfOrg = await sut.CreateAsync(organization); - sut.ClearChangeTracking(); + var postEfOrg = await sut.CreateAsync(organization); + sut.ClearChangeTracking(); - var efReturnedOrgs = await sut.GetManyByEnabledAsync(); - returnedOrgs.Concat(efReturnedOrgs); - } - - var postSqlOrg = await sqlOrganizationRepo.CreateAsync(organization); - returnedOrgs.Concat(await sqlOrganizationRepo.GetManyByEnabledAsync()); - - Assert.True(returnedOrgs.All(o => o.Enabled)); + var efReturnedOrgs = await sut.GetManyByEnabledAsync(); + returnedOrgs.Concat(efReturnedOrgs); } - // testing data matches here would require manipulating all organization abilities in the db - [CiSkippedTheory, EfOrganizationAutoData] - public async void GetManyAbilitiesAsync_Works(SqlRepo.OrganizationRepository sqlOrganizationRepo, List suts) - { - var list = new List(); - foreach (var sut in suts) - { - list.Concat(await sut.GetManyAbilitiesAsync()); - } + var postSqlOrg = await sqlOrganizationRepo.CreateAsync(organization); + returnedOrgs.Concat(await sqlOrganizationRepo.GetManyByEnabledAsync()); - list.Concat(await sqlOrganizationRepo.GetManyAbilitiesAsync()); - Assert.True(list.All(x => x.GetType() == typeof(OrganizationAbility))); + Assert.True(returnedOrgs.All(o => o.Enabled)); + } + + // testing data matches here would require manipulating all organization abilities in the db + [CiSkippedTheory, EfOrganizationAutoData] + public async void GetManyAbilitiesAsync_Works(SqlRepo.OrganizationRepository sqlOrganizationRepo, List suts) + { + var list = new List(); + foreach (var sut in suts) + { + list.Concat(await sut.GetManyAbilitiesAsync()); } + + list.Concat(await sqlOrganizationRepo.GetManyAbilitiesAsync()); + Assert.True(list.All(x => x.GetType() == typeof(OrganizationAbility))); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/OrganizationSponsorshipRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/OrganizationSponsorshipRepositoryTests.cs index 29482df29..ee7d0d271 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/OrganizationSponsorshipRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/OrganizationSponsorshipRepositoryTests.cs @@ -6,127 +6,126 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories +namespace Bit.Infrastructure.EFIntegration.Test.Repositories; + +public class OrganizationSponsorshipRepositoryTests { - public class OrganizationSponsorshipRepositoryTests + [CiSkippedTheory, EfOrganizationSponsorshipAutoData] + public async void CreateAsync_Works_DataMatches( + OrganizationSponsorship organizationSponsorship, Organization sponsoringOrg, + List efOrgRepos, + SqlRepo.OrganizationRepository sqlOrganizationRepo, + SqlRepo.OrganizationSponsorshipRepository sqlOrganizationSponsorshipRepo, + OrganizationSponsorshipCompare equalityComparer, + List suts) { - [CiSkippedTheory, EfOrganizationSponsorshipAutoData] - public async void CreateAsync_Works_DataMatches( - OrganizationSponsorship organizationSponsorship, Organization sponsoringOrg, - List efOrgRepos, - SqlRepo.OrganizationRepository sqlOrganizationRepo, - SqlRepo.OrganizationSponsorshipRepository sqlOrganizationSponsorshipRepo, - OrganizationSponsorshipCompare equalityComparer, - List suts) + organizationSponsorship.SponsoredOrganizationId = null; + + var savedOrganizationSponsorships = new List(); + foreach (var (sut, orgRepo) in suts.Zip(efOrgRepos)) { - organizationSponsorship.SponsoredOrganizationId = null; + var efSponsoringOrg = await orgRepo.CreateAsync(sponsoringOrg); + sut.ClearChangeTracking(); + organizationSponsorship.SponsoringOrganizationId = efSponsoringOrg.Id; - var savedOrganizationSponsorships = new List(); - foreach (var (sut, orgRepo) in suts.Zip(efOrgRepos)) - { - var efSponsoringOrg = await orgRepo.CreateAsync(sponsoringOrg); - sut.ClearChangeTracking(); - organizationSponsorship.SponsoringOrganizationId = efSponsoringOrg.Id; + await sut.CreateAsync(organizationSponsorship); + sut.ClearChangeTracking(); - await sut.CreateAsync(organizationSponsorship); - sut.ClearChangeTracking(); - - var savedOrganizationSponsorship = await sut.GetByIdAsync(organizationSponsorship.Id); - savedOrganizationSponsorships.Add(savedOrganizationSponsorship); - } - - var sqlSponsoringOrg = await sqlOrganizationRepo.CreateAsync(sponsoringOrg); - organizationSponsorship.SponsoringOrganizationId = sqlSponsoringOrg.Id; - - var sqlOrganizationSponsorship = await sqlOrganizationSponsorshipRepo.CreateAsync(organizationSponsorship); - savedOrganizationSponsorships.Add(await sqlOrganizationSponsorshipRepo.GetByIdAsync(sqlOrganizationSponsorship.Id)); - - var distinctItems = savedOrganizationSponsorships.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); + var savedOrganizationSponsorship = await sut.GetByIdAsync(organizationSponsorship.Id); + savedOrganizationSponsorships.Add(savedOrganizationSponsorship); } - [CiSkippedTheory, EfOrganizationSponsorshipAutoData] - public async void ReplaceAsync_Works_DataMatches(OrganizationSponsorship postOrganizationSponsorship, - OrganizationSponsorship replaceOrganizationSponsorship, Organization sponsoringOrg, - List efOrgRepos, - SqlRepo.OrganizationRepository sqlOrganizationRepo, - SqlRepo.OrganizationSponsorshipRepository sqlOrganizationSponsorshipRepo, - OrganizationSponsorshipCompare equalityComparer, List suts) + var sqlSponsoringOrg = await sqlOrganizationRepo.CreateAsync(sponsoringOrg); + organizationSponsorship.SponsoringOrganizationId = sqlSponsoringOrg.Id; + + var sqlOrganizationSponsorship = await sqlOrganizationSponsorshipRepo.CreateAsync(organizationSponsorship); + savedOrganizationSponsorships.Add(await sqlOrganizationSponsorshipRepo.GetByIdAsync(sqlOrganizationSponsorship.Id)); + + var distinctItems = savedOrganizationSponsorships.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); + } + + [CiSkippedTheory, EfOrganizationSponsorshipAutoData] + public async void ReplaceAsync_Works_DataMatches(OrganizationSponsorship postOrganizationSponsorship, + OrganizationSponsorship replaceOrganizationSponsorship, Organization sponsoringOrg, + List efOrgRepos, + SqlRepo.OrganizationRepository sqlOrganizationRepo, + SqlRepo.OrganizationSponsorshipRepository sqlOrganizationSponsorshipRepo, + OrganizationSponsorshipCompare equalityComparer, List suts) + { + postOrganizationSponsorship.SponsoredOrganizationId = null; + replaceOrganizationSponsorship.SponsoredOrganizationId = null; + + var savedOrganizationSponsorships = new List(); + foreach (var (sut, orgRepo) in suts.Zip(efOrgRepos)) { - postOrganizationSponsorship.SponsoredOrganizationId = null; - replaceOrganizationSponsorship.SponsoredOrganizationId = null; + var efSponsoringOrg = await orgRepo.CreateAsync(sponsoringOrg); + sut.ClearChangeTracking(); + postOrganizationSponsorship.SponsoringOrganizationId = efSponsoringOrg.Id; + replaceOrganizationSponsorship.SponsoringOrganizationId = efSponsoringOrg.Id; - var savedOrganizationSponsorships = new List(); - foreach (var (sut, orgRepo) in suts.Zip(efOrgRepos)) - { - var efSponsoringOrg = await orgRepo.CreateAsync(sponsoringOrg); - sut.ClearChangeTracking(); - postOrganizationSponsorship.SponsoringOrganizationId = efSponsoringOrg.Id; - replaceOrganizationSponsorship.SponsoringOrganizationId = efSponsoringOrg.Id; + var postEfOrganizationSponsorship = await sut.CreateAsync(postOrganizationSponsorship); + sut.ClearChangeTracking(); - var postEfOrganizationSponsorship = await sut.CreateAsync(postOrganizationSponsorship); - sut.ClearChangeTracking(); + replaceOrganizationSponsorship.Id = postEfOrganizationSponsorship.Id; + await sut.ReplaceAsync(replaceOrganizationSponsorship); + sut.ClearChangeTracking(); - replaceOrganizationSponsorship.Id = postEfOrganizationSponsorship.Id; - await sut.ReplaceAsync(replaceOrganizationSponsorship); - sut.ClearChangeTracking(); - - var replacedOrganizationSponsorship = await sut.GetByIdAsync(replaceOrganizationSponsorship.Id); - savedOrganizationSponsorships.Add(replacedOrganizationSponsorship); - } - - var sqlSponsoringOrg = await sqlOrganizationRepo.CreateAsync(sponsoringOrg); - postOrganizationSponsorship.SponsoringOrganizationId = sqlSponsoringOrg.Id; - - var postSqlOrganization = await sqlOrganizationSponsorshipRepo.CreateAsync(postOrganizationSponsorship); - replaceOrganizationSponsorship.Id = postSqlOrganization.Id; - await sqlOrganizationSponsorshipRepo.ReplaceAsync(replaceOrganizationSponsorship); - savedOrganizationSponsorships.Add(await sqlOrganizationSponsorshipRepo.GetByIdAsync(replaceOrganizationSponsorship.Id)); - - var distinctItems = savedOrganizationSponsorships.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); + var replacedOrganizationSponsorship = await sut.GetByIdAsync(replaceOrganizationSponsorship.Id); + savedOrganizationSponsorships.Add(replacedOrganizationSponsorship); } - [CiSkippedTheory, EfOrganizationSponsorshipAutoData] - public async void DeleteAsync_Works_DataMatches(OrganizationSponsorship organizationSponsorship, - Organization sponsoringOrg, - List efOrgRepos, - SqlRepo.OrganizationRepository sqlOrganizationRepo, - SqlRepo.OrganizationSponsorshipRepository sqlOrganizationSponsorshipRepo, - List suts) + var sqlSponsoringOrg = await sqlOrganizationRepo.CreateAsync(sponsoringOrg); + postOrganizationSponsorship.SponsoringOrganizationId = sqlSponsoringOrg.Id; + + var postSqlOrganization = await sqlOrganizationSponsorshipRepo.CreateAsync(postOrganizationSponsorship); + replaceOrganizationSponsorship.Id = postSqlOrganization.Id; + await sqlOrganizationSponsorshipRepo.ReplaceAsync(replaceOrganizationSponsorship); + savedOrganizationSponsorships.Add(await sqlOrganizationSponsorshipRepo.GetByIdAsync(replaceOrganizationSponsorship.Id)); + + var distinctItems = savedOrganizationSponsorships.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); + } + + [CiSkippedTheory, EfOrganizationSponsorshipAutoData] + public async void DeleteAsync_Works_DataMatches(OrganizationSponsorship organizationSponsorship, + Organization sponsoringOrg, + List efOrgRepos, + SqlRepo.OrganizationRepository sqlOrganizationRepo, + SqlRepo.OrganizationSponsorshipRepository sqlOrganizationSponsorshipRepo, + List suts) + { + organizationSponsorship.SponsoredOrganizationId = null; + + foreach (var (sut, orgRepo) in suts.Zip(efOrgRepos)) { - organizationSponsorship.SponsoredOrganizationId = null; + var efSponsoringOrg = await orgRepo.CreateAsync(sponsoringOrg); + sut.ClearChangeTracking(); + organizationSponsorship.SponsoringOrganizationId = efSponsoringOrg.Id; - foreach (var (sut, orgRepo) in suts.Zip(efOrgRepos)) - { - var efSponsoringOrg = await orgRepo.CreateAsync(sponsoringOrg); - sut.ClearChangeTracking(); - organizationSponsorship.SponsoringOrganizationId = efSponsoringOrg.Id; + var postEfOrganizationSponsorship = await sut.CreateAsync(organizationSponsorship); + sut.ClearChangeTracking(); - var postEfOrganizationSponsorship = await sut.CreateAsync(organizationSponsorship); - sut.ClearChangeTracking(); + var savedEfOrganizationSponsorship = await sut.GetByIdAsync(postEfOrganizationSponsorship.Id); + sut.ClearChangeTracking(); + Assert.True(savedEfOrganizationSponsorship != null); - var savedEfOrganizationSponsorship = await sut.GetByIdAsync(postEfOrganizationSponsorship.Id); - sut.ClearChangeTracking(); - Assert.True(savedEfOrganizationSponsorship != null); + await sut.DeleteAsync(savedEfOrganizationSponsorship); + sut.ClearChangeTracking(); - await sut.DeleteAsync(savedEfOrganizationSponsorship); - sut.ClearChangeTracking(); - - savedEfOrganizationSponsorship = await sut.GetByIdAsync(savedEfOrganizationSponsorship.Id); - Assert.True(savedEfOrganizationSponsorship == null); - } - - var sqlSponsoringOrg = await sqlOrganizationRepo.CreateAsync(sponsoringOrg); - organizationSponsorship.SponsoringOrganizationId = sqlSponsoringOrg.Id; - - var postSqlOrganizationSponsorship = await sqlOrganizationSponsorshipRepo.CreateAsync(organizationSponsorship); - var savedSqlOrganizationSponsorship = await sqlOrganizationSponsorshipRepo.GetByIdAsync(postSqlOrganizationSponsorship.Id); - Assert.True(savedSqlOrganizationSponsorship != null); - - await sqlOrganizationSponsorshipRepo.DeleteAsync(postSqlOrganizationSponsorship); - savedSqlOrganizationSponsorship = await sqlOrganizationSponsorshipRepo.GetByIdAsync(postSqlOrganizationSponsorship.Id); - Assert.True(savedSqlOrganizationSponsorship == null); + savedEfOrganizationSponsorship = await sut.GetByIdAsync(savedEfOrganizationSponsorship.Id); + Assert.True(savedEfOrganizationSponsorship == null); } + + var sqlSponsoringOrg = await sqlOrganizationRepo.CreateAsync(sponsoringOrg); + organizationSponsorship.SponsoringOrganizationId = sqlSponsoringOrg.Id; + + var postSqlOrganizationSponsorship = await sqlOrganizationSponsorshipRepo.CreateAsync(organizationSponsorship); + var savedSqlOrganizationSponsorship = await sqlOrganizationSponsorshipRepo.GetByIdAsync(postSqlOrganizationSponsorship.Id); + Assert.True(savedSqlOrganizationSponsorship != null); + + await sqlOrganizationSponsorshipRepo.DeleteAsync(postSqlOrganizationSponsorship); + savedSqlOrganizationSponsorship = await sqlOrganizationSponsorshipRepo.GetByIdAsync(postSqlOrganizationSponsorship.Id); + Assert.True(savedSqlOrganizationSponsorship == null); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/OrganizationUserRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/OrganizationUserRepositoryTests.cs index 34f1c6f4b..2becc0fc6 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/OrganizationUserRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/OrganizationUserRepositoryTests.cs @@ -7,142 +7,141 @@ using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using OrganizationUser = Bit.Core.Entities.OrganizationUser; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories +namespace Bit.Infrastructure.EFIntegration.Test.Repositories; + +public class OrganizationUserRepositoryTests { - public class OrganizationUserRepositoryTests + [CiSkippedTheory, EfOrganizationUserAutoData] + public async void CreateAsync_Works_DataMatches(OrganizationUser orgUser, User user, Organization org, + OrganizationUserCompare equalityComparer, List suts, + List efOrgRepos, List efUserRepos, + SqlRepo.OrganizationUserRepository sqlOrgUserRepo, SqlRepo.UserRepository sqlUserRepo, + SqlRepo.OrganizationRepository sqlOrgRepo) { - [CiSkippedTheory, EfOrganizationUserAutoData] - public async void CreateAsync_Works_DataMatches(OrganizationUser orgUser, User user, Organization org, - OrganizationUserCompare equalityComparer, List suts, - List efOrgRepos, List efUserRepos, - SqlRepo.OrganizationUserRepository sqlOrgUserRepo, SqlRepo.UserRepository sqlUserRepo, - SqlRepo.OrganizationRepository sqlOrgRepo) + var savedOrgUsers = new List(); + foreach (var sut in suts) { - var savedOrgUsers = new List(); - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); - var postEfUser = await efUserRepos[i].CreateAsync(user); - var postEfOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + var i = suts.IndexOf(sut); + var postEfUser = await efUserRepos[i].CreateAsync(user); + var postEfOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - orgUser.UserId = postEfUser.Id; - orgUser.OrganizationId = postEfOrg.Id; - var postEfOrgUser = await sut.CreateAsync(orgUser); - sut.ClearChangeTracking(); + orgUser.UserId = postEfUser.Id; + orgUser.OrganizationId = postEfOrg.Id; + var postEfOrgUser = await sut.CreateAsync(orgUser); + sut.ClearChangeTracking(); - var savedOrgUser = await sut.GetByIdAsync(postEfOrgUser.Id); - savedOrgUsers.Add(savedOrgUser); - } - - var postSqlUser = await sqlUserRepo.CreateAsync(user); - var postSqlOrg = await sqlOrgRepo.CreateAsync(org); - - orgUser.UserId = postSqlUser.Id; - orgUser.OrganizationId = postSqlOrg.Id; - var sqlOrgUser = await sqlOrgUserRepo.CreateAsync(orgUser); - - var savedSqlOrgUser = await sqlOrgUserRepo.GetByIdAsync(sqlOrgUser.Id); - savedOrgUsers.Add(savedSqlOrgUser); - - var distinctItems = savedOrgUsers.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); + var savedOrgUser = await sut.GetByIdAsync(postEfOrgUser.Id); + savedOrgUsers.Add(savedOrgUser); } - [CiSkippedTheory, EfOrganizationUserAutoData] - public async void ReplaceAsync_Works_DataMatches( - OrganizationUser postOrgUser, - OrganizationUser replaceOrgUser, - User user, - Organization org, - OrganizationUserCompare equalityComparer, - List suts, - List efUserRepos, - List efOrgRepos, - SqlRepo.OrganizationUserRepository sqlOrgUserRepo, - SqlRepo.UserRepository sqlUserRepo, - SqlRepo.OrganizationRepository sqlOrgRepo - ) + var postSqlUser = await sqlUserRepo.CreateAsync(user); + var postSqlOrg = await sqlOrgRepo.CreateAsync(org); + + orgUser.UserId = postSqlUser.Id; + orgUser.OrganizationId = postSqlOrg.Id; + var sqlOrgUser = await sqlOrgUserRepo.CreateAsync(orgUser); + + var savedSqlOrgUser = await sqlOrgUserRepo.GetByIdAsync(sqlOrgUser.Id); + savedOrgUsers.Add(savedSqlOrgUser); + + var distinctItems = savedOrgUsers.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); + } + + [CiSkippedTheory, EfOrganizationUserAutoData] + public async void ReplaceAsync_Works_DataMatches( + OrganizationUser postOrgUser, + OrganizationUser replaceOrgUser, + User user, + Organization org, + OrganizationUserCompare equalityComparer, + List suts, + List efUserRepos, + List efOrgRepos, + SqlRepo.OrganizationUserRepository sqlOrgUserRepo, + SqlRepo.UserRepository sqlUserRepo, + SqlRepo.OrganizationRepository sqlOrgRepo + ) + { + var savedOrgUsers = new List(); + foreach (var sut in suts) { - var savedOrgUsers = new List(); - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); - var postEfUser = await efUserRepos[i].CreateAsync(user); - var postEfOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + var i = suts.IndexOf(sut); + var postEfUser = await efUserRepos[i].CreateAsync(user); + var postEfOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - postOrgUser.UserId = replaceOrgUser.UserId = postEfUser.Id; - postOrgUser.OrganizationId = replaceOrgUser.OrganizationId = postEfOrg.Id; - var postEfOrgUser = await sut.CreateAsync(postOrgUser); - sut.ClearChangeTracking(); + postOrgUser.UserId = replaceOrgUser.UserId = postEfUser.Id; + postOrgUser.OrganizationId = replaceOrgUser.OrganizationId = postEfOrg.Id; + var postEfOrgUser = await sut.CreateAsync(postOrgUser); + sut.ClearChangeTracking(); - replaceOrgUser.Id = postOrgUser.Id; - await sut.ReplaceAsync(replaceOrgUser); - sut.ClearChangeTracking(); + replaceOrgUser.Id = postOrgUser.Id; + await sut.ReplaceAsync(replaceOrgUser); + sut.ClearChangeTracking(); - var replacedOrganizationUser = await sut.GetByIdAsync(replaceOrgUser.Id); - savedOrgUsers.Add(replacedOrganizationUser); - } - - var postSqlUser = await sqlUserRepo.CreateAsync(user); - var postSqlOrg = await sqlOrgRepo.CreateAsync(org); - - postOrgUser.UserId = replaceOrgUser.UserId = postSqlUser.Id; - postOrgUser.OrganizationId = replaceOrgUser.OrganizationId = postSqlOrg.Id; - var postSqlOrgUser = await sqlOrgUserRepo.CreateAsync(postOrgUser); - - replaceOrgUser.Id = postSqlOrgUser.Id; - await sqlOrgUserRepo.ReplaceAsync(replaceOrgUser); - - var replacedSqlUser = await sqlOrgUserRepo.GetByIdAsync(replaceOrgUser.Id); - - var distinctItems = savedOrgUsers.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); + var replacedOrganizationUser = await sut.GetByIdAsync(replaceOrgUser.Id); + savedOrgUsers.Add(replacedOrganizationUser); } - [CiSkippedTheory, EfOrganizationUserAutoData] - public async void DeleteAsync_Works_DataMatches(OrganizationUser orgUser, User user, Organization org, List suts, - List efUserRepos, List efOrgRepos, - SqlRepo.OrganizationUserRepository sqlOrgUserRepo, SqlRepo.UserRepository sqlUserRepo, - SqlRepo.OrganizationRepository sqlOrgRepo) + var postSqlUser = await sqlUserRepo.CreateAsync(user); + var postSqlOrg = await sqlOrgRepo.CreateAsync(org); + + postOrgUser.UserId = replaceOrgUser.UserId = postSqlUser.Id; + postOrgUser.OrganizationId = replaceOrgUser.OrganizationId = postSqlOrg.Id; + var postSqlOrgUser = await sqlOrgUserRepo.CreateAsync(postOrgUser); + + replaceOrgUser.Id = postSqlOrgUser.Id; + await sqlOrgUserRepo.ReplaceAsync(replaceOrgUser); + + var replacedSqlUser = await sqlOrgUserRepo.GetByIdAsync(replaceOrgUser.Id); + + var distinctItems = savedOrgUsers.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); + } + + [CiSkippedTheory, EfOrganizationUserAutoData] + public async void DeleteAsync_Works_DataMatches(OrganizationUser orgUser, User user, Organization org, List suts, + List efUserRepos, List efOrgRepos, + SqlRepo.OrganizationUserRepository sqlOrgUserRepo, SqlRepo.UserRepository sqlUserRepo, + SqlRepo.OrganizationRepository sqlOrgRepo) + { + foreach (var sut in suts) { - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); - var postEfUser = await efUserRepos[i].CreateAsync(user); - var postEfOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + var i = suts.IndexOf(sut); + var postEfUser = await efUserRepos[i].CreateAsync(user); + var postEfOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - orgUser.UserId = postEfUser.Id; - orgUser.OrganizationId = postEfOrg.Id; - var postEfOrgUser = await sut.CreateAsync(orgUser); - sut.ClearChangeTracking(); + orgUser.UserId = postEfUser.Id; + orgUser.OrganizationId = postEfOrg.Id; + var postEfOrgUser = await sut.CreateAsync(orgUser); + sut.ClearChangeTracking(); - var savedEfOrgUser = await sut.GetByIdAsync(postEfOrgUser.Id); - Assert.True(savedEfOrgUser != null); - sut.ClearChangeTracking(); + var savedEfOrgUser = await sut.GetByIdAsync(postEfOrgUser.Id); + Assert.True(savedEfOrgUser != null); + sut.ClearChangeTracking(); - await sut.DeleteAsync(savedEfOrgUser); - sut.ClearChangeTracking(); + await sut.DeleteAsync(savedEfOrgUser); + sut.ClearChangeTracking(); - savedEfOrgUser = await sut.GetByIdAsync(savedEfOrgUser.Id); - Assert.True(savedEfOrgUser == null); - } - - var postSqlUser = await sqlUserRepo.CreateAsync(user); - var postSqlOrg = await sqlOrgRepo.CreateAsync(org); - - orgUser.UserId = postSqlUser.Id; - orgUser.OrganizationId = postSqlOrg.Id; - var postSqlOrgUser = await sqlOrgUserRepo.CreateAsync(orgUser); - - var savedSqlOrgUser = await sqlOrgUserRepo.GetByIdAsync(postSqlOrgUser.Id); - Assert.True(savedSqlOrgUser != null); - - await sqlOrgUserRepo.DeleteAsync(postSqlOrgUser); - savedSqlOrgUser = await sqlOrgUserRepo.GetByIdAsync(postSqlOrgUser.Id); - Assert.True(savedSqlOrgUser == null); + savedEfOrgUser = await sut.GetByIdAsync(savedEfOrgUser.Id); + Assert.True(savedEfOrgUser == null); } + + var postSqlUser = await sqlUserRepo.CreateAsync(user); + var postSqlOrg = await sqlOrgRepo.CreateAsync(org); + + orgUser.UserId = postSqlUser.Id; + orgUser.OrganizationId = postSqlOrg.Id; + var postSqlOrgUser = await sqlOrgUserRepo.CreateAsync(orgUser); + + var savedSqlOrgUser = await sqlOrgUserRepo.GetByIdAsync(postSqlOrgUser.Id); + Assert.True(savedSqlOrgUser != null); + + await sqlOrgUserRepo.DeleteAsync(postSqlOrgUser); + savedSqlOrgUser = await sqlOrgUserRepo.GetByIdAsync(postSqlOrgUser.Id); + Assert.True(savedSqlOrgUser == null); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/PolicyRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/PolicyRepositoryTests.cs index d013de430..18a2676cd 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/PolicyRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/PolicyRepositoryTests.cs @@ -12,185 +12,184 @@ using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using Policy = Bit.Core.Entities.Policy; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories +namespace Bit.Infrastructure.EFIntegration.Test.Repositories; + +public class PolicyRepositoryTests { - public class PolicyRepositoryTests + [CiSkippedTheory, EfPolicyAutoData] + public async void CreateAsync_Works_DataMatches( + Policy policy, + Organization organization, + PolicyCompare equalityComparer, + List suts, + List efOrganizationRepos, + SqlRepo.PolicyRepository sqlPolicyRepo, + SqlRepo.OrganizationRepository sqlOrganizationRepo + ) { - [CiSkippedTheory, EfPolicyAutoData] - public async void CreateAsync_Works_DataMatches( - Policy policy, - Organization organization, - PolicyCompare equalityComparer, - List suts, - List efOrganizationRepos, - SqlRepo.PolicyRepository sqlPolicyRepo, - SqlRepo.OrganizationRepository sqlOrganizationRepo - ) + var savedPolicys = new List(); + foreach (var sut in suts) { - var savedPolicys = new List(); - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); + var i = suts.IndexOf(sut); - var efOrganization = await efOrganizationRepos[i].CreateAsync(organization); - sut.ClearChangeTracking(); + var efOrganization = await efOrganizationRepos[i].CreateAsync(organization); + sut.ClearChangeTracking(); - policy.OrganizationId = efOrganization.Id; - var postEfPolicy = await sut.CreateAsync(policy); - sut.ClearChangeTracking(); + policy.OrganizationId = efOrganization.Id; + var postEfPolicy = await sut.CreateAsync(policy); + sut.ClearChangeTracking(); - var savedPolicy = await sut.GetByIdAsync(postEfPolicy.Id); - savedPolicys.Add(savedPolicy); - } - - var sqlOrganization = await sqlOrganizationRepo.CreateAsync(organization); - - policy.OrganizationId = sqlOrganization.Id; - var sqlPolicy = await sqlPolicyRepo.CreateAsync(policy); - var savedSqlPolicy = await sqlPolicyRepo.GetByIdAsync(sqlPolicy.Id); - savedPolicys.Add(savedSqlPolicy); - - var distinctItems = savedPolicys.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); + var savedPolicy = await sut.GetByIdAsync(postEfPolicy.Id); + savedPolicys.Add(savedPolicy); } - [CiSkippedTheory] - [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, false, OrganizationUserStatusType.Confirmed, false, true, true, false)] // Ordinary user - [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, false, OrganizationUserStatusType.Invited, true, true, true, false)] // Invited user - [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.Owner, false, OrganizationUserStatusType.Confirmed, false, true, true, false)] // Owner - [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.Admin, false, OrganizationUserStatusType.Confirmed, false, true, true, false)] // Admin - [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, true, OrganizationUserStatusType.Confirmed, false, true, true, false)] // canManagePolicies - [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, false, OrganizationUserStatusType.Confirmed, false, true, true, true)] // Provider - [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, false, OrganizationUserStatusType.Confirmed, false, false, true, false)] // Policy disabled - [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, false, OrganizationUserStatusType.Confirmed, false, true, false, false)] // No policy of Type - [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, false, OrganizationUserStatusType.Invited, false, true, true, false)] // User not minStatus + var sqlOrganization = await sqlOrganizationRepo.CreateAsync(organization); - public async void GetManyByTypeApplicableToUser_Works_DataMatches( - // Inline data - OrganizationUserType userType, - bool canManagePolicies, - OrganizationUserStatusType orgUserStatus, - bool includeInvited, - bool policyEnabled, - bool policySameType, - bool isProvider, + policy.OrganizationId = sqlOrganization.Id; + var sqlPolicy = await sqlPolicyRepo.CreateAsync(policy); + var savedSqlPolicy = await sqlPolicyRepo.GetByIdAsync(sqlPolicy.Id); + savedPolicys.Add(savedSqlPolicy); - // Auto data - models - Policy policy, - User user, - Organization organization, - OrganizationUser orgUser, - Provider provider, - ProviderOrganization providerOrganization, - ProviderUser providerUser, - PolicyCompareIncludingOrganization equalityComparer, + var distinctItems = savedPolicys.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); + } - // Auto data - EF repos - List suts, - List efUserRepository, - List efOrganizationRepository, - List efOrganizationUserRepository, - List efProviderRepository, - List efProviderOrganizationRepository, - List efProviderUserRepository, + [CiSkippedTheory] + [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, false, OrganizationUserStatusType.Confirmed, false, true, true, false)] // Ordinary user + [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, false, OrganizationUserStatusType.Invited, true, true, true, false)] // Invited user + [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.Owner, false, OrganizationUserStatusType.Confirmed, false, true, true, false)] // Owner + [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.Admin, false, OrganizationUserStatusType.Confirmed, false, true, true, false)] // Admin + [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, true, OrganizationUserStatusType.Confirmed, false, true, true, false)] // canManagePolicies + [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, false, OrganizationUserStatusType.Confirmed, false, true, true, true)] // Provider + [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, false, OrganizationUserStatusType.Confirmed, false, false, true, false)] // Policy disabled + [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, false, OrganizationUserStatusType.Confirmed, false, true, false, false)] // No policy of Type + [EfPolicyApplicableToUserInlineAutoData(OrganizationUserType.User, false, OrganizationUserStatusType.Invited, false, true, true, false)] // User not minStatus - // Auto data - SQL repos - SqlRepo.PolicyRepository sqlPolicyRepo, - SqlRepo.UserRepository sqlUserRepo, - SqlRepo.OrganizationRepository sqlOrganizationRepo, - SqlRepo.ProviderRepository sqlProviderRepo, - SqlRepo.OrganizationUserRepository sqlOrganizationUserRepo, - SqlRepo.ProviderOrganizationRepository sqlProviderOrganizationRepo, - SqlRepo.ProviderUserRepository sqlProviderUserRepo - ) + public async void GetManyByTypeApplicableToUser_Works_DataMatches( + // Inline data + OrganizationUserType userType, + bool canManagePolicies, + OrganizationUserStatusType orgUserStatus, + bool includeInvited, + bool policyEnabled, + bool policySameType, + bool isProvider, + + // Auto data - models + Policy policy, + User user, + Organization organization, + OrganizationUser orgUser, + Provider provider, + ProviderOrganization providerOrganization, + ProviderUser providerUser, + PolicyCompareIncludingOrganization equalityComparer, + + // Auto data - EF repos + List suts, + List efUserRepository, + List efOrganizationRepository, + List efOrganizationUserRepository, + List efProviderRepository, + List efProviderOrganizationRepository, + List efProviderUserRepository, + + // Auto data - SQL repos + SqlRepo.PolicyRepository sqlPolicyRepo, + SqlRepo.UserRepository sqlUserRepo, + SqlRepo.OrganizationRepository sqlOrganizationRepo, + SqlRepo.ProviderRepository sqlProviderRepo, + SqlRepo.OrganizationUserRepository sqlOrganizationUserRepo, + SqlRepo.ProviderOrganizationRepository sqlProviderOrganizationRepo, + SqlRepo.ProviderUserRepository sqlProviderUserRepo + ) + { + // Combine EF and SQL repos into one list per type + var policyRepos = suts.ToList(); + policyRepos.Add(sqlPolicyRepo); + var userRepos = efUserRepository.ToList(); + userRepos.Add(sqlUserRepo); + var orgRepos = efOrganizationRepository.ToList(); + orgRepos.Add(sqlOrganizationRepo); + var orgUserRepos = efOrganizationUserRepository.ToList(); + orgUserRepos.Add(sqlOrganizationUserRepo); + var providerRepos = efProviderRepository.ToList(); + providerRepos.Add(sqlProviderRepo); + var providerOrgRepos = efProviderOrganizationRepository.ToList(); + providerOrgRepos.Add(sqlProviderOrganizationRepo); + var providerUserRepos = efProviderUserRepository.ToList(); + providerUserRepos.Add(sqlProviderUserRepo); + + // Arrange data + var savedPolicyType = PolicyType.SingleOrg; + var queriedPolicyType = policySameType ? savedPolicyType : PolicyType.DisableSend; + + orgUser.Type = userType; + orgUser.Status = orgUserStatus; + var permissionsData = new Permissions { ManagePolicies = canManagePolicies }; + orgUser.Permissions = JsonSerializer.Serialize(permissionsData, new JsonSerializerOptions { - // Combine EF and SQL repos into one list per type - var policyRepos = suts.ToList(); - policyRepos.Add(sqlPolicyRepo); - var userRepos = efUserRepository.ToList(); - userRepos.Add(sqlUserRepo); - var orgRepos = efOrganizationRepository.ToList(); - orgRepos.Add(sqlOrganizationRepo); - var orgUserRepos = efOrganizationUserRepository.ToList(); - orgUserRepos.Add(sqlOrganizationUserRepo); - var providerRepos = efProviderRepository.ToList(); - providerRepos.Add(sqlProviderRepo); - var providerOrgRepos = efProviderOrganizationRepository.ToList(); - providerOrgRepos.Add(sqlProviderOrganizationRepo); - var providerUserRepos = efProviderUserRepository.ToList(); - providerUserRepos.Add(sqlProviderUserRepo); + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + }); - // Arrange data - var savedPolicyType = PolicyType.SingleOrg; - var queriedPolicyType = policySameType ? savedPolicyType : PolicyType.DisableSend; + policy.Enabled = policyEnabled; + policy.Type = savedPolicyType; - orgUser.Type = userType; - orgUser.Status = orgUserStatus; - var permissionsData = new Permissions { ManagePolicies = canManagePolicies }; - orgUser.Permissions = JsonSerializer.Serialize(permissionsData, new JsonSerializerOptions + var results = new List(); + + foreach (var policyRepo in policyRepos) + { + var i = policyRepos.IndexOf(policyRepo); + + // Seed database + var savedUser = await userRepos[i].CreateAsync(user); + var savedOrg = await orgRepos[i].CreateAsync(organization); + + // Invited orgUsers are not associated with an account yet, so they are identified by Email not UserId + if (orgUserStatus == OrganizationUserStatusType.Invited) { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - }); - - policy.Enabled = policyEnabled; - policy.Type = savedPolicyType; - - var results = new List(); - - foreach (var policyRepo in policyRepos) + orgUser.Email = savedUser.Email; + orgUser.UserId = null; + } + else { - var i = policyRepos.IndexOf(policyRepo); - - // Seed database - var savedUser = await userRepos[i].CreateAsync(user); - var savedOrg = await orgRepos[i].CreateAsync(organization); - - // Invited orgUsers are not associated with an account yet, so they are identified by Email not UserId - if (orgUserStatus == OrganizationUserStatusType.Invited) - { - orgUser.Email = savedUser.Email; - orgUser.UserId = null; - } - else - { - orgUser.UserId = savedUser.Id; - } - - orgUser.OrganizationId = savedOrg.Id; - await orgUserRepos[i].CreateAsync(orgUser); - - if (isProvider) - { - var savedProvider = await providerRepos[i].CreateAsync(provider); - - providerOrganization.OrganizationId = savedOrg.Id; - providerOrganization.ProviderId = savedProvider.Id; - await providerOrgRepos[i].CreateAsync(providerOrganization); - - providerUser.UserId = savedUser.Id; - providerUser.ProviderId = savedProvider.Id; - await providerUserRepos[i].CreateAsync(providerUser); - } - - policy.OrganizationId = savedOrg.Id; - await policyRepo.CreateAsync(policy); - if (suts.Contains(policyRepo)) - { - (policyRepo as EfRepo.BaseEntityFrameworkRepository).ClearChangeTracking(); - } - - var minStatus = includeInvited ? OrganizationUserStatusType.Invited : OrganizationUserStatusType.Accepted; - - // Act - var result = await policyRepo.GetManyByTypeApplicableToUserIdAsync(savedUser.Id, queriedPolicyType, minStatus); - results.Add(result.FirstOrDefault()); + orgUser.UserId = savedUser.Id; } - // Assert - var distinctItems = results.Distinct(equalityComparer); + orgUser.OrganizationId = savedOrg.Id; + await orgUserRepos[i].CreateAsync(orgUser); - Assert.True(results.All(r => r == null) || - !distinctItems.Skip(1).Any()); + if (isProvider) + { + var savedProvider = await providerRepos[i].CreateAsync(provider); + + providerOrganization.OrganizationId = savedOrg.Id; + providerOrganization.ProviderId = savedProvider.Id; + await providerOrgRepos[i].CreateAsync(providerOrganization); + + providerUser.UserId = savedUser.Id; + providerUser.ProviderId = savedProvider.Id; + await providerUserRepos[i].CreateAsync(providerUser); + } + + policy.OrganizationId = savedOrg.Id; + await policyRepo.CreateAsync(policy); + if (suts.Contains(policyRepo)) + { + (policyRepo as EfRepo.BaseEntityFrameworkRepository).ClearChangeTracking(); + } + + var minStatus = includeInvited ? OrganizationUserStatusType.Invited : OrganizationUserStatusType.Accepted; + + // Act + var result = await policyRepo.GetManyByTypeApplicableToUserIdAsync(savedUser.Id, queriedPolicyType, minStatus); + results.Add(result.FirstOrDefault()); } + + // Assert + var distinctItems = results.Distinct(equalityComparer); + + Assert.True(results.All(r => r == null) || + !distinctItems.Skip(1).Any()); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/SendRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/SendRepositoryTests.cs index 6158be3ee..628b5562c 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/SendRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/SendRepositoryTests.cs @@ -6,60 +6,59 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories +namespace Bit.Infrastructure.EFIntegration.Test.Repositories; + +public class SendRepositoryTests { - public class SendRepositoryTests + [CiSkippedTheory, EfUserSendAutoData, EfOrganizationSendAutoData] + public async void CreateAsync_Works_DataMatches( + Send send, + User user, + Organization org, + SendCompare equalityComparer, + List suts, + List efUserRepos, + List efOrgRepos, + SqlRepo.SendRepository sqlSendRepo, + SqlRepo.UserRepository sqlUserRepo, + SqlRepo.OrganizationRepository sqlOrgRepo + ) { - [CiSkippedTheory, EfUserSendAutoData, EfOrganizationSendAutoData] - public async void CreateAsync_Works_DataMatches( - Send send, - User user, - Organization org, - SendCompare equalityComparer, - List suts, - List efUserRepos, - List efOrgRepos, - SqlRepo.SendRepository sqlSendRepo, - SqlRepo.UserRepository sqlUserRepo, - SqlRepo.OrganizationRepository sqlOrgRepo - ) + var savedSends = new List(); + foreach (var sut in suts) { - var savedSends = new List(); - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); + var i = suts.IndexOf(sut); - if (send.OrganizationId.HasValue) - { - var efOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); - send.OrganizationId = efOrg.Id; - } - var efUser = await efUserRepos[i].CreateAsync(user); - sut.ClearChangeTracking(); - - send.UserId = efUser.Id; - var postEfSend = await sut.CreateAsync(send); - sut.ClearChangeTracking(); - - var savedSend = await sut.GetByIdAsync(postEfSend.Id); - savedSends.Add(savedSend); - } - - var sqlUser = await sqlUserRepo.CreateAsync(user); if (send.OrganizationId.HasValue) { - var sqlOrg = await sqlOrgRepo.CreateAsync(org); - send.OrganizationId = sqlOrg.Id; + var efOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); + send.OrganizationId = efOrg.Id; } + var efUser = await efUserRepos[i].CreateAsync(user); + sut.ClearChangeTracking(); - send.UserId = sqlUser.Id; - var sqlSend = await sqlSendRepo.CreateAsync(send); - var savedSqlSend = await sqlSendRepo.GetByIdAsync(sqlSend.Id); - savedSends.Add(savedSqlSend); + send.UserId = efUser.Id; + var postEfSend = await sut.CreateAsync(send); + sut.ClearChangeTracking(); - var distinctItems = savedSends.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); + var savedSend = await sut.GetByIdAsync(postEfSend.Id); + savedSends.Add(savedSend); } + + var sqlUser = await sqlUserRepo.CreateAsync(user); + if (send.OrganizationId.HasValue) + { + var sqlOrg = await sqlOrgRepo.CreateAsync(org); + send.OrganizationId = sqlOrg.Id; + } + + send.UserId = sqlUser.Id; + var sqlSend = await sqlSendRepo.CreateAsync(send); + var savedSqlSend = await sqlSendRepo.GetByIdAsync(sqlSend.Id); + savedSends.Add(savedSqlSend); + + var distinctItems = savedSends.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/SsoConfigRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/SsoConfigRepositoryTests.cs index c36c9efb4..7858bc1f0 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/SsoConfigRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/SsoConfigRepositoryTests.cs @@ -6,222 +6,221 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories +namespace Bit.Infrastructure.EFIntegration.Test.Repositories; + +public class SsoConfigRepositoryTests { - public class SsoConfigRepositoryTests + [CiSkippedTheory, EfSsoConfigAutoData] + public async void CreateAsync_Works_DataMatches(SsoConfig ssoConfig, Organization org, + SsoConfigCompare equalityComparer, List suts, + List efOrgRepos, SqlRepo.SsoConfigRepository sqlSsoConfigRepo, + SqlRepo.OrganizationRepository sqlOrganizationRepo) { - [CiSkippedTheory, EfSsoConfigAutoData] - public async void CreateAsync_Works_DataMatches(SsoConfig ssoConfig, Organization org, - SsoConfigCompare equalityComparer, List suts, - List efOrgRepos, SqlRepo.SsoConfigRepository sqlSsoConfigRepo, - SqlRepo.OrganizationRepository sqlOrganizationRepo) + var savedSsoConfigs = new List(); + + foreach (var sut in suts) { - var savedSsoConfigs = new List(); + var i = suts.IndexOf(sut); - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); + var savedEfOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - var savedEfOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + ssoConfig.OrganizationId = savedEfOrg.Id; + var postEfSsoConfig = await sut.CreateAsync(ssoConfig); + sut.ClearChangeTracking(); - ssoConfig.OrganizationId = savedEfOrg.Id; - var postEfSsoConfig = await sut.CreateAsync(ssoConfig); - sut.ClearChangeTracking(); - - var savedEfSsoConfig = await sut.GetByIdAsync(ssoConfig.Id); - Assert.True(savedEfSsoConfig != null); - savedSsoConfigs.Add(savedEfSsoConfig); - } - - var sqlOrganization = await sqlOrganizationRepo.CreateAsync(org); - ssoConfig.OrganizationId = sqlOrganization.Id; - - var sqlSsoConfig = await sqlSsoConfigRepo.CreateAsync(ssoConfig); - var savedSqlSsoConfig = await sqlSsoConfigRepo.GetByIdAsync(sqlSsoConfig.Id); - Assert.True(savedSqlSsoConfig != null); - savedSsoConfigs.Add(savedSqlSsoConfig); - - var distinctItems = savedSsoConfigs.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); + var savedEfSsoConfig = await sut.GetByIdAsync(ssoConfig.Id); + Assert.True(savedEfSsoConfig != null); + savedSsoConfigs.Add(savedEfSsoConfig); } - [CiSkippedTheory, EfSsoConfigAutoData] - public async void ReplaceAsync_Works_DataMatches(SsoConfig postSsoConfig, SsoConfig replaceSsoConfig, - Organization org, SsoConfigCompare equalityComparer, List suts, - List efOrgRepos, SqlRepo.SsoConfigRepository sqlSsoConfigRepo, - SqlRepo.OrganizationRepository sqlOrganizationRepo) + var sqlOrganization = await sqlOrganizationRepo.CreateAsync(org); + ssoConfig.OrganizationId = sqlOrganization.Id; + + var sqlSsoConfig = await sqlSsoConfigRepo.CreateAsync(ssoConfig); + var savedSqlSsoConfig = await sqlSsoConfigRepo.GetByIdAsync(sqlSsoConfig.Id); + Assert.True(savedSqlSsoConfig != null); + savedSsoConfigs.Add(savedSqlSsoConfig); + + var distinctItems = savedSsoConfigs.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); + } + + [CiSkippedTheory, EfSsoConfigAutoData] + public async void ReplaceAsync_Works_DataMatches(SsoConfig postSsoConfig, SsoConfig replaceSsoConfig, + Organization org, SsoConfigCompare equalityComparer, List suts, + List efOrgRepos, SqlRepo.SsoConfigRepository sqlSsoConfigRepo, + SqlRepo.OrganizationRepository sqlOrganizationRepo) + { + var savedSsoConfigs = new List(); + + foreach (var sut in suts) { - var savedSsoConfigs = new List(); + var i = suts.IndexOf(sut); - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); + var savedEfOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - var savedEfOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + postSsoConfig.OrganizationId = replaceSsoConfig.OrganizationId = savedEfOrg.Id; + var postEfSsoConfig = await sut.CreateAsync(postSsoConfig); + sut.ClearChangeTracking(); - postSsoConfig.OrganizationId = replaceSsoConfig.OrganizationId = savedEfOrg.Id; - var postEfSsoConfig = await sut.CreateAsync(postSsoConfig); - sut.ClearChangeTracking(); + replaceSsoConfig.Id = postEfSsoConfig.Id; + savedSsoConfigs.Add(postEfSsoConfig); + await sut.ReplaceAsync(replaceSsoConfig); + sut.ClearChangeTracking(); - replaceSsoConfig.Id = postEfSsoConfig.Id; - savedSsoConfigs.Add(postEfSsoConfig); - await sut.ReplaceAsync(replaceSsoConfig); - sut.ClearChangeTracking(); - - var replacedSsoConfig = await sut.GetByIdAsync(replaceSsoConfig.Id); - Assert.True(replacedSsoConfig != null); - savedSsoConfigs.Add(replacedSsoConfig); - } - - var sqlOrganization = await sqlOrganizationRepo.CreateAsync(org); - postSsoConfig.OrganizationId = sqlOrganization.Id; - - var postSqlSsoConfig = await sqlSsoConfigRepo.CreateAsync(postSsoConfig); - replaceSsoConfig.Id = postSqlSsoConfig.Id; - savedSsoConfigs.Add(postSqlSsoConfig); - - await sqlSsoConfigRepo.ReplaceAsync(replaceSsoConfig); - var replacedSqlSsoConfig = await sqlSsoConfigRepo.GetByIdAsync(replaceSsoConfig.Id); - Assert.True(replacedSqlSsoConfig != null); - savedSsoConfigs.Add(replacedSqlSsoConfig); - - var distinctItems = savedSsoConfigs.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(2).Any()); + var replacedSsoConfig = await sut.GetByIdAsync(replaceSsoConfig.Id); + Assert.True(replacedSsoConfig != null); + savedSsoConfigs.Add(replacedSsoConfig); } - [CiSkippedTheory, EfSsoConfigAutoData] - public async void DeleteAsync_Works_DataMatches(SsoConfig ssoConfig, Organization org, List suts, - List efOrgRepos, SqlRepo.SsoConfigRepository sqlSsoConfigRepo, - SqlRepo.OrganizationRepository sqlOrganizationRepo) + var sqlOrganization = await sqlOrganizationRepo.CreateAsync(org); + postSsoConfig.OrganizationId = sqlOrganization.Id; + + var postSqlSsoConfig = await sqlSsoConfigRepo.CreateAsync(postSsoConfig); + replaceSsoConfig.Id = postSqlSsoConfig.Id; + savedSsoConfigs.Add(postSqlSsoConfig); + + await sqlSsoConfigRepo.ReplaceAsync(replaceSsoConfig); + var replacedSqlSsoConfig = await sqlSsoConfigRepo.GetByIdAsync(replaceSsoConfig.Id); + Assert.True(replacedSqlSsoConfig != null); + savedSsoConfigs.Add(replacedSqlSsoConfig); + + var distinctItems = savedSsoConfigs.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(2).Any()); + } + + [CiSkippedTheory, EfSsoConfigAutoData] + public async void DeleteAsync_Works_DataMatches(SsoConfig ssoConfig, Organization org, List suts, + List efOrgRepos, SqlRepo.SsoConfigRepository sqlSsoConfigRepo, + SqlRepo.OrganizationRepository sqlOrganizationRepo) + { + foreach (var sut in suts) { - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); + var i = suts.IndexOf(sut); - var savedEfOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + var savedEfOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - ssoConfig.OrganizationId = savedEfOrg.Id; - var postEfSsoConfig = await sut.CreateAsync(ssoConfig); - sut.ClearChangeTracking(); + ssoConfig.OrganizationId = savedEfOrg.Id; + var postEfSsoConfig = await sut.CreateAsync(ssoConfig); + sut.ClearChangeTracking(); - var savedEfSsoConfig = await sut.GetByIdAsync(postEfSsoConfig.Id); - Assert.True(savedEfSsoConfig != null); - sut.ClearChangeTracking(); + var savedEfSsoConfig = await sut.GetByIdAsync(postEfSsoConfig.Id); + Assert.True(savedEfSsoConfig != null); + sut.ClearChangeTracking(); - await sut.DeleteAsync(savedEfSsoConfig); - var deletedEfSsoConfig = await sut.GetByIdAsync(savedEfSsoConfig.Id); - Assert.True(deletedEfSsoConfig == null); - } - - var sqlOrganization = await sqlOrganizationRepo.CreateAsync(org); - ssoConfig.OrganizationId = sqlOrganization.Id; - - var postSqlSsoConfig = await sqlSsoConfigRepo.CreateAsync(ssoConfig); - var savedSqlSsoConfig = await sqlSsoConfigRepo.GetByIdAsync(postSqlSsoConfig.Id); - Assert.True(savedSqlSsoConfig != null); - - await sqlSsoConfigRepo.DeleteAsync(savedSqlSsoConfig); - savedSqlSsoConfig = await sqlSsoConfigRepo.GetByIdAsync(postSqlSsoConfig.Id); - Assert.True(savedSqlSsoConfig == null); + await sut.DeleteAsync(savedEfSsoConfig); + var deletedEfSsoConfig = await sut.GetByIdAsync(savedEfSsoConfig.Id); + Assert.True(deletedEfSsoConfig == null); } - [CiSkippedTheory, EfSsoConfigAutoData] - public async void GetByOrganizationIdAsync_Works_DataMatches(SsoConfig ssoConfig, Organization org, - SsoConfigCompare equalityComparer, List suts, - List efOrgRepos, SqlRepo.SsoConfigRepository sqlSsoConfigRepo, - SqlRepo.OrganizationRepository sqlOrgRepo) + var sqlOrganization = await sqlOrganizationRepo.CreateAsync(org); + ssoConfig.OrganizationId = sqlOrganization.Id; + + var postSqlSsoConfig = await sqlSsoConfigRepo.CreateAsync(ssoConfig); + var savedSqlSsoConfig = await sqlSsoConfigRepo.GetByIdAsync(postSqlSsoConfig.Id); + Assert.True(savedSqlSsoConfig != null); + + await sqlSsoConfigRepo.DeleteAsync(savedSqlSsoConfig); + savedSqlSsoConfig = await sqlSsoConfigRepo.GetByIdAsync(postSqlSsoConfig.Id); + Assert.True(savedSqlSsoConfig == null); + } + + [CiSkippedTheory, EfSsoConfigAutoData] + public async void GetByOrganizationIdAsync_Works_DataMatches(SsoConfig ssoConfig, Organization org, + SsoConfigCompare equalityComparer, List suts, + List efOrgRepos, SqlRepo.SsoConfigRepository sqlSsoConfigRepo, + SqlRepo.OrganizationRepository sqlOrgRepo) + { + var returnedList = new List(); + + foreach (var sut in suts) { - var returnedList = new List(); + var i = suts.IndexOf(sut); - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); + var savedEfOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - var savedEfOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + ssoConfig.OrganizationId = savedEfOrg.Id; + await sut.CreateAsync(ssoConfig); + sut.ClearChangeTracking(); - ssoConfig.OrganizationId = savedEfOrg.Id; - await sut.CreateAsync(ssoConfig); - sut.ClearChangeTracking(); - - var savedEfUser = await sut.GetByOrganizationIdAsync(savedEfOrg.Id); - Assert.True(savedEfUser != null); - returnedList.Add(savedEfUser); - } - - var savedSqlOrg = await sqlOrgRepo.CreateAsync(org); - ssoConfig.OrganizationId = savedSqlOrg.Id; - - var postSqlSsoConfig = await sqlSsoConfigRepo.CreateAsync(ssoConfig); - - var savedSqlSsoConfig = await sqlSsoConfigRepo.GetByOrganizationIdAsync(ssoConfig.OrganizationId); - Assert.True(savedSqlSsoConfig != null); - returnedList.Add(savedSqlSsoConfig); - - var distinctItems = returnedList.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); + var savedEfUser = await sut.GetByOrganizationIdAsync(savedEfOrg.Id); + Assert.True(savedEfUser != null); + returnedList.Add(savedEfUser); } - [CiSkippedTheory, EfSsoConfigAutoData] - public async void GetByIdentifierAsync_Works_DataMatches(SsoConfig ssoConfig, Organization org, - SsoConfigCompare equalityComparer, List suts, - List efOrgRepos, SqlRepo.SsoConfigRepository sqlSsoConfigRepo, - SqlRepo.OrganizationRepository sqlOrgRepo) + var savedSqlOrg = await sqlOrgRepo.CreateAsync(org); + ssoConfig.OrganizationId = savedSqlOrg.Id; + + var postSqlSsoConfig = await sqlSsoConfigRepo.CreateAsync(ssoConfig); + + var savedSqlSsoConfig = await sqlSsoConfigRepo.GetByOrganizationIdAsync(ssoConfig.OrganizationId); + Assert.True(savedSqlSsoConfig != null); + returnedList.Add(savedSqlSsoConfig); + + var distinctItems = returnedList.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); + } + + [CiSkippedTheory, EfSsoConfigAutoData] + public async void GetByIdentifierAsync_Works_DataMatches(SsoConfig ssoConfig, Organization org, + SsoConfigCompare equalityComparer, List suts, + List efOrgRepos, SqlRepo.SsoConfigRepository sqlSsoConfigRepo, + SqlRepo.OrganizationRepository sqlOrgRepo) + { + var returnedList = new List(); + + foreach (var sut in suts) { - var returnedList = new List(); + var i = suts.IndexOf(sut); - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); + var savedEfOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - var savedEfOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + ssoConfig.OrganizationId = savedEfOrg.Id; + await sut.CreateAsync(ssoConfig); + sut.ClearChangeTracking(); - ssoConfig.OrganizationId = savedEfOrg.Id; - await sut.CreateAsync(ssoConfig); - sut.ClearChangeTracking(); - - var savedEfSsoConfig = await sut.GetByIdentifierAsync(org.Identifier); - Assert.True(savedEfSsoConfig != null); - returnedList.Add(savedEfSsoConfig); - } - - var savedSqlOrg = await sqlOrgRepo.CreateAsync(org); - ssoConfig.OrganizationId = savedSqlOrg.Id; - - var postSqlSsoConfig = await sqlSsoConfigRepo.CreateAsync(ssoConfig); - - var savedSqlSsoConfig = await sqlSsoConfigRepo.GetByIdentifierAsync(org.Identifier); - Assert.True(savedSqlSsoConfig != null); - returnedList.Add(savedSqlSsoConfig); - - var distinctItems = returnedList.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); + var savedEfSsoConfig = await sut.GetByIdentifierAsync(org.Identifier); + Assert.True(savedEfSsoConfig != null); + returnedList.Add(savedEfSsoConfig); } - // Testing that data matches here would involve manipulating all SsoConfig records in the db - [CiSkippedTheory, EfSsoConfigAutoData] - public async void GetManyByRevisionNotBeforeDate_Works(SsoConfig ssoConfig, DateTime notBeforeDate, - Organization org, List suts, - List efOrgRepos) + var savedSqlOrg = await sqlOrgRepo.CreateAsync(org); + ssoConfig.OrganizationId = savedSqlOrg.Id; + + var postSqlSsoConfig = await sqlSsoConfigRepo.CreateAsync(ssoConfig); + + var savedSqlSsoConfig = await sqlSsoConfigRepo.GetByIdentifierAsync(org.Identifier); + Assert.True(savedSqlSsoConfig != null); + returnedList.Add(savedSqlSsoConfig); + + var distinctItems = returnedList.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); + } + + // Testing that data matches here would involve manipulating all SsoConfig records in the db + [CiSkippedTheory, EfSsoConfigAutoData] + public async void GetManyByRevisionNotBeforeDate_Works(SsoConfig ssoConfig, DateTime notBeforeDate, + Organization org, List suts, + List efOrgRepos) + { + foreach (var sut in suts) { - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); + var i = suts.IndexOf(sut); - var savedEfOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + var savedEfOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - ssoConfig.OrganizationId = savedEfOrg.Id; - await sut.CreateAsync(ssoConfig); - sut.ClearChangeTracking(); + ssoConfig.OrganizationId = savedEfOrg.Id; + await sut.CreateAsync(ssoConfig); + sut.ClearChangeTracking(); - var returnedEfSsoConfigs = await sut.GetManyByRevisionNotBeforeDate(notBeforeDate); - Assert.True(returnedEfSsoConfigs.All(sc => sc.RevisionDate >= notBeforeDate)); - } + var returnedEfSsoConfigs = await sut.GetManyByRevisionNotBeforeDate(notBeforeDate); + Assert.True(returnedEfSsoConfigs.All(sc => sc.RevisionDate >= notBeforeDate)); } } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/SsoUserRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/SsoUserRepositoryTests.cs index 9e9b66eea..bc43a0526 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/SsoUserRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/SsoUserRepositoryTests.cs @@ -6,182 +6,181 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories +namespace Bit.Infrastructure.EFIntegration.Test.Repositories; + +public class SsoUserRepositoryTests { - public class SsoUserRepositoryTests + [CiSkippedTheory, EfSsoUserAutoData] + public async void CreateAsync_Works_DataMatches(SsoUser ssoUser, User user, Organization org, + SsoUserCompare equalityComparer, List suts, + List efOrgRepos, List efUserRepos, + SqlRepo.SsoUserRepository sqlSsoUserRepo, SqlRepo.OrganizationRepository sqlOrgRepo, + SqlRepo.UserRepository sqlUserRepo) { - [CiSkippedTheory, EfSsoUserAutoData] - public async void CreateAsync_Works_DataMatches(SsoUser ssoUser, User user, Organization org, - SsoUserCompare equalityComparer, List suts, - List efOrgRepos, List efUserRepos, - SqlRepo.SsoUserRepository sqlSsoUserRepo, SqlRepo.OrganizationRepository sqlOrgRepo, - SqlRepo.UserRepository sqlUserRepo) + var createdSsoUsers = new List(); + foreach (var sut in suts) { - var createdSsoUsers = new List(); - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); + var i = suts.IndexOf(sut); - var efUser = await efUserRepos[i].CreateAsync(user); - var efOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + var efUser = await efUserRepos[i].CreateAsync(user); + var efOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - ssoUser.UserId = efUser.Id; - ssoUser.OrganizationId = efOrg.Id; - var postEfSsoUser = await sut.CreateAsync(ssoUser); - sut.ClearChangeTracking(); + ssoUser.UserId = efUser.Id; + ssoUser.OrganizationId = efOrg.Id; + var postEfSsoUser = await sut.CreateAsync(ssoUser); + sut.ClearChangeTracking(); - var savedSsoUser = await sut.GetByIdAsync(ssoUser.Id); - createdSsoUsers.Add(savedSsoUser); - } - - var sqlUser = await sqlUserRepo.CreateAsync(user); - var sqlOrganization = await sqlOrgRepo.CreateAsync(org); - - ssoUser.UserId = sqlUser.Id; - ssoUser.OrganizationId = sqlOrganization.Id; - var sqlSsoUser = await sqlSsoUserRepo.CreateAsync(ssoUser); - - createdSsoUsers.Add(await sqlSsoUserRepo.GetByIdAsync(sqlSsoUser.Id)); - - var distinctSsoUsers = createdSsoUsers.Distinct(equalityComparer); - Assert.True(!distinctSsoUsers.Skip(1).Any()); + var savedSsoUser = await sut.GetByIdAsync(ssoUser.Id); + createdSsoUsers.Add(savedSsoUser); } - [CiSkippedTheory, EfSsoUserAutoData] - public async void ReplaceAsync_Works_DataMatches(SsoUser postSsoUser, SsoUser replaceSsoUser, - Organization org, User user, SsoUserCompare equalityComparer, - List suts, List efUserRepos, - List efOrgRepos, SqlRepo.SsoUserRepository sqlSsoUserRepo, - SqlRepo.OrganizationRepository sqlOrgRepo, SqlRepo.UserRepository sqlUserRepo) + var sqlUser = await sqlUserRepo.CreateAsync(user); + var sqlOrganization = await sqlOrgRepo.CreateAsync(org); + + ssoUser.UserId = sqlUser.Id; + ssoUser.OrganizationId = sqlOrganization.Id; + var sqlSsoUser = await sqlSsoUserRepo.CreateAsync(ssoUser); + + createdSsoUsers.Add(await sqlSsoUserRepo.GetByIdAsync(sqlSsoUser.Id)); + + var distinctSsoUsers = createdSsoUsers.Distinct(equalityComparer); + Assert.True(!distinctSsoUsers.Skip(1).Any()); + } + + [CiSkippedTheory, EfSsoUserAutoData] + public async void ReplaceAsync_Works_DataMatches(SsoUser postSsoUser, SsoUser replaceSsoUser, + Organization org, User user, SsoUserCompare equalityComparer, + List suts, List efUserRepos, + List efOrgRepos, SqlRepo.SsoUserRepository sqlSsoUserRepo, + SqlRepo.OrganizationRepository sqlOrgRepo, SqlRepo.UserRepository sqlUserRepo) + { + var savedSsoUsers = new List(); + foreach (var sut in suts) { - var savedSsoUsers = new List(); - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); + var i = suts.IndexOf(sut); - var efUser = await efUserRepos[i].CreateAsync(user); - var efOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + var efUser = await efUserRepos[i].CreateAsync(user); + var efOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - postSsoUser.UserId = efUser.Id; - postSsoUser.OrganizationId = efOrg.Id; - var postEfSsoUser = await sut.CreateAsync(postSsoUser); - sut.ClearChangeTracking(); + postSsoUser.UserId = efUser.Id; + postSsoUser.OrganizationId = efOrg.Id; + var postEfSsoUser = await sut.CreateAsync(postSsoUser); + sut.ClearChangeTracking(); - replaceSsoUser.Id = postEfSsoUser.Id; - replaceSsoUser.UserId = postEfSsoUser.UserId; - replaceSsoUser.OrganizationId = postEfSsoUser.OrganizationId; - await sut.ReplaceAsync(replaceSsoUser); - sut.ClearChangeTracking(); + replaceSsoUser.Id = postEfSsoUser.Id; + replaceSsoUser.UserId = postEfSsoUser.UserId; + replaceSsoUser.OrganizationId = postEfSsoUser.OrganizationId; + await sut.ReplaceAsync(replaceSsoUser); + sut.ClearChangeTracking(); - var replacedSsoUser = await sut.GetByIdAsync(replaceSsoUser.Id); - savedSsoUsers.Add(replacedSsoUser); - } - - var sqlUser = await sqlUserRepo.CreateAsync(user); - var sqlOrganization = await sqlOrgRepo.CreateAsync(org); - - postSsoUser.UserId = sqlUser.Id; - postSsoUser.OrganizationId = sqlOrganization.Id; - var postSqlSsoUser = await sqlSsoUserRepo.CreateAsync(postSsoUser); - - replaceSsoUser.Id = postSqlSsoUser.Id; - replaceSsoUser.UserId = postSqlSsoUser.UserId; - replaceSsoUser.OrganizationId = postSqlSsoUser.OrganizationId; - await sqlSsoUserRepo.ReplaceAsync(replaceSsoUser); - - savedSsoUsers.Add(await sqlSsoUserRepo.GetByIdAsync(replaceSsoUser.Id)); - - var distinctItems = savedSsoUsers.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); + var replacedSsoUser = await sut.GetByIdAsync(replaceSsoUser.Id); + savedSsoUsers.Add(replacedSsoUser); } - [CiSkippedTheory, EfSsoUserAutoData] - public async void DeleteAsync_Works_DataMatches(SsoUser ssoUser, Organization org, User user, List suts, - List efUserRepos, List efOrgRepos, - SqlRepo.SsoUserRepository sqlSsoUserRepo, SqlRepo.UserRepository sqlUserRepo, - SqlRepo.OrganizationRepository sqlOrganizationRepo) + var sqlUser = await sqlUserRepo.CreateAsync(user); + var sqlOrganization = await sqlOrgRepo.CreateAsync(org); + + postSsoUser.UserId = sqlUser.Id; + postSsoUser.OrganizationId = sqlOrganization.Id; + var postSqlSsoUser = await sqlSsoUserRepo.CreateAsync(postSsoUser); + + replaceSsoUser.Id = postSqlSsoUser.Id; + replaceSsoUser.UserId = postSqlSsoUser.UserId; + replaceSsoUser.OrganizationId = postSqlSsoUser.OrganizationId; + await sqlSsoUserRepo.ReplaceAsync(replaceSsoUser); + + savedSsoUsers.Add(await sqlSsoUserRepo.GetByIdAsync(replaceSsoUser.Id)); + + var distinctItems = savedSsoUsers.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); + } + + [CiSkippedTheory, EfSsoUserAutoData] + public async void DeleteAsync_Works_DataMatches(SsoUser ssoUser, Organization org, User user, List suts, + List efUserRepos, List efOrgRepos, + SqlRepo.SsoUserRepository sqlSsoUserRepo, SqlRepo.UserRepository sqlUserRepo, + SqlRepo.OrganizationRepository sqlOrganizationRepo) + { + foreach (var sut in suts) { - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); + var i = suts.IndexOf(sut); - var savedEfUser = await efUserRepos[i].CreateAsync(user); - var savedEfOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + var savedEfUser = await efUserRepos[i].CreateAsync(user); + var savedEfOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - ssoUser.UserId = savedEfUser.Id; - ssoUser.OrganizationId = savedEfOrg.Id; - var postEfSsoUser = await sut.CreateAsync(ssoUser); - sut.ClearChangeTracking(); + ssoUser.UserId = savedEfUser.Id; + ssoUser.OrganizationId = savedEfOrg.Id; + var postEfSsoUser = await sut.CreateAsync(ssoUser); + sut.ClearChangeTracking(); - var savedEfSsoUser = await sut.GetByIdAsync(postEfSsoUser.Id); - Assert.True(savedEfSsoUser != null); - sut.ClearChangeTracking(); + var savedEfSsoUser = await sut.GetByIdAsync(postEfSsoUser.Id); + Assert.True(savedEfSsoUser != null); + sut.ClearChangeTracking(); - await sut.DeleteAsync(savedEfSsoUser); - savedEfSsoUser = await sut.GetByIdAsync(savedEfSsoUser.Id); - Assert.True(savedEfSsoUser == null); - } - - var sqlUser = await sqlUserRepo.CreateAsync(user); - var sqlOrganization = await sqlOrganizationRepo.CreateAsync(org); - ssoUser.UserId = sqlUser.Id; - ssoUser.OrganizationId = sqlOrganization.Id; - - var postSqlSsoUser = await sqlSsoUserRepo.CreateAsync(ssoUser); - var savedSqlSsoUser = await sqlSsoUserRepo.GetByIdAsync(postSqlSsoUser.Id); - Assert.True(savedSqlSsoUser != null); - - await sqlSsoUserRepo.DeleteAsync(savedSqlSsoUser); - savedSqlSsoUser = await sqlSsoUserRepo.GetByIdAsync(postSqlSsoUser.Id); - Assert.True(savedSqlSsoUser == null); + await sut.DeleteAsync(savedEfSsoUser); + savedEfSsoUser = await sut.GetByIdAsync(savedEfSsoUser.Id); + Assert.True(savedEfSsoUser == null); } - [CiSkippedTheory, EfSsoUserAutoData] - public async void DeleteAsync_UserIdOrganizationId_Works_DataMatches(SsoUser ssoUser, - User user, Organization org, List suts, - List efUserRepos, List efOrgRepos, - SqlRepo.SsoUserRepository sqlSsoUserRepo, SqlRepo.UserRepository sqlUserRepo, SqlRepo.OrganizationRepository sqlOrgRepo - ) + var sqlUser = await sqlUserRepo.CreateAsync(user); + var sqlOrganization = await sqlOrganizationRepo.CreateAsync(org); + ssoUser.UserId = sqlUser.Id; + ssoUser.OrganizationId = sqlOrganization.Id; + + var postSqlSsoUser = await sqlSsoUserRepo.CreateAsync(ssoUser); + var savedSqlSsoUser = await sqlSsoUserRepo.GetByIdAsync(postSqlSsoUser.Id); + Assert.True(savedSqlSsoUser != null); + + await sqlSsoUserRepo.DeleteAsync(savedSqlSsoUser); + savedSqlSsoUser = await sqlSsoUserRepo.GetByIdAsync(postSqlSsoUser.Id); + Assert.True(savedSqlSsoUser == null); + } + + [CiSkippedTheory, EfSsoUserAutoData] + public async void DeleteAsync_UserIdOrganizationId_Works_DataMatches(SsoUser ssoUser, + User user, Organization org, List suts, + List efUserRepos, List efOrgRepos, + SqlRepo.SsoUserRepository sqlSsoUserRepo, SqlRepo.UserRepository sqlUserRepo, SqlRepo.OrganizationRepository sqlOrgRepo + ) + { + foreach (var sut in suts) { - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); + var i = suts.IndexOf(sut); - var savedEfUser = await efUserRepos[i].CreateAsync(user); - var savedEfOrg = await efOrgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); + var savedEfUser = await efUserRepos[i].CreateAsync(user); + var savedEfOrg = await efOrgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); - ssoUser.UserId = savedEfUser.Id; - ssoUser.OrganizationId = savedEfOrg.Id; - var postEfSsoUser = await sut.CreateAsync(ssoUser); - sut.ClearChangeTracking(); + ssoUser.UserId = savedEfUser.Id; + ssoUser.OrganizationId = savedEfOrg.Id; + var postEfSsoUser = await sut.CreateAsync(ssoUser); + sut.ClearChangeTracking(); - var savedEfSsoUser = await sut.GetByIdAsync(postEfSsoUser.Id); - Assert.True(savedEfSsoUser != null); - sut.ClearChangeTracking(); + var savedEfSsoUser = await sut.GetByIdAsync(postEfSsoUser.Id); + Assert.True(savedEfSsoUser != null); + sut.ClearChangeTracking(); - await sut.DeleteAsync(savedEfSsoUser.UserId, savedEfSsoUser.OrganizationId); - sut.ClearChangeTracking(); + await sut.DeleteAsync(savedEfSsoUser.UserId, savedEfSsoUser.OrganizationId); + sut.ClearChangeTracking(); - savedEfSsoUser = await sut.GetByIdAsync(savedEfSsoUser.Id); - Assert.True(savedEfSsoUser == null); - } - - var sqlUser = await sqlUserRepo.CreateAsync(user); - var sqlOrganization = await sqlOrgRepo.CreateAsync(org); - ssoUser.UserId = sqlUser.Id; - ssoUser.OrganizationId = sqlOrganization.Id; - - var postSqlSsoUser = await sqlSsoUserRepo.CreateAsync(ssoUser); - var savedSqlSsoUser = await sqlSsoUserRepo.GetByIdAsync(postSqlSsoUser.Id); - Assert.True(savedSqlSsoUser != null); - - await sqlSsoUserRepo.DeleteAsync(savedSqlSsoUser.UserId, savedSqlSsoUser.OrganizationId); - savedSqlSsoUser = await sqlSsoUserRepo.GetByIdAsync(postSqlSsoUser.Id); - Assert.True(savedSqlSsoUser == null); + savedEfSsoUser = await sut.GetByIdAsync(savedEfSsoUser.Id); + Assert.True(savedEfSsoUser == null); } + + var sqlUser = await sqlUserRepo.CreateAsync(user); + var sqlOrganization = await sqlOrgRepo.CreateAsync(org); + ssoUser.UserId = sqlUser.Id; + ssoUser.OrganizationId = sqlOrganization.Id; + + var postSqlSsoUser = await sqlSsoUserRepo.CreateAsync(ssoUser); + var savedSqlSsoUser = await sqlSsoUserRepo.GetByIdAsync(postSqlSsoUser.Id); + Assert.True(savedSqlSsoUser != null); + + await sqlSsoUserRepo.DeleteAsync(savedSqlSsoUser.UserId, savedSqlSsoUser.OrganizationId); + savedSqlSsoUser = await sqlSsoUserRepo.GetByIdAsync(postSqlSsoUser.Id); + Assert.True(savedSqlSsoUser == null); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/TaxRateRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/TaxRateRepositoryTests.cs index d5616f78e..8892f6c70 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/TaxRateRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/TaxRateRepositoryTests.cs @@ -6,35 +6,34 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories +namespace Bit.Infrastructure.EFIntegration.Test.Repositories; + +public class TaxRateRepositoryTests { - public class TaxRateRepositoryTests + [CiSkippedTheory, EfTaxRateAutoData] + public async void CreateAsync_Works_DataMatches( + TaxRate taxRate, + TaxRateCompare equalityComparer, + List suts, + SqlRepo.TaxRateRepository sqlTaxRateRepo + ) { - [CiSkippedTheory, EfTaxRateAutoData] - public async void CreateAsync_Works_DataMatches( - TaxRate taxRate, - TaxRateCompare equalityComparer, - List suts, - SqlRepo.TaxRateRepository sqlTaxRateRepo - ) + var savedTaxRates = new List(); + foreach (var sut in suts) { - var savedTaxRates = new List(); - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); - var postEfTaxRate = await sut.CreateAsync(taxRate); - sut.ClearChangeTracking(); + var i = suts.IndexOf(sut); + var postEfTaxRate = await sut.CreateAsync(taxRate); + sut.ClearChangeTracking(); - var savedTaxRate = await sut.GetByIdAsync(postEfTaxRate.Id); - savedTaxRates.Add(savedTaxRate); - } - - var sqlTaxRate = await sqlTaxRateRepo.CreateAsync(taxRate); - var savedSqlTaxRate = await sqlTaxRateRepo.GetByIdAsync(sqlTaxRate.Id); - savedTaxRates.Add(savedSqlTaxRate); - - var distinctItems = savedTaxRates.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); + var savedTaxRate = await sut.GetByIdAsync(postEfTaxRate.Id); + savedTaxRates.Add(savedTaxRate); } + + var sqlTaxRate = await sqlTaxRateRepo.CreateAsync(taxRate); + var savedSqlTaxRate = await sqlTaxRateRepo.GetByIdAsync(sqlTaxRate.Id); + savedTaxRates.Add(savedSqlTaxRate); + + var distinctItems = savedTaxRates.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/TransactionRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/TransactionRepositoryTests.cs index 563a0377e..2f0d2cd8a 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/TransactionRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/TransactionRepositoryTests.cs @@ -6,59 +6,58 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories +namespace Bit.Infrastructure.EFIntegration.Test.Repositories; + +public class TransactionRepositoryTests { - public class TransactionRepositoryTests + + [CiSkippedTheory, EfUserTransactionAutoData, EfOrganizationTransactionAutoData] + public async void CreateAsync_Works_DataMatches( + Transaction transaction, + User user, + Organization org, + TransactionCompare equalityComparer, + List suts, + List efUserRepos, + List efOrgRepos, + SqlRepo.TransactionRepository sqlTransactionRepo, + SqlRepo.UserRepository sqlUserRepo, + SqlRepo.OrganizationRepository sqlOrgRepo + ) { - - [CiSkippedTheory, EfUserTransactionAutoData, EfOrganizationTransactionAutoData] - public async void CreateAsync_Works_DataMatches( - Transaction transaction, - User user, - Organization org, - TransactionCompare equalityComparer, - List suts, - List efUserRepos, - List efOrgRepos, - SqlRepo.TransactionRepository sqlTransactionRepo, - SqlRepo.UserRepository sqlUserRepo, - SqlRepo.OrganizationRepository sqlOrgRepo - ) + var savedTransactions = new List(); + foreach (var sut in suts) { - var savedTransactions = new List(); - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); - var efUser = await efUserRepos[i].CreateAsync(user); - if (transaction.OrganizationId.HasValue) - { - var efOrg = await efOrgRepos[i].CreateAsync(org); - transaction.OrganizationId = efOrg.Id; - } - sut.ClearChangeTracking(); - - transaction.UserId = efUser.Id; - var postEfTransaction = await sut.CreateAsync(transaction); - sut.ClearChangeTracking(); - - var savedTransaction = await sut.GetByIdAsync(postEfTransaction.Id); - savedTransactions.Add(savedTransaction); - } - - var sqlUser = await sqlUserRepo.CreateAsync(user); + var i = suts.IndexOf(sut); + var efUser = await efUserRepos[i].CreateAsync(user); if (transaction.OrganizationId.HasValue) { - var sqlOrg = await sqlOrgRepo.CreateAsync(org); - transaction.OrganizationId = sqlOrg.Id; + var efOrg = await efOrgRepos[i].CreateAsync(org); + transaction.OrganizationId = efOrg.Id; } + sut.ClearChangeTracking(); - transaction.UserId = sqlUser.Id; - var sqlTransaction = await sqlTransactionRepo.CreateAsync(transaction); - var savedSqlTransaction = await sqlTransactionRepo.GetByIdAsync(sqlTransaction.Id); - savedTransactions.Add(savedSqlTransaction); + transaction.UserId = efUser.Id; + var postEfTransaction = await sut.CreateAsync(transaction); + sut.ClearChangeTracking(); - var distinctItems = savedTransactions.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); + var savedTransaction = await sut.GetByIdAsync(postEfTransaction.Id); + savedTransactions.Add(savedTransaction); } + + var sqlUser = await sqlUserRepo.CreateAsync(user); + if (transaction.OrganizationId.HasValue) + { + var sqlOrg = await sqlOrgRepo.CreateAsync(org); + transaction.OrganizationId = sqlOrg.Id; + } + + transaction.UserId = sqlUser.Id; + var sqlTransaction = await sqlTransactionRepo.CreateAsync(transaction); + var savedSqlTransaction = await sqlTransactionRepo.GetByIdAsync(sqlTransaction.Id); + savedTransactions.Add(savedSqlTransaction); + + var distinctItems = savedTransactions.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } } diff --git a/test/Infrastructure.EFIntegration.Test/Repositories/UserRepositoryTests.cs b/test/Infrastructure.EFIntegration.Test/Repositories/UserRepositoryTests.cs index d362cb954..ce04ffdfb 100644 --- a/test/Infrastructure.EFIntegration.Test/Repositories/UserRepositoryTests.cs +++ b/test/Infrastructure.EFIntegration.Test/Repositories/UserRepositoryTests.cs @@ -7,284 +7,283 @@ using Xunit; using EfRepo = Bit.Infrastructure.EntityFramework.Repositories; using SqlRepo = Bit.Infrastructure.Dapper.Repositories; -namespace Bit.Infrastructure.EFIntegration.Test.Repositories +namespace Bit.Infrastructure.EFIntegration.Test.Repositories; + +public class UserRepositoryTests { - public class UserRepositoryTests + [CiSkippedTheory, EfUserAutoData] + public async void CreateAsync_Works_DataMatches( + User user, UserCompare equalityComparer, + List suts, + SqlRepo.UserRepository sqlUserRepo + ) { - [CiSkippedTheory, EfUserAutoData] - public async void CreateAsync_Works_DataMatches( - User user, UserCompare equalityComparer, - List suts, - SqlRepo.UserRepository sqlUserRepo - ) + var savedUsers = new List(); + + foreach (var sut in suts) { - var savedUsers = new List(); + var postEfUser = await sut.CreateAsync(user); - foreach (var sut in suts) - { - var postEfUser = await sut.CreateAsync(user); + sut.ClearChangeTracking(); - sut.ClearChangeTracking(); - - var savedUser = await sut.GetByIdAsync(postEfUser.Id); - savedUsers.Add(savedUser); - } - - var sqlUser = await sqlUserRepo.CreateAsync(user); - savedUsers.Add(await sqlUserRepo.GetByIdAsync(sqlUser.Id)); - - var distinctItems = savedUsers.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); + var savedUser = await sut.GetByIdAsync(postEfUser.Id); + savedUsers.Add(savedUser); } - [CiSkippedTheory, EfUserAutoData] - public async void ReplaceAsync_Works_DataMatches(User postUser, User replaceUser, - UserCompare equalityComparer, List suts, - SqlRepo.UserRepository sqlUserRepo) + var sqlUser = await sqlUserRepo.CreateAsync(user); + savedUsers.Add(await sqlUserRepo.GetByIdAsync(sqlUser.Id)); + + var distinctItems = savedUsers.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); + } + + [CiSkippedTheory, EfUserAutoData] + public async void ReplaceAsync_Works_DataMatches(User postUser, User replaceUser, + UserCompare equalityComparer, List suts, + SqlRepo.UserRepository sqlUserRepo) + { + var savedUsers = new List(); + foreach (var sut in suts) { - var savedUsers = new List(); - foreach (var sut in suts) - { - var postEfUser = await sut.CreateAsync(postUser); - replaceUser.Id = postEfUser.Id; - await sut.ReplaceAsync(replaceUser); - var replacedUser = await sut.GetByIdAsync(replaceUser.Id); - savedUsers.Add(replacedUser); - } - - var postSqlUser = await sqlUserRepo.CreateAsync(postUser); - replaceUser.Id = postSqlUser.Id; - await sqlUserRepo.ReplaceAsync(replaceUser); - savedUsers.Add(await sqlUserRepo.GetByIdAsync(replaceUser.Id)); - - var distinctItems = savedUsers.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); + var postEfUser = await sut.CreateAsync(postUser); + replaceUser.Id = postEfUser.Id; + await sut.ReplaceAsync(replaceUser); + var replacedUser = await sut.GetByIdAsync(replaceUser.Id); + savedUsers.Add(replacedUser); } - [CiSkippedTheory, EfUserAutoData] - public async void DeleteAsync_Works_DataMatches(User user, List suts, SqlRepo.UserRepository sqlUserRepo) + var postSqlUser = await sqlUserRepo.CreateAsync(postUser); + replaceUser.Id = postSqlUser.Id; + await sqlUserRepo.ReplaceAsync(replaceUser); + savedUsers.Add(await sqlUserRepo.GetByIdAsync(replaceUser.Id)); + + var distinctItems = savedUsers.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); + } + + [CiSkippedTheory, EfUserAutoData] + public async void DeleteAsync_Works_DataMatches(User user, List suts, SqlRepo.UserRepository sqlUserRepo) + { + foreach (var sut in suts) { - foreach (var sut in suts) - { - var postEfUser = await sut.CreateAsync(user); - sut.ClearChangeTracking(); + var postEfUser = await sut.CreateAsync(user); + sut.ClearChangeTracking(); - var savedEfUser = await sut.GetByIdAsync(postEfUser.Id); - Assert.True(savedEfUser != null); - sut.ClearChangeTracking(); + var savedEfUser = await sut.GetByIdAsync(postEfUser.Id); + Assert.True(savedEfUser != null); + sut.ClearChangeTracking(); - await sut.DeleteAsync(savedEfUser); - sut.ClearChangeTracking(); + await sut.DeleteAsync(savedEfUser); + sut.ClearChangeTracking(); - savedEfUser = await sut.GetByIdAsync(savedEfUser.Id); - Assert.True(savedEfUser == null); - } - - var postSqlUser = await sqlUserRepo.CreateAsync(user); - var savedSqlUser = await sqlUserRepo.GetByIdAsync(postSqlUser.Id); - Assert.True(savedSqlUser != null); - - await sqlUserRepo.DeleteAsync(postSqlUser); - savedSqlUser = await sqlUserRepo.GetByIdAsync(postSqlUser.Id); - Assert.True(savedSqlUser == null); + savedEfUser = await sut.GetByIdAsync(savedEfUser.Id); + Assert.True(savedEfUser == null); } - [CiSkippedTheory, EfUserAutoData] - public async void GetByEmailAsync_Works_DataMatches(User user, UserCompare equalityComparer, - List suts, SqlRepo.UserRepository sqlUserRepo) - { - var savedUsers = new List(); - foreach (var sut in suts) - { - var postEfUser = await sut.CreateAsync(user); - sut.ClearChangeTracking(); - var savedUser = await sut.GetByEmailAsync(postEfUser.Email.ToUpperInvariant()); - savedUsers.Add(savedUser); - } + var postSqlUser = await sqlUserRepo.CreateAsync(user); + var savedSqlUser = await sqlUserRepo.GetByIdAsync(postSqlUser.Id); + Assert.True(savedSqlUser != null); - var postSqlUser = await sqlUserRepo.CreateAsync(user); - savedUsers.Add(await sqlUserRepo.GetByEmailAsync(postSqlUser.Email.ToUpperInvariant())); + await sqlUserRepo.DeleteAsync(postSqlUser); + savedSqlUser = await sqlUserRepo.GetByIdAsync(postSqlUser.Id); + Assert.True(savedSqlUser == null); + } - var distinctItems = savedUsers.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); - } - - [CiSkippedTheory, EfUserAutoData] - public async void GetKdfInformationByEmailAsync_Works_DataMatches(User user, - UserKdfInformationCompare equalityComparer, List suts, - SqlRepo.UserRepository sqlUserRepo) - { - var savedKdfInformation = new List(); - foreach (var sut in suts) - { - var postEfUser = await sut.CreateAsync(user); - sut.ClearChangeTracking(); - var kdfInformation = await sut.GetKdfInformationByEmailAsync(postEfUser.Email.ToUpperInvariant()); - savedKdfInformation.Add(kdfInformation); - } - - var postSqlUser = await sqlUserRepo.CreateAsync(user); - var sqlKdfInformation = await sqlUserRepo.GetKdfInformationByEmailAsync(postSqlUser.Email); - savedKdfInformation.Add(sqlKdfInformation); - - var distinctItems = savedKdfInformation.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); - } - - [CiSkippedTheory, EfUserAutoData] - public async void SearchAsync_Works_DataMatches(User user, int skip, int take, - UserCompare equalityCompare, List suts, - SqlRepo.UserRepository sqlUserRepo) - { - var searchedEfUsers = new List(); - foreach (var sut in suts) - { - var postEfUser = await sut.CreateAsync(user); - sut.ClearChangeTracking(); - - var searchedEfUsersCollection = await sut.SearchAsync(postEfUser.Email.ToUpperInvariant(), skip, take); - searchedEfUsers.Concat(searchedEfUsersCollection.ToList()); - } - - var postSqlUser = await sqlUserRepo.CreateAsync(user); - var searchedSqlUsers = await sqlUserRepo.SearchAsync(postSqlUser.Email.ToUpperInvariant(), skip, take); - - var distinctItems = searchedEfUsers.Concat(searchedSqlUsers).Distinct(equalityCompare); - Assert.True(!distinctItems.Skip(1).Any()); - } - - [CiSkippedTheory, EfUserAutoData] - public async void GetManyByPremiumAsync_Works_DataMatches(User user, + [CiSkippedTheory, EfUserAutoData] + public async void GetByEmailAsync_Works_DataMatches(User user, UserCompare equalityComparer, List suts, SqlRepo.UserRepository sqlUserRepo) + { + var savedUsers = new List(); + foreach (var sut in suts) { - var returnedUsers = new List(); - foreach (var sut in suts) - { - var postEfUser = await sut.CreateAsync(user); - sut.ClearChangeTracking(); - - var searchedEfUsers = await sut.GetManyByPremiumAsync(user.Premium); - returnedUsers.Concat(searchedEfUsers.ToList()); - } - - var postSqlUser = await sqlUserRepo.CreateAsync(user); - var searchedSqlUsers = await sqlUserRepo.GetManyByPremiumAsync(user.Premium); - returnedUsers.Concat(searchedSqlUsers.ToList()); - - Assert.True(returnedUsers.All(x => x.Premium == user.Premium)); + var postEfUser = await sut.CreateAsync(user); + sut.ClearChangeTracking(); + var savedUser = await sut.GetByEmailAsync(postEfUser.Email.ToUpperInvariant()); + savedUsers.Add(savedUser); } - [CiSkippedTheory, EfUserAutoData] - public async void GetPublicKeyAsync_Works_DataMatches(User user, List suts, - SqlRepo.UserRepository sqlUserRepo) + var postSqlUser = await sqlUserRepo.CreateAsync(user); + savedUsers.Add(await sqlUserRepo.GetByEmailAsync(postSqlUser.Email.ToUpperInvariant())); + + var distinctItems = savedUsers.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); + } + + [CiSkippedTheory, EfUserAutoData] + public async void GetKdfInformationByEmailAsync_Works_DataMatches(User user, + UserKdfInformationCompare equalityComparer, List suts, + SqlRepo.UserRepository sqlUserRepo) + { + var savedKdfInformation = new List(); + foreach (var sut in suts) { - var returnedKeys = new List(); - foreach (var sut in suts) - { - var postEfUser = await sut.CreateAsync(user); - sut.ClearChangeTracking(); - - var efKey = await sut.GetPublicKeyAsync(postEfUser.Id); - returnedKeys.Add(efKey); - } - - var postSqlUser = await sqlUserRepo.CreateAsync(user); - var sqlKey = await sqlUserRepo.GetPublicKeyAsync(postSqlUser.Id); - returnedKeys.Add(sqlKey); - - Assert.True(!returnedKeys.Distinct().Skip(1).Any()); + var postEfUser = await sut.CreateAsync(user); + sut.ClearChangeTracking(); + var kdfInformation = await sut.GetKdfInformationByEmailAsync(postEfUser.Email.ToUpperInvariant()); + savedKdfInformation.Add(kdfInformation); } - [CiSkippedTheory, EfUserAutoData] - public async void GetAccountRevisionDateAsync(User user, List suts, - SqlRepo.UserRepository sqlUserRepo) + var postSqlUser = await sqlUserRepo.CreateAsync(user); + var sqlKdfInformation = await sqlUserRepo.GetKdfInformationByEmailAsync(postSqlUser.Email); + savedKdfInformation.Add(sqlKdfInformation); + + var distinctItems = savedKdfInformation.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); + } + + [CiSkippedTheory, EfUserAutoData] + public async void SearchAsync_Works_DataMatches(User user, int skip, int take, + UserCompare equalityCompare, List suts, + SqlRepo.UserRepository sqlUserRepo) + { + var searchedEfUsers = new List(); + foreach (var sut in suts) { - var returnedKeys = new List(); - foreach (var sut in suts) - { - var postEfUser = await sut.CreateAsync(user); - sut.ClearChangeTracking(); + var postEfUser = await sut.CreateAsync(user); + sut.ClearChangeTracking(); - var efKey = await sut.GetPublicKeyAsync(postEfUser.Id); - returnedKeys.Add(efKey); - } - - var postSqlUser = await sqlUserRepo.CreateAsync(user); - var sqlKey = await sqlUserRepo.GetPublicKeyAsync(postSqlUser.Id); - returnedKeys.Add(sqlKey); - - Assert.True(!returnedKeys.Distinct().Skip(1).Any()); + var searchedEfUsersCollection = await sut.SearchAsync(postEfUser.Email.ToUpperInvariant(), skip, take); + searchedEfUsers.Concat(searchedEfUsersCollection.ToList()); } - [CiSkippedTheory, EfUserAutoData] - public async void UpdateRenewalReminderDateAsync_Works_DataMatches(User user, - DateTime updatedReminderDate, List suts, - SqlRepo.UserRepository sqlUserRepo) + var postSqlUser = await sqlUserRepo.CreateAsync(user); + var searchedSqlUsers = await sqlUserRepo.SearchAsync(postSqlUser.Email.ToUpperInvariant(), skip, take); + + var distinctItems = searchedEfUsers.Concat(searchedSqlUsers).Distinct(equalityCompare); + Assert.True(!distinctItems.Skip(1).Any()); + } + + [CiSkippedTheory, EfUserAutoData] + public async void GetManyByPremiumAsync_Works_DataMatches(User user, + List suts, SqlRepo.UserRepository sqlUserRepo) + { + var returnedUsers = new List(); + foreach (var sut in suts) { - var savedDates = new List(); - foreach (var sut in suts) - { - var postEfUser = user; - postEfUser = await sut.CreateAsync(user); - sut.ClearChangeTracking(); + var postEfUser = await sut.CreateAsync(user); + sut.ClearChangeTracking(); - await sut.UpdateRenewalReminderDateAsync(postEfUser.Id, updatedReminderDate); - sut.ClearChangeTracking(); - - var replacedUser = await sut.GetByIdAsync(postEfUser.Id); - savedDates.Add(replacedUser.RenewalReminderDate); - } - - var postSqlUser = await sqlUserRepo.CreateAsync(user); - await sqlUserRepo.UpdateRenewalReminderDateAsync(postSqlUser.Id, updatedReminderDate); - var replacedSqlUser = await sqlUserRepo.GetByIdAsync(postSqlUser.Id); - savedDates.Add(replacedSqlUser.RenewalReminderDate); - - var distinctItems = savedDates.GroupBy(e => e.ToString()); - Assert.True(!distinctItems.Skip(1).Any() && - savedDates.All(e => e.ToString() == updatedReminderDate.ToString())); + var searchedEfUsers = await sut.GetManyByPremiumAsync(user.Premium); + returnedUsers.Concat(searchedEfUsers.ToList()); } - [CiSkippedTheory, EfUserAutoData] - public async void GetBySsoUserAsync_Works_DataMatches(User user, Organization org, - SsoUser ssoUser, UserCompare equalityComparer, List suts, - List ssoUserRepos, List orgRepos, - SqlRepo.UserRepository sqlUserRepo, SqlRepo.SsoUserRepository sqlSsoUserRepo, - SqlRepo.OrganizationRepository sqlOrgRepo) + var postSqlUser = await sqlUserRepo.CreateAsync(user); + var searchedSqlUsers = await sqlUserRepo.GetManyByPremiumAsync(user.Premium); + returnedUsers.Concat(searchedSqlUsers.ToList()); + + Assert.True(returnedUsers.All(x => x.Premium == user.Premium)); + } + + [CiSkippedTheory, EfUserAutoData] + public async void GetPublicKeyAsync_Works_DataMatches(User user, List suts, + SqlRepo.UserRepository sqlUserRepo) + { + var returnedKeys = new List(); + foreach (var sut in suts) { - var returnedList = new List(); - foreach (var sut in suts) - { - var i = suts.IndexOf(sut); + var postEfUser = await sut.CreateAsync(user); + sut.ClearChangeTracking(); - var postEfUser = await sut.CreateAsync(user); - sut.ClearChangeTracking(); - - var efOrg = await orgRepos[i].CreateAsync(org); - sut.ClearChangeTracking(); - - ssoUser.UserId = postEfUser.Id; - ssoUser.OrganizationId = efOrg.Id; - var postEfSsoUser = await ssoUserRepos[i].CreateAsync(ssoUser); - sut.ClearChangeTracking(); - - var returnedUser = await sut.GetBySsoUserAsync(postEfSsoUser.ExternalId.ToUpperInvariant(), efOrg.Id); - returnedList.Add(returnedUser); - } - - var sqlUser = await sqlUserRepo.CreateAsync(user); - var sqlOrganization = await sqlOrgRepo.CreateAsync(org); - - ssoUser.UserId = sqlUser.Id; - ssoUser.OrganizationId = sqlOrganization.Id; - var postSqlSsoUser = await sqlSsoUserRepo.CreateAsync(ssoUser); - - var returnedSqlUser = await sqlUserRepo - .GetBySsoUserAsync(postSqlSsoUser.ExternalId, sqlOrganization.Id); - returnedList.Add(returnedSqlUser); - - var distinctItems = returnedList.Distinct(equalityComparer); - Assert.True(!distinctItems.Skip(1).Any()); + var efKey = await sut.GetPublicKeyAsync(postEfUser.Id); + returnedKeys.Add(efKey); } + + var postSqlUser = await sqlUserRepo.CreateAsync(user); + var sqlKey = await sqlUserRepo.GetPublicKeyAsync(postSqlUser.Id); + returnedKeys.Add(sqlKey); + + Assert.True(!returnedKeys.Distinct().Skip(1).Any()); + } + + [CiSkippedTheory, EfUserAutoData] + public async void GetAccountRevisionDateAsync(User user, List suts, + SqlRepo.UserRepository sqlUserRepo) + { + var returnedKeys = new List(); + foreach (var sut in suts) + { + var postEfUser = await sut.CreateAsync(user); + sut.ClearChangeTracking(); + + var efKey = await sut.GetPublicKeyAsync(postEfUser.Id); + returnedKeys.Add(efKey); + } + + var postSqlUser = await sqlUserRepo.CreateAsync(user); + var sqlKey = await sqlUserRepo.GetPublicKeyAsync(postSqlUser.Id); + returnedKeys.Add(sqlKey); + + Assert.True(!returnedKeys.Distinct().Skip(1).Any()); + } + + [CiSkippedTheory, EfUserAutoData] + public async void UpdateRenewalReminderDateAsync_Works_DataMatches(User user, + DateTime updatedReminderDate, List suts, + SqlRepo.UserRepository sqlUserRepo) + { + var savedDates = new List(); + foreach (var sut in suts) + { + var postEfUser = user; + postEfUser = await sut.CreateAsync(user); + sut.ClearChangeTracking(); + + await sut.UpdateRenewalReminderDateAsync(postEfUser.Id, updatedReminderDate); + sut.ClearChangeTracking(); + + var replacedUser = await sut.GetByIdAsync(postEfUser.Id); + savedDates.Add(replacedUser.RenewalReminderDate); + } + + var postSqlUser = await sqlUserRepo.CreateAsync(user); + await sqlUserRepo.UpdateRenewalReminderDateAsync(postSqlUser.Id, updatedReminderDate); + var replacedSqlUser = await sqlUserRepo.GetByIdAsync(postSqlUser.Id); + savedDates.Add(replacedSqlUser.RenewalReminderDate); + + var distinctItems = savedDates.GroupBy(e => e.ToString()); + Assert.True(!distinctItems.Skip(1).Any() && + savedDates.All(e => e.ToString() == updatedReminderDate.ToString())); + } + + [CiSkippedTheory, EfUserAutoData] + public async void GetBySsoUserAsync_Works_DataMatches(User user, Organization org, + SsoUser ssoUser, UserCompare equalityComparer, List suts, + List ssoUserRepos, List orgRepos, + SqlRepo.UserRepository sqlUserRepo, SqlRepo.SsoUserRepository sqlSsoUserRepo, + SqlRepo.OrganizationRepository sqlOrgRepo) + { + var returnedList = new List(); + foreach (var sut in suts) + { + var i = suts.IndexOf(sut); + + var postEfUser = await sut.CreateAsync(user); + sut.ClearChangeTracking(); + + var efOrg = await orgRepos[i].CreateAsync(org); + sut.ClearChangeTracking(); + + ssoUser.UserId = postEfUser.Id; + ssoUser.OrganizationId = efOrg.Id; + var postEfSsoUser = await ssoUserRepos[i].CreateAsync(ssoUser); + sut.ClearChangeTracking(); + + var returnedUser = await sut.GetBySsoUserAsync(postEfSsoUser.ExternalId.ToUpperInvariant(), efOrg.Id); + returnedList.Add(returnedUser); + } + + var sqlUser = await sqlUserRepo.CreateAsync(user); + var sqlOrganization = await sqlOrgRepo.CreateAsync(org); + + ssoUser.UserId = sqlUser.Id; + ssoUser.OrganizationId = sqlOrganization.Id; + var postSqlSsoUser = await sqlSsoUserRepo.CreateAsync(ssoUser); + + var returnedSqlUser = await sqlUserRepo + .GetBySsoUserAsync(postSqlSsoUser.ExternalId, sqlOrganization.Id); + returnedList.Add(returnedSqlUser); + + var distinctItems = returnedList.Distinct(equalityComparer); + Assert.True(!distinctItems.Skip(1).Any()); } } diff --git a/test/IntegrationTestCommon/Factories/IdentityApplicationFactory.cs b/test/IntegrationTestCommon/Factories/IdentityApplicationFactory.cs index 0a8741b55..501ded613 100644 --- a/test/IntegrationTestCommon/Factories/IdentityApplicationFactory.cs +++ b/test/IntegrationTestCommon/Factories/IdentityApplicationFactory.cs @@ -7,40 +7,39 @@ using Bit.Identity; using Bit.Test.Common.Helpers; using Microsoft.AspNetCore.Http; -namespace Bit.IntegrationTestCommon.Factories +namespace Bit.IntegrationTestCommon.Factories; + +public class IdentityApplicationFactory : WebApplicationFactoryBase { - public class IdentityApplicationFactory : WebApplicationFactoryBase + public const string DefaultDeviceIdentifier = "92b9d953-b9b6-4eaf-9d3e-11d57144dfeb"; + + public async Task RegisterAsync(RegisterRequestModel model) { - public const string DefaultDeviceIdentifier = "92b9d953-b9b6-4eaf-9d3e-11d57144dfeb"; + return await Server.PostAsync("/accounts/register", JsonContent.Create(model)); + } - public async Task RegisterAsync(RegisterRequestModel model) + public async Task<(string Token, string RefreshToken)> TokenFromPasswordAsync(string username, + string password, + string deviceIdentifier = DefaultDeviceIdentifier, + string clientId = "web", + DeviceType deviceType = DeviceType.FirefoxBrowser, + string deviceName = "firefox") + { + var context = await Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary { - return await Server.PostAsync("/accounts/register", JsonContent.Create(model)); - } + { "scope", "api offline_access" }, + { "client_id", clientId }, + { "deviceType", ((int)deviceType).ToString() }, + { "deviceIdentifier", deviceIdentifier }, + { "deviceName", deviceName }, + { "grant_type", "password" }, + { "username", username }, + { "password", password }, + }), context => context.Request.Headers.Add("Auth-Email", CoreHelpers.Base64UrlEncodeString(username))); - public async Task<(string Token, string RefreshToken)> TokenFromPasswordAsync(string username, - string password, - string deviceIdentifier = DefaultDeviceIdentifier, - string clientId = "web", - DeviceType deviceType = DeviceType.FirefoxBrowser, - string deviceName = "firefox") - { - var context = await Server.PostAsync("/connect/token", new FormUrlEncodedContent(new Dictionary - { - { "scope", "api offline_access" }, - { "client_id", clientId }, - { "deviceType", ((int)deviceType).ToString() }, - { "deviceIdentifier", deviceIdentifier }, - { "deviceName", deviceName }, - { "grant_type", "password" }, - { "username", username }, - { "password", password }, - }), context => context.Request.Headers.Add("Auth-Email", CoreHelpers.Base64UrlEncodeString(username))); + using var body = await AssertHelper.AssertResponseTypeIs(context); + var root = body.RootElement; - using var body = await AssertHelper.AssertResponseTypeIs(context); - var root = body.RootElement; - - return (root.GetProperty("access_token").GetString(), root.GetProperty("refresh_token").GetString()); - } + return (root.GetProperty("access_token").GetString(), root.GetProperty("refresh_token").GetString()); } } diff --git a/test/IntegrationTestCommon/Factories/WebApplicationFactoryBase.cs b/test/IntegrationTestCommon/Factories/WebApplicationFactoryBase.cs index 04b4c0de4..45a1454ae 100644 --- a/test/IntegrationTestCommon/Factories/WebApplicationFactoryBase.cs +++ b/test/IntegrationTestCommon/Factories/WebApplicationFactoryBase.cs @@ -9,103 +9,102 @@ using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; -namespace Bit.IntegrationTestCommon.Factories +namespace Bit.IntegrationTestCommon.Factories; + +public static class FactoryConstants { - public static class FactoryConstants + public const string DefaultDatabaseName = "test_database"; + public const string WhitelistedIp = "1.1.1.1"; +} + +public abstract class WebApplicationFactoryBase : WebApplicationFactory + where T : class +{ + /// + /// The database name to use for this instance of the factory. By default it will use a shared database name so all instances will connect to the same database during it's lifetime. + /// + /// + /// This will need to be set BEFORE using the Server property + /// + public string DatabaseName { get; set; } = FactoryConstants.DefaultDatabaseName; + + /// + /// Configure the web host to use an EF in memory database + /// + protected override void ConfigureWebHost(IWebHostBuilder builder) { - public const string DefaultDatabaseName = "test_database"; - public const string WhitelistedIp = "1.1.1.1"; + builder.ConfigureAppConfiguration(c => + { + c.SetBasePath(AppContext.BaseDirectory) + .AddJsonFile("appsettings.json") + .AddJsonFile("appsettings.Development.json"); + + c.AddUserSecrets(typeof(Identity.Startup).Assembly, optional: true); + c.AddInMemoryCollection(new Dictionary + { + // Manually insert a EF provider so that ConfigureServices will add EF repositories but we will override + // DbContextOptions to use an in memory database + { "globalSettings:databaseProvider", "postgres" }, + { "globalSettings:postgreSql:connectionString", "Host=localhost;Username=test;Password=test;Database=test" }, + + // Clear the redis connection string for distributed caching, forcing an in-memory implementation + { "globalSettings:redis:connectionString", ""} + }); + }); + + builder.ConfigureTestServices(services => + { + var dbContextOptions = services.First(sd => sd.ServiceType == typeof(DbContextOptions)); + services.Remove(dbContextOptions); + services.AddScoped(_ => + { + return new DbContextOptionsBuilder() + .UseInMemoryDatabase(DatabaseName) + .Options; + }); + + // QUESTION: The normal licensing service should run fine on developer machines but not in CI + // should we have a fork here to leave the normal service for developers? + // TODO: Eventually add the license file to CI + var licensingService = services.First(sd => sd.ServiceType == typeof(ILicensingService)); + services.Remove(licensingService); + services.AddSingleton(); + + // FUTURE CONSIDERATION: Add way to run this self hosted/cloud, for now it is cloud only + var pushRegistrationService = services.First(sd => sd.ServiceType == typeof(IPushRegistrationService)); + services.Remove(pushRegistrationService); + services.AddSingleton(); + + // Even though we are cloud we currently set this up as cloud, we can use the EF/selfhosted service + // instead of using Noop for this service + // TODO: Install and use azurite in CI pipeline + var eventWriteService = services.First(sd => sd.ServiceType == typeof(IEventWriteService)); + services.Remove(eventWriteService); + services.AddSingleton(); + + var eventRepositoryService = services.First(sd => sd.ServiceType == typeof(IEventRepository)); + services.Remove(eventRepositoryService); + services.AddSingleton(); + + // Our Rate limiter works so well that it begins to fail tests unless we carve out + // one whitelisted ip. We should still test the rate limiter though and they should change the Ip + // to something that is NOT whitelisted + services.Configure(options => + { + options.IpWhitelist = new List + { + FactoryConstants.WhitelistedIp, + }; + }); + + // Fix IP Rate Limiting + services.AddSingleton(); + }); } - public abstract class WebApplicationFactoryBase : WebApplicationFactory - where T : class + public DatabaseContext GetDatabaseContext() { - /// - /// The database name to use for this instance of the factory. By default it will use a shared database name so all instances will connect to the same database during it's lifetime. - /// - /// - /// This will need to be set BEFORE using the Server property - /// - public string DatabaseName { get; set; } = FactoryConstants.DefaultDatabaseName; - - /// - /// Configure the web host to use an EF in memory database - /// - protected override void ConfigureWebHost(IWebHostBuilder builder) - { - builder.ConfigureAppConfiguration(c => - { - c.SetBasePath(AppContext.BaseDirectory) - .AddJsonFile("appsettings.json") - .AddJsonFile("appsettings.Development.json"); - - c.AddUserSecrets(typeof(Identity.Startup).Assembly, optional: true); - c.AddInMemoryCollection(new Dictionary - { - // Manually insert a EF provider so that ConfigureServices will add EF repositories but we will override - // DbContextOptions to use an in memory database - { "globalSettings:databaseProvider", "postgres" }, - { "globalSettings:postgreSql:connectionString", "Host=localhost;Username=test;Password=test;Database=test" }, - - // Clear the redis connection string for distributed caching, forcing an in-memory implementation - { "globalSettings:redis:connectionString", ""} - }); - }); - - builder.ConfigureTestServices(services => - { - var dbContextOptions = services.First(sd => sd.ServiceType == typeof(DbContextOptions)); - services.Remove(dbContextOptions); - services.AddScoped(_ => - { - return new DbContextOptionsBuilder() - .UseInMemoryDatabase(DatabaseName) - .Options; - }); - - // QUESTION: The normal licensing service should run fine on developer machines but not in CI - // should we have a fork here to leave the normal service for developers? - // TODO: Eventually add the license file to CI - var licensingService = services.First(sd => sd.ServiceType == typeof(ILicensingService)); - services.Remove(licensingService); - services.AddSingleton(); - - // FUTURE CONSIDERATION: Add way to run this self hosted/cloud, for now it is cloud only - var pushRegistrationService = services.First(sd => sd.ServiceType == typeof(IPushRegistrationService)); - services.Remove(pushRegistrationService); - services.AddSingleton(); - - // Even though we are cloud we currently set this up as cloud, we can use the EF/selfhosted service - // instead of using Noop for this service - // TODO: Install and use azurite in CI pipeline - var eventWriteService = services.First(sd => sd.ServiceType == typeof(IEventWriteService)); - services.Remove(eventWriteService); - services.AddSingleton(); - - var eventRepositoryService = services.First(sd => sd.ServiceType == typeof(IEventRepository)); - services.Remove(eventRepositoryService); - services.AddSingleton(); - - // Our Rate limiter works so well that it begins to fail tests unless we carve out - // one whitelisted ip. We should still test the rate limiter though and they should change the Ip - // to something that is NOT whitelisted - services.Configure(options => - { - options.IpWhitelist = new List - { - FactoryConstants.WhitelistedIp, - }; - }); - - // Fix IP Rate Limiting - services.AddSingleton(); - }); - } - - public DatabaseContext GetDatabaseContext() - { - var scope = Services.CreateScope(); - return scope.ServiceProvider.GetRequiredService(); - } + var scope = Services.CreateScope(); + return scope.ServiceProvider.GetRequiredService(); } } diff --git a/test/IntegrationTestCommon/Factories/WebApplicationFactoryExtensions.cs b/test/IntegrationTestCommon/Factories/WebApplicationFactoryExtensions.cs index ed428a772..88fc21006 100644 --- a/test/IntegrationTestCommon/Factories/WebApplicationFactoryExtensions.cs +++ b/test/IntegrationTestCommon/Factories/WebApplicationFactoryExtensions.cs @@ -4,65 +4,64 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.Primitives; -namespace Bit.IntegrationTestCommon.Factories +namespace Bit.IntegrationTestCommon.Factories; + +public static class WebApplicationFactoryExtensions { - public static class WebApplicationFactoryExtensions + private static async Task SendAsync(this TestServer server, + HttpMethod method, + string requestUri, + HttpContent content = null, + Action extraConfiguration = null) { - private static async Task SendAsync(this TestServer server, - HttpMethod method, - string requestUri, - HttpContent content = null, - Action extraConfiguration = null) + return await server.SendAsync(httpContext => { - return await server.SendAsync(httpContext => + // Automatically set the whitelisted IP so normal tests do not run into rate limit issues + // to test rate limiter, use the extraConfiguration parameter to set Connection.RemoteIpAddress + // it runs after this so it will take precedence. + httpContext.Connection.RemoteIpAddress = IPAddress.Parse(FactoryConstants.WhitelistedIp); + + httpContext.Request.Path = new PathString(requestUri); + httpContext.Request.Method = method.Method; + + if (content != null) { - // Automatically set the whitelisted IP so normal tests do not run into rate limit issues - // to test rate limiter, use the extraConfiguration parameter to set Connection.RemoteIpAddress - // it runs after this so it will take precedence. - httpContext.Connection.RemoteIpAddress = IPAddress.Parse(FactoryConstants.WhitelistedIp); - - httpContext.Request.Path = new PathString(requestUri); - httpContext.Request.Method = method.Method; - - if (content != null) + foreach (var header in content.Headers) { - foreach (var header in content.Headers) - { - httpContext.Request.Headers.Add(header.Key, new StringValues(header.Value.ToArray())); - } - - httpContext.Request.Body = content.ReadAsStream(); + httpContext.Request.Headers.Add(header.Key, new StringValues(header.Value.ToArray())); } - extraConfiguration?.Invoke(httpContext); - }); - } - public static Task PostAsync(this TestServer server, - string requestUri, - HttpContent content, - Action extraConfiguration = null) - => SendAsync(server, HttpMethod.Post, requestUri, content, extraConfiguration); - public static Task GetAsync(this TestServer server, - string requestUri, - Action extraConfiguration = null) - => SendAsync(server, HttpMethod.Get, requestUri, content: null, extraConfiguration); + httpContext.Request.Body = content.ReadAsStream(); + } - public static HttpContext SetAuthEmail(this HttpContext context, string username) - { - context.Request.Headers.Add("Auth-Email", CoreHelpers.Base64UrlEncodeString(username)); - return context; - } + extraConfiguration?.Invoke(httpContext); + }); + } + public static Task PostAsync(this TestServer server, + string requestUri, + HttpContent content, + Action extraConfiguration = null) + => SendAsync(server, HttpMethod.Post, requestUri, content, extraConfiguration); + public static Task GetAsync(this TestServer server, + string requestUri, + Action extraConfiguration = null) + => SendAsync(server, HttpMethod.Get, requestUri, content: null, extraConfiguration); - public static HttpContext SetIp(this HttpContext context, string ip) - { - context.Connection.RemoteIpAddress = IPAddress.Parse(ip); - return context; - } + public static HttpContext SetAuthEmail(this HttpContext context, string username) + { + context.Request.Headers.Add("Auth-Email", CoreHelpers.Base64UrlEncodeString(username)); + return context; + } - public static async Task ReadBodyAsStringAsync(this HttpContext context) - { - using var sr = new StreamReader(context.Response.Body); - return await sr.ReadToEndAsync(); - } + public static HttpContext SetIp(this HttpContext context, string ip) + { + context.Connection.RemoteIpAddress = IPAddress.Parse(ip); + return context; + } + + public static async Task ReadBodyAsStringAsync(this HttpContext context) + { + using var sr = new StreamReader(context.Response.Body); + return await sr.ReadToEndAsync(); } } diff --git a/util/EfShared/MigrationBuilderExtensions.cs b/util/EfShared/MigrationBuilderExtensions.cs index aa8fc04a8..cb9fad33c 100644 --- a/util/EfShared/MigrationBuilderExtensions.cs +++ b/util/EfShared/MigrationBuilderExtensions.cs @@ -4,29 +4,28 @@ using System.Runtime.CompilerServices; using Bit.Core.Utilities; using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit +namespace Bit; + +// This file is a manual addition to a project that it helps, a project that chooses to compile it +// should have a projet reference to Core.csproj and a package reference to Microsoft.EntityFrameworkCore.Design +// The reason for this is that if it belonged to it's own library you would have to add manual references to the above +// and manage the version for the EntityFrameworkCore package. This way it also doesn't create another dll +// To include this you can view examples in the MySqlMigrations and PostgresMigrations .csproj files. +// + +public static class MigrationBuilderExtensions { - // This file is a manual addition to a project that it helps, a project that chooses to compile it - // should have a projet reference to Core.csproj and a package reference to Microsoft.EntityFrameworkCore.Design - // The reason for this is that if it belonged to it's own library you would have to add manual references to the above - // and manage the version for the EntityFrameworkCore package. This way it also doesn't create another dll - // To include this you can view examples in the MySqlMigrations and PostgresMigrations .csproj files. - // - - public static class MigrationBuilderExtensions + /// + /// Reads an embedded resource for it's SQL contents and formats it with the specified direction for easier custom migration steps + /// + /// The MigrationBuilder instance the sql should be applied to + /// The file name portion of the resource name, it is assumed to be in a Scripts folder + /// The direction of the migration taking place + public static void SqlResource(this MigrationBuilder migrationBuilder, string resourceName, [CallerMemberName] string dir = null) { - /// - /// Reads an embedded resource for it's SQL contents and formats it with the specified direction for easier custom migration steps - /// - /// The MigrationBuilder instance the sql should be applied to - /// The file name portion of the resource name, it is assumed to be in a Scripts folder - /// The direction of the migration taking place - public static void SqlResource(this MigrationBuilder migrationBuilder, string resourceName, [CallerMemberName] string dir = null) - { - var formattedResourceName = string.IsNullOrEmpty(dir) ? resourceName : string.Format(resourceName, dir); + var formattedResourceName = string.IsNullOrEmpty(dir) ? resourceName : string.Format(resourceName, dir); - migrationBuilder.Sql(CoreHelpers.GetEmbeddedResourceContentsAsync( - $"Scripts.{formattedResourceName}")); - } + migrationBuilder.Sql(CoreHelpers.GetEmbeddedResourceContentsAsync( + $"Scripts.{formattedResourceName}")); } } diff --git a/util/Migrator/DbMigrator.cs b/util/Migrator/DbMigrator.cs index d0463a00e..ad62691fc 100644 --- a/util/Migrator/DbMigrator.cs +++ b/util/Migrator/DbMigrator.cs @@ -5,104 +5,103 @@ using Bit.Core; using DbUp; using Microsoft.Extensions.Logging; -namespace Bit.Migrator +namespace Bit.Migrator; + +public class DbMigrator { - public class DbMigrator + private readonly string _connectionString; + private readonly ILogger _logger; + private readonly string _masterConnectionString; + + public DbMigrator(string connectionString, ILogger logger) { - private readonly string _connectionString; - private readonly ILogger _logger; - private readonly string _masterConnectionString; - - public DbMigrator(string connectionString, ILogger logger) + _connectionString = connectionString; + _logger = logger; + _masterConnectionString = new SqlConnectionStringBuilder(connectionString) { - _connectionString = connectionString; - _logger = logger; - _masterConnectionString = new SqlConnectionStringBuilder(connectionString) - { - InitialCatalog = "master" - }.ConnectionString; + InitialCatalog = "master" + }.ConnectionString; + } + + public bool MigrateMsSqlDatabase(bool enableLogging = true, + CancellationToken cancellationToken = default(CancellationToken)) + { + if (enableLogging && _logger != null) + { + _logger.LogInformation(Constants.BypassFiltersEventId, "Migrating database."); } - public bool MigrateMsSqlDatabase(bool enableLogging = true, - CancellationToken cancellationToken = default(CancellationToken)) + using (var connection = new SqlConnection(_masterConnectionString)) { - if (enableLogging && _logger != null) + var databaseName = new SqlConnectionStringBuilder(_connectionString).InitialCatalog; + if (string.IsNullOrWhiteSpace(databaseName)) { - _logger.LogInformation(Constants.BypassFiltersEventId, "Migrating database."); + databaseName = "vault"; } - using (var connection = new SqlConnection(_masterConnectionString)) - { - var databaseName = new SqlConnectionStringBuilder(_connectionString).InitialCatalog; - if (string.IsNullOrWhiteSpace(databaseName)) - { - databaseName = "vault"; - } + var databaseNameQuoted = new SqlCommandBuilder().QuoteIdentifier(databaseName); + var command = new SqlCommand( + "IF ((SELECT COUNT(1) FROM sys.databases WHERE [name] = @DatabaseName) = 0) " + + "CREATE DATABASE " + databaseNameQuoted + ";", connection); + command.Parameters.Add("@DatabaseName", SqlDbType.VarChar).Value = databaseName; + command.Connection.Open(); + command.ExecuteNonQuery(); - var databaseNameQuoted = new SqlCommandBuilder().QuoteIdentifier(databaseName); - var command = new SqlCommand( - "IF ((SELECT COUNT(1) FROM sys.databases WHERE [name] = @DatabaseName) = 0) " + - "CREATE DATABASE " + databaseNameQuoted + ";", connection); - command.Parameters.Add("@DatabaseName", SqlDbType.VarChar).Value = databaseName; - command.Connection.Open(); - command.ExecuteNonQuery(); - - command.CommandText = "IF ((SELECT DATABASEPROPERTYEX([name], 'IsAutoClose') " + - "FROM sys.databases WHERE [name] = @DatabaseName) = 1) " + - "ALTER DATABASE " + databaseNameQuoted + " SET AUTO_CLOSE OFF;"; - command.ExecuteNonQuery(); - } - - cancellationToken.ThrowIfCancellationRequested(); - using (var connection = new SqlConnection(_connectionString)) - { - // Rename old migration scripts to new namespace. - var command = new SqlCommand( - "IF OBJECT_ID('Migration','U') IS NOT NULL " + - "UPDATE [dbo].[Migration] SET " + - "[ScriptName] = REPLACE([ScriptName], 'Bit.Setup.', 'Bit.Migrator.');", connection); - command.Connection.Open(); - command.ExecuteNonQuery(); - } - - cancellationToken.ThrowIfCancellationRequested(); - var builder = DeployChanges.To - .SqlDatabase(_connectionString) - .JournalToSqlTable("dbo", "Migration") - .WithScriptsAndCodeEmbeddedInAssembly(Assembly.GetExecutingAssembly(), - s => s.Contains($".DbScripts.") && !s.Contains(".Archive.")) - .WithTransaction() - .WithExecutionTimeout(new TimeSpan(0, 5, 0)); - - if (enableLogging) - { - if (_logger != null) - { - builder.LogTo(new DbUpLogger(_logger)); - } - else - { - builder.LogToConsole(); - } - } - - var upgrader = builder.Build(); - var result = upgrader.PerformUpgrade(); - - if (enableLogging && _logger != null) - { - if (result.Successful) - { - _logger.LogInformation(Constants.BypassFiltersEventId, "Migration successful."); - } - else - { - _logger.LogError(Constants.BypassFiltersEventId, result.Error, "Migration failed."); - } - } - - cancellationToken.ThrowIfCancellationRequested(); - return result.Successful; + command.CommandText = "IF ((SELECT DATABASEPROPERTYEX([name], 'IsAutoClose') " + + "FROM sys.databases WHERE [name] = @DatabaseName) = 1) " + + "ALTER DATABASE " + databaseNameQuoted + " SET AUTO_CLOSE OFF;"; + command.ExecuteNonQuery(); } + + cancellationToken.ThrowIfCancellationRequested(); + using (var connection = new SqlConnection(_connectionString)) + { + // Rename old migration scripts to new namespace. + var command = new SqlCommand( + "IF OBJECT_ID('Migration','U') IS NOT NULL " + + "UPDATE [dbo].[Migration] SET " + + "[ScriptName] = REPLACE([ScriptName], 'Bit.Setup.', 'Bit.Migrator.');", connection); + command.Connection.Open(); + command.ExecuteNonQuery(); + } + + cancellationToken.ThrowIfCancellationRequested(); + var builder = DeployChanges.To + .SqlDatabase(_connectionString) + .JournalToSqlTable("dbo", "Migration") + .WithScriptsAndCodeEmbeddedInAssembly(Assembly.GetExecutingAssembly(), + s => s.Contains($".DbScripts.") && !s.Contains(".Archive.")) + .WithTransaction() + .WithExecutionTimeout(new TimeSpan(0, 5, 0)); + + if (enableLogging) + { + if (_logger != null) + { + builder.LogTo(new DbUpLogger(_logger)); + } + else + { + builder.LogToConsole(); + } + } + + var upgrader = builder.Build(); + var result = upgrader.PerformUpgrade(); + + if (enableLogging && _logger != null) + { + if (result.Successful) + { + _logger.LogInformation(Constants.BypassFiltersEventId, "Migration successful."); + } + else + { + _logger.LogError(Constants.BypassFiltersEventId, result.Error, "Migration failed."); + } + } + + cancellationToken.ThrowIfCancellationRequested(); + return result.Successful; } } diff --git a/util/Migrator/DbUpLogger.cs b/util/Migrator/DbUpLogger.cs index 1c1707d21..a65b3ec0e 100644 --- a/util/Migrator/DbUpLogger.cs +++ b/util/Migrator/DbUpLogger.cs @@ -2,30 +2,29 @@ using DbUp.Engine.Output; using Microsoft.Extensions.Logging; -namespace Bit.Migrator +namespace Bit.Migrator; + +public class DbUpLogger : IUpgradeLog { - public class DbUpLogger : IUpgradeLog + private readonly ILogger _logger; + + public DbUpLogger(ILogger logger) { - private readonly ILogger _logger; + _logger = logger; + } - public DbUpLogger(ILogger logger) - { - _logger = logger; - } + public void WriteError(string format, params object[] args) + { + _logger.LogError(Constants.BypassFiltersEventId, format, args); + } - public void WriteError(string format, params object[] args) - { - _logger.LogError(Constants.BypassFiltersEventId, format, args); - } + public void WriteInformation(string format, params object[] args) + { + _logger.LogInformation(Constants.BypassFiltersEventId, format, args); + } - public void WriteInformation(string format, params object[] args) - { - _logger.LogInformation(Constants.BypassFiltersEventId, format, args); - } - - public void WriteWarning(string format, params object[] args) - { - _logger.LogWarning(Constants.BypassFiltersEventId, format, args); - } + public void WriteWarning(string format, params object[] args) + { + _logger.LogWarning(Constants.BypassFiltersEventId, format, args); } } diff --git a/util/MySqlMigrations/Factories.cs b/util/MySqlMigrations/Factories.cs index 734b88dd8..538c39612 100644 --- a/util/MySqlMigrations/Factories.cs +++ b/util/MySqlMigrations/Factories.cs @@ -4,35 +4,34 @@ using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore.Design; using Microsoft.Extensions.Configuration; -namespace MySqlMigrations -{ - public static class GlobalSettingsFactory - { - public static GlobalSettings GlobalSettings { get; } = new GlobalSettings(); - static GlobalSettingsFactory() - { - var configBuilder = new ConfigurationBuilder().AddUserSecrets(); - var Configuration = configBuilder.Build(); - ConfigurationBinder.Bind(Configuration.GetSection("GlobalSettings"), GlobalSettings); - } - } +namespace MySqlMigrations; - public class DatabaseContextFactory : IDesignTimeDbContextFactory +public static class GlobalSettingsFactory +{ + public static GlobalSettings GlobalSettings { get; } = new GlobalSettings(); + static GlobalSettingsFactory() { - public DatabaseContext CreateDbContext(string[] args) - { - var globalSettings = GlobalSettingsFactory.GlobalSettings; - var optionsBuilder = new DbContextOptionsBuilder(); - var connectionString = globalSettings.MySql?.ConnectionString; - if (string.IsNullOrWhiteSpace(connectionString)) - { - throw new Exception("No MySql connection string found."); - } - optionsBuilder.UseMySql( - connectionString, - ServerVersion.AutoDetect(connectionString), - b => b.MigrationsAssembly("MySqlMigrations")); - return new DatabaseContext(optionsBuilder.Options); - } + var configBuilder = new ConfigurationBuilder().AddUserSecrets(); + var Configuration = configBuilder.Build(); + ConfigurationBinder.Bind(Configuration.GetSection("GlobalSettings"), GlobalSettings); + } +} + +public class DatabaseContextFactory : IDesignTimeDbContextFactory +{ + public DatabaseContext CreateDbContext(string[] args) + { + var globalSettings = GlobalSettingsFactory.GlobalSettings; + var optionsBuilder = new DbContextOptionsBuilder(); + var connectionString = globalSettings.MySql?.ConnectionString; + if (string.IsNullOrWhiteSpace(connectionString)) + { + throw new Exception("No MySql connection string found."); + } + optionsBuilder.UseMySql( + connectionString, + ServerVersion.AutoDetect(connectionString), + b => b.MigrationsAssembly("MySqlMigrations")); + return new DatabaseContext(optionsBuilder.Options); } } diff --git a/util/MySqlMigrations/Migrations/20210617183900_Init.cs b/util/MySqlMigrations/Migrations/20210617183900_Init.cs index 859091b72..d85ad6a1e 100644 --- a/util/MySqlMigrations/Migrations/20210617183900_Init.cs +++ b/util/MySqlMigrations/Migrations/20210617183900_Init.cs @@ -1,1129 +1,1128 @@ using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations +namespace Bit.MySqlMigrations.Migrations; + +public partial class Init : Migration { - public partial class Init : Migration + protected override void Up(MigrationBuilder migrationBuilder) { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.AlterDatabase() - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Event", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Date = table.Column(type: "datetime(6)", nullable: false), - Type = table.Column(type: "int", nullable: false), - UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - OrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - CipherId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - CollectionId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - PolicyId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - GroupId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - OrganizationUserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - DeviceType = table.Column(type: "tinyint unsigned", nullable: true), - IpAddress = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - ActingUserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci") - }, - constraints: table => - { - table.PrimaryKey("PK_Event", x => x.Id); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Grant", - columns: table => new - { - Key = table.Column(type: "varchar(200)", maxLength: 200, nullable: false) - .Annotation("MySql:CharSet", "utf8mb4"), - Type = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - SubjectId = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - SessionId = table.Column(type: "varchar(100)", maxLength: 100, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - ClientId = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Description = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - ExpirationDate = table.Column(type: "datetime(6)", nullable: true), - ConsumedDate = table.Column(type: "datetime(6)", nullable: true), - Data = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4") - }, - constraints: table => - { - table.PrimaryKey("PK_Grant", x => x.Key); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Installation", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Email = table.Column(type: "varchar(256)", maxLength: 256, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Key = table.Column(type: "varchar(150)", maxLength: 150, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Enabled = table.Column(type: "tinyint(1)", nullable: false), - CreationDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Installation", x => x.Id); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Organization", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Identifier = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Name = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BusinessName = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BusinessAddress1 = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BusinessAddress2 = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BusinessAddress3 = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BusinessCountry = table.Column(type: "varchar(2)", maxLength: 2, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BusinessTaxNumber = table.Column(type: "varchar(30)", maxLength: 30, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BillingEmail = table.Column(type: "varchar(256)", maxLength: 256, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Plan = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - PlanType = table.Column(type: "tinyint unsigned", nullable: false), - Seats = table.Column(type: "int", nullable: true), - MaxCollections = table.Column(type: "smallint", nullable: true), - UsePolicies = table.Column(type: "tinyint(1)", nullable: false), - UseSso = table.Column(type: "tinyint(1)", nullable: false), - UseGroups = table.Column(type: "tinyint(1)", nullable: false), - UseDirectory = table.Column(type: "tinyint(1)", nullable: false), - UseEvents = table.Column(type: "tinyint(1)", nullable: false), - UseTotp = table.Column(type: "tinyint(1)", nullable: false), - Use2fa = table.Column(type: "tinyint(1)", nullable: false), - UseApi = table.Column(type: "tinyint(1)", nullable: false), - UseResetPassword = table.Column(type: "tinyint(1)", nullable: false), - SelfHost = table.Column(type: "tinyint(1)", nullable: false), - UsersGetPremium = table.Column(type: "tinyint(1)", nullable: false), - Storage = table.Column(type: "bigint", nullable: true), - MaxStorageGb = table.Column(type: "smallint", nullable: true), - Gateway = table.Column(type: "tinyint unsigned", nullable: true), - GatewayCustomerId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - GatewaySubscriptionId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - ReferenceData = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Enabled = table.Column(type: "tinyint(1)", nullable: false), - LicenseKey = table.Column(type: "varchar(100)", maxLength: 100, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - ApiKey = table.Column(type: "varchar(30)", maxLength: 30, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - PublicKey = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - PrivateKey = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - TwoFactorProviders = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - ExpirationDate = table.Column(type: "datetime(6)", nullable: true), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Organization", x => x.Id); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Provider", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Name = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BusinessName = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BusinessAddress1 = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BusinessAddress2 = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BusinessAddress3 = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BusinessCountry = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BusinessTaxNumber = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - BillingEmail = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Status = table.Column(type: "tinyint unsigned", nullable: false), - Enabled = table.Column(type: "tinyint(1)", nullable: false), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Provider", x => x.Id); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "TaxRate", - columns: table => new - { - Id = table.Column(type: "varchar(40)", maxLength: 40, nullable: false) - .Annotation("MySql:CharSet", "utf8mb4"), - Country = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - State = table.Column(type: "varchar(2)", maxLength: 2, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - PostalCode = table.Column(type: "varchar(10)", maxLength: 10, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Rate = table.Column(type: "decimal(65,30)", nullable: false), - Active = table.Column(type: "tinyint(1)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_TaxRate", x => x.Id); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "User", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Name = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Email = table.Column(type: "varchar(256)", maxLength: 256, nullable: false) - .Annotation("MySql:CharSet", "utf8mb4"), - EmailVerified = table.Column(type: "tinyint(1)", nullable: false), - MasterPassword = table.Column(type: "varchar(300)", maxLength: 300, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - MasterPasswordHint = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Culture = table.Column(type: "varchar(10)", maxLength: 10, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - SecurityStamp = table.Column(type: "varchar(50)", maxLength: 50, nullable: false) - .Annotation("MySql:CharSet", "utf8mb4"), - TwoFactorProviders = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - TwoFactorRecoveryCode = table.Column(type: "varchar(32)", maxLength: 32, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - EquivalentDomains = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - ExcludedGlobalEquivalentDomains = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - AccountRevisionDate = table.Column(type: "datetime(6)", nullable: false), - Key = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - PublicKey = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - PrivateKey = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Premium = table.Column(type: "tinyint(1)", nullable: false), - PremiumExpirationDate = table.Column(type: "datetime(6)", nullable: true), - RenewalReminderDate = table.Column(type: "datetime(6)", nullable: true), - Storage = table.Column(type: "bigint", nullable: true), - MaxStorageGb = table.Column(type: "smallint", nullable: true), - Gateway = table.Column(type: "tinyint unsigned", nullable: true), - GatewayCustomerId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - GatewaySubscriptionId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - ReferenceData = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - LicenseKey = table.Column(type: "varchar(100)", maxLength: 100, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - ApiKey = table.Column(type: "varchar(30)", maxLength: 30, nullable: false) - .Annotation("MySql:CharSet", "utf8mb4"), - Kdf = table.Column(type: "tinyint unsigned", nullable: false), - KdfIterations = table.Column(type: "int", nullable: false), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_User", x => x.Id); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Collection", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Name = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - ExternalId = table.Column(type: "varchar(300)", maxLength: 300, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Collection", x => x.Id); - table.ForeignKey( - name: "FK_Collection_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Group", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Name = table.Column(type: "varchar(100)", maxLength: 100, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - AccessAll = table.Column(type: "tinyint(1)", nullable: false), - ExternalId = table.Column(type: "varchar(300)", maxLength: 300, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Group", x => x.Id); - table.ForeignKey( - name: "FK_Group_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Policy", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Type = table.Column(type: "tinyint unsigned", nullable: false), - Data = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Enabled = table.Column(type: "tinyint(1)", nullable: false), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Policy", x => x.Id); - table.ForeignKey( - name: "FK_Policy_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "SsoConfig", - columns: table => new - { - Id = table.Column(type: "bigint", nullable: false) - .Annotation("MySql:ValueGenerationStrategy", MySqlValueGenerationStrategy.IdentityColumn), - Enabled = table.Column(type: "tinyint(1)", nullable: false), - OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Data = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_SsoConfig", x => x.Id); - table.ForeignKey( - name: "FK_SsoConfig_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "ProviderOrganization", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - ProviderId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Key = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Settings = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_ProviderOrganization", x => x.Id); - table.ForeignKey( - name: "FK_ProviderOrganization_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_ProviderOrganization_Provider_ProviderId", - column: x => x.ProviderId, - principalTable: "Provider", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Cipher", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - OrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - Type = table.Column(type: "tinyint unsigned", nullable: false), - Data = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Favorites = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Folders = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Attachments = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false), - DeletedDate = table.Column(type: "datetime(6)", nullable: true), - Reprompt = table.Column(type: "tinyint unsigned", nullable: true) - }, - constraints: table => - { - table.PrimaryKey("PK_Cipher", x => x.Id); - table.ForeignKey( - name: "FK_Cipher_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_Cipher_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Device", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - UserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Name = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Type = table.Column(type: "tinyint unsigned", nullable: false), - Identifier = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - PushToken = table.Column(type: "varchar(255)", maxLength: 255, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Device", x => x.Id); - table.ForeignKey( - name: "FK_Device_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "EmergencyAccess", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - GrantorId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - GranteeId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - Email = table.Column(type: "varchar(256)", maxLength: 256, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - KeyEncrypted = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Type = table.Column(type: "tinyint unsigned", nullable: false), - Status = table.Column(type: "tinyint unsigned", nullable: false), - WaitTimeDays = table.Column(type: "int", nullable: false), - RecoveryInitiatedDate = table.Column(type: "datetime(6)", nullable: true), - LastNotificationDate = table.Column(type: "datetime(6)", nullable: true), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_EmergencyAccess", x => x.Id); - table.ForeignKey( - name: "FK_EmergencyAccess_User_GranteeId", - column: x => x.GranteeId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_EmergencyAccess_User_GrantorId", - column: x => x.GrantorId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Folder", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - UserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Name = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Folder", x => x.Id); - table.ForeignKey( - name: "FK_Folder_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "OrganizationUser", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - Email = table.Column(type: "varchar(256)", maxLength: 256, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Key = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - ResetPasswordKey = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Status = table.Column(type: "tinyint unsigned", nullable: false), - Type = table.Column(type: "tinyint unsigned", nullable: false), - AccessAll = table.Column(type: "tinyint(1)", nullable: false), - ExternalId = table.Column(type: "varchar(300)", maxLength: 300, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false), - Permissions = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4") - }, - constraints: table => - { - table.PrimaryKey("PK_OrganizationUser", x => x.Id); - table.ForeignKey( - name: "FK_OrganizationUser_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_OrganizationUser_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "ProviderUser", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - ProviderId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - Email = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Key = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Status = table.Column(type: "tinyint unsigned", nullable: false), - Type = table.Column(type: "tinyint unsigned", nullable: false), - Permissions = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_ProviderUser", x => x.Id); - table.ForeignKey( - name: "FK_ProviderUser_Provider_ProviderId", - column: x => x.ProviderId, - principalTable: "Provider", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_ProviderUser_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Send", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - OrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - Type = table.Column(type: "tinyint unsigned", nullable: false), - Data = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Key = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Password = table.Column(type: "varchar(300)", maxLength: 300, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - MaxAccessCount = table.Column(type: "int", nullable: true), - AccessCount = table.Column(type: "int", nullable: false), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false), - ExpirationDate = table.Column(type: "datetime(6)", nullable: true), - DeletionDate = table.Column(type: "datetime(6)", nullable: false), - Disabled = table.Column(type: "tinyint(1)", nullable: false), - HideEmail = table.Column(type: "tinyint(1)", nullable: true) - }, - constraints: table => - { - table.PrimaryKey("PK_Send", x => x.Id); - table.ForeignKey( - name: "FK_Send_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_Send_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "SsoUser", - columns: table => new - { - Id = table.Column(type: "bigint", nullable: false) - .Annotation("MySql:ValueGenerationStrategy", MySqlValueGenerationStrategy.IdentityColumn), - UserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - OrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - ExternalId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_SsoUser", x => x.Id); - table.ForeignKey( - name: "FK_SsoUser_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_SsoUser_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "Transaction", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - OrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - Type = table.Column(type: "tinyint unsigned", nullable: false), - Amount = table.Column(type: "decimal(65,30)", nullable: false), - Refunded = table.Column(type: "tinyint(1)", nullable: true), - RefundedAmount = table.Column(type: "decimal(65,30)", nullable: true), - Details = table.Column(type: "varchar(100)", maxLength: 100, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - PaymentMethodType = table.Column(type: "tinyint unsigned", nullable: true), - Gateway = table.Column(type: "tinyint unsigned", nullable: true), - GatewayId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Transaction", x => x.Id); - table.ForeignKey( - name: "FK_Transaction_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_Transaction_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "U2f", - columns: table => new - { - Id = table.Column(type: "int", nullable: false) - .Annotation("MySql:ValueGenerationStrategy", MySqlValueGenerationStrategy.IdentityColumn), - UserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - KeyHandle = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Challenge = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - AppId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Version = table.Column(type: "varchar(20)", maxLength: 20, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_U2f", x => x.Id); - table.ForeignKey( - name: "FK_U2f_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "CollectionGroups", - columns: table => new - { - CollectionId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - GroupId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - ReadOnly = table.Column(type: "tinyint(1)", nullable: false), - HidePasswords = table.Column(type: "tinyint(1)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_CollectionGroups", x => new { x.CollectionId, x.GroupId }); - table.ForeignKey( - name: "FK_CollectionGroups_Collection_CollectionId", - column: x => x.CollectionId, - principalTable: "Collection", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_CollectionGroups_Group_GroupId", - column: x => x.GroupId, - principalTable: "Group", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "CollectionCipher", - columns: table => new - { - CollectionId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - CipherId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci") - }, - constraints: table => - { - table.PrimaryKey("PK_CollectionCipher", x => new { x.CollectionId, x.CipherId }); - table.ForeignKey( - name: "FK_CollectionCipher_Cipher_CipherId", - column: x => x.CipherId, - principalTable: "Cipher", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_CollectionCipher_Collection_CollectionId", - column: x => x.CollectionId, - principalTable: "Collection", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "CollectionUsers", - columns: table => new - { - CollectionId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - OrganizationUserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - ReadOnly = table.Column(type: "tinyint(1)", nullable: false), - HidePasswords = table.Column(type: "tinyint(1)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_CollectionUsers", x => new { x.CollectionId, x.OrganizationUserId }); - table.ForeignKey( - name: "FK_CollectionUsers_Collection_CollectionId", - column: x => x.CollectionId, - principalTable: "Collection", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_CollectionUsers_OrganizationUser_OrganizationUserId", - column: x => x.OrganizationUserId, - principalTable: "OrganizationUser", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_CollectionUsers_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "GroupUser", - columns: table => new - { - GroupId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - OrganizationUserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci") - }, - constraints: table => - { - table.PrimaryKey("PK_GroupUser", x => new { x.GroupId, x.OrganizationUserId }); - table.ForeignKey( - name: "FK_GroupUser_Group_GroupId", - column: x => x.GroupId, - principalTable: "Group", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_GroupUser_OrganizationUser_OrganizationUserId", - column: x => x.OrganizationUserId, - principalTable: "OrganizationUser", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_GroupUser_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateTable( - name: "ProviderOrganizationProviderUser", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - ProviderOrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - ProviderUserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Type = table.Column(type: "tinyint unsigned", nullable: false), - Permissions = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_ProviderOrganizationProviderUser", x => x.Id); - table.ForeignKey( - name: "FK_ProviderOrganizationProviderUser_ProviderOrganization_Provid~", - column: x => x.ProviderOrganizationId, - principalTable: "ProviderOrganization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_ProviderOrganizationProviderUser_ProviderUser_ProviderUserId", - column: x => x.ProviderUserId, - principalTable: "ProviderUser", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); - - migrationBuilder.CreateIndex( - name: "IX_Cipher_OrganizationId", - table: "Cipher", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_Cipher_UserId", - table: "Cipher", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_Collection_OrganizationId", - table: "Collection", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_CollectionCipher_CipherId", - table: "CollectionCipher", - column: "CipherId"); - - migrationBuilder.CreateIndex( - name: "IX_CollectionGroups_GroupId", - table: "CollectionGroups", - column: "GroupId"); - - migrationBuilder.CreateIndex( - name: "IX_CollectionUsers_OrganizationUserId", - table: "CollectionUsers", - column: "OrganizationUserId"); - - migrationBuilder.CreateIndex( - name: "IX_CollectionUsers_UserId", - table: "CollectionUsers", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_Device_UserId", - table: "Device", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_EmergencyAccess_GranteeId", - table: "EmergencyAccess", - column: "GranteeId"); - - migrationBuilder.CreateIndex( - name: "IX_EmergencyAccess_GrantorId", - table: "EmergencyAccess", - column: "GrantorId"); - - migrationBuilder.CreateIndex( - name: "IX_Folder_UserId", - table: "Folder", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_Group_OrganizationId", - table: "Group", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_GroupUser_OrganizationUserId", - table: "GroupUser", - column: "OrganizationUserId"); - - migrationBuilder.CreateIndex( - name: "IX_GroupUser_UserId", - table: "GroupUser", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_OrganizationUser_OrganizationId", - table: "OrganizationUser", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_OrganizationUser_UserId", - table: "OrganizationUser", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_Policy_OrganizationId", - table: "Policy", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_ProviderOrganization_OrganizationId", - table: "ProviderOrganization", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_ProviderOrganization_ProviderId", - table: "ProviderOrganization", - column: "ProviderId"); - - migrationBuilder.CreateIndex( - name: "IX_ProviderOrganizationProviderUser_ProviderOrganizationId", - table: "ProviderOrganizationProviderUser", - column: "ProviderOrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_ProviderOrganizationProviderUser_ProviderUserId", - table: "ProviderOrganizationProviderUser", - column: "ProviderUserId"); - - migrationBuilder.CreateIndex( - name: "IX_ProviderUser_ProviderId", - table: "ProviderUser", - column: "ProviderId"); - - migrationBuilder.CreateIndex( - name: "IX_ProviderUser_UserId", - table: "ProviderUser", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_Send_OrganizationId", - table: "Send", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_Send_UserId", - table: "Send", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_SsoConfig_OrganizationId", - table: "SsoConfig", - column: "OrganizationId"); + migrationBuilder.AlterDatabase() + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Event", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Date = table.Column(type: "datetime(6)", nullable: false), + Type = table.Column(type: "int", nullable: false), + UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + OrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + CipherId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + CollectionId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + PolicyId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + GroupId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + OrganizationUserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + DeviceType = table.Column(type: "tinyint unsigned", nullable: true), + IpAddress = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + ActingUserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci") + }, + constraints: table => + { + table.PrimaryKey("PK_Event", x => x.Id); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Grant", + columns: table => new + { + Key = table.Column(type: "varchar(200)", maxLength: 200, nullable: false) + .Annotation("MySql:CharSet", "utf8mb4"), + Type = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + SubjectId = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + SessionId = table.Column(type: "varchar(100)", maxLength: 100, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + ClientId = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Description = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + ExpirationDate = table.Column(type: "datetime(6)", nullable: true), + ConsumedDate = table.Column(type: "datetime(6)", nullable: true), + Data = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4") + }, + constraints: table => + { + table.PrimaryKey("PK_Grant", x => x.Key); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Installation", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Email = table.Column(type: "varchar(256)", maxLength: 256, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Key = table.Column(type: "varchar(150)", maxLength: 150, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Enabled = table.Column(type: "tinyint(1)", nullable: false), + CreationDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Installation", x => x.Id); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Organization", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Identifier = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Name = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BusinessName = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BusinessAddress1 = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BusinessAddress2 = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BusinessAddress3 = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BusinessCountry = table.Column(type: "varchar(2)", maxLength: 2, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BusinessTaxNumber = table.Column(type: "varchar(30)", maxLength: 30, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BillingEmail = table.Column(type: "varchar(256)", maxLength: 256, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Plan = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + PlanType = table.Column(type: "tinyint unsigned", nullable: false), + Seats = table.Column(type: "int", nullable: true), + MaxCollections = table.Column(type: "smallint", nullable: true), + UsePolicies = table.Column(type: "tinyint(1)", nullable: false), + UseSso = table.Column(type: "tinyint(1)", nullable: false), + UseGroups = table.Column(type: "tinyint(1)", nullable: false), + UseDirectory = table.Column(type: "tinyint(1)", nullable: false), + UseEvents = table.Column(type: "tinyint(1)", nullable: false), + UseTotp = table.Column(type: "tinyint(1)", nullable: false), + Use2fa = table.Column(type: "tinyint(1)", nullable: false), + UseApi = table.Column(type: "tinyint(1)", nullable: false), + UseResetPassword = table.Column(type: "tinyint(1)", nullable: false), + SelfHost = table.Column(type: "tinyint(1)", nullable: false), + UsersGetPremium = table.Column(type: "tinyint(1)", nullable: false), + Storage = table.Column(type: "bigint", nullable: true), + MaxStorageGb = table.Column(type: "smallint", nullable: true), + Gateway = table.Column(type: "tinyint unsigned", nullable: true), + GatewayCustomerId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + GatewaySubscriptionId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + ReferenceData = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Enabled = table.Column(type: "tinyint(1)", nullable: false), + LicenseKey = table.Column(type: "varchar(100)", maxLength: 100, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + ApiKey = table.Column(type: "varchar(30)", maxLength: 30, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + PublicKey = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + PrivateKey = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + TwoFactorProviders = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + ExpirationDate = table.Column(type: "datetime(6)", nullable: true), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Organization", x => x.Id); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Provider", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Name = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BusinessName = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BusinessAddress1 = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BusinessAddress2 = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BusinessAddress3 = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BusinessCountry = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BusinessTaxNumber = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + BillingEmail = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Status = table.Column(type: "tinyint unsigned", nullable: false), + Enabled = table.Column(type: "tinyint(1)", nullable: false), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Provider", x => x.Id); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "TaxRate", + columns: table => new + { + Id = table.Column(type: "varchar(40)", maxLength: 40, nullable: false) + .Annotation("MySql:CharSet", "utf8mb4"), + Country = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + State = table.Column(type: "varchar(2)", maxLength: 2, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + PostalCode = table.Column(type: "varchar(10)", maxLength: 10, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Rate = table.Column(type: "decimal(65,30)", nullable: false), + Active = table.Column(type: "tinyint(1)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_TaxRate", x => x.Id); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "User", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Name = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Email = table.Column(type: "varchar(256)", maxLength: 256, nullable: false) + .Annotation("MySql:CharSet", "utf8mb4"), + EmailVerified = table.Column(type: "tinyint(1)", nullable: false), + MasterPassword = table.Column(type: "varchar(300)", maxLength: 300, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + MasterPasswordHint = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Culture = table.Column(type: "varchar(10)", maxLength: 10, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + SecurityStamp = table.Column(type: "varchar(50)", maxLength: 50, nullable: false) + .Annotation("MySql:CharSet", "utf8mb4"), + TwoFactorProviders = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + TwoFactorRecoveryCode = table.Column(type: "varchar(32)", maxLength: 32, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + EquivalentDomains = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + ExcludedGlobalEquivalentDomains = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + AccountRevisionDate = table.Column(type: "datetime(6)", nullable: false), + Key = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + PublicKey = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + PrivateKey = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Premium = table.Column(type: "tinyint(1)", nullable: false), + PremiumExpirationDate = table.Column(type: "datetime(6)", nullable: true), + RenewalReminderDate = table.Column(type: "datetime(6)", nullable: true), + Storage = table.Column(type: "bigint", nullable: true), + MaxStorageGb = table.Column(type: "smallint", nullable: true), + Gateway = table.Column(type: "tinyint unsigned", nullable: true), + GatewayCustomerId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + GatewaySubscriptionId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + ReferenceData = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + LicenseKey = table.Column(type: "varchar(100)", maxLength: 100, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + ApiKey = table.Column(type: "varchar(30)", maxLength: 30, nullable: false) + .Annotation("MySql:CharSet", "utf8mb4"), + Kdf = table.Column(type: "tinyint unsigned", nullable: false), + KdfIterations = table.Column(type: "int", nullable: false), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_User", x => x.Id); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Collection", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Name = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + ExternalId = table.Column(type: "varchar(300)", maxLength: 300, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Collection", x => x.Id); + table.ForeignKey( + name: "FK_Collection_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Group", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Name = table.Column(type: "varchar(100)", maxLength: 100, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + AccessAll = table.Column(type: "tinyint(1)", nullable: false), + ExternalId = table.Column(type: "varchar(300)", maxLength: 300, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Group", x => x.Id); + table.ForeignKey( + name: "FK_Group_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Policy", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Type = table.Column(type: "tinyint unsigned", nullable: false), + Data = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Enabled = table.Column(type: "tinyint(1)", nullable: false), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Policy", x => x.Id); + table.ForeignKey( + name: "FK_Policy_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "SsoConfig", + columns: table => new + { + Id = table.Column(type: "bigint", nullable: false) + .Annotation("MySql:ValueGenerationStrategy", MySqlValueGenerationStrategy.IdentityColumn), + Enabled = table.Column(type: "tinyint(1)", nullable: false), + OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Data = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_SsoConfig", x => x.Id); + table.ForeignKey( + name: "FK_SsoConfig_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "ProviderOrganization", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + ProviderId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Key = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Settings = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_ProviderOrganization", x => x.Id); + table.ForeignKey( + name: "FK_ProviderOrganization_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_ProviderOrganization_Provider_ProviderId", + column: x => x.ProviderId, + principalTable: "Provider", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Cipher", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + OrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + Type = table.Column(type: "tinyint unsigned", nullable: false), + Data = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Favorites = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Folders = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Attachments = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false), + DeletedDate = table.Column(type: "datetime(6)", nullable: true), + Reprompt = table.Column(type: "tinyint unsigned", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_Cipher", x => x.Id); + table.ForeignKey( + name: "FK_Cipher_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_Cipher_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Device", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + UserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Name = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Type = table.Column(type: "tinyint unsigned", nullable: false), + Identifier = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + PushToken = table.Column(type: "varchar(255)", maxLength: 255, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Device", x => x.Id); + table.ForeignKey( + name: "FK_Device_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "EmergencyAccess", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + GrantorId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + GranteeId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + Email = table.Column(type: "varchar(256)", maxLength: 256, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + KeyEncrypted = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Type = table.Column(type: "tinyint unsigned", nullable: false), + Status = table.Column(type: "tinyint unsigned", nullable: false), + WaitTimeDays = table.Column(type: "int", nullable: false), + RecoveryInitiatedDate = table.Column(type: "datetime(6)", nullable: true), + LastNotificationDate = table.Column(type: "datetime(6)", nullable: true), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_EmergencyAccess", x => x.Id); + table.ForeignKey( + name: "FK_EmergencyAccess_User_GranteeId", + column: x => x.GranteeId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_EmergencyAccess_User_GrantorId", + column: x => x.GrantorId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Folder", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + UserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Name = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Folder", x => x.Id); + table.ForeignKey( + name: "FK_Folder_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "OrganizationUser", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + Email = table.Column(type: "varchar(256)", maxLength: 256, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Key = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + ResetPasswordKey = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Status = table.Column(type: "tinyint unsigned", nullable: false), + Type = table.Column(type: "tinyint unsigned", nullable: false), + AccessAll = table.Column(type: "tinyint(1)", nullable: false), + ExternalId = table.Column(type: "varchar(300)", maxLength: 300, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false), + Permissions = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4") + }, + constraints: table => + { + table.PrimaryKey("PK_OrganizationUser", x => x.Id); + table.ForeignKey( + name: "FK_OrganizationUser_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_OrganizationUser_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "ProviderUser", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + ProviderId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + Email = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Key = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Status = table.Column(type: "tinyint unsigned", nullable: false), + Type = table.Column(type: "tinyint unsigned", nullable: false), + Permissions = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_ProviderUser", x => x.Id); + table.ForeignKey( + name: "FK_ProviderUser_Provider_ProviderId", + column: x => x.ProviderId, + principalTable: "Provider", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_ProviderUser_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Send", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + OrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + Type = table.Column(type: "tinyint unsigned", nullable: false), + Data = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Key = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Password = table.Column(type: "varchar(300)", maxLength: 300, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + MaxAccessCount = table.Column(type: "int", nullable: true), + AccessCount = table.Column(type: "int", nullable: false), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false), + ExpirationDate = table.Column(type: "datetime(6)", nullable: true), + DeletionDate = table.Column(type: "datetime(6)", nullable: false), + Disabled = table.Column(type: "tinyint(1)", nullable: false), + HideEmail = table.Column(type: "tinyint(1)", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_Send", x => x.Id); + table.ForeignKey( + name: "FK_Send_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_Send_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "SsoUser", + columns: table => new + { + Id = table.Column(type: "bigint", nullable: false) + .Annotation("MySql:ValueGenerationStrategy", MySqlValueGenerationStrategy.IdentityColumn), + UserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + OrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + ExternalId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_SsoUser", x => x.Id); + table.ForeignKey( + name: "FK_SsoUser_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_SsoUser_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "Transaction", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + OrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + Type = table.Column(type: "tinyint unsigned", nullable: false), + Amount = table.Column(type: "decimal(65,30)", nullable: false), + Refunded = table.Column(type: "tinyint(1)", nullable: true), + RefundedAmount = table.Column(type: "decimal(65,30)", nullable: true), + Details = table.Column(type: "varchar(100)", maxLength: 100, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + PaymentMethodType = table.Column(type: "tinyint unsigned", nullable: true), + Gateway = table.Column(type: "tinyint unsigned", nullable: true), + GatewayId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Transaction", x => x.Id); + table.ForeignKey( + name: "FK_Transaction_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_Transaction_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "U2f", + columns: table => new + { + Id = table.Column(type: "int", nullable: false) + .Annotation("MySql:ValueGenerationStrategy", MySqlValueGenerationStrategy.IdentityColumn), + UserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + KeyHandle = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Challenge = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + AppId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Version = table.Column(type: "varchar(20)", maxLength: 20, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_U2f", x => x.Id); + table.ForeignKey( + name: "FK_U2f_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "CollectionGroups", + columns: table => new + { + CollectionId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + GroupId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + ReadOnly = table.Column(type: "tinyint(1)", nullable: false), + HidePasswords = table.Column(type: "tinyint(1)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_CollectionGroups", x => new { x.CollectionId, x.GroupId }); + table.ForeignKey( + name: "FK_CollectionGroups_Collection_CollectionId", + column: x => x.CollectionId, + principalTable: "Collection", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_CollectionGroups_Group_GroupId", + column: x => x.GroupId, + principalTable: "Group", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "CollectionCipher", + columns: table => new + { + CollectionId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + CipherId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci") + }, + constraints: table => + { + table.PrimaryKey("PK_CollectionCipher", x => new { x.CollectionId, x.CipherId }); + table.ForeignKey( + name: "FK_CollectionCipher_Cipher_CipherId", + column: x => x.CipherId, + principalTable: "Cipher", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_CollectionCipher_Collection_CollectionId", + column: x => x.CollectionId, + principalTable: "Collection", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "CollectionUsers", + columns: table => new + { + CollectionId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + OrganizationUserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + ReadOnly = table.Column(type: "tinyint(1)", nullable: false), + HidePasswords = table.Column(type: "tinyint(1)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_CollectionUsers", x => new { x.CollectionId, x.OrganizationUserId }); + table.ForeignKey( + name: "FK_CollectionUsers_Collection_CollectionId", + column: x => x.CollectionId, + principalTable: "Collection", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_CollectionUsers_OrganizationUser_OrganizationUserId", + column: x => x.OrganizationUserId, + principalTable: "OrganizationUser", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_CollectionUsers_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "GroupUser", + columns: table => new + { + GroupId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + OrganizationUserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + UserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci") + }, + constraints: table => + { + table.PrimaryKey("PK_GroupUser", x => new { x.GroupId, x.OrganizationUserId }); + table.ForeignKey( + name: "FK_GroupUser_Group_GroupId", + column: x => x.GroupId, + principalTable: "Group", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_GroupUser_OrganizationUser_OrganizationUserId", + column: x => x.OrganizationUserId, + principalTable: "OrganizationUser", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_GroupUser_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateTable( + name: "ProviderOrganizationProviderUser", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + ProviderOrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + ProviderUserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Type = table.Column(type: "tinyint unsigned", nullable: false), + Permissions = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_ProviderOrganizationProviderUser", x => x.Id); + table.ForeignKey( + name: "FK_ProviderOrganizationProviderUser_ProviderOrganization_Provid~", + column: x => x.ProviderOrganizationId, + principalTable: "ProviderOrganization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_ProviderOrganizationProviderUser_ProviderUser_ProviderUserId", + column: x => x.ProviderUserId, + principalTable: "ProviderUser", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); + + migrationBuilder.CreateIndex( + name: "IX_Cipher_OrganizationId", + table: "Cipher", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_Cipher_UserId", + table: "Cipher", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_Collection_OrganizationId", + table: "Collection", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_CollectionCipher_CipherId", + table: "CollectionCipher", + column: "CipherId"); + + migrationBuilder.CreateIndex( + name: "IX_CollectionGroups_GroupId", + table: "CollectionGroups", + column: "GroupId"); + + migrationBuilder.CreateIndex( + name: "IX_CollectionUsers_OrganizationUserId", + table: "CollectionUsers", + column: "OrganizationUserId"); + + migrationBuilder.CreateIndex( + name: "IX_CollectionUsers_UserId", + table: "CollectionUsers", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_Device_UserId", + table: "Device", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_EmergencyAccess_GranteeId", + table: "EmergencyAccess", + column: "GranteeId"); + + migrationBuilder.CreateIndex( + name: "IX_EmergencyAccess_GrantorId", + table: "EmergencyAccess", + column: "GrantorId"); + + migrationBuilder.CreateIndex( + name: "IX_Folder_UserId", + table: "Folder", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_Group_OrganizationId", + table: "Group", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_GroupUser_OrganizationUserId", + table: "GroupUser", + column: "OrganizationUserId"); + + migrationBuilder.CreateIndex( + name: "IX_GroupUser_UserId", + table: "GroupUser", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_OrganizationUser_OrganizationId", + table: "OrganizationUser", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_OrganizationUser_UserId", + table: "OrganizationUser", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_Policy_OrganizationId", + table: "Policy", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_ProviderOrganization_OrganizationId", + table: "ProviderOrganization", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_ProviderOrganization_ProviderId", + table: "ProviderOrganization", + column: "ProviderId"); + + migrationBuilder.CreateIndex( + name: "IX_ProviderOrganizationProviderUser_ProviderOrganizationId", + table: "ProviderOrganizationProviderUser", + column: "ProviderOrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_ProviderOrganizationProviderUser_ProviderUserId", + table: "ProviderOrganizationProviderUser", + column: "ProviderUserId"); + + migrationBuilder.CreateIndex( + name: "IX_ProviderUser_ProviderId", + table: "ProviderUser", + column: "ProviderId"); + + migrationBuilder.CreateIndex( + name: "IX_ProviderUser_UserId", + table: "ProviderUser", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_Send_OrganizationId", + table: "Send", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_Send_UserId", + table: "Send", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_SsoConfig_OrganizationId", + table: "SsoConfig", + column: "OrganizationId"); - migrationBuilder.CreateIndex( - name: "IX_SsoUser_OrganizationId", - table: "SsoUser", - column: "OrganizationId"); + migrationBuilder.CreateIndex( + name: "IX_SsoUser_OrganizationId", + table: "SsoUser", + column: "OrganizationId"); - migrationBuilder.CreateIndex( - name: "IX_SsoUser_UserId", - table: "SsoUser", - column: "UserId"); + migrationBuilder.CreateIndex( + name: "IX_SsoUser_UserId", + table: "SsoUser", + column: "UserId"); - migrationBuilder.CreateIndex( - name: "IX_Transaction_OrganizationId", - table: "Transaction", - column: "OrganizationId"); + migrationBuilder.CreateIndex( + name: "IX_Transaction_OrganizationId", + table: "Transaction", + column: "OrganizationId"); - migrationBuilder.CreateIndex( - name: "IX_Transaction_UserId", - table: "Transaction", - column: "UserId"); + migrationBuilder.CreateIndex( + name: "IX_Transaction_UserId", + table: "Transaction", + column: "UserId"); - migrationBuilder.CreateIndex( - name: "IX_U2f_UserId", - table: "U2f", - column: "UserId"); - } + migrationBuilder.CreateIndex( + name: "IX_U2f_UserId", + table: "U2f", + column: "UserId"); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropTable( - name: "CollectionCipher"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropTable( + name: "CollectionCipher"); - migrationBuilder.DropTable( - name: "CollectionGroups"); + migrationBuilder.DropTable( + name: "CollectionGroups"); - migrationBuilder.DropTable( - name: "CollectionUsers"); + migrationBuilder.DropTable( + name: "CollectionUsers"); - migrationBuilder.DropTable( - name: "Device"); + migrationBuilder.DropTable( + name: "Device"); - migrationBuilder.DropTable( - name: "EmergencyAccess"); + migrationBuilder.DropTable( + name: "EmergencyAccess"); - migrationBuilder.DropTable( - name: "Event"); + migrationBuilder.DropTable( + name: "Event"); - migrationBuilder.DropTable( - name: "Folder"); + migrationBuilder.DropTable( + name: "Folder"); - migrationBuilder.DropTable( - name: "Grant"); + migrationBuilder.DropTable( + name: "Grant"); - migrationBuilder.DropTable( - name: "GroupUser"); + migrationBuilder.DropTable( + name: "GroupUser"); - migrationBuilder.DropTable( - name: "Installation"); + migrationBuilder.DropTable( + name: "Installation"); - migrationBuilder.DropTable( - name: "Policy"); + migrationBuilder.DropTable( + name: "Policy"); - migrationBuilder.DropTable( - name: "ProviderOrganizationProviderUser"); + migrationBuilder.DropTable( + name: "ProviderOrganizationProviderUser"); - migrationBuilder.DropTable( - name: "Send"); + migrationBuilder.DropTable( + name: "Send"); - migrationBuilder.DropTable( - name: "SsoConfig"); + migrationBuilder.DropTable( + name: "SsoConfig"); - migrationBuilder.DropTable( - name: "SsoUser"); + migrationBuilder.DropTable( + name: "SsoUser"); - migrationBuilder.DropTable( - name: "TaxRate"); + migrationBuilder.DropTable( + name: "TaxRate"); - migrationBuilder.DropTable( - name: "Transaction"); + migrationBuilder.DropTable( + name: "Transaction"); - migrationBuilder.DropTable( - name: "U2f"); + migrationBuilder.DropTable( + name: "U2f"); - migrationBuilder.DropTable( - name: "Cipher"); + migrationBuilder.DropTable( + name: "Cipher"); - migrationBuilder.DropTable( - name: "Collection"); + migrationBuilder.DropTable( + name: "Collection"); - migrationBuilder.DropTable( - name: "Group"); + migrationBuilder.DropTable( + name: "Group"); - migrationBuilder.DropTable( - name: "OrganizationUser"); + migrationBuilder.DropTable( + name: "OrganizationUser"); - migrationBuilder.DropTable( - name: "ProviderOrganization"); + migrationBuilder.DropTable( + name: "ProviderOrganization"); - migrationBuilder.DropTable( - name: "ProviderUser"); + migrationBuilder.DropTable( + name: "ProviderUser"); - migrationBuilder.DropTable( - name: "Organization"); + migrationBuilder.DropTable( + name: "Organization"); - migrationBuilder.DropTable( - name: "Provider"); + migrationBuilder.DropTable( + name: "Provider"); - migrationBuilder.DropTable( - name: "User"); - } + migrationBuilder.DropTable( + name: "User"); } } diff --git a/util/MySqlMigrations/Migrations/20210709095522_RemoveProviderOrganizationProviderUser.cs b/util/MySqlMigrations/Migrations/20210709095522_RemoveProviderOrganizationProviderUser.cs index cf28c6fce..cc53719fa 100644 --- a/util/MySqlMigrations/Migrations/20210709095522_RemoveProviderOrganizationProviderUser.cs +++ b/util/MySqlMigrations/Migrations/20210709095522_RemoveProviderOrganizationProviderUser.cs @@ -1,90 +1,89 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations +namespace Bit.MySqlMigrations.Migrations; + +public partial class RemoveProviderOrganizationProviderUser : Migration { - public partial class RemoveProviderOrganizationProviderUser : Migration + protected override void Up(MigrationBuilder migrationBuilder) { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropTable( - name: "ProviderOrganizationProviderUser"); + migrationBuilder.DropTable( + name: "ProviderOrganizationProviderUser"); - migrationBuilder.AddColumn( - name: "UseEvents", - table: "Provider", - type: "tinyint(1)", - nullable: false, - defaultValue: false); + migrationBuilder.AddColumn( + name: "UseEvents", + table: "Provider", + type: "tinyint(1)", + nullable: false, + defaultValue: false); - migrationBuilder.AddColumn( - name: "ProviderId", - table: "Event", - type: "char(36)", - nullable: true, - collation: "ascii_general_ci"); + migrationBuilder.AddColumn( + name: "ProviderId", + table: "Event", + type: "char(36)", + nullable: true, + collation: "ascii_general_ci"); - migrationBuilder.AddColumn( - name: "ProviderUserId", - table: "Event", - type: "char(36)", - nullable: true, - collation: "ascii_general_ci"); - } + migrationBuilder.AddColumn( + name: "ProviderUserId", + table: "Event", + type: "char(36)", + nullable: true, + collation: "ascii_general_ci"); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "UseEvents", - table: "Provider"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "UseEvents", + table: "Provider"); - migrationBuilder.DropColumn( - name: "ProviderId", - table: "Event"); + migrationBuilder.DropColumn( + name: "ProviderId", + table: "Event"); - migrationBuilder.DropColumn( - name: "ProviderUserId", - table: "Event"); + migrationBuilder.DropColumn( + name: "ProviderUserId", + table: "Event"); - migrationBuilder.CreateTable( - name: "ProviderOrganizationProviderUser", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - Permissions = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - ProviderOrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - ProviderUserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - RevisionDate = table.Column(type: "datetime(6)", nullable: false), - Type = table.Column(type: "tinyint unsigned", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_ProviderOrganizationProviderUser", x => x.Id); - table.ForeignKey( - name: "FK_ProviderOrganizationProviderUser_ProviderOrganization_Provid~", - column: x => x.ProviderOrganizationId, - principalTable: "ProviderOrganization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_ProviderOrganizationProviderUser_ProviderUser_ProviderUserId", - column: x => x.ProviderUserId, - principalTable: "ProviderUser", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); + migrationBuilder.CreateTable( + name: "ProviderOrganizationProviderUser", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + Permissions = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + ProviderOrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + ProviderUserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + RevisionDate = table.Column(type: "datetime(6)", nullable: false), + Type = table.Column(type: "tinyint unsigned", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_ProviderOrganizationProviderUser", x => x.Id); + table.ForeignKey( + name: "FK_ProviderOrganizationProviderUser_ProviderOrganization_Provid~", + column: x => x.ProviderOrganizationId, + principalTable: "ProviderOrganization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_ProviderOrganizationProviderUser_ProviderUser_ProviderUserId", + column: x => x.ProviderUserId, + principalTable: "ProviderUser", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); - migrationBuilder.CreateIndex( - name: "IX_ProviderOrganizationProviderUser_ProviderOrganizationId", - table: "ProviderOrganizationProviderUser", - column: "ProviderOrganizationId"); + migrationBuilder.CreateIndex( + name: "IX_ProviderOrganizationProviderUser_ProviderOrganizationId", + table: "ProviderOrganizationProviderUser", + column: "ProviderOrganizationId"); - migrationBuilder.CreateIndex( - name: "IX_ProviderOrganizationProviderUser_ProviderUserId", - table: "ProviderOrganizationProviderUser", - column: "ProviderUserId"); - } + migrationBuilder.CreateIndex( + name: "IX_ProviderOrganizationProviderUser_ProviderUserId", + table: "ProviderOrganizationProviderUser", + column: "ProviderUserId"); } } diff --git a/util/MySqlMigrations/Migrations/20210716142145_UserForcePasswordReset.cs b/util/MySqlMigrations/Migrations/20210716142145_UserForcePasswordReset.cs index 762aa0546..1c385968c 100644 --- a/util/MySqlMigrations/Migrations/20210716142145_UserForcePasswordReset.cs +++ b/util/MySqlMigrations/Migrations/20210716142145_UserForcePasswordReset.cs @@ -1,24 +1,23 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations -{ - public partial class UserForcePasswordReset : Migration - { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.AddColumn( - name: "ForcePasswordReset", - table: "User", - type: "tinyint(1)", - nullable: false, - defaultValue: false); - } +namespace Bit.MySqlMigrations.Migrations; - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "ForcePasswordReset", - table: "User"); - } +public partial class UserForcePasswordReset : Migration +{ + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "ForcePasswordReset", + table: "User", + type: "tinyint(1)", + nullable: false, + defaultValue: false); + } + + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "ForcePasswordReset", + table: "User"); } } diff --git a/util/MySqlMigrations/Migrations/20210921132418_AddMaxAutoscaleSeatsToOrganization.cs b/util/MySqlMigrations/Migrations/20210921132418_AddMaxAutoscaleSeatsToOrganization.cs index 2ae51dc4a..168667ecb 100644 --- a/util/MySqlMigrations/Migrations/20210921132418_AddMaxAutoscaleSeatsToOrganization.cs +++ b/util/MySqlMigrations/Migrations/20210921132418_AddMaxAutoscaleSeatsToOrganization.cs @@ -1,44 +1,43 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations +namespace Bit.MySqlMigrations.Migrations; + +public partial class AddMaxAutoscaleSeatsToOrganization : Migration { - public partial class AddMaxAutoscaleSeatsToOrganization : Migration + protected override void Up(MigrationBuilder migrationBuilder) { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.AddColumn( - name: "MaxAutoscaleSeats", - table: "Organization", - type: "int", - nullable: true); + migrationBuilder.AddColumn( + name: "MaxAutoscaleSeats", + table: "Organization", + type: "int", + nullable: true); - migrationBuilder.AddColumn( - name: "OwnersNotifiedOfAutoscaling", - table: "Organization", - type: "datetime(6)", - nullable: true); + migrationBuilder.AddColumn( + name: "OwnersNotifiedOfAutoscaling", + table: "Organization", + type: "datetime(6)", + nullable: true); - migrationBuilder.AddColumn( - name: "ProviderOrganizationId", - table: "Event", - type: "char(36)", - nullable: true, - collation: "ascii_general_ci"); - } + migrationBuilder.AddColumn( + name: "ProviderOrganizationId", + table: "Event", + type: "char(36)", + nullable: true, + collation: "ascii_general_ci"); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "MaxAutoscaleSeats", - table: "Organization"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "MaxAutoscaleSeats", + table: "Organization"); - migrationBuilder.DropColumn( - name: "OwnersNotifiedOfAutoscaling", - table: "Organization"); + migrationBuilder.DropColumn( + name: "OwnersNotifiedOfAutoscaling", + table: "Organization"); - migrationBuilder.DropColumn( - name: "ProviderOrganizationId", - table: "Event"); - } + migrationBuilder.DropColumn( + name: "ProviderOrganizationId", + table: "Event"); } } diff --git a/util/MySqlMigrations/Migrations/20211011144835_SplitManageCollectionsPermissions2.cs b/util/MySqlMigrations/Migrations/20211011144835_SplitManageCollectionsPermissions2.cs index 8884d3340..19817d128 100644 --- a/util/MySqlMigrations/Migrations/20211011144835_SplitManageCollectionsPermissions2.cs +++ b/util/MySqlMigrations/Migrations/20211011144835_SplitManageCollectionsPermissions2.cs @@ -1,21 +1,20 @@ using Bit.Core.Utilities; using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations +namespace Bit.MySqlMigrations.Migrations; + +public partial class SplitManageCollectionsPermissions2 : Migration { - public partial class SplitManageCollectionsPermissions2 : Migration + private const string _scriptLocation = + "MySqlMigrations.Scripts.2021-09-21_01_SplitManageCollectionsPermission.sql"; + + protected override void Up(MigrationBuilder migrationBuilder) { - private const string _scriptLocation = - "MySqlMigrations.Scripts.2021-09-21_01_SplitManageCollectionsPermission.sql"; + migrationBuilder.Sql(CoreHelpers.GetEmbeddedResourceContentsAsync(_scriptLocation)); + } - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.Sql(CoreHelpers.GetEmbeddedResourceContentsAsync(_scriptLocation)); - } - - protected override void Down(MigrationBuilder migrationBuilder) - { - throw new Exception("Irreversible migration"); - } + protected override void Down(MigrationBuilder migrationBuilder) + { + throw new Exception("Irreversible migration"); } } diff --git a/util/MySqlMigrations/Migrations/20211021201150_SetMaxAutoscaleSeatsToCurrentSeatCount.cs b/util/MySqlMigrations/Migrations/20211021201150_SetMaxAutoscaleSeatsToCurrentSeatCount.cs index af746bccc..00574ab65 100644 --- a/util/MySqlMigrations/Migrations/20211021201150_SetMaxAutoscaleSeatsToCurrentSeatCount.cs +++ b/util/MySqlMigrations/Migrations/20211021201150_SetMaxAutoscaleSeatsToCurrentSeatCount.cs @@ -1,21 +1,20 @@ using Bit.Core.Utilities; using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations +namespace Bit.MySqlMigrations.Migrations; + +public partial class SetMaxAutoscaleSeatsToCurrentSeatCount : Migration { - public partial class SetMaxAutoscaleSeatsToCurrentSeatCount : Migration + private const string _scriptLocation = + "MySqlMigrations.Scripts.2021-10-21_00_SetMaxAutoscaleSeatCount.sql"; + + protected override void Up(MigrationBuilder migrationBuilder) { - private const string _scriptLocation = - "MySqlMigrations.Scripts.2021-10-21_00_SetMaxAutoscaleSeatCount.sql"; + migrationBuilder.Sql(CoreHelpers.GetEmbeddedResourceContentsAsync(_scriptLocation)); + } - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.Sql(CoreHelpers.GetEmbeddedResourceContentsAsync(_scriptLocation)); - } - - protected override void Down(MigrationBuilder migrationBuilder) - { - throw new Exception("Irreversible migration"); - } + protected override void Down(MigrationBuilder migrationBuilder) + { + throw new Exception("Irreversible migration"); } } diff --git a/util/MySqlMigrations/Migrations/20211108041911_KeyConnector.cs b/util/MySqlMigrations/Migrations/20211108041911_KeyConnector.cs index 3e48440fa..59ed36ff4 100644 --- a/util/MySqlMigrations/Migrations/20211108041911_KeyConnector.cs +++ b/util/MySqlMigrations/Migrations/20211108041911_KeyConnector.cs @@ -1,24 +1,23 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations -{ - public partial class KeyConnector : Migration - { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.AddColumn( - name: "UsesKeyConnector", - table: "User", - type: "tinyint(1)", - nullable: false, - defaultValue: false); - } +namespace Bit.MySqlMigrations.Migrations; - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "UsesKeyConnector", - table: "User"); - } +public partial class KeyConnector : Migration +{ + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "UsesKeyConnector", + table: "User", + type: "tinyint(1)", + nullable: false, + defaultValue: false); + } + + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "UsesKeyConnector", + table: "User"); } } diff --git a/util/MySqlMigrations/Migrations/20211108225243_OrganizationSponsorship.cs b/util/MySqlMigrations/Migrations/20211108225243_OrganizationSponsorship.cs index 68dc13557..155ecfc00 100644 --- a/util/MySqlMigrations/Migrations/20211108225243_OrganizationSponsorship.cs +++ b/util/MySqlMigrations/Migrations/20211108225243_OrganizationSponsorship.cs @@ -1,85 +1,84 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations +namespace Bit.MySqlMigrations.Migrations; + +public partial class OrganizationSponsorship : Migration { - public partial class OrganizationSponsorship : Migration + protected override void Up(MigrationBuilder migrationBuilder) { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.AddColumn( - name: "UsesCryptoAgent", - table: "User", - type: "tinyint(1)", - nullable: false, - defaultValue: false); + migrationBuilder.AddColumn( + name: "UsesCryptoAgent", + table: "User", + type: "tinyint(1)", + nullable: false, + defaultValue: false); - migrationBuilder.CreateTable( - name: "OrganizationSponsorship", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - InstallationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - SponsoringOrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - SponsoringOrganizationUserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - SponsoredOrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), - FriendlyName = table.Column(type: "varchar(256)", maxLength: 256, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - OfferedToEmail = table.Column(type: "varchar(256)", maxLength: 256, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - PlanSponsorshipType = table.Column(type: "tinyint unsigned", nullable: true), - CloudSponsor = table.Column(type: "tinyint(1)", nullable: false), - LastSyncDate = table.Column(type: "datetime(6)", nullable: true), - TimesRenewedWithoutValidation = table.Column(type: "tinyint unsigned", nullable: false), - SponsorshipLapsedDate = table.Column(type: "datetime(6)", nullable: true) - }, - constraints: table => - { - table.PrimaryKey("PK_OrganizationSponsorship", x => x.Id); - table.ForeignKey( - name: "FK_OrganizationSponsorship_Installation_InstallationId", - column: x => x.InstallationId, - principalTable: "Installation", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoredOrganizationId", - column: x => x.SponsoredOrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", - column: x => x.SponsoringOrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }) - .Annotation("MySql:CharSet", "utf8mb4"); + migrationBuilder.CreateTable( + name: "OrganizationSponsorship", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + InstallationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + SponsoringOrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + SponsoringOrganizationUserId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + SponsoredOrganizationId = table.Column(type: "char(36)", nullable: true, collation: "ascii_general_ci"), + FriendlyName = table.Column(type: "varchar(256)", maxLength: 256, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + OfferedToEmail = table.Column(type: "varchar(256)", maxLength: 256, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + PlanSponsorshipType = table.Column(type: "tinyint unsigned", nullable: true), + CloudSponsor = table.Column(type: "tinyint(1)", nullable: false), + LastSyncDate = table.Column(type: "datetime(6)", nullable: true), + TimesRenewedWithoutValidation = table.Column(type: "tinyint unsigned", nullable: false), + SponsorshipLapsedDate = table.Column(type: "datetime(6)", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_OrganizationSponsorship", x => x.Id); + table.ForeignKey( + name: "FK_OrganizationSponsorship_Installation_InstallationId", + column: x => x.InstallationId, + principalTable: "Installation", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoredOrganizationId", + column: x => x.SponsoredOrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", + column: x => x.SponsoringOrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }) + .Annotation("MySql:CharSet", "utf8mb4"); - migrationBuilder.CreateIndex( - name: "IX_OrganizationSponsorship_InstallationId", - table: "OrganizationSponsorship", - column: "InstallationId"); + migrationBuilder.CreateIndex( + name: "IX_OrganizationSponsorship_InstallationId", + table: "OrganizationSponsorship", + column: "InstallationId"); - migrationBuilder.CreateIndex( - name: "IX_OrganizationSponsorship_SponsoredOrganizationId", - table: "OrganizationSponsorship", - column: "SponsoredOrganizationId"); + migrationBuilder.CreateIndex( + name: "IX_OrganizationSponsorship_SponsoredOrganizationId", + table: "OrganizationSponsorship", + column: "SponsoredOrganizationId"); - migrationBuilder.CreateIndex( - name: "IX_OrganizationSponsorship_SponsoringOrganizationId", - table: "OrganizationSponsorship", - column: "SponsoringOrganizationId"); - } + migrationBuilder.CreateIndex( + name: "IX_OrganizationSponsorship_SponsoringOrganizationId", + table: "OrganizationSponsorship", + column: "SponsoringOrganizationId"); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropTable( - name: "OrganizationSponsorship"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropTable( + name: "OrganizationSponsorship"); - migrationBuilder.DropColumn( - name: "UsesCryptoAgent", - table: "User"); - } + migrationBuilder.DropColumn( + name: "UsesCryptoAgent", + table: "User"); } } diff --git a/util/MySqlMigrations/Migrations/20211115145402_KeyConnectorFlag.cs b/util/MySqlMigrations/Migrations/20211115145402_KeyConnectorFlag.cs index d68eb65a1..62d924f5c 100644 --- a/util/MySqlMigrations/Migrations/20211115145402_KeyConnectorFlag.cs +++ b/util/MySqlMigrations/Migrations/20211115145402_KeyConnectorFlag.cs @@ -1,24 +1,23 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations -{ - public partial class KeyConnectorFlag : Migration - { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.AddColumn( - name: "UseKeyConnector", - table: "Organization", - type: "tinyint(1)", - nullable: false, - defaultValue: false); - } +namespace Bit.MySqlMigrations.Migrations; - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "UseKeyConnector", - table: "Organization"); - } +public partial class KeyConnectorFlag : Migration +{ + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "UseKeyConnector", + table: "Organization", + type: "tinyint(1)", + nullable: false, + defaultValue: false); + } + + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "UseKeyConnector", + table: "Organization"); } } diff --git a/util/MySqlMigrations/Migrations/20220121092546_RemoveU2F.cs b/util/MySqlMigrations/Migrations/20220121092546_RemoveU2F.cs index 8d9250fee..8950ac914 100644 --- a/util/MySqlMigrations/Migrations/20220121092546_RemoveU2F.cs +++ b/util/MySqlMigrations/Migrations/20220121092546_RemoveU2F.cs @@ -1,51 +1,50 @@ using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations +namespace Bit.MySqlMigrations.Migrations; + +public partial class RemoveU2F : Migration { - public partial class RemoveU2F : Migration + protected override void Up(MigrationBuilder migrationBuilder) { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropTable( - name: "U2f"); - } + migrationBuilder.DropTable( + name: "U2f"); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.CreateTable( - name: "U2f", - columns: table => new - { - Id = table.Column(type: "int", nullable: false) - .Annotation("MySql:ValueGenerationStrategy", MySqlValueGenerationStrategy.IdentityColumn), - AppId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - Challenge = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - CreationDate = table.Column(type: "datetime(6)", nullable: false), - KeyHandle = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - UserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Version = table.Column(type: "varchar(20)", maxLength: 20, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4") - }, - constraints: table => - { - table.PrimaryKey("PK_U2f", x => x.Id); - table.ForeignKey( - name: "FK_U2f_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.CreateTable( + name: "U2f", + columns: table => new + { + Id = table.Column(type: "int", nullable: false) + .Annotation("MySql:ValueGenerationStrategy", MySqlValueGenerationStrategy.IdentityColumn), + AppId = table.Column(type: "varchar(50)", maxLength: 50, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + Challenge = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + CreationDate = table.Column(type: "datetime(6)", nullable: false), + KeyHandle = table.Column(type: "varchar(200)", maxLength: 200, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + UserId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Version = table.Column(type: "varchar(20)", maxLength: 20, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4") + }, + constraints: table => + { + table.PrimaryKey("PK_U2f", x => x.Id); + table.ForeignKey( + name: "FK_U2f_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); - migrationBuilder.CreateIndex( - name: "IX_U2f_UserId", - table: "U2f", - column: "UserId"); - } + migrationBuilder.CreateIndex( + name: "IX_U2f_UserId", + table: "U2f", + column: "UserId"); } } diff --git a/util/MySqlMigrations/Migrations/20220301215315_FailedLoginCaptcha.cs b/util/MySqlMigrations/Migrations/20220301215315_FailedLoginCaptcha.cs index f7c8bcc57..91245fc46 100644 --- a/util/MySqlMigrations/Migrations/20220301215315_FailedLoginCaptcha.cs +++ b/util/MySqlMigrations/Migrations/20220301215315_FailedLoginCaptcha.cs @@ -1,34 +1,33 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations +namespace Bit.MySqlMigrations.Migrations; + +public partial class FailedLoginCaptcha : Migration { - public partial class FailedLoginCaptcha : Migration + protected override void Up(MigrationBuilder migrationBuilder) { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.AddColumn( - name: "FailedLoginCount", - table: "User", - type: "int", - nullable: false, - defaultValue: 0); + migrationBuilder.AddColumn( + name: "FailedLoginCount", + table: "User", + type: "int", + nullable: false, + defaultValue: 0); - migrationBuilder.AddColumn( - name: "LastFailedLoginDate", - table: "User", - type: "datetime(6)", - nullable: true); - } + migrationBuilder.AddColumn( + name: "LastFailedLoginDate", + table: "User", + type: "datetime(6)", + nullable: true); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "FailedLoginCount", - table: "User"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "FailedLoginCount", + table: "User"); - migrationBuilder.DropColumn( - name: "LastFailedLoginDate", - table: "User"); - } + migrationBuilder.DropColumn( + name: "LastFailedLoginDate", + table: "User"); } } diff --git a/util/MySqlMigrations/Migrations/20220322191314_SelfHostF4E.cs b/util/MySqlMigrations/Migrations/20220322191314_SelfHostF4E.cs index c6f9c8934..993399e50 100644 --- a/util/MySqlMigrations/Migrations/20220322191314_SelfHostF4E.cs +++ b/util/MySqlMigrations/Migrations/20220322191314_SelfHostF4E.cs @@ -1,158 +1,157 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations +namespace Bit.MySqlMigrations.Migrations; + +public partial class SelfHostF4E : Migration { - public partial class SelfHostF4E : Migration + private const string _scriptLocationTemplate = "2022-03-01_00_{0}_MigrateOrganizationApiKeys.sql"; + + protected override void Up(MigrationBuilder migrationBuilder) { - private const string _scriptLocationTemplate = "2022-03-01_00_{0}_MigrateOrganizationApiKeys.sql"; + migrationBuilder.DropForeignKey( + name: "FK_OrganizationSponsorship_Installation_InstallationId", + table: "OrganizationSponsorship"); - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropForeignKey( - name: "FK_OrganizationSponsorship_Installation_InstallationId", - table: "OrganizationSponsorship"); + migrationBuilder.DropIndex( + name: "IX_OrganizationSponsorship_InstallationId", + table: "OrganizationSponsorship"); - migrationBuilder.DropIndex( - name: "IX_OrganizationSponsorship_InstallationId", - table: "OrganizationSponsorship"); + migrationBuilder.DropColumn( + name: "InstallationId", + table: "OrganizationSponsorship"); - migrationBuilder.DropColumn( - name: "InstallationId", - table: "OrganizationSponsorship"); + migrationBuilder.DropColumn( + name: "TimesRenewedWithoutValidation", + table: "OrganizationSponsorship"); - migrationBuilder.DropColumn( - name: "TimesRenewedWithoutValidation", - table: "OrganizationSponsorship"); + migrationBuilder.CreateTable( + name: "OrganizationApiKey", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Type = table.Column(type: "tinyint unsigned", nullable: false), + ApiKey = table.Column(type: "varchar(30)", maxLength: 30, nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"), + RevisionDate = table.Column(type: "datetime(6)", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_OrganizationApiKey", x => x.Id); + table.ForeignKey( + name: "FK_OrganizationApiKey_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); - migrationBuilder.CreateTable( - name: "OrganizationApiKey", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Type = table.Column(type: "tinyint unsigned", nullable: false), - ApiKey = table.Column(type: "varchar(30)", maxLength: 30, nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"), - RevisionDate = table.Column(type: "datetime(6)", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_OrganizationApiKey", x => x.Id); - table.ForeignKey( - name: "FK_OrganizationApiKey_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); + migrationBuilder.SqlResource(_scriptLocationTemplate); - migrationBuilder.SqlResource(_scriptLocationTemplate); + migrationBuilder.DropColumn( + name: "ApiKey", + table: "Organization"); - migrationBuilder.DropColumn( - name: "ApiKey", - table: "Organization"); + migrationBuilder.RenameColumn( + name: "SponsorshipLapsedDate", + table: "OrganizationSponsorship", + newName: "ValidUntil"); - migrationBuilder.RenameColumn( - name: "SponsorshipLapsedDate", - table: "OrganizationSponsorship", - newName: "ValidUntil"); - - migrationBuilder.RenameColumn( - name: "CloudSponsor", - table: "OrganizationSponsorship", - newName: "ToDelete"); + migrationBuilder.RenameColumn( + name: "CloudSponsor", + table: "OrganizationSponsorship", + newName: "ToDelete"); - migrationBuilder.CreateTable( - name: "OrganizationConnection", - columns: table => new - { - Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Type = table.Column(type: "tinyint unsigned", nullable: false), - OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), - Enabled = table.Column(type: "tinyint(1)", nullable: false), - Config = table.Column(type: "longtext", nullable: true) - .Annotation("MySql:CharSet", "utf8mb4") - }, - constraints: table => - { - table.PrimaryKey("PK_OrganizationConnection", x => x.Id); - table.ForeignKey( - name: "FK_OrganizationConnection_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }) - .Annotation("MySql:CharSet", "utf8mb4"); + migrationBuilder.CreateTable( + name: "OrganizationConnection", + columns: table => new + { + Id = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Type = table.Column(type: "tinyint unsigned", nullable: false), + OrganizationId = table.Column(type: "char(36)", nullable: false, collation: "ascii_general_ci"), + Enabled = table.Column(type: "tinyint(1)", nullable: false), + Config = table.Column(type: "longtext", nullable: true) + .Annotation("MySql:CharSet", "utf8mb4") + }, + constraints: table => + { + table.PrimaryKey("PK_OrganizationConnection", x => x.Id); + table.ForeignKey( + name: "FK_OrganizationConnection_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }) + .Annotation("MySql:CharSet", "utf8mb4"); - migrationBuilder.CreateIndex( - name: "IX_OrganizationApiKey_OrganizationId", - table: "OrganizationApiKey", - column: "OrganizationId"); + migrationBuilder.CreateIndex( + name: "IX_OrganizationApiKey_OrganizationId", + table: "OrganizationApiKey", + column: "OrganizationId"); - migrationBuilder.CreateIndex( - name: "IX_OrganizationConnection_OrganizationId", - table: "OrganizationConnection", - column: "OrganizationId"); - } + migrationBuilder.CreateIndex( + name: "IX_OrganizationConnection_OrganizationId", + table: "OrganizationConnection", + column: "OrganizationId"); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.AddColumn( - name: "ApiKey", - table: "Organization", - type: "varchar(30)", - maxLength: 30, - nullable: true) - .Annotation("MySql:CharSet", "utf8mb4"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "ApiKey", + table: "Organization", + type: "varchar(30)", + maxLength: 30, + nullable: true) + .Annotation("MySql:CharSet", "utf8mb4"); - migrationBuilder.SqlResource(_scriptLocationTemplate); + migrationBuilder.SqlResource(_scriptLocationTemplate); - migrationBuilder.DropTable( - name: "OrganizationApiKey"); + migrationBuilder.DropTable( + name: "OrganizationApiKey"); - migrationBuilder.DropTable( - name: "OrganizationConnection"); + migrationBuilder.DropTable( + name: "OrganizationConnection"); - migrationBuilder.RenameColumn( - name: "ValidUntil", - table: "OrganizationSponsorship", - newName: "SponsorshipLapsedDate"); + migrationBuilder.RenameColumn( + name: "ValidUntil", + table: "OrganizationSponsorship", + newName: "SponsorshipLapsedDate"); - migrationBuilder.RenameColumn( - name: "ToDelete", - table: "OrganizationSponsorship", - newName: "CloudSponsor"); + migrationBuilder.RenameColumn( + name: "ToDelete", + table: "OrganizationSponsorship", + newName: "CloudSponsor"); - migrationBuilder.AddColumn( - name: "InstallationId", - table: "OrganizationSponsorship", - type: "char(36)", - nullable: true, - collation: "ascii_general_ci"); + migrationBuilder.AddColumn( + name: "InstallationId", + table: "OrganizationSponsorship", + type: "char(36)", + nullable: true, + collation: "ascii_general_ci"); - migrationBuilder.AddColumn( - name: "TimesRenewedWithoutValidation", - table: "OrganizationSponsorship", - type: "tinyint unsigned", - nullable: false, - defaultValue: (byte)0); + migrationBuilder.AddColumn( + name: "TimesRenewedWithoutValidation", + table: "OrganizationSponsorship", + type: "tinyint unsigned", + nullable: false, + defaultValue: (byte)0); - migrationBuilder.CreateIndex( - name: "IX_OrganizationSponsorship_InstallationId", - table: "OrganizationSponsorship", - column: "InstallationId"); + migrationBuilder.CreateIndex( + name: "IX_OrganizationSponsorship_InstallationId", + table: "OrganizationSponsorship", + column: "InstallationId"); - migrationBuilder.AddForeignKey( - name: "FK_OrganizationSponsorship_Installation_InstallationId", - table: "OrganizationSponsorship", - column: "InstallationId", - principalTable: "Installation", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - } + migrationBuilder.AddForeignKey( + name: "FK_OrganizationSponsorship_Installation_InstallationId", + table: "OrganizationSponsorship", + column: "InstallationId", + principalTable: "Installation", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); } } diff --git a/util/MySqlMigrations/Migrations/20220411191518_SponsorshipBulkActions.cs b/util/MySqlMigrations/Migrations/20220411191518_SponsorshipBulkActions.cs index 30e31e015..9b66e00cd 100644 --- a/util/MySqlMigrations/Migrations/20220411191518_SponsorshipBulkActions.cs +++ b/util/MySqlMigrations/Migrations/20220411191518_SponsorshipBulkActions.cs @@ -1,81 +1,80 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations +namespace Bit.MySqlMigrations.Migrations; + +public partial class SponsorshipBulkActions : Migration { - public partial class SponsorshipBulkActions : Migration + protected override void Up(MigrationBuilder migrationBuilder) { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", - table: "OrganizationSponsorship"); + migrationBuilder.DropForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", + table: "OrganizationSponsorship"); - migrationBuilder.AlterColumn( - name: "SponsoringOrganizationUserId", - table: "OrganizationSponsorship", - type: "char(36)", - nullable: false, - defaultValue: new Guid("00000000-0000-0000-0000-000000000000"), - collation: "ascii_general_ci", - oldClrType: typeof(Guid), - oldType: "char(36)", - oldNullable: true) - .OldAnnotation("Relational:Collation", "ascii_general_ci"); + migrationBuilder.AlterColumn( + name: "SponsoringOrganizationUserId", + table: "OrganizationSponsorship", + type: "char(36)", + nullable: false, + defaultValue: new Guid("00000000-0000-0000-0000-000000000000"), + collation: "ascii_general_ci", + oldClrType: typeof(Guid), + oldType: "char(36)", + oldNullable: true) + .OldAnnotation("Relational:Collation", "ascii_general_ci"); - migrationBuilder.AlterColumn( - name: "SponsoringOrganizationId", - table: "OrganizationSponsorship", - type: "char(36)", - nullable: false, - defaultValue: new Guid("00000000-0000-0000-0000-000000000000"), - collation: "ascii_general_ci", - oldClrType: typeof(Guid), - oldType: "char(36)", - oldNullable: true) - .OldAnnotation("Relational:Collation", "ascii_general_ci"); + migrationBuilder.AlterColumn( + name: "SponsoringOrganizationId", + table: "OrganizationSponsorship", + type: "char(36)", + nullable: false, + defaultValue: new Guid("00000000-0000-0000-0000-000000000000"), + collation: "ascii_general_ci", + oldClrType: typeof(Guid), + oldType: "char(36)", + oldNullable: true) + .OldAnnotation("Relational:Collation", "ascii_general_ci"); - migrationBuilder.AddForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", - table: "OrganizationSponsorship", - column: "SponsoringOrganizationId", - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - } + migrationBuilder.AddForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", + table: "OrganizationSponsorship", + column: "SponsoringOrganizationId", + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", - table: "OrganizationSponsorship"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", + table: "OrganizationSponsorship"); - migrationBuilder.AlterColumn( - name: "SponsoringOrganizationUserId", - table: "OrganizationSponsorship", - type: "char(36)", - nullable: true, - collation: "ascii_general_ci", - oldClrType: typeof(Guid), - oldType: "char(36)") - .OldAnnotation("Relational:Collation", "ascii_general_ci"); + migrationBuilder.AlterColumn( + name: "SponsoringOrganizationUserId", + table: "OrganizationSponsorship", + type: "char(36)", + nullable: true, + collation: "ascii_general_ci", + oldClrType: typeof(Guid), + oldType: "char(36)") + .OldAnnotation("Relational:Collation", "ascii_general_ci"); - migrationBuilder.AlterColumn( - name: "SponsoringOrganizationId", - table: "OrganizationSponsorship", - type: "char(36)", - nullable: true, - collation: "ascii_general_ci", - oldClrType: typeof(Guid), - oldType: "char(36)") - .OldAnnotation("Relational:Collation", "ascii_general_ci"); + migrationBuilder.AlterColumn( + name: "SponsoringOrganizationId", + table: "OrganizationSponsorship", + type: "char(36)", + nullable: true, + collation: "ascii_general_ci", + oldClrType: typeof(Guid), + oldType: "char(36)") + .OldAnnotation("Relational:Collation", "ascii_general_ci"); - migrationBuilder.AddForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", - table: "OrganizationSponsorship", - column: "SponsoringOrganizationId", - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - } + migrationBuilder.AddForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", + table: "OrganizationSponsorship", + column: "SponsoringOrganizationId", + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); } } diff --git a/util/MySqlMigrations/Migrations/20220420170738_AddInstallationIdToEvents.cs b/util/MySqlMigrations/Migrations/20220420170738_AddInstallationIdToEvents.cs index 77b3b5a69..d07b0e41d 100644 --- a/util/MySqlMigrations/Migrations/20220420170738_AddInstallationIdToEvents.cs +++ b/util/MySqlMigrations/Migrations/20220420170738_AddInstallationIdToEvents.cs @@ -1,70 +1,69 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations +namespace Bit.MySqlMigrations.Migrations; + +public partial class AddInstallationIdToEvents : Migration { - public partial class AddInstallationIdToEvents : Migration + protected override void Up(MigrationBuilder migrationBuilder) { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", - table: "OrganizationSponsorship"); + migrationBuilder.DropForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", + table: "OrganizationSponsorship"); - migrationBuilder.AlterColumn( - name: "SponsoringOrganizationId", - table: "OrganizationSponsorship", - type: "char(36)", - nullable: true, - collation: "ascii_general_ci", - oldClrType: typeof(Guid), - oldType: "char(36)") - .OldAnnotation("Relational:Collation", "ascii_general_ci"); + migrationBuilder.AlterColumn( + name: "SponsoringOrganizationId", + table: "OrganizationSponsorship", + type: "char(36)", + nullable: true, + collation: "ascii_general_ci", + oldClrType: typeof(Guid), + oldType: "char(36)") + .OldAnnotation("Relational:Collation", "ascii_general_ci"); - migrationBuilder.AddColumn( - name: "InstallationId", - table: "Event", - type: "char(36)", - nullable: true, - collation: "ascii_general_ci"); + migrationBuilder.AddColumn( + name: "InstallationId", + table: "Event", + type: "char(36)", + nullable: true, + collation: "ascii_general_ci"); - migrationBuilder.AddForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", - table: "OrganizationSponsorship", - column: "SponsoringOrganizationId", - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - } + migrationBuilder.AddForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", + table: "OrganizationSponsorship", + column: "SponsoringOrganizationId", + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", - table: "OrganizationSponsorship"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", + table: "OrganizationSponsorship"); - migrationBuilder.DropColumn( - name: "InstallationId", - table: "Event"); + migrationBuilder.DropColumn( + name: "InstallationId", + table: "Event"); - migrationBuilder.AlterColumn( - name: "SponsoringOrganizationId", - table: "OrganizationSponsorship", - type: "char(36)", - nullable: false, - defaultValue: new Guid("00000000-0000-0000-0000-000000000000"), - collation: "ascii_general_ci", - oldClrType: typeof(Guid), - oldType: "char(36)", - oldNullable: true) - .OldAnnotation("Relational:Collation", "ascii_general_ci"); + migrationBuilder.AlterColumn( + name: "SponsoringOrganizationId", + table: "OrganizationSponsorship", + type: "char(36)", + nullable: false, + defaultValue: new Guid("00000000-0000-0000-0000-000000000000"), + collation: "ascii_general_ci", + oldClrType: typeof(Guid), + oldType: "char(36)", + oldNullable: true) + .OldAnnotation("Relational:Collation", "ascii_general_ci"); - migrationBuilder.AddForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", - table: "OrganizationSponsorship", - column: "SponsoringOrganizationId", - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - } + migrationBuilder.AddForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganizationId", + table: "OrganizationSponsorship", + column: "SponsoringOrganizationId", + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); } } diff --git a/util/MySqlMigrations/Migrations/20220524171600_DeviceUnknownVerification.cs b/util/MySqlMigrations/Migrations/20220524171600_DeviceUnknownVerification.cs index 23017d4da..2ce7dadf2 100644 --- a/util/MySqlMigrations/Migrations/20220524171600_DeviceUnknownVerification.cs +++ b/util/MySqlMigrations/Migrations/20220524171600_DeviceUnknownVerification.cs @@ -1,24 +1,23 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations -{ - public partial class DeviceUnknownVerification : Migration - { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.AddColumn( - name: "UnknownDeviceVerificationEnabled", - table: "User", - type: "tinyint(1)", - nullable: false, - defaultValue: true); - } +namespace Bit.MySqlMigrations.Migrations; - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "UnknownDeviceVerificationEnabled", - table: "User"); - } +public partial class DeviceUnknownVerification : Migration +{ + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "UnknownDeviceVerificationEnabled", + table: "User", + type: "tinyint(1)", + nullable: false, + defaultValue: true); + } + + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "UnknownDeviceVerificationEnabled", + table: "User"); } } diff --git a/util/MySqlMigrations/Migrations/20220608191914_DeactivatedUserStatus.cs b/util/MySqlMigrations/Migrations/20220608191914_DeactivatedUserStatus.cs index 52b10dcb3..d0c5caf37 100644 --- a/util/MySqlMigrations/Migrations/20220608191914_DeactivatedUserStatus.cs +++ b/util/MySqlMigrations/Migrations/20220608191914_DeactivatedUserStatus.cs @@ -1,29 +1,28 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.MySqlMigrations.Migrations -{ - public partial class DeactivatedUserStatus : Migration - { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.AlterColumn( - name: "Status", - table: "OrganizationUser", - type: "smallint", - nullable: false, - oldClrType: typeof(byte), - oldType: "tinyint unsigned"); - } +namespace Bit.MySqlMigrations.Migrations; - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.AlterColumn( - name: "Status", - table: "OrganizationUser", - type: "tinyint unsigned", - nullable: false, - oldClrType: typeof(short), - oldType: "smallint"); - } +public partial class DeactivatedUserStatus : Migration +{ + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AlterColumn( + name: "Status", + table: "OrganizationUser", + type: "smallint", + nullable: false, + oldClrType: typeof(byte), + oldType: "tinyint unsigned"); + } + + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.AlterColumn( + name: "Status", + table: "OrganizationUser", + type: "tinyint unsigned", + nullable: false, + oldClrType: typeof(short), + oldType: "smallint"); } } diff --git a/util/MySqlMigrations/Migrations/20220707163017_UseScimFlag.cs b/util/MySqlMigrations/Migrations/20220707163017_UseScimFlag.cs index 4c2998612..c0033e60d 100644 --- a/util/MySqlMigrations/Migrations/20220707163017_UseScimFlag.cs +++ b/util/MySqlMigrations/Migrations/20220707163017_UseScimFlag.cs @@ -2,25 +2,24 @@ #nullable disable -namespace Bit.MySqlMigrations.Migrations -{ - public partial class UseScimFlag : Migration - { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.AddColumn( - name: "UseScim", - table: "Organization", - type: "tinyint(1)", - nullable: false, - defaultValue: false); - } +namespace Bit.MySqlMigrations.Migrations; - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "UseScim", - table: "Organization"); - } +public partial class UseScimFlag : Migration +{ + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "UseScim", + table: "Organization", + type: "tinyint(1)", + nullable: false, + defaultValue: false); + } + + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "UseScim", + table: "Organization"); } } diff --git a/util/PostgresMigrations/Factories.cs b/util/PostgresMigrations/Factories.cs index 532dddf73..5504fe58b 100644 --- a/util/PostgresMigrations/Factories.cs +++ b/util/PostgresMigrations/Factories.cs @@ -4,34 +4,33 @@ using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore.Design; using Microsoft.Extensions.Configuration; -namespace MySqlMigrations -{ - public static class GlobalSettingsFactory - { - public static GlobalSettings GlobalSettings { get; } = new GlobalSettings(); - static GlobalSettingsFactory() - { - var configBuilder = new ConfigurationBuilder().AddUserSecrets(); - var Configuration = configBuilder.Build(); - ConfigurationBinder.Bind(Configuration.GetSection("GlobalSettings"), GlobalSettings); - } - } +namespace MySqlMigrations; - public class DatabaseContextFactory : IDesignTimeDbContextFactory +public static class GlobalSettingsFactory +{ + public static GlobalSettings GlobalSettings { get; } = new GlobalSettings(); + static GlobalSettingsFactory() { - public DatabaseContext CreateDbContext(string[] args) - { - var globalSettings = GlobalSettingsFactory.GlobalSettings; - var optionsBuilder = new DbContextOptionsBuilder(); - var connectionString = globalSettings.PostgreSql?.ConnectionString; - if (string.IsNullOrWhiteSpace(connectionString)) - { - throw new Exception("No Postgres connection string found."); - } - optionsBuilder.UseNpgsql( - connectionString, - b => b.MigrationsAssembly("PostgresMigrations")); - return new DatabaseContext(optionsBuilder.Options); - } + var configBuilder = new ConfigurationBuilder().AddUserSecrets(); + var Configuration = configBuilder.Build(); + ConfigurationBinder.Bind(Configuration.GetSection("GlobalSettings"), GlobalSettings); + } +} + +public class DatabaseContextFactory : IDesignTimeDbContextFactory +{ + public DatabaseContext CreateDbContext(string[] args) + { + var globalSettings = GlobalSettingsFactory.GlobalSettings; + var optionsBuilder = new DbContextOptionsBuilder(); + var connectionString = globalSettings.PostgreSql?.ConnectionString; + if (string.IsNullOrWhiteSpace(connectionString)) + { + throw new Exception("No Postgres connection string found."); + } + optionsBuilder.UseNpgsql( + connectionString, + b => b.MigrationsAssembly("PostgresMigrations")); + return new DatabaseContext(optionsBuilder.Options); } } diff --git a/util/PostgresMigrations/Migrations/20210708191531_Init.cs b/util/PostgresMigrations/Migrations/20210708191531_Init.cs index 3de407ff2..068e292ce 100644 --- a/util/PostgresMigrations/Migrations/20210708191531_Init.cs +++ b/util/PostgresMigrations/Migrations/20210708191531_Init.cs @@ -1,1008 +1,1007 @@ using Microsoft.EntityFrameworkCore.Migrations; using Npgsql.EntityFrameworkCore.PostgreSQL.Metadata; -namespace Bit.PostgresMigrations.Migrations +namespace Bit.PostgresMigrations.Migrations; + +public partial class Init : Migration { - public partial class Init : Migration + protected override void Up(MigrationBuilder migrationBuilder) { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.AlterDatabase() - .Annotation("Npgsql:CollationDefinition:postgresIndetermanisticCollation", "en-u-ks-primary,en-u-ks-primary,icu,False"); - - migrationBuilder.CreateTable( - name: "Event", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - Date = table.Column(type: "timestamp without time zone", nullable: false), - Type = table.Column(type: "integer", nullable: false), - UserId = table.Column(type: "uuid", nullable: true), - OrganizationId = table.Column(type: "uuid", nullable: true), - CipherId = table.Column(type: "uuid", nullable: true), - CollectionId = table.Column(type: "uuid", nullable: true), - PolicyId = table.Column(type: "uuid", nullable: true), - GroupId = table.Column(type: "uuid", nullable: true), - OrganizationUserId = table.Column(type: "uuid", nullable: true), - DeviceType = table.Column(type: "smallint", nullable: true), - IpAddress = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - ActingUserId = table.Column(type: "uuid", nullable: true) - }, - constraints: table => - { - table.PrimaryKey("PK_Event", x => x.Id); - }); - - migrationBuilder.CreateTable( - name: "Grant", - columns: table => new - { - Key = table.Column(type: "character varying(200)", maxLength: 200, nullable: false), - Type = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - SubjectId = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), - SessionId = table.Column(type: "character varying(100)", maxLength: 100, nullable: true), - ClientId = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), - Description = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - ExpirationDate = table.Column(type: "timestamp without time zone", nullable: true), - ConsumedDate = table.Column(type: "timestamp without time zone", nullable: true), - Data = table.Column(type: "text", nullable: true) - }, - constraints: table => - { - table.PrimaryKey("PK_Grant", x => x.Key); - }); - - migrationBuilder.CreateTable( - name: "Installation", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - Email = table.Column(type: "character varying(256)", maxLength: 256, nullable: true), - Key = table.Column(type: "character varying(150)", maxLength: 150, nullable: true), - Enabled = table.Column(type: "boolean", nullable: false), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Installation", x => x.Id); - }); - - migrationBuilder.CreateTable( - name: "Organization", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - Identifier = table.Column(type: "character varying(50)", maxLength: 50, nullable: true, collation: "postgresIndetermanisticCollation"), - Name = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - BusinessName = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - BusinessAddress1 = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - BusinessAddress2 = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - BusinessAddress3 = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - BusinessCountry = table.Column(type: "character varying(2)", maxLength: 2, nullable: true), - BusinessTaxNumber = table.Column(type: "character varying(30)", maxLength: 30, nullable: true), - BillingEmail = table.Column(type: "character varying(256)", maxLength: 256, nullable: true), - Plan = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - PlanType = table.Column(type: "smallint", nullable: false), - Seats = table.Column(type: "integer", nullable: true), - MaxCollections = table.Column(type: "smallint", nullable: true), - UsePolicies = table.Column(type: "boolean", nullable: false), - UseSso = table.Column(type: "boolean", nullable: false), - UseGroups = table.Column(type: "boolean", nullable: false), - UseDirectory = table.Column(type: "boolean", nullable: false), - UseEvents = table.Column(type: "boolean", nullable: false), - UseTotp = table.Column(type: "boolean", nullable: false), - Use2fa = table.Column(type: "boolean", nullable: false), - UseApi = table.Column(type: "boolean", nullable: false), - UseResetPassword = table.Column(type: "boolean", nullable: false), - SelfHost = table.Column(type: "boolean", nullable: false), - UsersGetPremium = table.Column(type: "boolean", nullable: false), - Storage = table.Column(type: "bigint", nullable: true), - MaxStorageGb = table.Column(type: "smallint", nullable: true), - Gateway = table.Column(type: "smallint", nullable: true), - GatewayCustomerId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - GatewaySubscriptionId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - ReferenceData = table.Column(type: "text", nullable: true), - Enabled = table.Column(type: "boolean", nullable: false), - LicenseKey = table.Column(type: "character varying(100)", maxLength: 100, nullable: true), - ApiKey = table.Column(type: "character varying(30)", maxLength: 30, nullable: true), - PublicKey = table.Column(type: "text", nullable: true), - PrivateKey = table.Column(type: "text", nullable: true), - TwoFactorProviders = table.Column(type: "text", nullable: true), - ExpirationDate = table.Column(type: "timestamp without time zone", nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Organization", x => x.Id); - }); - - migrationBuilder.CreateTable( - name: "Provider", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - Name = table.Column(type: "text", nullable: true), - BusinessName = table.Column(type: "text", nullable: true), - BusinessAddress1 = table.Column(type: "text", nullable: true), - BusinessAddress2 = table.Column(type: "text", nullable: true), - BusinessAddress3 = table.Column(type: "text", nullable: true), - BusinessCountry = table.Column(type: "text", nullable: true), - BusinessTaxNumber = table.Column(type: "text", nullable: true), - BillingEmail = table.Column(type: "text", nullable: true), - Status = table.Column(type: "smallint", nullable: false), - UseEvents = table.Column(type: "boolean", nullable: false), - Enabled = table.Column(type: "boolean", nullable: false), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Provider", x => x.Id); - }); - - migrationBuilder.CreateTable( - name: "TaxRate", - columns: table => new - { - Id = table.Column(type: "character varying(40)", maxLength: 40, nullable: false), - Country = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - State = table.Column(type: "character varying(2)", maxLength: 2, nullable: true), - PostalCode = table.Column(type: "character varying(10)", maxLength: 10, nullable: true), - Rate = table.Column(type: "numeric", nullable: false), - Active = table.Column(type: "boolean", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_TaxRate", x => x.Id); - }); - - migrationBuilder.CreateTable( - name: "User", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - Name = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - Email = table.Column(type: "character varying(256)", maxLength: 256, nullable: false, collation: "postgresIndetermanisticCollation"), - EmailVerified = table.Column(type: "boolean", nullable: false), - MasterPassword = table.Column(type: "character varying(300)", maxLength: 300, nullable: true), - MasterPasswordHint = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - Culture = table.Column(type: "character varying(10)", maxLength: 10, nullable: true), - SecurityStamp = table.Column(type: "character varying(50)", maxLength: 50, nullable: false), - TwoFactorProviders = table.Column(type: "text", nullable: true), - TwoFactorRecoveryCode = table.Column(type: "character varying(32)", maxLength: 32, nullable: true), - EquivalentDomains = table.Column(type: "text", nullable: true), - ExcludedGlobalEquivalentDomains = table.Column(type: "text", nullable: true), - AccountRevisionDate = table.Column(type: "timestamp without time zone", nullable: false), - Key = table.Column(type: "text", nullable: true), - PublicKey = table.Column(type: "text", nullable: true), - PrivateKey = table.Column(type: "text", nullable: true), - Premium = table.Column(type: "boolean", nullable: false), - PremiumExpirationDate = table.Column(type: "timestamp without time zone", nullable: true), - RenewalReminderDate = table.Column(type: "timestamp without time zone", nullable: true), - Storage = table.Column(type: "bigint", nullable: true), - MaxStorageGb = table.Column(type: "smallint", nullable: true), - Gateway = table.Column(type: "smallint", nullable: true), - GatewayCustomerId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - GatewaySubscriptionId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - ReferenceData = table.Column(type: "text", nullable: true), - LicenseKey = table.Column(type: "character varying(100)", maxLength: 100, nullable: true), - ApiKey = table.Column(type: "character varying(30)", maxLength: 30, nullable: false), - Kdf = table.Column(type: "smallint", nullable: false), - KdfIterations = table.Column(type: "integer", nullable: false), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_User", x => x.Id); - }); - - migrationBuilder.CreateTable( - name: "Collection", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - OrganizationId = table.Column(type: "uuid", nullable: false), - Name = table.Column(type: "text", nullable: true), - ExternalId = table.Column(type: "character varying(300)", maxLength: 300, nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Collection", x => x.Id); - table.ForeignKey( - name: "FK_Collection_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateTable( - name: "Group", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - OrganizationId = table.Column(type: "uuid", nullable: false), - Name = table.Column(type: "character varying(100)", maxLength: 100, nullable: true), - AccessAll = table.Column(type: "boolean", nullable: false), - ExternalId = table.Column(type: "character varying(300)", maxLength: 300, nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Group", x => x.Id); - table.ForeignKey( - name: "FK_Group_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateTable( - name: "Policy", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - OrganizationId = table.Column(type: "uuid", nullable: false), - Type = table.Column(type: "smallint", nullable: false), - Data = table.Column(type: "text", nullable: true), - Enabled = table.Column(type: "boolean", nullable: false), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Policy", x => x.Id); - table.ForeignKey( - name: "FK_Policy_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateTable( - name: "SsoConfig", - columns: table => new - { - Id = table.Column(type: "bigint", nullable: false) - .Annotation("Npgsql:ValueGenerationStrategy", NpgsqlValueGenerationStrategy.IdentityByDefaultColumn), - Enabled = table.Column(type: "boolean", nullable: false), - OrganizationId = table.Column(type: "uuid", nullable: false), - Data = table.Column(type: "text", nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_SsoConfig", x => x.Id); - table.ForeignKey( - name: "FK_SsoConfig_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateTable( - name: "ProviderOrganization", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - ProviderId = table.Column(type: "uuid", nullable: false), - OrganizationId = table.Column(type: "uuid", nullable: false), - Key = table.Column(type: "text", nullable: true), - Settings = table.Column(type: "text", nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_ProviderOrganization", x => x.Id); - table.ForeignKey( - name: "FK_ProviderOrganization_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_ProviderOrganization_Provider_ProviderId", - column: x => x.ProviderId, - principalTable: "Provider", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateTable( - name: "Cipher", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - UserId = table.Column(type: "uuid", nullable: true), - OrganizationId = table.Column(type: "uuid", nullable: true), - Type = table.Column(type: "smallint", nullable: false), - Data = table.Column(type: "text", nullable: true), - Favorites = table.Column(type: "text", nullable: true), - Folders = table.Column(type: "text", nullable: true), - Attachments = table.Column(type: "text", nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false), - DeletedDate = table.Column(type: "timestamp without time zone", nullable: true), - Reprompt = table.Column(type: "smallint", nullable: true) - }, - constraints: table => - { - table.PrimaryKey("PK_Cipher", x => x.Id); - table.ForeignKey( - name: "FK_Cipher_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_Cipher_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }); - - migrationBuilder.CreateTable( - name: "Device", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - UserId = table.Column(type: "uuid", nullable: false), - Name = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - Type = table.Column(type: "smallint", nullable: false), - Identifier = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - PushToken = table.Column(type: "character varying(255)", maxLength: 255, nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Device", x => x.Id); - table.ForeignKey( - name: "FK_Device_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateTable( - name: "EmergencyAccess", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - GrantorId = table.Column(type: "uuid", nullable: false), - GranteeId = table.Column(type: "uuid", nullable: true), - Email = table.Column(type: "character varying(256)", maxLength: 256, nullable: true), - KeyEncrypted = table.Column(type: "text", nullable: true), - Type = table.Column(type: "smallint", nullable: false), - Status = table.Column(type: "smallint", nullable: false), - WaitTimeDays = table.Column(type: "integer", nullable: false), - RecoveryInitiatedDate = table.Column(type: "timestamp without time zone", nullable: true), - LastNotificationDate = table.Column(type: "timestamp without time zone", nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_EmergencyAccess", x => x.Id); - table.ForeignKey( - name: "FK_EmergencyAccess_User_GranteeId", - column: x => x.GranteeId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_EmergencyAccess_User_GrantorId", - column: x => x.GrantorId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateTable( - name: "Folder", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - UserId = table.Column(type: "uuid", nullable: false), - Name = table.Column(type: "text", nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Folder", x => x.Id); - table.ForeignKey( - name: "FK_Folder_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateTable( - name: "OrganizationUser", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - OrganizationId = table.Column(type: "uuid", nullable: false), - UserId = table.Column(type: "uuid", nullable: true), - Email = table.Column(type: "character varying(256)", maxLength: 256, nullable: true), - Key = table.Column(type: "text", nullable: true), - ResetPasswordKey = table.Column(type: "text", nullable: true), - Status = table.Column(type: "smallint", nullable: false), - Type = table.Column(type: "smallint", nullable: false), - AccessAll = table.Column(type: "boolean", nullable: false), - ExternalId = table.Column(type: "character varying(300)", maxLength: 300, nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false), - Permissions = table.Column(type: "text", nullable: true) - }, - constraints: table => - { - table.PrimaryKey("PK_OrganizationUser", x => x.Id); - table.ForeignKey( - name: "FK_OrganizationUser_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_OrganizationUser_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }); - - migrationBuilder.CreateTable( - name: "ProviderUser", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - ProviderId = table.Column(type: "uuid", nullable: false), - UserId = table.Column(type: "uuid", nullable: true), - Email = table.Column(type: "text", nullable: true), - Key = table.Column(type: "text", nullable: true), - Status = table.Column(type: "smallint", nullable: false), - Type = table.Column(type: "smallint", nullable: false), - Permissions = table.Column(type: "text", nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_ProviderUser", x => x.Id); - table.ForeignKey( - name: "FK_ProviderUser_Provider_ProviderId", - column: x => x.ProviderId, - principalTable: "Provider", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_ProviderUser_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }); - - migrationBuilder.CreateTable( - name: "Send", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - UserId = table.Column(type: "uuid", nullable: true), - OrganizationId = table.Column(type: "uuid", nullable: true), - Type = table.Column(type: "smallint", nullable: false), - Data = table.Column(type: "text", nullable: true), - Key = table.Column(type: "text", nullable: true), - Password = table.Column(type: "character varying(300)", maxLength: 300, nullable: true), - MaxAccessCount = table.Column(type: "integer", nullable: true), - AccessCount = table.Column(type: "integer", nullable: false), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false), - ExpirationDate = table.Column(type: "timestamp without time zone", nullable: true), - DeletionDate = table.Column(type: "timestamp without time zone", nullable: false), - Disabled = table.Column(type: "boolean", nullable: false), - HideEmail = table.Column(type: "boolean", nullable: true) - }, - constraints: table => - { - table.PrimaryKey("PK_Send", x => x.Id); - table.ForeignKey( - name: "FK_Send_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_Send_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }); - - migrationBuilder.CreateTable( - name: "SsoUser", - columns: table => new - { - Id = table.Column(type: "bigint", nullable: false) - .Annotation("Npgsql:ValueGenerationStrategy", NpgsqlValueGenerationStrategy.IdentityByDefaultColumn), - UserId = table.Column(type: "uuid", nullable: false), - OrganizationId = table.Column(type: "uuid", nullable: true), - ExternalId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true, collation: "postgresIndetermanisticCollation"), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_SsoUser", x => x.Id); - table.ForeignKey( - name: "FK_SsoUser_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_SsoUser_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateTable( - name: "Transaction", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - UserId = table.Column(type: "uuid", nullable: true), - OrganizationId = table.Column(type: "uuid", nullable: true), - Type = table.Column(type: "smallint", nullable: false), - Amount = table.Column(type: "numeric", nullable: false), - Refunded = table.Column(type: "boolean", nullable: true), - RefundedAmount = table.Column(type: "numeric", nullable: true), - Details = table.Column(type: "character varying(100)", maxLength: 100, nullable: true), - PaymentMethodType = table.Column(type: "smallint", nullable: true), - Gateway = table.Column(type: "smallint", nullable: true), - GatewayId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_Transaction", x => x.Id); - table.ForeignKey( - name: "FK_Transaction_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_Transaction_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }); - - migrationBuilder.CreateTable( - name: "U2f", - columns: table => new - { - Id = table.Column(type: "integer", nullable: false) - .Annotation("Npgsql:ValueGenerationStrategy", NpgsqlValueGenerationStrategy.IdentityByDefaultColumn), - UserId = table.Column(type: "uuid", nullable: false), - KeyHandle = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), - Challenge = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), - AppId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - Version = table.Column(type: "character varying(20)", maxLength: 20, nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_U2f", x => x.Id); - table.ForeignKey( - name: "FK_U2f_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateTable( - name: "CollectionGroups", - columns: table => new - { - CollectionId = table.Column(type: "uuid", nullable: false), - GroupId = table.Column(type: "uuid", nullable: false), - ReadOnly = table.Column(type: "boolean", nullable: false), - HidePasswords = table.Column(type: "boolean", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_CollectionGroups", x => new { x.CollectionId, x.GroupId }); - table.ForeignKey( - name: "FK_CollectionGroups_Collection_CollectionId", - column: x => x.CollectionId, - principalTable: "Collection", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_CollectionGroups_Group_GroupId", - column: x => x.GroupId, - principalTable: "Group", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateTable( - name: "CollectionCipher", - columns: table => new - { - CollectionId = table.Column(type: "uuid", nullable: false), - CipherId = table.Column(type: "uuid", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_CollectionCipher", x => new { x.CollectionId, x.CipherId }); - table.ForeignKey( - name: "FK_CollectionCipher_Cipher_CipherId", - column: x => x.CipherId, - principalTable: "Cipher", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_CollectionCipher_Collection_CollectionId", - column: x => x.CollectionId, - principalTable: "Collection", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateTable( - name: "CollectionUsers", - columns: table => new - { - CollectionId = table.Column(type: "uuid", nullable: false), - OrganizationUserId = table.Column(type: "uuid", nullable: false), - UserId = table.Column(type: "uuid", nullable: true), - ReadOnly = table.Column(type: "boolean", nullable: false), - HidePasswords = table.Column(type: "boolean", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_CollectionUsers", x => new { x.CollectionId, x.OrganizationUserId }); - table.ForeignKey( - name: "FK_CollectionUsers_Collection_CollectionId", - column: x => x.CollectionId, - principalTable: "Collection", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_CollectionUsers_OrganizationUser_OrganizationUserId", - column: x => x.OrganizationUserId, - principalTable: "OrganizationUser", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_CollectionUsers_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }); - - migrationBuilder.CreateTable( - name: "GroupUser", - columns: table => new - { - GroupId = table.Column(type: "uuid", nullable: false), - OrganizationUserId = table.Column(type: "uuid", nullable: false), - UserId = table.Column(type: "uuid", nullable: true) - }, - constraints: table => - { - table.PrimaryKey("PK_GroupUser", x => new { x.GroupId, x.OrganizationUserId }); - table.ForeignKey( - name: "FK_GroupUser_Group_GroupId", - column: x => x.GroupId, - principalTable: "Group", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_GroupUser_OrganizationUser_OrganizationUserId", - column: x => x.OrganizationUserId, - principalTable: "OrganizationUser", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_GroupUser_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }); - - migrationBuilder.CreateTable( - name: "ProviderOrganizationProviderUser", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - ProviderOrganizationId = table.Column(type: "uuid", nullable: false), - ProviderUserId = table.Column(type: "uuid", nullable: false), - Type = table.Column(type: "smallint", nullable: false), - Permissions = table.Column(type: "text", nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_ProviderOrganizationProviderUser", x => x.Id); - table.ForeignKey( - name: "FK_ProviderOrganizationProviderUser_ProviderOrganization_Provi~", - column: x => x.ProviderOrganizationId, - principalTable: "ProviderOrganization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_ProviderOrganizationProviderUser_ProviderUser_ProviderUserId", - column: x => x.ProviderUserId, - principalTable: "ProviderUser", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); - - migrationBuilder.CreateIndex( - name: "IX_Cipher_OrganizationId", - table: "Cipher", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_Cipher_UserId", - table: "Cipher", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_Collection_OrganizationId", - table: "Collection", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_CollectionCipher_CipherId", - table: "CollectionCipher", - column: "CipherId"); - - migrationBuilder.CreateIndex( - name: "IX_CollectionGroups_GroupId", - table: "CollectionGroups", - column: "GroupId"); - - migrationBuilder.CreateIndex( - name: "IX_CollectionUsers_OrganizationUserId", - table: "CollectionUsers", - column: "OrganizationUserId"); - - migrationBuilder.CreateIndex( - name: "IX_CollectionUsers_UserId", - table: "CollectionUsers", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_Device_UserId", - table: "Device", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_EmergencyAccess_GranteeId", - table: "EmergencyAccess", - column: "GranteeId"); - - migrationBuilder.CreateIndex( - name: "IX_EmergencyAccess_GrantorId", - table: "EmergencyAccess", - column: "GrantorId"); - - migrationBuilder.CreateIndex( - name: "IX_Folder_UserId", - table: "Folder", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_Group_OrganizationId", - table: "Group", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_GroupUser_OrganizationUserId", - table: "GroupUser", - column: "OrganizationUserId"); - - migrationBuilder.CreateIndex( - name: "IX_GroupUser_UserId", - table: "GroupUser", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_OrganizationUser_OrganizationId", - table: "OrganizationUser", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_OrganizationUser_UserId", - table: "OrganizationUser", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_Policy_OrganizationId", - table: "Policy", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_ProviderOrganization_OrganizationId", - table: "ProviderOrganization", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_ProviderOrganization_ProviderId", - table: "ProviderOrganization", - column: "ProviderId"); - - migrationBuilder.CreateIndex( - name: "IX_ProviderOrganizationProviderUser_ProviderOrganizationId", - table: "ProviderOrganizationProviderUser", - column: "ProviderOrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_ProviderOrganizationProviderUser_ProviderUserId", - table: "ProviderOrganizationProviderUser", - column: "ProviderUserId"); - - migrationBuilder.CreateIndex( - name: "IX_ProviderUser_ProviderId", - table: "ProviderUser", - column: "ProviderId"); - - migrationBuilder.CreateIndex( - name: "IX_ProviderUser_UserId", - table: "ProviderUser", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_Send_OrganizationId", - table: "Send", - column: "OrganizationId"); - - migrationBuilder.CreateIndex( - name: "IX_Send_UserId", - table: "Send", - column: "UserId"); - - migrationBuilder.CreateIndex( - name: "IX_SsoConfig_OrganizationId", - table: "SsoConfig", - column: "OrganizationId"); + migrationBuilder.AlterDatabase() + .Annotation("Npgsql:CollationDefinition:postgresIndetermanisticCollation", "en-u-ks-primary,en-u-ks-primary,icu,False"); + + migrationBuilder.CreateTable( + name: "Event", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + Date = table.Column(type: "timestamp without time zone", nullable: false), + Type = table.Column(type: "integer", nullable: false), + UserId = table.Column(type: "uuid", nullable: true), + OrganizationId = table.Column(type: "uuid", nullable: true), + CipherId = table.Column(type: "uuid", nullable: true), + CollectionId = table.Column(type: "uuid", nullable: true), + PolicyId = table.Column(type: "uuid", nullable: true), + GroupId = table.Column(type: "uuid", nullable: true), + OrganizationUserId = table.Column(type: "uuid", nullable: true), + DeviceType = table.Column(type: "smallint", nullable: true), + IpAddress = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + ActingUserId = table.Column(type: "uuid", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_Event", x => x.Id); + }); + + migrationBuilder.CreateTable( + name: "Grant", + columns: table => new + { + Key = table.Column(type: "character varying(200)", maxLength: 200, nullable: false), + Type = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + SubjectId = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), + SessionId = table.Column(type: "character varying(100)", maxLength: 100, nullable: true), + ClientId = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), + Description = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + ExpirationDate = table.Column(type: "timestamp without time zone", nullable: true), + ConsumedDate = table.Column(type: "timestamp without time zone", nullable: true), + Data = table.Column(type: "text", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_Grant", x => x.Key); + }); + + migrationBuilder.CreateTable( + name: "Installation", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + Email = table.Column(type: "character varying(256)", maxLength: 256, nullable: true), + Key = table.Column(type: "character varying(150)", maxLength: 150, nullable: true), + Enabled = table.Column(type: "boolean", nullable: false), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Installation", x => x.Id); + }); + + migrationBuilder.CreateTable( + name: "Organization", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + Identifier = table.Column(type: "character varying(50)", maxLength: 50, nullable: true, collation: "postgresIndetermanisticCollation"), + Name = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + BusinessName = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + BusinessAddress1 = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + BusinessAddress2 = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + BusinessAddress3 = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + BusinessCountry = table.Column(type: "character varying(2)", maxLength: 2, nullable: true), + BusinessTaxNumber = table.Column(type: "character varying(30)", maxLength: 30, nullable: true), + BillingEmail = table.Column(type: "character varying(256)", maxLength: 256, nullable: true), + Plan = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + PlanType = table.Column(type: "smallint", nullable: false), + Seats = table.Column(type: "integer", nullable: true), + MaxCollections = table.Column(type: "smallint", nullable: true), + UsePolicies = table.Column(type: "boolean", nullable: false), + UseSso = table.Column(type: "boolean", nullable: false), + UseGroups = table.Column(type: "boolean", nullable: false), + UseDirectory = table.Column(type: "boolean", nullable: false), + UseEvents = table.Column(type: "boolean", nullable: false), + UseTotp = table.Column(type: "boolean", nullable: false), + Use2fa = table.Column(type: "boolean", nullable: false), + UseApi = table.Column(type: "boolean", nullable: false), + UseResetPassword = table.Column(type: "boolean", nullable: false), + SelfHost = table.Column(type: "boolean", nullable: false), + UsersGetPremium = table.Column(type: "boolean", nullable: false), + Storage = table.Column(type: "bigint", nullable: true), + MaxStorageGb = table.Column(type: "smallint", nullable: true), + Gateway = table.Column(type: "smallint", nullable: true), + GatewayCustomerId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + GatewaySubscriptionId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + ReferenceData = table.Column(type: "text", nullable: true), + Enabled = table.Column(type: "boolean", nullable: false), + LicenseKey = table.Column(type: "character varying(100)", maxLength: 100, nullable: true), + ApiKey = table.Column(type: "character varying(30)", maxLength: 30, nullable: true), + PublicKey = table.Column(type: "text", nullable: true), + PrivateKey = table.Column(type: "text", nullable: true), + TwoFactorProviders = table.Column(type: "text", nullable: true), + ExpirationDate = table.Column(type: "timestamp without time zone", nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Organization", x => x.Id); + }); + + migrationBuilder.CreateTable( + name: "Provider", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + Name = table.Column(type: "text", nullable: true), + BusinessName = table.Column(type: "text", nullable: true), + BusinessAddress1 = table.Column(type: "text", nullable: true), + BusinessAddress2 = table.Column(type: "text", nullable: true), + BusinessAddress3 = table.Column(type: "text", nullable: true), + BusinessCountry = table.Column(type: "text", nullable: true), + BusinessTaxNumber = table.Column(type: "text", nullable: true), + BillingEmail = table.Column(type: "text", nullable: true), + Status = table.Column(type: "smallint", nullable: false), + UseEvents = table.Column(type: "boolean", nullable: false), + Enabled = table.Column(type: "boolean", nullable: false), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Provider", x => x.Id); + }); + + migrationBuilder.CreateTable( + name: "TaxRate", + columns: table => new + { + Id = table.Column(type: "character varying(40)", maxLength: 40, nullable: false), + Country = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + State = table.Column(type: "character varying(2)", maxLength: 2, nullable: true), + PostalCode = table.Column(type: "character varying(10)", maxLength: 10, nullable: true), + Rate = table.Column(type: "numeric", nullable: false), + Active = table.Column(type: "boolean", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_TaxRate", x => x.Id); + }); + + migrationBuilder.CreateTable( + name: "User", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + Name = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + Email = table.Column(type: "character varying(256)", maxLength: 256, nullable: false, collation: "postgresIndetermanisticCollation"), + EmailVerified = table.Column(type: "boolean", nullable: false), + MasterPassword = table.Column(type: "character varying(300)", maxLength: 300, nullable: true), + MasterPasswordHint = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + Culture = table.Column(type: "character varying(10)", maxLength: 10, nullable: true), + SecurityStamp = table.Column(type: "character varying(50)", maxLength: 50, nullable: false), + TwoFactorProviders = table.Column(type: "text", nullable: true), + TwoFactorRecoveryCode = table.Column(type: "character varying(32)", maxLength: 32, nullable: true), + EquivalentDomains = table.Column(type: "text", nullable: true), + ExcludedGlobalEquivalentDomains = table.Column(type: "text", nullable: true), + AccountRevisionDate = table.Column(type: "timestamp without time zone", nullable: false), + Key = table.Column(type: "text", nullable: true), + PublicKey = table.Column(type: "text", nullable: true), + PrivateKey = table.Column(type: "text", nullable: true), + Premium = table.Column(type: "boolean", nullable: false), + PremiumExpirationDate = table.Column(type: "timestamp without time zone", nullable: true), + RenewalReminderDate = table.Column(type: "timestamp without time zone", nullable: true), + Storage = table.Column(type: "bigint", nullable: true), + MaxStorageGb = table.Column(type: "smallint", nullable: true), + Gateway = table.Column(type: "smallint", nullable: true), + GatewayCustomerId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + GatewaySubscriptionId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + ReferenceData = table.Column(type: "text", nullable: true), + LicenseKey = table.Column(type: "character varying(100)", maxLength: 100, nullable: true), + ApiKey = table.Column(type: "character varying(30)", maxLength: 30, nullable: false), + Kdf = table.Column(type: "smallint", nullable: false), + KdfIterations = table.Column(type: "integer", nullable: false), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_User", x => x.Id); + }); + + migrationBuilder.CreateTable( + name: "Collection", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + OrganizationId = table.Column(type: "uuid", nullable: false), + Name = table.Column(type: "text", nullable: true), + ExternalId = table.Column(type: "character varying(300)", maxLength: 300, nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Collection", x => x.Id); + table.ForeignKey( + name: "FK_Collection_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "Group", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + OrganizationId = table.Column(type: "uuid", nullable: false), + Name = table.Column(type: "character varying(100)", maxLength: 100, nullable: true), + AccessAll = table.Column(type: "boolean", nullable: false), + ExternalId = table.Column(type: "character varying(300)", maxLength: 300, nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Group", x => x.Id); + table.ForeignKey( + name: "FK_Group_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "Policy", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + OrganizationId = table.Column(type: "uuid", nullable: false), + Type = table.Column(type: "smallint", nullable: false), + Data = table.Column(type: "text", nullable: true), + Enabled = table.Column(type: "boolean", nullable: false), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Policy", x => x.Id); + table.ForeignKey( + name: "FK_Policy_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "SsoConfig", + columns: table => new + { + Id = table.Column(type: "bigint", nullable: false) + .Annotation("Npgsql:ValueGenerationStrategy", NpgsqlValueGenerationStrategy.IdentityByDefaultColumn), + Enabled = table.Column(type: "boolean", nullable: false), + OrganizationId = table.Column(type: "uuid", nullable: false), + Data = table.Column(type: "text", nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_SsoConfig", x => x.Id); + table.ForeignKey( + name: "FK_SsoConfig_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "ProviderOrganization", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + ProviderId = table.Column(type: "uuid", nullable: false), + OrganizationId = table.Column(type: "uuid", nullable: false), + Key = table.Column(type: "text", nullable: true), + Settings = table.Column(type: "text", nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_ProviderOrganization", x => x.Id); + table.ForeignKey( + name: "FK_ProviderOrganization_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_ProviderOrganization_Provider_ProviderId", + column: x => x.ProviderId, + principalTable: "Provider", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "Cipher", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + UserId = table.Column(type: "uuid", nullable: true), + OrganizationId = table.Column(type: "uuid", nullable: true), + Type = table.Column(type: "smallint", nullable: false), + Data = table.Column(type: "text", nullable: true), + Favorites = table.Column(type: "text", nullable: true), + Folders = table.Column(type: "text", nullable: true), + Attachments = table.Column(type: "text", nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false), + DeletedDate = table.Column(type: "timestamp without time zone", nullable: true), + Reprompt = table.Column(type: "smallint", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_Cipher", x => x.Id); + table.ForeignKey( + name: "FK_Cipher_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_Cipher_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }); + + migrationBuilder.CreateTable( + name: "Device", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + UserId = table.Column(type: "uuid", nullable: false), + Name = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + Type = table.Column(type: "smallint", nullable: false), + Identifier = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + PushToken = table.Column(type: "character varying(255)", maxLength: 255, nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Device", x => x.Id); + table.ForeignKey( + name: "FK_Device_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "EmergencyAccess", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + GrantorId = table.Column(type: "uuid", nullable: false), + GranteeId = table.Column(type: "uuid", nullable: true), + Email = table.Column(type: "character varying(256)", maxLength: 256, nullable: true), + KeyEncrypted = table.Column(type: "text", nullable: true), + Type = table.Column(type: "smallint", nullable: false), + Status = table.Column(type: "smallint", nullable: false), + WaitTimeDays = table.Column(type: "integer", nullable: false), + RecoveryInitiatedDate = table.Column(type: "timestamp without time zone", nullable: true), + LastNotificationDate = table.Column(type: "timestamp without time zone", nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_EmergencyAccess", x => x.Id); + table.ForeignKey( + name: "FK_EmergencyAccess_User_GranteeId", + column: x => x.GranteeId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_EmergencyAccess_User_GrantorId", + column: x => x.GrantorId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "Folder", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + UserId = table.Column(type: "uuid", nullable: false), + Name = table.Column(type: "text", nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Folder", x => x.Id); + table.ForeignKey( + name: "FK_Folder_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "OrganizationUser", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + OrganizationId = table.Column(type: "uuid", nullable: false), + UserId = table.Column(type: "uuid", nullable: true), + Email = table.Column(type: "character varying(256)", maxLength: 256, nullable: true), + Key = table.Column(type: "text", nullable: true), + ResetPasswordKey = table.Column(type: "text", nullable: true), + Status = table.Column(type: "smallint", nullable: false), + Type = table.Column(type: "smallint", nullable: false), + AccessAll = table.Column(type: "boolean", nullable: false), + ExternalId = table.Column(type: "character varying(300)", maxLength: 300, nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false), + Permissions = table.Column(type: "text", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_OrganizationUser", x => x.Id); + table.ForeignKey( + name: "FK_OrganizationUser_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_OrganizationUser_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }); + + migrationBuilder.CreateTable( + name: "ProviderUser", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + ProviderId = table.Column(type: "uuid", nullable: false), + UserId = table.Column(type: "uuid", nullable: true), + Email = table.Column(type: "text", nullable: true), + Key = table.Column(type: "text", nullable: true), + Status = table.Column(type: "smallint", nullable: false), + Type = table.Column(type: "smallint", nullable: false), + Permissions = table.Column(type: "text", nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_ProviderUser", x => x.Id); + table.ForeignKey( + name: "FK_ProviderUser_Provider_ProviderId", + column: x => x.ProviderId, + principalTable: "Provider", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_ProviderUser_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }); + + migrationBuilder.CreateTable( + name: "Send", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + UserId = table.Column(type: "uuid", nullable: true), + OrganizationId = table.Column(type: "uuid", nullable: true), + Type = table.Column(type: "smallint", nullable: false), + Data = table.Column(type: "text", nullable: true), + Key = table.Column(type: "text", nullable: true), + Password = table.Column(type: "character varying(300)", maxLength: 300, nullable: true), + MaxAccessCount = table.Column(type: "integer", nullable: true), + AccessCount = table.Column(type: "integer", nullable: false), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false), + ExpirationDate = table.Column(type: "timestamp without time zone", nullable: true), + DeletionDate = table.Column(type: "timestamp without time zone", nullable: false), + Disabled = table.Column(type: "boolean", nullable: false), + HideEmail = table.Column(type: "boolean", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_Send", x => x.Id); + table.ForeignKey( + name: "FK_Send_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_Send_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }); + + migrationBuilder.CreateTable( + name: "SsoUser", + columns: table => new + { + Id = table.Column(type: "bigint", nullable: false) + .Annotation("Npgsql:ValueGenerationStrategy", NpgsqlValueGenerationStrategy.IdentityByDefaultColumn), + UserId = table.Column(type: "uuid", nullable: false), + OrganizationId = table.Column(type: "uuid", nullable: true), + ExternalId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true, collation: "postgresIndetermanisticCollation"), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_SsoUser", x => x.Id); + table.ForeignKey( + name: "FK_SsoUser_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_SsoUser_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "Transaction", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + UserId = table.Column(type: "uuid", nullable: true), + OrganizationId = table.Column(type: "uuid", nullable: true), + Type = table.Column(type: "smallint", nullable: false), + Amount = table.Column(type: "numeric", nullable: false), + Refunded = table.Column(type: "boolean", nullable: true), + RefundedAmount = table.Column(type: "numeric", nullable: true), + Details = table.Column(type: "character varying(100)", maxLength: 100, nullable: true), + PaymentMethodType = table.Column(type: "smallint", nullable: true), + Gateway = table.Column(type: "smallint", nullable: true), + GatewayId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_Transaction", x => x.Id); + table.ForeignKey( + name: "FK_Transaction_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_Transaction_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }); + + migrationBuilder.CreateTable( + name: "U2f", + columns: table => new + { + Id = table.Column(type: "integer", nullable: false) + .Annotation("Npgsql:ValueGenerationStrategy", NpgsqlValueGenerationStrategy.IdentityByDefaultColumn), + UserId = table.Column(type: "uuid", nullable: false), + KeyHandle = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), + Challenge = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), + AppId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + Version = table.Column(type: "character varying(20)", maxLength: 20, nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_U2f", x => x.Id); + table.ForeignKey( + name: "FK_U2f_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "CollectionGroups", + columns: table => new + { + CollectionId = table.Column(type: "uuid", nullable: false), + GroupId = table.Column(type: "uuid", nullable: false), + ReadOnly = table.Column(type: "boolean", nullable: false), + HidePasswords = table.Column(type: "boolean", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_CollectionGroups", x => new { x.CollectionId, x.GroupId }); + table.ForeignKey( + name: "FK_CollectionGroups_Collection_CollectionId", + column: x => x.CollectionId, + principalTable: "Collection", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_CollectionGroups_Group_GroupId", + column: x => x.GroupId, + principalTable: "Group", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "CollectionCipher", + columns: table => new + { + CollectionId = table.Column(type: "uuid", nullable: false), + CipherId = table.Column(type: "uuid", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_CollectionCipher", x => new { x.CollectionId, x.CipherId }); + table.ForeignKey( + name: "FK_CollectionCipher_Cipher_CipherId", + column: x => x.CipherId, + principalTable: "Cipher", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_CollectionCipher_Collection_CollectionId", + column: x => x.CollectionId, + principalTable: "Collection", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateTable( + name: "CollectionUsers", + columns: table => new + { + CollectionId = table.Column(type: "uuid", nullable: false), + OrganizationUserId = table.Column(type: "uuid", nullable: false), + UserId = table.Column(type: "uuid", nullable: true), + ReadOnly = table.Column(type: "boolean", nullable: false), + HidePasswords = table.Column(type: "boolean", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_CollectionUsers", x => new { x.CollectionId, x.OrganizationUserId }); + table.ForeignKey( + name: "FK_CollectionUsers_Collection_CollectionId", + column: x => x.CollectionId, + principalTable: "Collection", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_CollectionUsers_OrganizationUser_OrganizationUserId", + column: x => x.OrganizationUserId, + principalTable: "OrganizationUser", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_CollectionUsers_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }); + + migrationBuilder.CreateTable( + name: "GroupUser", + columns: table => new + { + GroupId = table.Column(type: "uuid", nullable: false), + OrganizationUserId = table.Column(type: "uuid", nullable: false), + UserId = table.Column(type: "uuid", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_GroupUser", x => new { x.GroupId, x.OrganizationUserId }); + table.ForeignKey( + name: "FK_GroupUser_Group_GroupId", + column: x => x.GroupId, + principalTable: "Group", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_GroupUser_OrganizationUser_OrganizationUserId", + column: x => x.OrganizationUserId, + principalTable: "OrganizationUser", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_GroupUser_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }); + + migrationBuilder.CreateTable( + name: "ProviderOrganizationProviderUser", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + ProviderOrganizationId = table.Column(type: "uuid", nullable: false), + ProviderUserId = table.Column(type: "uuid", nullable: false), + Type = table.Column(type: "smallint", nullable: false), + Permissions = table.Column(type: "text", nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_ProviderOrganizationProviderUser", x => x.Id); + table.ForeignKey( + name: "FK_ProviderOrganizationProviderUser_ProviderOrganization_Provi~", + column: x => x.ProviderOrganizationId, + principalTable: "ProviderOrganization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_ProviderOrganizationProviderUser_ProviderUser_ProviderUserId", + column: x => x.ProviderUserId, + principalTable: "ProviderUser", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); + + migrationBuilder.CreateIndex( + name: "IX_Cipher_OrganizationId", + table: "Cipher", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_Cipher_UserId", + table: "Cipher", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_Collection_OrganizationId", + table: "Collection", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_CollectionCipher_CipherId", + table: "CollectionCipher", + column: "CipherId"); + + migrationBuilder.CreateIndex( + name: "IX_CollectionGroups_GroupId", + table: "CollectionGroups", + column: "GroupId"); + + migrationBuilder.CreateIndex( + name: "IX_CollectionUsers_OrganizationUserId", + table: "CollectionUsers", + column: "OrganizationUserId"); + + migrationBuilder.CreateIndex( + name: "IX_CollectionUsers_UserId", + table: "CollectionUsers", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_Device_UserId", + table: "Device", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_EmergencyAccess_GranteeId", + table: "EmergencyAccess", + column: "GranteeId"); + + migrationBuilder.CreateIndex( + name: "IX_EmergencyAccess_GrantorId", + table: "EmergencyAccess", + column: "GrantorId"); + + migrationBuilder.CreateIndex( + name: "IX_Folder_UserId", + table: "Folder", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_Group_OrganizationId", + table: "Group", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_GroupUser_OrganizationUserId", + table: "GroupUser", + column: "OrganizationUserId"); + + migrationBuilder.CreateIndex( + name: "IX_GroupUser_UserId", + table: "GroupUser", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_OrganizationUser_OrganizationId", + table: "OrganizationUser", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_OrganizationUser_UserId", + table: "OrganizationUser", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_Policy_OrganizationId", + table: "Policy", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_ProviderOrganization_OrganizationId", + table: "ProviderOrganization", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_ProviderOrganization_ProviderId", + table: "ProviderOrganization", + column: "ProviderId"); + + migrationBuilder.CreateIndex( + name: "IX_ProviderOrganizationProviderUser_ProviderOrganizationId", + table: "ProviderOrganizationProviderUser", + column: "ProviderOrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_ProviderOrganizationProviderUser_ProviderUserId", + table: "ProviderOrganizationProviderUser", + column: "ProviderUserId"); + + migrationBuilder.CreateIndex( + name: "IX_ProviderUser_ProviderId", + table: "ProviderUser", + column: "ProviderId"); + + migrationBuilder.CreateIndex( + name: "IX_ProviderUser_UserId", + table: "ProviderUser", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_Send_OrganizationId", + table: "Send", + column: "OrganizationId"); + + migrationBuilder.CreateIndex( + name: "IX_Send_UserId", + table: "Send", + column: "UserId"); + + migrationBuilder.CreateIndex( + name: "IX_SsoConfig_OrganizationId", + table: "SsoConfig", + column: "OrganizationId"); - migrationBuilder.CreateIndex( - name: "IX_SsoUser_OrganizationId", - table: "SsoUser", - column: "OrganizationId"); + migrationBuilder.CreateIndex( + name: "IX_SsoUser_OrganizationId", + table: "SsoUser", + column: "OrganizationId"); - migrationBuilder.CreateIndex( - name: "IX_SsoUser_UserId", - table: "SsoUser", - column: "UserId"); + migrationBuilder.CreateIndex( + name: "IX_SsoUser_UserId", + table: "SsoUser", + column: "UserId"); - migrationBuilder.CreateIndex( - name: "IX_Transaction_OrganizationId", - table: "Transaction", - column: "OrganizationId"); + migrationBuilder.CreateIndex( + name: "IX_Transaction_OrganizationId", + table: "Transaction", + column: "OrganizationId"); - migrationBuilder.CreateIndex( - name: "IX_Transaction_UserId", - table: "Transaction", - column: "UserId"); + migrationBuilder.CreateIndex( + name: "IX_Transaction_UserId", + table: "Transaction", + column: "UserId"); - migrationBuilder.CreateIndex( - name: "IX_U2f_UserId", - table: "U2f", - column: "UserId"); - } + migrationBuilder.CreateIndex( + name: "IX_U2f_UserId", + table: "U2f", + column: "UserId"); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropTable( - name: "CollectionCipher"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropTable( + name: "CollectionCipher"); - migrationBuilder.DropTable( - name: "CollectionGroups"); + migrationBuilder.DropTable( + name: "CollectionGroups"); - migrationBuilder.DropTable( - name: "CollectionUsers"); + migrationBuilder.DropTable( + name: "CollectionUsers"); - migrationBuilder.DropTable( - name: "Device"); + migrationBuilder.DropTable( + name: "Device"); - migrationBuilder.DropTable( - name: "EmergencyAccess"); + migrationBuilder.DropTable( + name: "EmergencyAccess"); - migrationBuilder.DropTable( - name: "Event"); + migrationBuilder.DropTable( + name: "Event"); - migrationBuilder.DropTable( - name: "Folder"); + migrationBuilder.DropTable( + name: "Folder"); - migrationBuilder.DropTable( - name: "Grant"); + migrationBuilder.DropTable( + name: "Grant"); - migrationBuilder.DropTable( - name: "GroupUser"); + migrationBuilder.DropTable( + name: "GroupUser"); - migrationBuilder.DropTable( - name: "Installation"); + migrationBuilder.DropTable( + name: "Installation"); - migrationBuilder.DropTable( - name: "Policy"); + migrationBuilder.DropTable( + name: "Policy"); - migrationBuilder.DropTable( - name: "ProviderOrganizationProviderUser"); + migrationBuilder.DropTable( + name: "ProviderOrganizationProviderUser"); - migrationBuilder.DropTable( - name: "Send"); + migrationBuilder.DropTable( + name: "Send"); - migrationBuilder.DropTable( - name: "SsoConfig"); + migrationBuilder.DropTable( + name: "SsoConfig"); - migrationBuilder.DropTable( - name: "SsoUser"); + migrationBuilder.DropTable( + name: "SsoUser"); - migrationBuilder.DropTable( - name: "TaxRate"); + migrationBuilder.DropTable( + name: "TaxRate"); - migrationBuilder.DropTable( - name: "Transaction"); + migrationBuilder.DropTable( + name: "Transaction"); - migrationBuilder.DropTable( - name: "U2f"); + migrationBuilder.DropTable( + name: "U2f"); - migrationBuilder.DropTable( - name: "Cipher"); + migrationBuilder.DropTable( + name: "Cipher"); - migrationBuilder.DropTable( - name: "Collection"); + migrationBuilder.DropTable( + name: "Collection"); - migrationBuilder.DropTable( - name: "Group"); + migrationBuilder.DropTable( + name: "Group"); - migrationBuilder.DropTable( - name: "OrganizationUser"); + migrationBuilder.DropTable( + name: "OrganizationUser"); - migrationBuilder.DropTable( - name: "ProviderOrganization"); + migrationBuilder.DropTable( + name: "ProviderOrganization"); - migrationBuilder.DropTable( - name: "ProviderUser"); + migrationBuilder.DropTable( + name: "ProviderUser"); - migrationBuilder.DropTable( - name: "Organization"); + migrationBuilder.DropTable( + name: "Organization"); - migrationBuilder.DropTable( - name: "Provider"); + migrationBuilder.DropTable( + name: "Provider"); - migrationBuilder.DropTable( - name: "User"); - } + migrationBuilder.DropTable( + name: "User"); } } diff --git a/util/PostgresMigrations/Migrations/20210709092227_RemoveProviderOrganizationProviderUser.cs b/util/PostgresMigrations/Migrations/20210709092227_RemoveProviderOrganizationProviderUser.cs index ba7da780f..f0f0d235b 100644 --- a/util/PostgresMigrations/Migrations/20210709092227_RemoveProviderOrganizationProviderUser.cs +++ b/util/PostgresMigrations/Migrations/20210709092227_RemoveProviderOrganizationProviderUser.cs @@ -1,75 +1,74 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations +namespace Bit.PostgresMigrations.Migrations; + +public partial class RemoveProviderOrganizationProviderUser : Migration { - public partial class RemoveProviderOrganizationProviderUser : Migration + protected override void Up(MigrationBuilder migrationBuilder) { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropTable( - name: "ProviderOrganizationProviderUser"); + migrationBuilder.DropTable( + name: "ProviderOrganizationProviderUser"); - migrationBuilder.AddColumn( - name: "ProviderId", - table: "Event", - type: "uuid", - nullable: true); + migrationBuilder.AddColumn( + name: "ProviderId", + table: "Event", + type: "uuid", + nullable: true); - migrationBuilder.AddColumn( - name: "ProviderUserId", - table: "Event", - type: "uuid", - nullable: true); - } + migrationBuilder.AddColumn( + name: "ProviderUserId", + table: "Event", + type: "uuid", + nullable: true); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "ProviderId", - table: "Event"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "ProviderId", + table: "Event"); - migrationBuilder.DropColumn( - name: "ProviderUserId", - table: "Event"); + migrationBuilder.DropColumn( + name: "ProviderUserId", + table: "Event"); - migrationBuilder.CreateTable( - name: "ProviderOrganizationProviderUser", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - Permissions = table.Column(type: "text", nullable: true), - ProviderOrganizationId = table.Column(type: "uuid", nullable: false), - ProviderUserId = table.Column(type: "uuid", nullable: false), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false), - Type = table.Column(type: "smallint", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_ProviderOrganizationProviderUser", x => x.Id); - table.ForeignKey( - name: "FK_ProviderOrganizationProviderUser_ProviderOrganization_Provi~", - column: x => x.ProviderOrganizationId, - principalTable: "ProviderOrganization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - table.ForeignKey( - name: "FK_ProviderOrganizationProviderUser_ProviderUser_ProviderUserId", - column: x => x.ProviderUserId, - principalTable: "ProviderUser", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); + migrationBuilder.CreateTable( + name: "ProviderOrganizationProviderUser", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + Permissions = table.Column(type: "text", nullable: true), + ProviderOrganizationId = table.Column(type: "uuid", nullable: false), + ProviderUserId = table.Column(type: "uuid", nullable: false), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false), + Type = table.Column(type: "smallint", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_ProviderOrganizationProviderUser", x => x.Id); + table.ForeignKey( + name: "FK_ProviderOrganizationProviderUser_ProviderOrganization_Provi~", + column: x => x.ProviderOrganizationId, + principalTable: "ProviderOrganization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + table.ForeignKey( + name: "FK_ProviderOrganizationProviderUser_ProviderUser_ProviderUserId", + column: x => x.ProviderUserId, + principalTable: "ProviderUser", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); - migrationBuilder.CreateIndex( - name: "IX_ProviderOrganizationProviderUser_ProviderOrganizationId", - table: "ProviderOrganizationProviderUser", - column: "ProviderOrganizationId"); + migrationBuilder.CreateIndex( + name: "IX_ProviderOrganizationProviderUser_ProviderOrganizationId", + table: "ProviderOrganizationProviderUser", + column: "ProviderOrganizationId"); - migrationBuilder.CreateIndex( - name: "IX_ProviderOrganizationProviderUser_ProviderUserId", - table: "ProviderOrganizationProviderUser", - column: "ProviderUserId"); - } + migrationBuilder.CreateIndex( + name: "IX_ProviderOrganizationProviderUser_ProviderUserId", + table: "ProviderOrganizationProviderUser", + column: "ProviderUserId"); } } diff --git a/util/PostgresMigrations/Migrations/20210716141748_UserForcePasswordReset.cs b/util/PostgresMigrations/Migrations/20210716141748_UserForcePasswordReset.cs index bb39dfe4b..5b435b218 100644 --- a/util/PostgresMigrations/Migrations/20210716141748_UserForcePasswordReset.cs +++ b/util/PostgresMigrations/Migrations/20210716141748_UserForcePasswordReset.cs @@ -1,24 +1,23 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations -{ - public partial class UserForcePasswordReset : Migration - { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.AddColumn( - name: "ForcePasswordReset", - table: "User", - type: "boolean", - nullable: false, - defaultValue: false); - } +namespace Bit.PostgresMigrations.Migrations; - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "ForcePasswordReset", - table: "User"); - } +public partial class UserForcePasswordReset : Migration +{ + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "ForcePasswordReset", + table: "User", + type: "boolean", + nullable: false, + defaultValue: false); + } + + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "ForcePasswordReset", + table: "User"); } } diff --git a/util/PostgresMigrations/Migrations/20210920201829_AddMaxAutoscaleSeatsToOrganization.cs b/util/PostgresMigrations/Migrations/20210920201829_AddMaxAutoscaleSeatsToOrganization.cs index 98d2acce6..41ab20399 100644 --- a/util/PostgresMigrations/Migrations/20210920201829_AddMaxAutoscaleSeatsToOrganization.cs +++ b/util/PostgresMigrations/Migrations/20210920201829_AddMaxAutoscaleSeatsToOrganization.cs @@ -1,43 +1,42 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations +namespace Bit.PostgresMigrations.Migrations; + +public partial class AddMaxAutoscaleSeatsToOrganization : Migration { - public partial class AddMaxAutoscaleSeatsToOrganization : Migration + protected override void Up(MigrationBuilder migrationBuilder) { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.AddColumn( - name: "MaxAutoscaleSeats", - table: "Organization", - type: "integer", - nullable: true); + migrationBuilder.AddColumn( + name: "MaxAutoscaleSeats", + table: "Organization", + type: "integer", + nullable: true); - migrationBuilder.AddColumn( - name: "OwnersNotifiedOfAutoscaling", - table: "Organization", - type: "timestamp without time zone", - nullable: true); + migrationBuilder.AddColumn( + name: "OwnersNotifiedOfAutoscaling", + table: "Organization", + type: "timestamp without time zone", + nullable: true); - migrationBuilder.AddColumn( - name: "ProviderOrganizationId", - table: "Event", - type: "uuid", - nullable: true); - } + migrationBuilder.AddColumn( + name: "ProviderOrganizationId", + table: "Event", + type: "uuid", + nullable: true); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "MaxAutoscaleSeats", - table: "Organization"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "MaxAutoscaleSeats", + table: "Organization"); - migrationBuilder.DropColumn( - name: "OwnersNotifiedOfAutoscaling", - table: "Organization"); + migrationBuilder.DropColumn( + name: "OwnersNotifiedOfAutoscaling", + table: "Organization"); - migrationBuilder.DropColumn( - name: "ProviderOrganizationId", - table: "Event"); - } + migrationBuilder.DropColumn( + name: "ProviderOrganizationId", + table: "Event"); } } diff --git a/util/PostgresMigrations/Migrations/20211011145128_SplitManageCollectionsPermissions2.cs b/util/PostgresMigrations/Migrations/20211011145128_SplitManageCollectionsPermissions2.cs index 90b0884ab..d1c08d3fb 100644 --- a/util/PostgresMigrations/Migrations/20211011145128_SplitManageCollectionsPermissions2.cs +++ b/util/PostgresMigrations/Migrations/20211011145128_SplitManageCollectionsPermissions2.cs @@ -1,21 +1,20 @@ using Bit.Core.Utilities; using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations +namespace Bit.PostgresMigrations.Migrations; + +public partial class SplitManageCollectionsPermissions2 : Migration { - public partial class SplitManageCollectionsPermissions2 : Migration + private const string _scriptLocation = + "PostgresMigrations.Scripts.2021-09-21_01_SplitManageCollectionsPermission.psql"; + + protected override void Up(MigrationBuilder migrationBuilder) { - private const string _scriptLocation = - "PostgresMigrations.Scripts.2021-09-21_01_SplitManageCollectionsPermission.psql"; + migrationBuilder.Sql(CoreHelpers.GetEmbeddedResourceContentsAsync(_scriptLocation)); + } - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.Sql(CoreHelpers.GetEmbeddedResourceContentsAsync(_scriptLocation)); - } - - protected override void Down(MigrationBuilder migrationBuilder) - { - throw new Exception("Irreversible migration"); - } + protected override void Down(MigrationBuilder migrationBuilder) + { + throw new Exception("Irreversible migration"); } } diff --git a/util/PostgresMigrations/Migrations/20211021204521_SetMaxAutoscaleSeatsToCurrentSeatCount.cs b/util/PostgresMigrations/Migrations/20211021204521_SetMaxAutoscaleSeatsToCurrentSeatCount.cs index c8c569a0d..c569d7f1b 100644 --- a/util/PostgresMigrations/Migrations/20211021204521_SetMaxAutoscaleSeatsToCurrentSeatCount.cs +++ b/util/PostgresMigrations/Migrations/20211021204521_SetMaxAutoscaleSeatsToCurrentSeatCount.cs @@ -1,21 +1,20 @@ using Bit.Core.Utilities; using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations +namespace Bit.PostgresMigrations.Migrations; + +public partial class SetMaxAutoscaleSeatsToCurrentSeatCount : Migration { - public partial class SetMaxAutoscaleSeatsToCurrentSeatCount : Migration + private const string _scriptLocation = + "PostgresMigrations.Scripts.2021-10-21_00_SetMaxAutoscaleSeatCount.psql"; + + protected override void Up(MigrationBuilder migrationBuilder) { - private const string _scriptLocation = - "PostgresMigrations.Scripts.2021-10-21_00_SetMaxAutoscaleSeatCount.psql"; + migrationBuilder.Sql(CoreHelpers.GetEmbeddedResourceContentsAsync(_scriptLocation)); + } - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.Sql(CoreHelpers.GetEmbeddedResourceContentsAsync(_scriptLocation)); - } - - protected override void Down(MigrationBuilder migrationBuilder) - { - throw new Exception("Irreversible migration"); - } + protected override void Down(MigrationBuilder migrationBuilder) + { + throw new Exception("Irreversible migration"); } } diff --git a/util/PostgresMigrations/Migrations/20211108041547_KeyConnector.cs b/util/PostgresMigrations/Migrations/20211108041547_KeyConnector.cs index 264869124..7619e7689 100644 --- a/util/PostgresMigrations/Migrations/20211108041547_KeyConnector.cs +++ b/util/PostgresMigrations/Migrations/20211108041547_KeyConnector.cs @@ -1,24 +1,23 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations -{ - public partial class KeyConnector : Migration - { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.AddColumn( - name: "UsesKeyConnector", - table: "User", - type: "boolean", - nullable: false, - defaultValue: false); - } +namespace Bit.PostgresMigrations.Migrations; - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "UsesKeyConnector", - table: "User"); - } +public partial class KeyConnector : Migration +{ + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "UsesKeyConnector", + table: "User", + type: "boolean", + nullable: false, + defaultValue: false); + } + + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "UsesKeyConnector", + table: "User"); } } diff --git a/util/PostgresMigrations/Migrations/20211108225011_OrganizationSponsorship.cs b/util/PostgresMigrations/Migrations/20211108225011_OrganizationSponsorship.cs index 6918e885d..a787141e7 100644 --- a/util/PostgresMigrations/Migrations/20211108225011_OrganizationSponsorship.cs +++ b/util/PostgresMigrations/Migrations/20211108225011_OrganizationSponsorship.cs @@ -1,82 +1,81 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations +namespace Bit.PostgresMigrations.Migrations; + +public partial class OrganizationSponsorship : Migration { - public partial class OrganizationSponsorship : Migration + protected override void Up(MigrationBuilder migrationBuilder) { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.AddColumn( - name: "UsesCryptoAgent", - table: "User", - type: "boolean", - nullable: false, - defaultValue: false); + migrationBuilder.AddColumn( + name: "UsesCryptoAgent", + table: "User", + type: "boolean", + nullable: false, + defaultValue: false); - migrationBuilder.CreateTable( - name: "OrganizationSponsorship", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - InstallationId = table.Column(type: "uuid", nullable: true), - SponsoringOrganizationId = table.Column(type: "uuid", nullable: true), - SponsoringOrganizationUserId = table.Column(type: "uuid", nullable: true), - SponsoredOrganizationId = table.Column(type: "uuid", nullable: true), - FriendlyName = table.Column(type: "character varying(256)", maxLength: 256, nullable: true), - OfferedToEmail = table.Column(type: "character varying(256)", maxLength: 256, nullable: true), - PlanSponsorshipType = table.Column(type: "smallint", nullable: true), - CloudSponsor = table.Column(type: "boolean", nullable: false), - LastSyncDate = table.Column(type: "timestamp without time zone", nullable: true), - TimesRenewedWithoutValidation = table.Column(type: "smallint", nullable: false), - SponsorshipLapsedDate = table.Column(type: "timestamp without time zone", nullable: true) - }, - constraints: table => - { - table.PrimaryKey("PK_OrganizationSponsorship", x => x.Id); - table.ForeignKey( - name: "FK_OrganizationSponsorship_Installation_InstallationId", - column: x => x.InstallationId, - principalTable: "Installation", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoredOrganizationId", - column: x => x.SponsoredOrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - table.ForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", - column: x => x.SponsoringOrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - }); + migrationBuilder.CreateTable( + name: "OrganizationSponsorship", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + InstallationId = table.Column(type: "uuid", nullable: true), + SponsoringOrganizationId = table.Column(type: "uuid", nullable: true), + SponsoringOrganizationUserId = table.Column(type: "uuid", nullable: true), + SponsoredOrganizationId = table.Column(type: "uuid", nullable: true), + FriendlyName = table.Column(type: "character varying(256)", maxLength: 256, nullable: true), + OfferedToEmail = table.Column(type: "character varying(256)", maxLength: 256, nullable: true), + PlanSponsorshipType = table.Column(type: "smallint", nullable: true), + CloudSponsor = table.Column(type: "boolean", nullable: false), + LastSyncDate = table.Column(type: "timestamp without time zone", nullable: true), + TimesRenewedWithoutValidation = table.Column(type: "smallint", nullable: false), + SponsorshipLapsedDate = table.Column(type: "timestamp without time zone", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_OrganizationSponsorship", x => x.Id); + table.ForeignKey( + name: "FK_OrganizationSponsorship_Installation_InstallationId", + column: x => x.InstallationId, + principalTable: "Installation", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoredOrganizationId", + column: x => x.SponsoredOrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + table.ForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", + column: x => x.SponsoringOrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + }); - migrationBuilder.CreateIndex( - name: "IX_OrganizationSponsorship_InstallationId", - table: "OrganizationSponsorship", - column: "InstallationId"); + migrationBuilder.CreateIndex( + name: "IX_OrganizationSponsorship_InstallationId", + table: "OrganizationSponsorship", + column: "InstallationId"); - migrationBuilder.CreateIndex( - name: "IX_OrganizationSponsorship_SponsoredOrganizationId", - table: "OrganizationSponsorship", - column: "SponsoredOrganizationId"); + migrationBuilder.CreateIndex( + name: "IX_OrganizationSponsorship_SponsoredOrganizationId", + table: "OrganizationSponsorship", + column: "SponsoredOrganizationId"); - migrationBuilder.CreateIndex( - name: "IX_OrganizationSponsorship_SponsoringOrganizationId", - table: "OrganizationSponsorship", - column: "SponsoringOrganizationId"); - } + migrationBuilder.CreateIndex( + name: "IX_OrganizationSponsorship_SponsoringOrganizationId", + table: "OrganizationSponsorship", + column: "SponsoringOrganizationId"); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropTable( - name: "OrganizationSponsorship"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropTable( + name: "OrganizationSponsorship"); - migrationBuilder.DropColumn( - name: "UsesCryptoAgent", - table: "User"); - } + migrationBuilder.DropColumn( + name: "UsesCryptoAgent", + table: "User"); } } diff --git a/util/PostgresMigrations/Migrations/20211115142623_KeyConnectorFlag.cs b/util/PostgresMigrations/Migrations/20211115142623_KeyConnectorFlag.cs index edc522086..225f67bf9 100644 --- a/util/PostgresMigrations/Migrations/20211115142623_KeyConnectorFlag.cs +++ b/util/PostgresMigrations/Migrations/20211115142623_KeyConnectorFlag.cs @@ -1,24 +1,23 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations -{ - public partial class KeyConnectorFlag : Migration - { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.AddColumn( - name: "UseKeyConnector", - table: "Organization", - type: "boolean", - nullable: false, - defaultValue: false); - } +namespace Bit.PostgresMigrations.Migrations; - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "UseKeyConnector", - table: "Organization"); - } +public partial class KeyConnectorFlag : Migration +{ + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "UseKeyConnector", + table: "Organization", + type: "boolean", + nullable: false, + defaultValue: false); + } + + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "UseKeyConnector", + table: "Organization"); } } diff --git a/util/PostgresMigrations/Migrations/20220121092321_RemoveU2F.cs b/util/PostgresMigrations/Migrations/20220121092321_RemoveU2F.cs index 906c30be4..0679e212c 100644 --- a/util/PostgresMigrations/Migrations/20220121092321_RemoveU2F.cs +++ b/util/PostgresMigrations/Migrations/20220121092321_RemoveU2F.cs @@ -1,46 +1,45 @@ using Microsoft.EntityFrameworkCore.Migrations; using Npgsql.EntityFrameworkCore.PostgreSQL.Metadata; -namespace Bit.PostgresMigrations.Migrations +namespace Bit.PostgresMigrations.Migrations; + +public partial class RemoveU2F : Migration { - public partial class RemoveU2F : Migration + protected override void Up(MigrationBuilder migrationBuilder) { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropTable( - name: "U2f"); - } + migrationBuilder.DropTable( + name: "U2f"); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.CreateTable( - name: "U2f", - columns: table => new - { - Id = table.Column(type: "integer", nullable: false) - .Annotation("Npgsql:ValueGenerationStrategy", NpgsqlValueGenerationStrategy.IdentityByDefaultColumn), - AppId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), - Challenge = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), - CreationDate = table.Column(type: "timestamp without time zone", nullable: false), - KeyHandle = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), - UserId = table.Column(type: "uuid", nullable: false), - Version = table.Column(type: "character varying(20)", maxLength: 20, nullable: true) - }, - constraints: table => - { - table.PrimaryKey("PK_U2f", x => x.Id); - table.ForeignKey( - name: "FK_U2f_User_UserId", - column: x => x.UserId, - principalTable: "User", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.CreateTable( + name: "U2f", + columns: table => new + { + Id = table.Column(type: "integer", nullable: false) + .Annotation("Npgsql:ValueGenerationStrategy", NpgsqlValueGenerationStrategy.IdentityByDefaultColumn), + AppId = table.Column(type: "character varying(50)", maxLength: 50, nullable: true), + Challenge = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), + CreationDate = table.Column(type: "timestamp without time zone", nullable: false), + KeyHandle = table.Column(type: "character varying(200)", maxLength: 200, nullable: true), + UserId = table.Column(type: "uuid", nullable: false), + Version = table.Column(type: "character varying(20)", maxLength: 20, nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_U2f", x => x.Id); + table.ForeignKey( + name: "FK_U2f_User_UserId", + column: x => x.UserId, + principalTable: "User", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); - migrationBuilder.CreateIndex( - name: "IX_U2f_UserId", - table: "U2f", - column: "UserId"); - } + migrationBuilder.CreateIndex( + name: "IX_U2f_UserId", + table: "U2f", + column: "UserId"); } } diff --git a/util/PostgresMigrations/Migrations/20220301211818_FailedLoginCaptcha.cs b/util/PostgresMigrations/Migrations/20220301211818_FailedLoginCaptcha.cs index 6c57172fb..6015ef357 100644 --- a/util/PostgresMigrations/Migrations/20220301211818_FailedLoginCaptcha.cs +++ b/util/PostgresMigrations/Migrations/20220301211818_FailedLoginCaptcha.cs @@ -1,34 +1,33 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations +namespace Bit.PostgresMigrations.Migrations; + +public partial class FailedLoginCaptcha : Migration { - public partial class FailedLoginCaptcha : Migration + protected override void Up(MigrationBuilder migrationBuilder) { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.AddColumn( - name: "FailedLoginCount", - table: "User", - type: "integer", - nullable: false, - defaultValue: 0); + migrationBuilder.AddColumn( + name: "FailedLoginCount", + table: "User", + type: "integer", + nullable: false, + defaultValue: 0); - migrationBuilder.AddColumn( - name: "LastFailedLoginDate", - table: "User", - type: "timestamp without time zone", - nullable: true); - } + migrationBuilder.AddColumn( + name: "LastFailedLoginDate", + table: "User", + type: "timestamp without time zone", + nullable: true); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "FailedLoginCount", - table: "User"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "FailedLoginCount", + table: "User"); - migrationBuilder.DropColumn( - name: "LastFailedLoginDate", - table: "User"); - } + migrationBuilder.DropColumn( + name: "LastFailedLoginDate", + table: "User"); } } diff --git a/util/PostgresMigrations/Migrations/20220322183505_SelfHostF4E.cs b/util/PostgresMigrations/Migrations/20220322183505_SelfHostF4E.cs index b636101b0..0c030f0dd 100644 --- a/util/PostgresMigrations/Migrations/20220322183505_SelfHostF4E.cs +++ b/util/PostgresMigrations/Migrations/20220322183505_SelfHostF4E.cs @@ -1,154 +1,153 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations +namespace Bit.PostgresMigrations.Migrations; + +public partial class SelfHostF4E : Migration { - public partial class SelfHostF4E : Migration + private const string _scriptLocationTemplate = "2022-03-01_00_{0}_MigrateOrganizationApiKeys.psql"; + + protected override void Up(MigrationBuilder migrationBuilder) { - private const string _scriptLocationTemplate = "2022-03-01_00_{0}_MigrateOrganizationApiKeys.psql"; + migrationBuilder.DropForeignKey( + name: "FK_OrganizationSponsorship_Installation_InstallationId", + table: "OrganizationSponsorship"); - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropForeignKey( - name: "FK_OrganizationSponsorship_Installation_InstallationId", - table: "OrganizationSponsorship"); + migrationBuilder.DropIndex( + name: "IX_OrganizationSponsorship_InstallationId", + table: "OrganizationSponsorship"); - migrationBuilder.DropIndex( - name: "IX_OrganizationSponsorship_InstallationId", - table: "OrganizationSponsorship"); + migrationBuilder.DropColumn( + name: "InstallationId", + table: "OrganizationSponsorship"); - migrationBuilder.DropColumn( - name: "InstallationId", - table: "OrganizationSponsorship"); + migrationBuilder.DropColumn( + name: "TimesRenewedWithoutValidation", + table: "OrganizationSponsorship"); - migrationBuilder.DropColumn( - name: "TimesRenewedWithoutValidation", - table: "OrganizationSponsorship"); + migrationBuilder.CreateTable( + name: "OrganizationApiKey", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + OrganizationId = table.Column(type: "uuid", nullable: false), + Type = table.Column(type: "smallint", nullable: false), + ApiKey = table.Column(type: "character varying(30)", maxLength: 30, nullable: true), + RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) + }, + constraints: table => + { + table.PrimaryKey("PK_OrganizationApiKey", x => x.Id); + table.ForeignKey( + name: "FK_OrganizationApiKey_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); - migrationBuilder.CreateTable( - name: "OrganizationApiKey", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - OrganizationId = table.Column(type: "uuid", nullable: false), - Type = table.Column(type: "smallint", nullable: false), - ApiKey = table.Column(type: "character varying(30)", maxLength: 30, nullable: true), - RevisionDate = table.Column(type: "timestamp without time zone", nullable: false) - }, - constraints: table => - { - table.PrimaryKey("PK_OrganizationApiKey", x => x.Id); - table.ForeignKey( - name: "FK_OrganizationApiKey_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); + migrationBuilder.SqlResource(_scriptLocationTemplate); - migrationBuilder.SqlResource(_scriptLocationTemplate); + migrationBuilder.DropColumn( + name: "ApiKey", + table: "Organization"); - migrationBuilder.DropColumn( - name: "ApiKey", - table: "Organization"); + migrationBuilder.RenameColumn( + name: "SponsorshipLapsedDate", + table: "OrganizationSponsorship", + newName: "ValidUntil"); - migrationBuilder.RenameColumn( - name: "SponsorshipLapsedDate", - table: "OrganizationSponsorship", - newName: "ValidUntil"); - - migrationBuilder.RenameColumn( - name: "CloudSponsor", - table: "OrganizationSponsorship", - newName: "ToDelete"); + migrationBuilder.RenameColumn( + name: "CloudSponsor", + table: "OrganizationSponsorship", + newName: "ToDelete"); - migrationBuilder.CreateTable( - name: "OrganizationConnection", - columns: table => new - { - Id = table.Column(type: "uuid", nullable: false), - Type = table.Column(type: "smallint", nullable: false), - OrganizationId = table.Column(type: "uuid", nullable: false), - Enabled = table.Column(type: "boolean", nullable: false), - Config = table.Column(type: "text", nullable: true) - }, - constraints: table => - { - table.PrimaryKey("PK_OrganizationConnection", x => x.Id); - table.ForeignKey( - name: "FK_OrganizationConnection_Organization_OrganizationId", - column: x => x.OrganizationId, - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - }); + migrationBuilder.CreateTable( + name: "OrganizationConnection", + columns: table => new + { + Id = table.Column(type: "uuid", nullable: false), + Type = table.Column(type: "smallint", nullable: false), + OrganizationId = table.Column(type: "uuid", nullable: false), + Enabled = table.Column(type: "boolean", nullable: false), + Config = table.Column(type: "text", nullable: true) + }, + constraints: table => + { + table.PrimaryKey("PK_OrganizationConnection", x => x.Id); + table.ForeignKey( + name: "FK_OrganizationConnection_Organization_OrganizationId", + column: x => x.OrganizationId, + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + }); - migrationBuilder.CreateIndex( - name: "IX_OrganizationApiKey_OrganizationId", - table: "OrganizationApiKey", - column: "OrganizationId"); + migrationBuilder.CreateIndex( + name: "IX_OrganizationApiKey_OrganizationId", + table: "OrganizationApiKey", + column: "OrganizationId"); - migrationBuilder.CreateIndex( - name: "IX_OrganizationConnection_OrganizationId", - table: "OrganizationConnection", - column: "OrganizationId"); - } + migrationBuilder.CreateIndex( + name: "IX_OrganizationConnection_OrganizationId", + table: "OrganizationConnection", + column: "OrganizationId"); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.AddColumn( - name: "ApiKey", - table: "Organization", - type: "character varying(30)", - maxLength: 30, - nullable: true); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "ApiKey", + table: "Organization", + type: "character varying(30)", + maxLength: 30, + nullable: true); - migrationBuilder.SqlResource(_scriptLocationTemplate); + migrationBuilder.SqlResource(_scriptLocationTemplate); - migrationBuilder.DropTable( - name: "OrganizationApiKey"); + migrationBuilder.DropTable( + name: "OrganizationApiKey"); - migrationBuilder.DropTable( - name: "OrganizationConnection"); + migrationBuilder.DropTable( + name: "OrganizationConnection"); - migrationBuilder.RenameColumn( - name: "ValidUntil", - table: "OrganizationSponsorship", - newName: "SponsorshipLapsedDate"); + migrationBuilder.RenameColumn( + name: "ValidUntil", + table: "OrganizationSponsorship", + newName: "SponsorshipLapsedDate"); - migrationBuilder.RenameColumn( - name: "ToDelete", - table: "OrganizationSponsorship", - newName: "CloudSponsor"); + migrationBuilder.RenameColumn( + name: "ToDelete", + table: "OrganizationSponsorship", + newName: "CloudSponsor"); - migrationBuilder.AddColumn( - name: "InstallationId", - table: "OrganizationSponsorship", - type: "uuid", - nullable: true); + migrationBuilder.AddColumn( + name: "InstallationId", + table: "OrganizationSponsorship", + type: "uuid", + nullable: true); - migrationBuilder.AddColumn( - name: "TimesRenewedWithoutValidation", - table: "OrganizationSponsorship", - type: "smallint", - nullable: false, - defaultValue: (byte)0); + migrationBuilder.AddColumn( + name: "TimesRenewedWithoutValidation", + table: "OrganizationSponsorship", + type: "smallint", + nullable: false, + defaultValue: (byte)0); - migrationBuilder.CreateIndex( - name: "IX_OrganizationSponsorship_InstallationId", - table: "OrganizationSponsorship", - column: "InstallationId"); + migrationBuilder.CreateIndex( + name: "IX_OrganizationSponsorship_InstallationId", + table: "OrganizationSponsorship", + column: "InstallationId"); - migrationBuilder.AddForeignKey( - name: "FK_OrganizationSponsorship_Installation_InstallationId", - table: "OrganizationSponsorship", - column: "InstallationId", - principalTable: "Installation", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - } + migrationBuilder.AddForeignKey( + name: "FK_OrganizationSponsorship_Installation_InstallationId", + table: "OrganizationSponsorship", + column: "InstallationId", + principalTable: "Installation", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); } } diff --git a/util/PostgresMigrations/Migrations/20220411190525_SponsorshipBulkActions.cs b/util/PostgresMigrations/Migrations/20220411190525_SponsorshipBulkActions.cs index 46b76b2bf..7b569a62c 100644 --- a/util/PostgresMigrations/Migrations/20220411190525_SponsorshipBulkActions.cs +++ b/util/PostgresMigrations/Migrations/20220411190525_SponsorshipBulkActions.cs @@ -1,73 +1,72 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations +namespace Bit.PostgresMigrations.Migrations; + +public partial class SponsorshipBulkActions : Migration { - public partial class SponsorshipBulkActions : Migration + protected override void Up(MigrationBuilder migrationBuilder) { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", - table: "OrganizationSponsorship"); + migrationBuilder.DropForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", + table: "OrganizationSponsorship"); - migrationBuilder.AlterColumn( - name: "SponsoringOrganizationUserId", - table: "OrganizationSponsorship", - type: "uuid", - nullable: false, - defaultValue: new Guid("00000000-0000-0000-0000-000000000000"), - oldClrType: typeof(Guid), - oldType: "uuid", - oldNullable: true); + migrationBuilder.AlterColumn( + name: "SponsoringOrganizationUserId", + table: "OrganizationSponsorship", + type: "uuid", + nullable: false, + defaultValue: new Guid("00000000-0000-0000-0000-000000000000"), + oldClrType: typeof(Guid), + oldType: "uuid", + oldNullable: true); - migrationBuilder.AlterColumn( - name: "SponsoringOrganizationId", - table: "OrganizationSponsorship", - type: "uuid", - nullable: false, - defaultValue: new Guid("00000000-0000-0000-0000-000000000000"), - oldClrType: typeof(Guid), - oldType: "uuid", - oldNullable: true); + migrationBuilder.AlterColumn( + name: "SponsoringOrganizationId", + table: "OrganizationSponsorship", + type: "uuid", + nullable: false, + defaultValue: new Guid("00000000-0000-0000-0000-000000000000"), + oldClrType: typeof(Guid), + oldType: "uuid", + oldNullable: true); - migrationBuilder.AddForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", - table: "OrganizationSponsorship", - column: "SponsoringOrganizationId", - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - } + migrationBuilder.AddForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", + table: "OrganizationSponsorship", + column: "SponsoringOrganizationId", + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", - table: "OrganizationSponsorship"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", + table: "OrganizationSponsorship"); - migrationBuilder.AlterColumn( - name: "SponsoringOrganizationUserId", - table: "OrganizationSponsorship", - type: "uuid", - nullable: true, - oldClrType: typeof(Guid), - oldType: "uuid"); + migrationBuilder.AlterColumn( + name: "SponsoringOrganizationUserId", + table: "OrganizationSponsorship", + type: "uuid", + nullable: true, + oldClrType: typeof(Guid), + oldType: "uuid"); - migrationBuilder.AlterColumn( - name: "SponsoringOrganizationId", - table: "OrganizationSponsorship", - type: "uuid", - nullable: true, - oldClrType: typeof(Guid), - oldType: "uuid"); + migrationBuilder.AlterColumn( + name: "SponsoringOrganizationId", + table: "OrganizationSponsorship", + type: "uuid", + nullable: true, + oldClrType: typeof(Guid), + oldType: "uuid"); - migrationBuilder.AddForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", - table: "OrganizationSponsorship", - column: "SponsoringOrganizationId", - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - } + migrationBuilder.AddForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", + table: "OrganizationSponsorship", + column: "SponsoringOrganizationId", + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); } } diff --git a/util/PostgresMigrations/Migrations/20220420171153_AddInstallationIdToEvents.cs b/util/PostgresMigrations/Migrations/20220420171153_AddInstallationIdToEvents.cs index 94bfa5b7c..a02d9e70a 100644 --- a/util/PostgresMigrations/Migrations/20220420171153_AddInstallationIdToEvents.cs +++ b/util/PostgresMigrations/Migrations/20220420171153_AddInstallationIdToEvents.cs @@ -1,65 +1,64 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations +namespace Bit.PostgresMigrations.Migrations; + +public partial class AddInstallationIdToEvents : Migration { - public partial class AddInstallationIdToEvents : Migration + protected override void Up(MigrationBuilder migrationBuilder) { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", - table: "OrganizationSponsorship"); + migrationBuilder.DropForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", + table: "OrganizationSponsorship"); - migrationBuilder.AlterColumn( - name: "SponsoringOrganizationId", - table: "OrganizationSponsorship", - type: "uuid", - nullable: true, - oldClrType: typeof(Guid), - oldType: "uuid"); + migrationBuilder.AlterColumn( + name: "SponsoringOrganizationId", + table: "OrganizationSponsorship", + type: "uuid", + nullable: true, + oldClrType: typeof(Guid), + oldType: "uuid"); - migrationBuilder.AddColumn( - name: "InstallationId", - table: "Event", - type: "uuid", - nullable: true); + migrationBuilder.AddColumn( + name: "InstallationId", + table: "Event", + type: "uuid", + nullable: true); - migrationBuilder.AddForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", - table: "OrganizationSponsorship", - column: "SponsoringOrganizationId", - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Restrict); - } + migrationBuilder.AddForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", + table: "OrganizationSponsorship", + column: "SponsoringOrganizationId", + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Restrict); + } - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", - table: "OrganizationSponsorship"); + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", + table: "OrganizationSponsorship"); - migrationBuilder.DropColumn( - name: "InstallationId", - table: "Event"); + migrationBuilder.DropColumn( + name: "InstallationId", + table: "Event"); - migrationBuilder.AlterColumn( - name: "SponsoringOrganizationId", - table: "OrganizationSponsorship", - type: "uuid", - nullable: false, - defaultValue: new Guid("00000000-0000-0000-0000-000000000000"), - oldClrType: typeof(Guid), - oldType: "uuid", - oldNullable: true); + migrationBuilder.AlterColumn( + name: "SponsoringOrganizationId", + table: "OrganizationSponsorship", + type: "uuid", + nullable: false, + defaultValue: new Guid("00000000-0000-0000-0000-000000000000"), + oldClrType: typeof(Guid), + oldType: "uuid", + oldNullable: true); - migrationBuilder.AddForeignKey( - name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", - table: "OrganizationSponsorship", - column: "SponsoringOrganizationId", - principalTable: "Organization", - principalColumn: "Id", - onDelete: ReferentialAction.Cascade); - } + migrationBuilder.AddForeignKey( + name: "FK_OrganizationSponsorship_Organization_SponsoringOrganization~", + table: "OrganizationSponsorship", + column: "SponsoringOrganizationId", + principalTable: "Organization", + principalColumn: "Id", + onDelete: ReferentialAction.Cascade); } } diff --git a/util/PostgresMigrations/Migrations/20220524170740_DeviceUnknownVerification.cs b/util/PostgresMigrations/Migrations/20220524170740_DeviceUnknownVerification.cs index 880c65908..3dd1b4c5f 100644 --- a/util/PostgresMigrations/Migrations/20220524170740_DeviceUnknownVerification.cs +++ b/util/PostgresMigrations/Migrations/20220524170740_DeviceUnknownVerification.cs @@ -1,24 +1,23 @@ using Microsoft.EntityFrameworkCore.Migrations; -namespace Bit.PostgresMigrations.Migrations -{ - public partial class DeviceUnknownVerification : Migration - { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.AddColumn( - name: "UnknownDeviceVerificationEnabled", - table: "User", - type: "boolean", - nullable: false, - defaultValue: true); - } +namespace Bit.PostgresMigrations.Migrations; - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "UnknownDeviceVerificationEnabled", - table: "User"); - } +public partial class DeviceUnknownVerification : Migration +{ + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "UnknownDeviceVerificationEnabled", + table: "User", + type: "boolean", + nullable: false, + defaultValue: true); + } + + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "UnknownDeviceVerificationEnabled", + table: "User"); } } diff --git a/util/PostgresMigrations/Migrations/20220707162231_UseScimFlag.cs b/util/PostgresMigrations/Migrations/20220707162231_UseScimFlag.cs index 6a71e38fb..02c7ca90e 100644 --- a/util/PostgresMigrations/Migrations/20220707162231_UseScimFlag.cs +++ b/util/PostgresMigrations/Migrations/20220707162231_UseScimFlag.cs @@ -2,25 +2,24 @@ #nullable disable -namespace Bit.PostgresMigrations.Migrations -{ - public partial class UseScimFlag : Migration - { - protected override void Up(MigrationBuilder migrationBuilder) - { - migrationBuilder.AddColumn( - name: "UseScim", - table: "Organization", - type: "boolean", - nullable: false, - defaultValue: false); - } +namespace Bit.PostgresMigrations.Migrations; - protected override void Down(MigrationBuilder migrationBuilder) - { - migrationBuilder.DropColumn( - name: "UseScim", - table: "Organization"); - } +public partial class UseScimFlag : Migration +{ + protected override void Up(MigrationBuilder migrationBuilder) + { + migrationBuilder.AddColumn( + name: "UseScim", + table: "Organization", + type: "boolean", + nullable: false, + defaultValue: false); + } + + protected override void Down(MigrationBuilder migrationBuilder) + { + migrationBuilder.DropColumn( + name: "UseScim", + table: "Organization"); } } diff --git a/util/Server/Program.cs b/util/Server/Program.cs index 25f5fd440..767b96514 100644 --- a/util/Server/Program.cs +++ b/util/Server/Program.cs @@ -1,41 +1,40 @@ -namespace Bit.Server +namespace Bit.Server; + +public class Program { - public class Program + public static void Main(string[] args) { - public static void Main(string[] args) + var config = new ConfigurationBuilder() + .AddCommandLine(args) + .Build(); + + var builder = new WebHostBuilder() + .UseConfiguration(config) + .UseKestrel() + .UseStartup() + .ConfigureLogging((hostingContext, logging) => + { + logging.AddConsole().AddDebug(); + }) + .ConfigureKestrel((context, options) => { }); + + var contentRoot = config.GetValue("contentRoot"); + if (!string.IsNullOrWhiteSpace(contentRoot)) { - var config = new ConfigurationBuilder() - .AddCommandLine(args) - .Build(); - - var builder = new WebHostBuilder() - .UseConfiguration(config) - .UseKestrel() - .UseStartup() - .ConfigureLogging((hostingContext, logging) => - { - logging.AddConsole().AddDebug(); - }) - .ConfigureKestrel((context, options) => { }); - - var contentRoot = config.GetValue("contentRoot"); - if (!string.IsNullOrWhiteSpace(contentRoot)) - { - builder.UseContentRoot(contentRoot); - } - else - { - builder.UseContentRoot(Directory.GetCurrentDirectory()); - } - - var webRoot = config.GetValue("webRoot"); - if (string.IsNullOrWhiteSpace(webRoot)) - { - builder.UseWebRoot(webRoot); - } - - var host = builder.Build(); - host.Run(); + builder.UseContentRoot(contentRoot); } + else + { + builder.UseContentRoot(Directory.GetCurrentDirectory()); + } + + var webRoot = config.GetValue("webRoot"); + if (string.IsNullOrWhiteSpace(webRoot)) + { + builder.UseWebRoot(webRoot); + } + + var host = builder.Build(); + host.Run(); } } diff --git a/util/Server/Startup.cs b/util/Server/Startup.cs index 362d87383..7b195beb5 100644 --- a/util/Server/Startup.cs +++ b/util/Server/Startup.cs @@ -1,90 +1,89 @@ using System.Globalization; using Microsoft.AspNetCore.StaticFiles; -namespace Bit.Server +namespace Bit.Server; + +public class Startup { - public class Startup + private readonly List _longCachedPaths = new List { - private readonly List _longCachedPaths = new List - { - "/app/", "/locales/", "/fonts/", "/connectors/", "/scripts/" - }; - private readonly List _mediumCachedPaths = new List - { - "/images/" - }; + "/app/", "/locales/", "/fonts/", "/connectors/", "/scripts/" + }; + private readonly List _mediumCachedPaths = new List + { + "/images/" + }; - public Startup() - { - CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); - } + public Startup() + { + CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); + } - public void ConfigureServices(IServiceCollection services) - { - services.AddRouting(); - } + public void ConfigureServices(IServiceCollection services) + { + services.AddRouting(); + } - public void Configure( - IApplicationBuilder app, - IConfiguration configuration) + public void Configure( + IApplicationBuilder app, + IConfiguration configuration) + { + if (configuration.GetValue("serveUnknown") ?? false) { - if (configuration.GetValue("serveUnknown") ?? false) + app.UseStaticFiles(new StaticFileOptions { - app.UseStaticFiles(new StaticFileOptions - { - ServeUnknownFileTypes = true, - DefaultContentType = "application/octet-stream" - }); - app.UseRouting(); - app.UseEndpoints(endpoints => - { - endpoints.MapGet("/alive", - async context => await context.Response.WriteAsync(System.DateTime.UtcNow.ToString())); - }); - } - else if (configuration.GetValue("webVault") ?? false) + ServeUnknownFileTypes = true, + DefaultContentType = "application/octet-stream" + }); + app.UseRouting(); + app.UseEndpoints(endpoints => { - // TODO: This should be removed when asp.net natively support avif - var provider = new FileExtensionContentTypeProvider { Mappings = { [".avif"] = "image/avif" } }; + endpoints.MapGet("/alive", + async context => await context.Response.WriteAsync(System.DateTime.UtcNow.ToString())); + }); + } + else if (configuration.GetValue("webVault") ?? false) + { + // TODO: This should be removed when asp.net natively support avif + var provider = new FileExtensionContentTypeProvider { Mappings = { [".avif"] = "image/avif" } }; - var options = new DefaultFilesOptions(); - options.DefaultFileNames.Clear(); - options.DefaultFileNames.Add("index.html"); - app.UseDefaultFiles(options); - app.UseStaticFiles(new StaticFileOptions + var options = new DefaultFilesOptions(); + options.DefaultFileNames.Clear(); + options.DefaultFileNames.Add("index.html"); + app.UseDefaultFiles(options); + app.UseStaticFiles(new StaticFileOptions + { + ContentTypeProvider = provider, + OnPrepareResponse = ctx => { - ContentTypeProvider = provider, - OnPrepareResponse = ctx => + if (!ctx.Context.Request.Path.HasValue || + ctx.Context.Response.Headers.ContainsKey("Cache-Control")) { - if (!ctx.Context.Request.Path.HasValue || - ctx.Context.Response.Headers.ContainsKey("Cache-Control")) - { - return; - } - var path = ctx.Context.Request.Path.Value; - if (_longCachedPaths.Any(ext => path.StartsWith(ext))) - { - // 14 days - ctx.Context.Response.Headers.Append("Cache-Control", "max-age=1209600"); - } - if (_mediumCachedPaths.Any(ext => path.StartsWith(ext))) - { - // 7 days - ctx.Context.Response.Headers.Append("Cache-Control", "max-age=604800"); - } + return; } - }); - } - else + var path = ctx.Context.Request.Path.Value; + if (_longCachedPaths.Any(ext => path.StartsWith(ext))) + { + // 14 days + ctx.Context.Response.Headers.Append("Cache-Control", "max-age=1209600"); + } + if (_mediumCachedPaths.Any(ext => path.StartsWith(ext))) + { + // 7 days + ctx.Context.Response.Headers.Append("Cache-Control", "max-age=604800"); + } + } + }); + } + else + { + app.UseFileServer(); + app.UseRouting(); + app.UseEndpoints(endpoints => { - app.UseFileServer(); - app.UseRouting(); - app.UseEndpoints(endpoints => - { - endpoints.MapGet("/alive", - async context => await context.Response.WriteAsync(System.DateTime.UtcNow.ToString())); - }); - } + endpoints.MapGet("/alive", + async context => await context.Response.WriteAsync(System.DateTime.UtcNow.ToString())); + }); } } } diff --git a/util/Setup/AppIdBuilder.cs b/util/Setup/AppIdBuilder.cs index 46fe222b6..6e984aa90 100644 --- a/util/Setup/AppIdBuilder.cs +++ b/util/Setup/AppIdBuilder.cs @@ -1,34 +1,33 @@ -namespace Bit.Setup +namespace Bit.Setup; + +public class AppIdBuilder { - public class AppIdBuilder + private readonly Context _context; + + public AppIdBuilder(Context context) { - private readonly Context _context; + _context = context; + } - public AppIdBuilder(Context context) + public void Build() + { + var model = new TemplateModel { - _context = context; - } + Url = _context.Config.Url + }; - public void Build() + // Needed for backwards compatability with migrated U2F tokens. + Helpers.WriteLine(_context, "Building FIDO U2F app id."); + Directory.CreateDirectory("/bitwarden/web/"); + var template = Helpers.ReadTemplate("AppId"); + using (var sw = File.CreateText("/bitwarden/web/app-id.json")) { - var model = new TemplateModel - { - Url = _context.Config.Url - }; - - // Needed for backwards compatability with migrated U2F tokens. - Helpers.WriteLine(_context, "Building FIDO U2F app id."); - Directory.CreateDirectory("/bitwarden/web/"); - var template = Helpers.ReadTemplate("AppId"); - using (var sw = File.CreateText("/bitwarden/web/app-id.json")) - { - sw.Write(template(model)); - } - } - - public class TemplateModel - { - public string Url { get; set; } + sw.Write(template(model)); } } + + public class TemplateModel + { + public string Url { get; set; } + } } diff --git a/util/Setup/CertBuilder.cs b/util/Setup/CertBuilder.cs index 3a43888f2..a01e9d98b 100644 --- a/util/Setup/CertBuilder.cs +++ b/util/Setup/CertBuilder.cs @@ -1,112 +1,111 @@ -namespace Bit.Setup -{ - public class CertBuilder - { - private readonly Context _context; +namespace Bit.Setup; - public CertBuilder(Context context) +public class CertBuilder +{ + private readonly Context _context; + + public CertBuilder(Context context) + { + _context = context; + } + + public void BuildForInstall() + { + if (_context.Stub) { - _context = context; + _context.Config.Ssl = true; + _context.Install.Trusted = true; + _context.Install.SelfSignedCert = false; + _context.Install.DiffieHellman = false; + _context.Install.IdentityCertPassword = "IDENTITY_CERT_PASSWORD"; + return; } - public void BuildForInstall() + _context.Config.Ssl = _context.Config.SslManagedLetsEncrypt; + + if (!_context.Config.Ssl) { - if (_context.Stub) + var skipSSL = _context.Parameters.ContainsKey("skip-ssl") && (_context.Parameters["skip-ssl"] == "true" || _context.Parameters["skip-ssl"] == "1"); + + if (!skipSSL) { - _context.Config.Ssl = true; - _context.Install.Trusted = true; - _context.Install.SelfSignedCert = false; - _context.Install.DiffieHellman = false; - _context.Install.IdentityCertPassword = "IDENTITY_CERT_PASSWORD"; - return; - } - - _context.Config.Ssl = _context.Config.SslManagedLetsEncrypt; - - if (!_context.Config.Ssl) - { - var skipSSL = _context.Parameters.ContainsKey("skip-ssl") && (_context.Parameters["skip-ssl"] == "true" || _context.Parameters["skip-ssl"] == "1"); - - if (!skipSSL) + _context.Config.Ssl = Helpers.ReadQuestion("Do you have a SSL certificate to use?"); + if (_context.Config.Ssl) { - _context.Config.Ssl = Helpers.ReadQuestion("Do you have a SSL certificate to use?"); - if (_context.Config.Ssl) - { - Directory.CreateDirectory($"/bitwarden/ssl/{_context.Install.Domain}/"); - var message = "Make sure 'certificate.crt' and 'private.key' are provided in the \n" + - "appropriate directory before running 'start' (see docs for info)."; - Helpers.ShowBanner(_context, "NOTE", message); - } - else if (Helpers.ReadQuestion("Do you want to generate a self-signed SSL certificate?")) - { - Directory.CreateDirectory($"/bitwarden/ssl/self/{_context.Install.Domain}/"); - Helpers.WriteLine(_context, "Generating self signed SSL certificate."); - _context.Config.Ssl = true; - _context.Install.Trusted = false; - _context.Install.SelfSignedCert = true; - Helpers.Exec("openssl req -x509 -newkey rsa:4096 -sha256 -nodes -days 36500 " + - $"-keyout /bitwarden/ssl/self/{_context.Install.Domain}/private.key " + - $"-out /bitwarden/ssl/self/{_context.Install.Domain}/certificate.crt " + - $"-reqexts SAN -extensions SAN " + - $"-config <(cat /usr/lib/ssl/openssl.cnf <(printf '[SAN]\nsubjectAltName=DNS:{_context.Install.Domain}\nbasicConstraints=CA:true')) " + - $"-subj \"/C=US/ST=California/L=Santa Barbara/O=Bitwarden Inc./OU=Bitwarden/CN={_context.Install.Domain}\""); - } + Directory.CreateDirectory($"/bitwarden/ssl/{_context.Install.Domain}/"); + var message = "Make sure 'certificate.crt' and 'private.key' are provided in the \n" + + "appropriate directory before running 'start' (see docs for info)."; + Helpers.ShowBanner(_context, "NOTE", message); + } + else if (Helpers.ReadQuestion("Do you want to generate a self-signed SSL certificate?")) + { + Directory.CreateDirectory($"/bitwarden/ssl/self/{_context.Install.Domain}/"); + Helpers.WriteLine(_context, "Generating self signed SSL certificate."); + _context.Config.Ssl = true; + _context.Install.Trusted = false; + _context.Install.SelfSignedCert = true; + Helpers.Exec("openssl req -x509 -newkey rsa:4096 -sha256 -nodes -days 36500 " + + $"-keyout /bitwarden/ssl/self/{_context.Install.Domain}/private.key " + + $"-out /bitwarden/ssl/self/{_context.Install.Domain}/certificate.crt " + + $"-reqexts SAN -extensions SAN " + + $"-config <(cat /usr/lib/ssl/openssl.cnf <(printf '[SAN]\nsubjectAltName=DNS:{_context.Install.Domain}\nbasicConstraints=CA:true')) " + + $"-subj \"/C=US/ST=California/L=Santa Barbara/O=Bitwarden Inc./OU=Bitwarden/CN={_context.Install.Domain}\""); } } - - if (_context.Config.SslManagedLetsEncrypt) - { - _context.Install.Trusted = true; - _context.Install.DiffieHellman = true; - Directory.CreateDirectory($"/bitwarden/letsencrypt/live/{_context.Install.Domain}/"); - Helpers.Exec($"openssl dhparam -out " + - $"/bitwarden/letsencrypt/live/{_context.Install.Domain}/dhparam.pem 2048"); - } - else if (_context.Config.Ssl && !_context.Install.SelfSignedCert) - { - _context.Install.Trusted = Helpers.ReadQuestion("Is this a trusted SSL certificate " + - "(requires ca.crt, see docs)?"); - } - - Helpers.WriteLine(_context, "Generating key for IdentityServer."); - _context.Install.IdentityCertPassword = Helpers.SecureRandomString(32, alpha: true, numeric: true); - Directory.CreateDirectory("/bitwarden/identity/"); - Helpers.Exec("openssl req -x509 -newkey rsa:4096 -sha256 -nodes -keyout identity.key " + - "-out identity.crt -subj \"/CN=Bitwarden IdentityServer\" -days 36500"); - Helpers.Exec("openssl pkcs12 -export -out /bitwarden/identity/identity.pfx -inkey identity.key " + - $"-in identity.crt -passout pass:{_context.Install.IdentityCertPassword}"); - - Helpers.WriteLine(_context); - - if (!_context.Config.Ssl) - { - var message = "You are not using a SSL certificate. Bitwarden requires HTTPS to operate. \n" + - "You must front your installation with a HTTPS proxy or the web vault (and \n" + - "other Bitwarden apps) will not work properly."; - Helpers.ShowBanner(_context, "WARNING", message, ConsoleColor.Yellow); - } - else if (_context.Config.Ssl && !_context.Install.Trusted) - { - var message = "You are using an untrusted SSL certificate. This certificate will not be \n" + - "trusted by Bitwarden client applications. You must add this certificate to \n" + - "the trusted store on each device or else you will receive errors when trying \n" + - "to connect to your installation."; - Helpers.ShowBanner(_context, "WARNING", message, ConsoleColor.Yellow); - } } - public void BuildForUpdater() + if (_context.Config.SslManagedLetsEncrypt) { - if (_context.Config.EnableKeyConnector && !File.Exists("/bitwarden/key-connector/bwkc.pfx")) - { - Directory.CreateDirectory("/bitwarden/key-connector/"); - var keyConnectorCertPassword = Helpers.GetValueFromEnvFile("key-connector", - "keyConnectorSettings__certificate__filesystemPassword"); - Helpers.Exec("openssl req -x509 -newkey rsa:4096 -sha256 -nodes -keyout bwkc.key " + - "-out bwkc.crt -subj \"/CN=Bitwarden Key Connector\" -days 36500"); - Helpers.Exec("openssl pkcs12 -export -out /bitwarden/key-connector/bwkc.pfx -inkey bwkc.key " + - $"-in bwkc.crt -passout pass:{keyConnectorCertPassword}"); - } + _context.Install.Trusted = true; + _context.Install.DiffieHellman = true; + Directory.CreateDirectory($"/bitwarden/letsencrypt/live/{_context.Install.Domain}/"); + Helpers.Exec($"openssl dhparam -out " + + $"/bitwarden/letsencrypt/live/{_context.Install.Domain}/dhparam.pem 2048"); + } + else if (_context.Config.Ssl && !_context.Install.SelfSignedCert) + { + _context.Install.Trusted = Helpers.ReadQuestion("Is this a trusted SSL certificate " + + "(requires ca.crt, see docs)?"); + } + + Helpers.WriteLine(_context, "Generating key for IdentityServer."); + _context.Install.IdentityCertPassword = Helpers.SecureRandomString(32, alpha: true, numeric: true); + Directory.CreateDirectory("/bitwarden/identity/"); + Helpers.Exec("openssl req -x509 -newkey rsa:4096 -sha256 -nodes -keyout identity.key " + + "-out identity.crt -subj \"/CN=Bitwarden IdentityServer\" -days 36500"); + Helpers.Exec("openssl pkcs12 -export -out /bitwarden/identity/identity.pfx -inkey identity.key " + + $"-in identity.crt -passout pass:{_context.Install.IdentityCertPassword}"); + + Helpers.WriteLine(_context); + + if (!_context.Config.Ssl) + { + var message = "You are not using a SSL certificate. Bitwarden requires HTTPS to operate. \n" + + "You must front your installation with a HTTPS proxy or the web vault (and \n" + + "other Bitwarden apps) will not work properly."; + Helpers.ShowBanner(_context, "WARNING", message, ConsoleColor.Yellow); + } + else if (_context.Config.Ssl && !_context.Install.Trusted) + { + var message = "You are using an untrusted SSL certificate. This certificate will not be \n" + + "trusted by Bitwarden client applications. You must add this certificate to \n" + + "the trusted store on each device or else you will receive errors when trying \n" + + "to connect to your installation."; + Helpers.ShowBanner(_context, "WARNING", message, ConsoleColor.Yellow); + } + } + + public void BuildForUpdater() + { + if (_context.Config.EnableKeyConnector && !File.Exists("/bitwarden/key-connector/bwkc.pfx")) + { + Directory.CreateDirectory("/bitwarden/key-connector/"); + var keyConnectorCertPassword = Helpers.GetValueFromEnvFile("key-connector", + "keyConnectorSettings__certificate__filesystemPassword"); + Helpers.Exec("openssl req -x509 -newkey rsa:4096 -sha256 -nodes -keyout bwkc.key " + + "-out bwkc.crt -subj \"/CN=Bitwarden Key Connector\" -days 36500"); + Helpers.Exec("openssl pkcs12 -export -out /bitwarden/key-connector/bwkc.pfx -inkey bwkc.key " + + $"-in bwkc.crt -passout pass:{keyConnectorCertPassword}"); } } } diff --git a/util/Setup/Configuration.cs b/util/Setup/Configuration.cs index e0062b522..b58b87952 100644 --- a/util/Setup/Configuration.cs +++ b/util/Setup/Configuration.cs @@ -1,122 +1,121 @@ using System.ComponentModel; using YamlDotNet.Serialization; -namespace Bit.Setup +namespace Bit.Setup; + +public class Configuration { - public class Configuration + [Description("Note: After making changes to this file you need to run the `rebuild` or `update`\n" + + "command for them to be applied.\n\n" + + + "Full URL for accessing the installation from a browser. (Required)")] + public string Url { get; set; } = "https://localhost"; + + [Description("Auto-generate the `./docker/docker-compose.yml` config file.\n" + + "WARNING: Disabling generated config files can break future updates. You will be\n" + + "responsible for maintaining this config file.\n" + + "Template: https://github.com/bitwarden/server/blob/master/util/Setup/Templates/DockerCompose.hbs")] + public bool GenerateComposeConfig { get; set; } = true; + + [Description("Auto-generate the `./nginx/default.conf` file.\n" + + "WARNING: Disabling generated config files can break future updates. You will be\n" + + "responsible for maintaining this config file.\n" + + "Template: https://github.com/bitwarden/server/blob/master/util/Setup/Templates/NginxConfig.hbs")] + public bool GenerateNginxConfig { get; set; } = true; + + [Description("Docker compose file port mapping for HTTP. Leave empty to remove the port mapping.\n" + + "Learn more: https://docs.docker.com/compose/compose-file/#ports")] + public string HttpPort { get; set; } = "80"; + + [Description("Docker compose file port mapping for HTTPS. Leave empty to remove the port mapping.\n" + + "Learn more: https://docs.docker.com/compose/compose-file/#ports")] + public string HttpsPort { get; set; } = "443"; + + [Description("Docker compose file version. Leave empty for default.\n" + + "Learn more: https://docs.docker.com/compose/compose-file/compose-versioning/")] + public string ComposeVersion { get; set; } + + [Description("Configure Nginx for Captcha.")] + public bool Captcha { get; set; } = false; + + [Description("Configure Nginx for SSL.")] + public bool Ssl { get; set; } = true; + + [Description("SSL versions used by Nginx (ssl_protocols). Leave empty for recommended default.\n" + + "Learn more: https://wiki.mozilla.org/Security/Server_Side_TLS")] + public string SslVersions { get; set; } + + [Description("SSL ciphersuites used by Nginx (ssl_ciphers). Leave empty for recommended default.\n" + + "Learn more: https://wiki.mozilla.org/Security/Server_Side_TLS")] + public string SslCiphersuites { get; set; } + + [Description("Installation uses a managed Let's Encrypt certificate.")] + public bool SslManagedLetsEncrypt { get; set; } + + [Description("The actual certificate. (Required if using SSL without managed Let's Encrypt)\n" + + "Note: Path uses the container's ssl directory. The `./ssl` host directory is mapped to\n" + + "`/etc/ssl` within the container.")] + public string SslCertificatePath { get; set; } + + [Description("The certificate's private key. (Required if using SSL without managed Let's Encrypt)\n" + + "Note: Path uses the container's ssl directory. The `./ssl` host directory is mapped to\n" + + "`/etc/ssl` within the container.")] + public string SslKeyPath { get; set; } + + [Description("If the certificate is trusted by a CA, you should provide the CA's certificate.\n" + + "Note: Path uses the container's ssl directory. The `./ssl` host directory is mapped to\n" + + "`/etc/ssl` within the container.")] + public string SslCaPath { get; set; } + + [Description("Diffie Hellman ephemeral parameters\n" + + "Learn more: https://security.stackexchange.com/q/94390/79072\n" + + "Note: Path uses the container's ssl directory. The `./ssl` host directory is mapped to\n" + + "`/etc/ssl` within the container.")] + public string SslDiffieHellmanPath { get; set; } + + [Description("Nginx Header Content-Security-Policy parameter\n" + + "WARNING: Reconfiguring this parameter may break features. By changing this parameter\n" + + "you become responsible for maintaining this value.")] + public string NginxHeaderContentSecurityPolicy { get; set; } = "default-src 'self'; style-src 'self' " + + "'unsafe-inline'; img-src 'self' data: https://haveibeenpwned.com https://www.gravatar.com; " + + "child-src 'self' https://*.duosecurity.com https://*.duofederal.com; " + + "frame-src 'self' https://*.duosecurity.com https://*.duofederal.com; " + + "connect-src 'self' wss://{0} https://api.pwnedpasswords.com " + + "https://2fa.directory; object-src 'self' blob:;"; + + [Description("Communicate with the Bitwarden push relay service (push.bitwarden.com) for mobile\n" + + "app live sync.")] + public bool PushNotifications { get; set; } = true; + + [Description("Use a docker volume (`mssql_data`) instead of a host-mapped volume for the persisted " + + "database.\n" + + "WARNING: Changing this value will cause you to lose access to the existing persisted database.\n" + + "Learn more: https://docs.docker.com/storage/volumes/")] + public bool DatabaseDockerVolume { get; set; } + + [Description("Defines \"real\" IPs in nginx.conf. Useful for defining proxy servers that forward the \n" + + "client IP address.\n" + + "Learn more: https://nginx.org/en/docs/http/ngx_http_realip_module.html\n\n" + + "Defined as a dictionary, e.g.:\n" + + "real_ips: ['10.10.0.0/24', '172.16.0.0/16']")] + public List RealIps { get; set; } + + [Description("Enable Key Connector (https://bitwarden.com/help/article/deploy-key-connector)")] + public bool EnableKeyConnector { get; set; } = false; + + [Description("Enable SCIM")] + public bool EnableScim { get; set; } = false; + + [YamlIgnore] + public string Domain { - [Description("Note: After making changes to this file you need to run the `rebuild` or `update`\n" + - "command for them to be applied.\n\n" + - - "Full URL for accessing the installation from a browser. (Required)")] - public string Url { get; set; } = "https://localhost"; - - [Description("Auto-generate the `./docker/docker-compose.yml` config file.\n" + - "WARNING: Disabling generated config files can break future updates. You will be\n" + - "responsible for maintaining this config file.\n" + - "Template: https://github.com/bitwarden/server/blob/master/util/Setup/Templates/DockerCompose.hbs")] - public bool GenerateComposeConfig { get; set; } = true; - - [Description("Auto-generate the `./nginx/default.conf` file.\n" + - "WARNING: Disabling generated config files can break future updates. You will be\n" + - "responsible for maintaining this config file.\n" + - "Template: https://github.com/bitwarden/server/blob/master/util/Setup/Templates/NginxConfig.hbs")] - public bool GenerateNginxConfig { get; set; } = true; - - [Description("Docker compose file port mapping for HTTP. Leave empty to remove the port mapping.\n" + - "Learn more: https://docs.docker.com/compose/compose-file/#ports")] - public string HttpPort { get; set; } = "80"; - - [Description("Docker compose file port mapping for HTTPS. Leave empty to remove the port mapping.\n" + - "Learn more: https://docs.docker.com/compose/compose-file/#ports")] - public string HttpsPort { get; set; } = "443"; - - [Description("Docker compose file version. Leave empty for default.\n" + - "Learn more: https://docs.docker.com/compose/compose-file/compose-versioning/")] - public string ComposeVersion { get; set; } - - [Description("Configure Nginx for Captcha.")] - public bool Captcha { get; set; } = false; - - [Description("Configure Nginx for SSL.")] - public bool Ssl { get; set; } = true; - - [Description("SSL versions used by Nginx (ssl_protocols). Leave empty for recommended default.\n" + - "Learn more: https://wiki.mozilla.org/Security/Server_Side_TLS")] - public string SslVersions { get; set; } - - [Description("SSL ciphersuites used by Nginx (ssl_ciphers). Leave empty for recommended default.\n" + - "Learn more: https://wiki.mozilla.org/Security/Server_Side_TLS")] - public string SslCiphersuites { get; set; } - - [Description("Installation uses a managed Let's Encrypt certificate.")] - public bool SslManagedLetsEncrypt { get; set; } - - [Description("The actual certificate. (Required if using SSL without managed Let's Encrypt)\n" + - "Note: Path uses the container's ssl directory. The `./ssl` host directory is mapped to\n" + - "`/etc/ssl` within the container.")] - public string SslCertificatePath { get; set; } - - [Description("The certificate's private key. (Required if using SSL without managed Let's Encrypt)\n" + - "Note: Path uses the container's ssl directory. The `./ssl` host directory is mapped to\n" + - "`/etc/ssl` within the container.")] - public string SslKeyPath { get; set; } - - [Description("If the certificate is trusted by a CA, you should provide the CA's certificate.\n" + - "Note: Path uses the container's ssl directory. The `./ssl` host directory is mapped to\n" + - "`/etc/ssl` within the container.")] - public string SslCaPath { get; set; } - - [Description("Diffie Hellman ephemeral parameters\n" + - "Learn more: https://security.stackexchange.com/q/94390/79072\n" + - "Note: Path uses the container's ssl directory. The `./ssl` host directory is mapped to\n" + - "`/etc/ssl` within the container.")] - public string SslDiffieHellmanPath { get; set; } - - [Description("Nginx Header Content-Security-Policy parameter\n" + - "WARNING: Reconfiguring this parameter may break features. By changing this parameter\n" + - "you become responsible for maintaining this value.")] - public string NginxHeaderContentSecurityPolicy { get; set; } = "default-src 'self'; style-src 'self' " + - "'unsafe-inline'; img-src 'self' data: https://haveibeenpwned.com https://www.gravatar.com; " + - "child-src 'self' https://*.duosecurity.com https://*.duofederal.com; " + - "frame-src 'self' https://*.duosecurity.com https://*.duofederal.com; " + - "connect-src 'self' wss://{0} https://api.pwnedpasswords.com " + - "https://2fa.directory; object-src 'self' blob:;"; - - [Description("Communicate with the Bitwarden push relay service (push.bitwarden.com) for mobile\n" + - "app live sync.")] - public bool PushNotifications { get; set; } = true; - - [Description("Use a docker volume (`mssql_data`) instead of a host-mapped volume for the persisted " + - "database.\n" + - "WARNING: Changing this value will cause you to lose access to the existing persisted database.\n" + - "Learn more: https://docs.docker.com/storage/volumes/")] - public bool DatabaseDockerVolume { get; set; } - - [Description("Defines \"real\" IPs in nginx.conf. Useful for defining proxy servers that forward the \n" + - "client IP address.\n" + - "Learn more: https://nginx.org/en/docs/http/ngx_http_realip_module.html\n\n" + - "Defined as a dictionary, e.g.:\n" + - "real_ips: ['10.10.0.0/24', '172.16.0.0/16']")] - public List RealIps { get; set; } - - [Description("Enable Key Connector (https://bitwarden.com/help/article/deploy-key-connector)")] - public bool EnableKeyConnector { get; set; } = false; - - [Description("Enable SCIM")] - public bool EnableScim { get; set; } = false; - - [YamlIgnore] - public string Domain + get { - get + if (Uri.TryCreate(Url, UriKind.Absolute, out var uri)) { - if (Uri.TryCreate(Url, UriKind.Absolute, out var uri)) - { - return uri.Host; - } - return null; + return uri.Host; } + return null; } } } diff --git a/util/Setup/Context.cs b/util/Setup/Context.cs index cf8efa90e..f82e5005c 100644 --- a/util/Setup/Context.cs +++ b/util/Setup/Context.cs @@ -1,153 +1,152 @@ using YamlDotNet.Serialization; using YamlDotNet.Serialization.NamingConventions; -namespace Bit.Setup +namespace Bit.Setup; + +public class Context { - public class Context + private const string ConfigPath = "/bitwarden/config.yml"; + + public string[] Args { get; set; } + public bool Quiet { get; set; } + public bool Stub { get; set; } + public IDictionary Parameters { get; set; } + public string OutputDir { get; set; } = "/etc/bitwarden"; + public string HostOS { get; set; } = "win"; + public string CoreVersion { get; set; } = "latest"; + public string WebVersion { get; set; } = "latest"; + public string KeyConnectorVersion { get; set; } = "latest"; + public Installation Install { get; set; } = new Installation(); + public Configuration Config { get; set; } = new Configuration(); + + public bool PrintToScreen() { - private const string ConfigPath = "/bitwarden/config.yml"; + return !Quiet || Parameters.ContainsKey("install"); + } - public string[] Args { get; set; } - public bool Quiet { get; set; } - public bool Stub { get; set; } - public IDictionary Parameters { get; set; } - public string OutputDir { get; set; } = "/etc/bitwarden"; - public string HostOS { get; set; } = "win"; - public string CoreVersion { get; set; } = "latest"; - public string WebVersion { get; set; } = "latest"; - public string KeyConnectorVersion { get; set; } = "latest"; - public Installation Install { get; set; } = new Installation(); - public Configuration Config { get; set; } = new Configuration(); - - public bool PrintToScreen() + public void LoadConfiguration() + { + if (!File.Exists(ConfigPath)) { - return !Quiet || Parameters.ContainsKey("install"); - } + Helpers.WriteLine(this, "No existing `config.yml` detected. Let's generate one."); - public void LoadConfiguration() - { - if (!File.Exists(ConfigPath)) + // Looks like updating from older version. Try to create config file. + var url = Helpers.GetValueFromEnvFile("global", "globalSettings__baseServiceUri__vault"); + if (!Uri.TryCreate(url, UriKind.Absolute, out var uri)) { - Helpers.WriteLine(this, "No existing `config.yml` detected. Let's generate one."); + Helpers.WriteLine(this, "Unable to determine existing installation url."); + return; + } + Config.Url = url; - // Looks like updating from older version. Try to create config file. - var url = Helpers.GetValueFromEnvFile("global", "globalSettings__baseServiceUri__vault"); - if (!Uri.TryCreate(url, UriKind.Absolute, out var uri)) + var push = Helpers.GetValueFromEnvFile("global", "globalSettings__pushRelayBaseUri"); + Config.PushNotifications = push != "REPLACE"; + + var composeFile = "/bitwarden/docker/docker-compose.yml"; + if (File.Exists(composeFile)) + { + var fileLines = File.ReadAllLines(composeFile); + foreach (var line in fileLines) { - Helpers.WriteLine(this, "Unable to determine existing installation url."); - return; - } - Config.Url = url; - - var push = Helpers.GetValueFromEnvFile("global", "globalSettings__pushRelayBaseUri"); - Config.PushNotifications = push != "REPLACE"; - - var composeFile = "/bitwarden/docker/docker-compose.yml"; - if (File.Exists(composeFile)) - { - var fileLines = File.ReadAllLines(composeFile); - foreach (var line in fileLines) + if (!line.StartsWith("# Parameter:")) { - if (!line.StartsWith("# Parameter:")) - { - continue; - } + continue; + } - var paramParts = line.Split("="); - if (paramParts.Length < 2) - { - continue; - } + var paramParts = line.Split("="); + if (paramParts.Length < 2) + { + continue; + } - if (paramParts[0] == "# Parameter:MssqlDataDockerVolume" && - bool.TryParse(paramParts[1], out var mssqlDataDockerVolume)) - { - Config.DatabaseDockerVolume = mssqlDataDockerVolume; - continue; - } + if (paramParts[0] == "# Parameter:MssqlDataDockerVolume" && + bool.TryParse(paramParts[1], out var mssqlDataDockerVolume)) + { + Config.DatabaseDockerVolume = mssqlDataDockerVolume; + continue; + } - if (paramParts[0] == "# Parameter:HttpPort" && int.TryParse(paramParts[1], out var httpPort)) - { - Config.HttpPort = httpPort == 0 ? null : httpPort.ToString(); - continue; - } + if (paramParts[0] == "# Parameter:HttpPort" && int.TryParse(paramParts[1], out var httpPort)) + { + Config.HttpPort = httpPort == 0 ? null : httpPort.ToString(); + continue; + } - if (paramParts[0] == "# Parameter:HttpsPort" && int.TryParse(paramParts[1], out var httpsPort)) - { - Config.HttpsPort = httpsPort == 0 ? null : httpsPort.ToString(); - continue; - } + if (paramParts[0] == "# Parameter:HttpsPort" && int.TryParse(paramParts[1], out var httpsPort)) + { + Config.HttpsPort = httpsPort == 0 ? null : httpsPort.ToString(); + continue; } } + } - var nginxFile = "/bitwarden/nginx/default.conf"; - if (File.Exists(nginxFile)) + var nginxFile = "/bitwarden/nginx/default.conf"; + if (File.Exists(nginxFile)) + { + var confContent = File.ReadAllText(nginxFile); + var selfSigned = confContent.Contains("/etc/ssl/self/"); + Config.Ssl = confContent.Contains("ssl http2;"); + Config.SslManagedLetsEncrypt = !selfSigned && confContent.Contains("/etc/letsencrypt/live/"); + var diffieHellman = confContent.Contains("/dhparam.pem;"); + var trusted = confContent.Contains("ssl_trusted_certificate "); + if (Config.SslManagedLetsEncrypt) { - var confContent = File.ReadAllText(nginxFile); - var selfSigned = confContent.Contains("/etc/ssl/self/"); - Config.Ssl = confContent.Contains("ssl http2;"); - Config.SslManagedLetsEncrypt = !selfSigned && confContent.Contains("/etc/letsencrypt/live/"); - var diffieHellman = confContent.Contains("/dhparam.pem;"); - var trusted = confContent.Contains("ssl_trusted_certificate "); - if (Config.SslManagedLetsEncrypt) + Config.Ssl = true; + } + else if (Config.Ssl) + { + var sslPath = selfSigned ? $"/etc/ssl/self/{Config.Domain}" : $"/etc/ssl/{Config.Domain}"; + Config.SslCertificatePath = string.Concat(sslPath, "/", "certificate.crt"); + Config.SslKeyPath = string.Concat(sslPath, "/", "private.key"); + if (trusted) { - Config.Ssl = true; + Config.SslCaPath = string.Concat(sslPath, "/", "ca.crt"); } - else if (Config.Ssl) + if (diffieHellman) { - var sslPath = selfSigned ? $"/etc/ssl/self/{Config.Domain}" : $"/etc/ssl/{Config.Domain}"; - Config.SslCertificatePath = string.Concat(sslPath, "/", "certificate.crt"); - Config.SslKeyPath = string.Concat(sslPath, "/", "private.key"); - if (trusted) - { - Config.SslCaPath = string.Concat(sslPath, "/", "ca.crt"); - } - if (diffieHellman) - { - Config.SslDiffieHellmanPath = string.Concat(sslPath, "/", "dhparam.pem"); - } + Config.SslDiffieHellmanPath = string.Concat(sslPath, "/", "dhparam.pem"); } } - - SaveConfiguration(); } - var configText = File.ReadAllText(ConfigPath); - var deserializer = new DeserializerBuilder() - .WithNamingConvention(UnderscoredNamingConvention.Instance) - .Build(); - Config = deserializer.Deserialize(configText); + SaveConfiguration(); } - public void SaveConfiguration() - { - if (Config == null) - { - throw new Exception("Config is null."); - } - var serializer = new SerializerBuilder() - .WithNamingConvention(UnderscoredNamingConvention.Instance) - .WithTypeInspector(inner => new CommentGatheringTypeInspector(inner)) - .WithEmissionPhaseObjectGraphVisitor(args => new CommentsObjectGraphVisitor(args.InnerVisitor)) - .Build(); - var yaml = serializer.Serialize(Config); - Directory.CreateDirectory("/bitwarden/"); - using (var sw = File.CreateText(ConfigPath)) - { - sw.Write(yaml); - } - } + var configText = File.ReadAllText(ConfigPath); + var deserializer = new DeserializerBuilder() + .WithNamingConvention(UnderscoredNamingConvention.Instance) + .Build(); + Config = deserializer.Deserialize(configText); + } - public class Installation + public void SaveConfiguration() + { + if (Config == null) { - public Guid InstallationId { get; set; } - public string InstallationKey { get; set; } - public bool DiffieHellman { get; set; } - public bool Trusted { get; set; } - public bool SelfSignedCert { get; set; } - public string IdentityCertPassword { get; set; } - public string Domain { get; set; } - public string Database { get; set; } + throw new Exception("Config is null."); + } + var serializer = new SerializerBuilder() + .WithNamingConvention(UnderscoredNamingConvention.Instance) + .WithTypeInspector(inner => new CommentGatheringTypeInspector(inner)) + .WithEmissionPhaseObjectGraphVisitor(args => new CommentsObjectGraphVisitor(args.InnerVisitor)) + .Build(); + var yaml = serializer.Serialize(Config); + Directory.CreateDirectory("/bitwarden/"); + using (var sw = File.CreateText(ConfigPath)) + { + sw.Write(yaml); } } + + public class Installation + { + public Guid InstallationId { get; set; } + public string InstallationKey { get; set; } + public bool DiffieHellman { get; set; } + public bool Trusted { get; set; } + public bool SelfSignedCert { get; set; } + public string IdentityCertPassword { get; set; } + public string Domain { get; set; } + public string Database { get; set; } + } } diff --git a/util/Setup/DockerComposeBuilder.cs b/util/Setup/DockerComposeBuilder.cs index d007ffe1c..0d76dc9e9 100644 --- a/util/Setup/DockerComposeBuilder.cs +++ b/util/Setup/DockerComposeBuilder.cs @@ -1,80 +1,79 @@ -namespace Bit.Setup +namespace Bit.Setup; + +public class DockerComposeBuilder { - public class DockerComposeBuilder + private readonly Context _context; + + public DockerComposeBuilder(Context context) { - private readonly Context _context; + _context = context; + } - public DockerComposeBuilder(Context context) + public void BuildForInstaller() + { + _context.Config.DatabaseDockerVolume = _context.HostOS == "mac"; + Build(); + } + + public void BuildForUpdater() + { + Build(); + } + + private void Build() + { + Directory.CreateDirectory("/bitwarden/docker/"); + Helpers.WriteLine(_context, "Building docker-compose.yml."); + if (!_context.Config.GenerateComposeConfig) { - _context = context; + Helpers.WriteLine(_context, "...skipped"); + return; } - public void BuildForInstaller() + var template = Helpers.ReadTemplate("DockerCompose"); + var model = new TemplateModel(_context); + using (var sw = File.CreateText("/bitwarden/docker/docker-compose.yml")) { - _context.Config.DatabaseDockerVolume = _context.HostOS == "mac"; - Build(); - } - - public void BuildForUpdater() - { - Build(); - } - - private void Build() - { - Directory.CreateDirectory("/bitwarden/docker/"); - Helpers.WriteLine(_context, "Building docker-compose.yml."); - if (!_context.Config.GenerateComposeConfig) - { - Helpers.WriteLine(_context, "...skipped"); - return; - } - - var template = Helpers.ReadTemplate("DockerCompose"); - var model = new TemplateModel(_context); - using (var sw = File.CreateText("/bitwarden/docker/docker-compose.yml")) - { - sw.Write(template(model)); - } - } - - public class TemplateModel - { - public TemplateModel(Context context) - { - if (!string.IsNullOrWhiteSpace(context.Config.ComposeVersion)) - { - ComposeVersion = context.Config.ComposeVersion; - } - MssqlDataDockerVolume = context.Config.DatabaseDockerVolume; - EnableKeyConnector = context.Config.EnableKeyConnector; - EnableScim = context.Config.EnableScim; - HttpPort = context.Config.HttpPort; - HttpsPort = context.Config.HttpsPort; - if (!string.IsNullOrWhiteSpace(context.CoreVersion)) - { - CoreVersion = context.CoreVersion; - } - if (!string.IsNullOrWhiteSpace(context.WebVersion)) - { - WebVersion = context.WebVersion; - } - if (!string.IsNullOrWhiteSpace(context.KeyConnectorVersion)) - { - KeyConnectorVersion = context.KeyConnectorVersion; - } - } - - public string ComposeVersion { get; set; } = "3"; - public bool MssqlDataDockerVolume { get; set; } - public bool EnableKeyConnector { get; set; } - public bool EnableScim { get; set; } - public string HttpPort { get; set; } - public string HttpsPort { get; set; } - public bool HasPort => !string.IsNullOrWhiteSpace(HttpPort) || !string.IsNullOrWhiteSpace(HttpsPort); - public string CoreVersion { get; set; } = "latest"; - public string WebVersion { get; set; } = "latest"; - public string KeyConnectorVersion { get; set; } = "latest"; + sw.Write(template(model)); } } + + public class TemplateModel + { + public TemplateModel(Context context) + { + if (!string.IsNullOrWhiteSpace(context.Config.ComposeVersion)) + { + ComposeVersion = context.Config.ComposeVersion; + } + MssqlDataDockerVolume = context.Config.DatabaseDockerVolume; + EnableKeyConnector = context.Config.EnableKeyConnector; + EnableScim = context.Config.EnableScim; + HttpPort = context.Config.HttpPort; + HttpsPort = context.Config.HttpsPort; + if (!string.IsNullOrWhiteSpace(context.CoreVersion)) + { + CoreVersion = context.CoreVersion; + } + if (!string.IsNullOrWhiteSpace(context.WebVersion)) + { + WebVersion = context.WebVersion; + } + if (!string.IsNullOrWhiteSpace(context.KeyConnectorVersion)) + { + KeyConnectorVersion = context.KeyConnectorVersion; + } + } + + public string ComposeVersion { get; set; } = "3"; + public bool MssqlDataDockerVolume { get; set; } + public bool EnableKeyConnector { get; set; } + public bool EnableScim { get; set; } + public string HttpPort { get; set; } + public string HttpsPort { get; set; } + public bool HasPort => !string.IsNullOrWhiteSpace(HttpPort) || !string.IsNullOrWhiteSpace(HttpsPort); + public string CoreVersion { get; set; } = "latest"; + public string WebVersion { get; set; } = "latest"; + public string KeyConnectorVersion { get; set; } = "latest"; + } } diff --git a/util/Setup/EnvironmentFileBuilder.cs b/util/Setup/EnvironmentFileBuilder.cs index 77a94bd06..893ca8537 100644 --- a/util/Setup/EnvironmentFileBuilder.cs +++ b/util/Setup/EnvironmentFileBuilder.cs @@ -1,225 +1,224 @@ using System.Data.SqlClient; -namespace Bit.Setup +namespace Bit.Setup; + +public class EnvironmentFileBuilder { - public class EnvironmentFileBuilder + private readonly Context _context; + + private IDictionary _globalValues; + private IDictionary _mssqlValues; + private IDictionary _globalOverrideValues; + private IDictionary _mssqlOverrideValues; + private IDictionary _keyConnectorOverrideValues; + + public EnvironmentFileBuilder(Context context) { - private readonly Context _context; - - private IDictionary _globalValues; - private IDictionary _mssqlValues; - private IDictionary _globalOverrideValues; - private IDictionary _mssqlOverrideValues; - private IDictionary _keyConnectorOverrideValues; - - public EnvironmentFileBuilder(Context context) + _context = context; + _globalValues = new Dictionary { - _context = context; - _globalValues = new Dictionary - { - ["ASPNETCORE_ENVIRONMENT"] = "Production", - ["globalSettings__selfHosted"] = "true", - ["globalSettings__baseServiceUri__vault"] = "http://localhost", - ["globalSettings__pushRelayBaseUri"] = "https://push.bitwarden.com", - }; - _mssqlValues = new Dictionary - { - ["ACCEPT_EULA"] = "Y", - ["MSSQL_PID"] = "Express", - ["SA_PASSWORD"] = "SECRET", - }; + ["ASPNETCORE_ENVIRONMENT"] = "Production", + ["globalSettings__selfHosted"] = "true", + ["globalSettings__baseServiceUri__vault"] = "http://localhost", + ["globalSettings__pushRelayBaseUri"] = "https://push.bitwarden.com", + }; + _mssqlValues = new Dictionary + { + ["ACCEPT_EULA"] = "Y", + ["MSSQL_PID"] = "Express", + ["SA_PASSWORD"] = "SECRET", + }; + } + + public void BuildForInstaller() + { + Directory.CreateDirectory("/bitwarden/env/"); + Init(); + Build(); + } + + public void BuildForUpdater() + { + Init(); + LoadExistingValues(_globalOverrideValues, "/bitwarden/env/global.override.env"); + LoadExistingValues(_mssqlOverrideValues, "/bitwarden/env/mssql.override.env"); + LoadExistingValues(_keyConnectorOverrideValues, "/bitwarden/env/key-connector.override.env"); + + if (_context.Config.PushNotifications && + _globalOverrideValues.ContainsKey("globalSettings__pushRelayBaseUri") && + _globalOverrideValues["globalSettings__pushRelayBaseUri"] == "REPLACE") + { + _globalOverrideValues.Remove("globalSettings__pushRelayBaseUri"); } - public void BuildForInstaller() + Build(); + } + + private void Init() + { + var dbPassword = _context.Stub ? "RANDOM_DATABASE_PASSWORD" : Helpers.SecureRandomString(32); + var dbConnectionString = new SqlConnectionStringBuilder { - Directory.CreateDirectory("/bitwarden/env/"); - Init(); - Build(); + DataSource = "tcp:mssql,1433", + InitialCatalog = _context.Install?.Database ?? "vault", + UserID = "sa", + Password = dbPassword, + MultipleActiveResultSets = false, + Encrypt = true, + ConnectTimeout = 30, + TrustServerCertificate = true, + PersistSecurityInfo = false + }.ConnectionString; + + _globalOverrideValues = new Dictionary + { + ["globalSettings__baseServiceUri__vault"] = _context.Config.Url, + ["globalSettings__sqlServer__connectionString"] = $"\"{dbConnectionString.Replace("\"", "\\\"")}\"", + ["globalSettings__identityServer__certificatePassword"] = _context.Install?.IdentityCertPassword, + ["globalSettings__internalIdentityKey"] = _context.Stub ? "RANDOM_IDENTITY_KEY" : + Helpers.SecureRandomString(64, alpha: true, numeric: true), + ["globalSettings__oidcIdentityClientKey"] = _context.Stub ? "RANDOM_IDENTITY_KEY" : + Helpers.SecureRandomString(64, alpha: true, numeric: true), + ["globalSettings__duo__aKey"] = _context.Stub ? "RANDOM_DUO_AKEY" : + Helpers.SecureRandomString(64, alpha: true, numeric: true), + ["globalSettings__installation__id"] = _context.Install?.InstallationId.ToString(), + ["globalSettings__installation__key"] = _context.Install?.InstallationKey, + ["globalSettings__yubico__clientId"] = "REPLACE", + ["globalSettings__yubico__key"] = "REPLACE", + ["globalSettings__mail__replyToEmail"] = $"no-reply@{_context.Config.Domain}", + ["globalSettings__mail__smtp__host"] = "REPLACE", + ["globalSettings__mail__smtp__port"] = "587", + ["globalSettings__mail__smtp__ssl"] = "false", + ["globalSettings__mail__smtp__username"] = "REPLACE", + ["globalSettings__mail__smtp__password"] = "REPLACE", + ["globalSettings__disableUserRegistration"] = "false", + ["globalSettings__hibpApiKey"] = "REPLACE", + ["adminSettings__admins"] = string.Empty, + }; + + if (!_context.Config.PushNotifications) + { + _globalOverrideValues.Add("globalSettings__pushRelayBaseUri", "REPLACE"); } - public void BuildForUpdater() + _mssqlOverrideValues = new Dictionary { - Init(); - LoadExistingValues(_globalOverrideValues, "/bitwarden/env/global.override.env"); - LoadExistingValues(_mssqlOverrideValues, "/bitwarden/env/mssql.override.env"); - LoadExistingValues(_keyConnectorOverrideValues, "/bitwarden/env/key-connector.override.env"); + ["SA_PASSWORD"] = dbPassword, + }; - if (_context.Config.PushNotifications && - _globalOverrideValues.ContainsKey("globalSettings__pushRelayBaseUri") && - _globalOverrideValues["globalSettings__pushRelayBaseUri"] == "REPLACE") - { - _globalOverrideValues.Remove("globalSettings__pushRelayBaseUri"); - } + _keyConnectorOverrideValues = new Dictionary + { + ["keyConnectorSettings__webVaultUri"] = _context.Config.Url, + ["keyConnectorSettings__identityServerUri"] = "http://identity:5000", + ["keyConnectorSettings__database__provider"] = "json", + ["keyConnectorSettings__database__jsonFilePath"] = "/etc/bitwarden/key-connector/data.json", + ["keyConnectorSettings__rsaKey__provider"] = "certificate", + ["keyConnectorSettings__certificate__provider"] = "filesystem", + ["keyConnectorSettings__certificate__filesystemPath"] = "/etc/bitwarden/key-connector/bwkc.pfx", + ["keyConnectorSettings__certificate__filesystemPassword"] = Helpers.SecureRandomString(32, alpha: true, numeric: true), + }; + } - Build(); + private void LoadExistingValues(IDictionary _values, string file) + { + if (!File.Exists(file)) + { + return; } - private void Init() + var fileLines = File.ReadAllLines(file); + foreach (var line in fileLines) { - var dbPassword = _context.Stub ? "RANDOM_DATABASE_PASSWORD" : Helpers.SecureRandomString(32); - var dbConnectionString = new SqlConnectionStringBuilder + if (!line.Contains("=")) { - DataSource = "tcp:mssql,1433", - InitialCatalog = _context.Install?.Database ?? "vault", - UserID = "sa", - Password = dbPassword, - MultipleActiveResultSets = false, - Encrypt = true, - ConnectTimeout = 30, - TrustServerCertificate = true, - PersistSecurityInfo = false - }.ConnectionString; - - _globalOverrideValues = new Dictionary - { - ["globalSettings__baseServiceUri__vault"] = _context.Config.Url, - ["globalSettings__sqlServer__connectionString"] = $"\"{dbConnectionString.Replace("\"", "\\\"")}\"", - ["globalSettings__identityServer__certificatePassword"] = _context.Install?.IdentityCertPassword, - ["globalSettings__internalIdentityKey"] = _context.Stub ? "RANDOM_IDENTITY_KEY" : - Helpers.SecureRandomString(64, alpha: true, numeric: true), - ["globalSettings__oidcIdentityClientKey"] = _context.Stub ? "RANDOM_IDENTITY_KEY" : - Helpers.SecureRandomString(64, alpha: true, numeric: true), - ["globalSettings__duo__aKey"] = _context.Stub ? "RANDOM_DUO_AKEY" : - Helpers.SecureRandomString(64, alpha: true, numeric: true), - ["globalSettings__installation__id"] = _context.Install?.InstallationId.ToString(), - ["globalSettings__installation__key"] = _context.Install?.InstallationKey, - ["globalSettings__yubico__clientId"] = "REPLACE", - ["globalSettings__yubico__key"] = "REPLACE", - ["globalSettings__mail__replyToEmail"] = $"no-reply@{_context.Config.Domain}", - ["globalSettings__mail__smtp__host"] = "REPLACE", - ["globalSettings__mail__smtp__port"] = "587", - ["globalSettings__mail__smtp__ssl"] = "false", - ["globalSettings__mail__smtp__username"] = "REPLACE", - ["globalSettings__mail__smtp__password"] = "REPLACE", - ["globalSettings__disableUserRegistration"] = "false", - ["globalSettings__hibpApiKey"] = "REPLACE", - ["adminSettings__admins"] = string.Empty, - }; - - if (!_context.Config.PushNotifications) - { - _globalOverrideValues.Add("globalSettings__pushRelayBaseUri", "REPLACE"); + continue; } - _mssqlOverrideValues = new Dictionary + var value = string.Empty; + var lineParts = line.Split("=", 2); + if (lineParts.Length < 1) { - ["SA_PASSWORD"] = dbPassword, - }; - - _keyConnectorOverrideValues = new Dictionary - { - ["keyConnectorSettings__webVaultUri"] = _context.Config.Url, - ["keyConnectorSettings__identityServerUri"] = "http://identity:5000", - ["keyConnectorSettings__database__provider"] = "json", - ["keyConnectorSettings__database__jsonFilePath"] = "/etc/bitwarden/key-connector/data.json", - ["keyConnectorSettings__rsaKey__provider"] = "certificate", - ["keyConnectorSettings__certificate__provider"] = "filesystem", - ["keyConnectorSettings__certificate__filesystemPath"] = "/etc/bitwarden/key-connector/bwkc.pfx", - ["keyConnectorSettings__certificate__filesystemPassword"] = Helpers.SecureRandomString(32, alpha: true, numeric: true), - }; - } - - private void LoadExistingValues(IDictionary _values, string file) - { - if (!File.Exists(file)) - { - return; + continue; } - var fileLines = File.ReadAllLines(file); - foreach (var line in fileLines) + if (lineParts.Length > 1) { - if (!line.Contains("=")) - { - continue; - } - - var value = string.Empty; - var lineParts = line.Split("=", 2); - if (lineParts.Length < 1) - { - continue; - } - - if (lineParts.Length > 1) - { - value = lineParts[1]; - } - - if (_values.ContainsKey(lineParts[0])) - { - _values[lineParts[0]] = value; - } - else - { - _values.Add(lineParts[0], value.Replace("\\\"", "\"")); - } - } - } - - private void Build() - { - var template = Helpers.ReadTemplate("EnvironmentFile"); - - Helpers.WriteLine(_context, "Building docker environment files."); - Directory.CreateDirectory("/bitwarden/docker/"); - using (var sw = File.CreateText("/bitwarden/docker/global.env")) - { - sw.Write(template(new TemplateModel(_globalValues))); - } - Helpers.Exec("chmod 600 /bitwarden/docker/global.env"); - - using (var sw = File.CreateText("/bitwarden/docker/mssql.env")) - { - sw.Write(template(new TemplateModel(_mssqlValues))); - } - Helpers.Exec("chmod 600 /bitwarden/docker/mssql.env"); - - Helpers.WriteLine(_context, "Building docker environment override files."); - Directory.CreateDirectory("/bitwarden/env/"); - using (var sw = File.CreateText("/bitwarden/env/global.override.env")) - { - sw.Write(template(new TemplateModel(_globalOverrideValues))); - } - Helpers.Exec("chmod 600 /bitwarden/env/global.override.env"); - - using (var sw = File.CreateText("/bitwarden/env/mssql.override.env")) - { - sw.Write(template(new TemplateModel(_mssqlOverrideValues))); - } - Helpers.Exec("chmod 600 /bitwarden/env/mssql.override.env"); - - if (_context.Config.EnableKeyConnector) - { - using (var sw = File.CreateText("/bitwarden/env/key-connector.override.env")) - { - sw.Write(template(new TemplateModel(_keyConnectorOverrideValues))); - } - - Helpers.Exec("chmod 600 /bitwarden/env/key-connector.override.env"); + value = lineParts[1]; } - // Empty uid env file. Only used on Linux hosts. - if (!File.Exists("/bitwarden/env/uid.env")) + if (_values.ContainsKey(lineParts[0])) { - using (var sw = File.CreateText("/bitwarden/env/uid.env")) { } + _values[lineParts[0]] = value; } - } - - public class TemplateModel - { - public TemplateModel(IEnumerable> variables) + else { - Variables = variables.Select(v => new Kvp { Key = v.Key, Value = v.Value }); - } - - public IEnumerable Variables { get; set; } - - public class Kvp - { - public string Key { get; set; } - public string Value { get; set; } + _values.Add(lineParts[0], value.Replace("\\\"", "\"")); } } } + + private void Build() + { + var template = Helpers.ReadTemplate("EnvironmentFile"); + + Helpers.WriteLine(_context, "Building docker environment files."); + Directory.CreateDirectory("/bitwarden/docker/"); + using (var sw = File.CreateText("/bitwarden/docker/global.env")) + { + sw.Write(template(new TemplateModel(_globalValues))); + } + Helpers.Exec("chmod 600 /bitwarden/docker/global.env"); + + using (var sw = File.CreateText("/bitwarden/docker/mssql.env")) + { + sw.Write(template(new TemplateModel(_mssqlValues))); + } + Helpers.Exec("chmod 600 /bitwarden/docker/mssql.env"); + + Helpers.WriteLine(_context, "Building docker environment override files."); + Directory.CreateDirectory("/bitwarden/env/"); + using (var sw = File.CreateText("/bitwarden/env/global.override.env")) + { + sw.Write(template(new TemplateModel(_globalOverrideValues))); + } + Helpers.Exec("chmod 600 /bitwarden/env/global.override.env"); + + using (var sw = File.CreateText("/bitwarden/env/mssql.override.env")) + { + sw.Write(template(new TemplateModel(_mssqlOverrideValues))); + } + Helpers.Exec("chmod 600 /bitwarden/env/mssql.override.env"); + + if (_context.Config.EnableKeyConnector) + { + using (var sw = File.CreateText("/bitwarden/env/key-connector.override.env")) + { + sw.Write(template(new TemplateModel(_keyConnectorOverrideValues))); + } + + Helpers.Exec("chmod 600 /bitwarden/env/key-connector.override.env"); + } + + // Empty uid env file. Only used on Linux hosts. + if (!File.Exists("/bitwarden/env/uid.env")) + { + using (var sw = File.CreateText("/bitwarden/env/uid.env")) { } + } + } + + public class TemplateModel + { + public TemplateModel(IEnumerable> variables) + { + Variables = variables.Select(v => new Kvp { Key = v.Key, Value = v.Value }); + } + + public IEnumerable Variables { get; set; } + + public class Kvp + { + public string Key { get; set; } + public string Value { get; set; } + } + } } diff --git a/util/Setup/Helpers.cs b/util/Setup/Helpers.cs index 06c48f2fe..ea7351b98 100644 --- a/util/Setup/Helpers.cs +++ b/util/Setup/Helpers.cs @@ -4,223 +4,222 @@ using System.Runtime.InteropServices; using System.Security.Cryptography; using System.Text; -namespace Bit.Setup +namespace Bit.Setup; + +public static class Helpers { - public static class Helpers + public static string SecureRandomString(int length, bool alpha = true, bool upper = true, bool lower = true, + bool numeric = true, bool special = false) { - public static string SecureRandomString(int length, bool alpha = true, bool upper = true, bool lower = true, - bool numeric = true, bool special = false) + return SecureRandomString(length, RandomStringCharacters(alpha, upper, lower, numeric, special)); + } + + // ref https://stackoverflow.com/a/8996788/1090359 with modifications + public static string SecureRandomString(int length, string characters) + { + if (length < 0) { - return SecureRandomString(length, RandomStringCharacters(alpha, upper, lower, numeric, special)); + throw new ArgumentOutOfRangeException(nameof(length), "length cannot be less than zero."); } - // ref https://stackoverflow.com/a/8996788/1090359 with modifications - public static string SecureRandomString(int length, string characters) + if ((characters?.Length ?? 0) == 0) { - if (length < 0) - { - throw new ArgumentOutOfRangeException(nameof(length), "length cannot be less than zero."); - } + throw new ArgumentOutOfRangeException(nameof(characters), "characters invalid."); + } - if ((characters?.Length ?? 0) == 0) - { - throw new ArgumentOutOfRangeException(nameof(characters), "characters invalid."); - } + const int byteSize = 0x100; + if (byteSize < characters.Length) + { + throw new ArgumentException( + string.Format("{0} may contain no more than {1} characters.", nameof(characters), byteSize), + nameof(characters)); + } - const int byteSize = 0x100; - if (byteSize < characters.Length) + var outOfRangeStart = byteSize - (byteSize % characters.Length); + using (var rng = RandomNumberGenerator.Create()) + { + var sb = new StringBuilder(); + var buffer = new byte[128]; + while (sb.Length < length) { - throw new ArgumentException( - string.Format("{0} may contain no more than {1} characters.", nameof(characters), byteSize), - nameof(characters)); - } - - var outOfRangeStart = byteSize - (byteSize % characters.Length); - using (var rng = RandomNumberGenerator.Create()) - { - var sb = new StringBuilder(); - var buffer = new byte[128]; - while (sb.Length < length) + rng.GetBytes(buffer); + for (var i = 0; i < buffer.Length && sb.Length < length; ++i) { - rng.GetBytes(buffer); - for (var i = 0; i < buffer.Length && sb.Length < length; ++i) + // Divide the byte into charSet-sized groups. If the random value falls into the last group and the + // last group is too small to choose from the entire allowedCharSet, ignore the value in order to + // avoid biasing the result. + if (outOfRangeStart <= buffer[i]) { - // Divide the byte into charSet-sized groups. If the random value falls into the last group and the - // last group is too small to choose from the entire allowedCharSet, ignore the value in order to - // avoid biasing the result. - if (outOfRangeStart <= buffer[i]) - { - continue; - } - - sb.Append(characters[buffer[i] % characters.Length]); + continue; } - } - return sb.ToString(); + sb.Append(characters[buffer[i] % characters.Length]); + } + } + + return sb.ToString(); + } + } + + private static string RandomStringCharacters(bool alpha, bool upper, bool lower, bool numeric, bool special) + { + var characters = string.Empty; + if (alpha) + { + if (upper) + { + characters += "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + } + + if (lower) + { + characters += "abcdefghijklmnopqrstuvwxyz"; } } - private static string RandomStringCharacters(bool alpha, bool upper, bool lower, bool numeric, bool special) + if (numeric) { - var characters = string.Empty; - if (alpha) - { - if (upper) - { - characters += "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; - } - - if (lower) - { - characters += "abcdefghijklmnopqrstuvwxyz"; - } - } - - if (numeric) - { - characters += "0123456789"; - } - - if (special) - { - characters += "!@#$%^*&"; - } - - return characters; + characters += "0123456789"; } - public static string GetValueFromEnvFile(string envFile, string key) + if (special) { - if (!File.Exists($"/bitwarden/env/{envFile}.override.env")) - { - return null; - } + characters += "!@#$%^*&"; + } - var lines = File.ReadAllLines($"/bitwarden/env/{envFile}.override.env"); - foreach (var line in lines) - { - if (line.StartsWith($"{key}=")) - { - return line.Split(new char[] { '=' }, 2)[1].Trim('"').Replace("\\\"", "\""); - } - } + return characters; + } + public static string GetValueFromEnvFile(string envFile, string key) + { + if (!File.Exists($"/bitwarden/env/{envFile}.override.env")) + { return null; } - public static string Exec(string cmd, bool returnStdout = false) + var lines = File.ReadAllLines($"/bitwarden/env/{envFile}.override.env"); + foreach (var line in lines) { - var process = new Process + if (line.StartsWith($"{key}=")) { - StartInfo = new ProcessStartInfo - { - RedirectStandardOutput = true, - UseShellExecute = false, - CreateNoWindow = true, - WindowStyle = ProcessWindowStyle.Hidden - } - }; - - if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - var escapedArgs = cmd.Replace("\"", "\\\""); - process.StartInfo.FileName = "/bin/bash"; - process.StartInfo.Arguments = $"-c \"{escapedArgs}\""; + return line.Split(new char[] { '=' }, 2)[1].Trim('"').Replace("\\\"", "\""); } - else - { - process.StartInfo.FileName = "powershell"; - process.StartInfo.Arguments = cmd; - } - - process.Start(); - var result = returnStdout ? process.StandardOutput.ReadToEnd() : null; - process.WaitForExit(); - return result; } - public static string ReadInput(string prompt) + return null; + } + + public static string Exec(string cmd, bool returnStdout = false) + { + var process = new Process { - Console.ForegroundColor = ConsoleColor.Cyan; - Console.Write("(!) "); - Console.ResetColor(); - Console.Write(prompt); - if (prompt.EndsWith("?")) + StartInfo = new ProcessStartInfo { - Console.Write(" (y/n)"); + RedirectStandardOutput = true, + UseShellExecute = false, + CreateNoWindow = true, + WindowStyle = ProcessWindowStyle.Hidden } - Console.Write(": "); - var input = Console.ReadLine(); + }; + + if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + var escapedArgs = cmd.Replace("\"", "\\\""); + process.StartInfo.FileName = "/bin/bash"; + process.StartInfo.Arguments = $"-c \"{escapedArgs}\""; + } + else + { + process.StartInfo.FileName = "powershell"; + process.StartInfo.Arguments = cmd; + } + + process.Start(); + var result = returnStdout ? process.StandardOutput.ReadToEnd() : null; + process.WaitForExit(); + return result; + } + + public static string ReadInput(string prompt) + { + Console.ForegroundColor = ConsoleColor.Cyan; + Console.Write("(!) "); + Console.ResetColor(); + Console.Write(prompt); + if (prompt.EndsWith("?")) + { + Console.Write(" (y/n)"); + } + Console.Write(": "); + var input = Console.ReadLine(); + Console.WriteLine(); + return input; + } + + public static bool ReadQuestion(string prompt) + { + var input = ReadInput(prompt).ToLowerInvariant().Trim(); + return input == "y" || input == "yes"; + } + + public static void ShowBanner(Context context, string title, string message, ConsoleColor? color = null) + { + if (!context.PrintToScreen()) + { + return; + } + if (color != null) + { + Console.ForegroundColor = color.Value; + } + Console.WriteLine($"!!!!!!!!!! {title} !!!!!!!!!!"); + Console.WriteLine(message); + Console.WriteLine(); + Console.ResetColor(); + } + + public static HandlebarsDotNet.HandlebarsTemplate ReadTemplate(string templateName) + { + var assembly = typeof(Helpers).GetTypeInfo().Assembly; + var fullTemplateName = $"Bit.Setup.Templates.{templateName}.hbs"; + if (!assembly.GetManifestResourceNames().Any(f => f == fullTemplateName)) + { + return null; + } + using (var s = assembly.GetManifestResourceStream(fullTemplateName)) + using (var sr = new StreamReader(s)) + { + var templateText = sr.ReadToEnd(); + return HandlebarsDotNet.Handlebars.Compile(templateText); + } + } + + public static void WriteLine(Context context, string format = null, object arg0 = null, object arg1 = null, + object arg2 = null) + { + if (!context.PrintToScreen()) + { + return; + } + if (format != null && arg0 != null && arg1 != null && arg2 != null) + { + Console.WriteLine(format, arg0, arg1, arg2); + } + else if (format != null && arg0 != null && arg1 != null) + { + Console.WriteLine(format, arg0, arg1); + } + else if (format != null && arg0 != null) + { + Console.WriteLine(format, arg0); + } + else if (format != null) + { + Console.WriteLine(format); + } + else + { Console.WriteLine(); - return input; - } - - public static bool ReadQuestion(string prompt) - { - var input = ReadInput(prompt).ToLowerInvariant().Trim(); - return input == "y" || input == "yes"; - } - - public static void ShowBanner(Context context, string title, string message, ConsoleColor? color = null) - { - if (!context.PrintToScreen()) - { - return; - } - if (color != null) - { - Console.ForegroundColor = color.Value; - } - Console.WriteLine($"!!!!!!!!!! {title} !!!!!!!!!!"); - Console.WriteLine(message); - Console.WriteLine(); - Console.ResetColor(); - } - - public static HandlebarsDotNet.HandlebarsTemplate ReadTemplate(string templateName) - { - var assembly = typeof(Helpers).GetTypeInfo().Assembly; - var fullTemplateName = $"Bit.Setup.Templates.{templateName}.hbs"; - if (!assembly.GetManifestResourceNames().Any(f => f == fullTemplateName)) - { - return null; - } - using (var s = assembly.GetManifestResourceStream(fullTemplateName)) - using (var sr = new StreamReader(s)) - { - var templateText = sr.ReadToEnd(); - return HandlebarsDotNet.Handlebars.Compile(templateText); - } - } - - public static void WriteLine(Context context, string format = null, object arg0 = null, object arg1 = null, - object arg2 = null) - { - if (!context.PrintToScreen()) - { - return; - } - if (format != null && arg0 != null && arg1 != null && arg2 != null) - { - Console.WriteLine(format, arg0, arg1, arg2); - } - else if (format != null && arg0 != null && arg1 != null) - { - Console.WriteLine(format, arg0, arg1); - } - else if (format != null && arg0 != null) - { - Console.WriteLine(format, arg0); - } - else if (format != null) - { - Console.WriteLine(format); - } - else - { - Console.WriteLine(); - } } } } diff --git a/util/Setup/NginxConfigBuilder.cs b/util/Setup/NginxConfigBuilder.cs index f2ad08ced..420793cef 100644 --- a/util/Setup/NginxConfigBuilder.cs +++ b/util/Setup/NginxConfigBuilder.cs @@ -1,133 +1,132 @@ -namespace Bit.Setup +namespace Bit.Setup; + +public class NginxConfigBuilder { - public class NginxConfigBuilder + private const string ConfFile = "/bitwarden/nginx/default.conf"; + + private readonly Context _context; + + public NginxConfigBuilder(Context context) { - private const string ConfFile = "/bitwarden/nginx/default.conf"; + _context = context; + } - private readonly Context _context; - - public NginxConfigBuilder(Context context) + public void BuildForInstaller() + { + var model = new TemplateModel(_context); + if (model.Ssl && !_context.Config.SslManagedLetsEncrypt) { - _context = context; - } - - public void BuildForInstaller() - { - var model = new TemplateModel(_context); - if (model.Ssl && !_context.Config.SslManagedLetsEncrypt) + var sslPath = _context.Install.SelfSignedCert ? + $"/etc/ssl/self/{model.Domain}" : $"/etc/ssl/{model.Domain}"; + _context.Config.SslCertificatePath = model.CertificatePath = + string.Concat(sslPath, "/", "certificate.crt"); + _context.Config.SslKeyPath = model.KeyPath = + string.Concat(sslPath, "/", "private.key"); + if (_context.Install.Trusted) { - var sslPath = _context.Install.SelfSignedCert ? - $"/etc/ssl/self/{model.Domain}" : $"/etc/ssl/{model.Domain}"; - _context.Config.SslCertificatePath = model.CertificatePath = - string.Concat(sslPath, "/", "certificate.crt"); - _context.Config.SslKeyPath = model.KeyPath = - string.Concat(sslPath, "/", "private.key"); - if (_context.Install.Trusted) - { - _context.Config.SslCaPath = model.CaPath = - string.Concat(sslPath, "/", "ca.crt"); - } - if (_context.Install.DiffieHellman) - { - _context.Config.SslDiffieHellmanPath = model.DiffieHellmanPath = - string.Concat(sslPath, "/", "dhparam.pem"); - } + _context.Config.SslCaPath = model.CaPath = + string.Concat(sslPath, "/", "ca.crt"); } - Build(model); - } - - public void BuildForUpdater() - { - var model = new TemplateModel(_context); - Build(model); - } - - private void Build(TemplateModel model) - { - Directory.CreateDirectory("/bitwarden/nginx/"); - Helpers.WriteLine(_context, "Building nginx config."); - if (!_context.Config.GenerateNginxConfig) + if (_context.Install.DiffieHellman) { - Helpers.WriteLine(_context, "...skipped"); - return; - } - - var template = Helpers.ReadTemplate("NginxConfig"); - using (var sw = File.CreateText(ConfFile)) - { - sw.WriteLine(template(model)); + _context.Config.SslDiffieHellmanPath = model.DiffieHellmanPath = + string.Concat(sslPath, "/", "dhparam.pem"); } } + Build(model); + } - public class TemplateModel + public void BuildForUpdater() + { + var model = new TemplateModel(_context); + Build(model); + } + + private void Build(TemplateModel model) + { + Directory.CreateDirectory("/bitwarden/nginx/"); + Helpers.WriteLine(_context, "Building nginx config."); + if (!_context.Config.GenerateNginxConfig) { - public TemplateModel() { } + Helpers.WriteLine(_context, "...skipped"); + return; + } - public TemplateModel(Context context) - { - Captcha = context.Config.Captcha; - Ssl = context.Config.Ssl; - EnableKeyConnector = context.Config.EnableKeyConnector; - EnableScim = context.Config.EnableScim; - Domain = context.Config.Domain; - Url = context.Config.Url; - RealIps = context.Config.RealIps; - ContentSecurityPolicy = string.Format(context.Config.NginxHeaderContentSecurityPolicy, Domain); - - if (Ssl) - { - if (context.Config.SslManagedLetsEncrypt) - { - var sslPath = $"/etc/letsencrypt/live/{Domain}"; - CertificatePath = CaPath = string.Concat(sslPath, "/", "fullchain.pem"); - KeyPath = string.Concat(sslPath, "/", "privkey.pem"); - DiffieHellmanPath = string.Concat(sslPath, "/", "dhparam.pem"); - } - else - { - CertificatePath = context.Config.SslCertificatePath; - KeyPath = context.Config.SslKeyPath; - CaPath = context.Config.SslCaPath; - DiffieHellmanPath = context.Config.SslDiffieHellmanPath; - } - } - - if (!string.IsNullOrWhiteSpace(context.Config.SslCiphersuites)) - { - SslCiphers = context.Config.SslCiphersuites; - } - else - { - SslCiphers = "ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:" + - "ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:ECDHE-ECDSA-AES128-GCM-SHA256:" + - "ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA384:" + - "ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256"; - } - - if (!string.IsNullOrWhiteSpace(context.Config.SslVersions)) - { - SslProtocols = context.Config.SslVersions; - } - else - { - SslProtocols = "TLSv1.2"; - } - } - - public bool Captcha { get; set; } - public bool Ssl { get; set; } - public bool EnableKeyConnector { get; set; } - public bool EnableScim { get; set; } - public string Domain { get; set; } - public string Url { get; set; } - public string CertificatePath { get; set; } - public string KeyPath { get; set; } - public string CaPath { get; set; } - public string DiffieHellmanPath { get; set; } - public string SslCiphers { get; set; } - public string SslProtocols { get; set; } - public string ContentSecurityPolicy { get; set; } - public List RealIps { get; set; } + var template = Helpers.ReadTemplate("NginxConfig"); + using (var sw = File.CreateText(ConfFile)) + { + sw.WriteLine(template(model)); } } + + public class TemplateModel + { + public TemplateModel() { } + + public TemplateModel(Context context) + { + Captcha = context.Config.Captcha; + Ssl = context.Config.Ssl; + EnableKeyConnector = context.Config.EnableKeyConnector; + EnableScim = context.Config.EnableScim; + Domain = context.Config.Domain; + Url = context.Config.Url; + RealIps = context.Config.RealIps; + ContentSecurityPolicy = string.Format(context.Config.NginxHeaderContentSecurityPolicy, Domain); + + if (Ssl) + { + if (context.Config.SslManagedLetsEncrypt) + { + var sslPath = $"/etc/letsencrypt/live/{Domain}"; + CertificatePath = CaPath = string.Concat(sslPath, "/", "fullchain.pem"); + KeyPath = string.Concat(sslPath, "/", "privkey.pem"); + DiffieHellmanPath = string.Concat(sslPath, "/", "dhparam.pem"); + } + else + { + CertificatePath = context.Config.SslCertificatePath; + KeyPath = context.Config.SslKeyPath; + CaPath = context.Config.SslCaPath; + DiffieHellmanPath = context.Config.SslDiffieHellmanPath; + } + } + + if (!string.IsNullOrWhiteSpace(context.Config.SslCiphersuites)) + { + SslCiphers = context.Config.SslCiphersuites; + } + else + { + SslCiphers = "ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:" + + "ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:ECDHE-ECDSA-AES128-GCM-SHA256:" + + "ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA384:" + + "ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256"; + } + + if (!string.IsNullOrWhiteSpace(context.Config.SslVersions)) + { + SslProtocols = context.Config.SslVersions; + } + else + { + SslProtocols = "TLSv1.2"; + } + } + + public bool Captcha { get; set; } + public bool Ssl { get; set; } + public bool EnableKeyConnector { get; set; } + public bool EnableScim { get; set; } + public string Domain { get; set; } + public string Url { get; set; } + public string CertificatePath { get; set; } + public string KeyPath { get; set; } + public string CaPath { get; set; } + public string DiffieHellmanPath { get; set; } + public string SslCiphers { get; set; } + public string SslProtocols { get; set; } + public string ContentSecurityPolicy { get; set; } + public List RealIps { get; set; } + } } diff --git a/util/Setup/Program.cs b/util/Setup/Program.cs index 8eb6474fd..507b329b2 100644 --- a/util/Setup/Program.cs +++ b/util/Setup/Program.cs @@ -3,328 +3,327 @@ using System.Globalization; using System.Net.Http.Json; using Bit.Migrator; -namespace Bit.Setup +namespace Bit.Setup; + +public class Program { - public class Program + private static Context _context; + + public static void Main(string[] args) { - private static Context _context; + CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); - public static void Main(string[] args) + _context = new Context { - CultureInfo.DefaultThreadCurrentCulture = new CultureInfo("en-US"); + Args = args + }; + ParseParameters(); - _context = new Context - { - Args = args - }; - ParseParameters(); + if (_context.Parameters.ContainsKey("q")) + { + _context.Quiet = _context.Parameters["q"] == "true" || _context.Parameters["q"] == "1"; + } + if (_context.Parameters.ContainsKey("os")) + { + _context.HostOS = _context.Parameters["os"]; + } + if (_context.Parameters.ContainsKey("corev")) + { + _context.CoreVersion = _context.Parameters["corev"]; + } + if (_context.Parameters.ContainsKey("webv")) + { + _context.WebVersion = _context.Parameters["webv"]; + } + if (_context.Parameters.ContainsKey("keyconnectorv")) + { + _context.KeyConnectorVersion = _context.Parameters["keyconnectorv"]; + } + if (_context.Parameters.ContainsKey("stub")) + { + _context.Stub = _context.Parameters["stub"] == "true" || + _context.Parameters["stub"] == "1"; + } - if (_context.Parameters.ContainsKey("q")) - { - _context.Quiet = _context.Parameters["q"] == "true" || _context.Parameters["q"] == "1"; - } - if (_context.Parameters.ContainsKey("os")) - { - _context.HostOS = _context.Parameters["os"]; - } - if (_context.Parameters.ContainsKey("corev")) - { - _context.CoreVersion = _context.Parameters["corev"]; - } - if (_context.Parameters.ContainsKey("webv")) - { - _context.WebVersion = _context.Parameters["webv"]; - } - if (_context.Parameters.ContainsKey("keyconnectorv")) - { - _context.KeyConnectorVersion = _context.Parameters["keyconnectorv"]; - } - if (_context.Parameters.ContainsKey("stub")) - { - _context.Stub = _context.Parameters["stub"] == "true" || - _context.Parameters["stub"] == "1"; - } + Helpers.WriteLine(_context); - Helpers.WriteLine(_context); + if (_context.Parameters.ContainsKey("install")) + { + Install(); + } + else if (_context.Parameters.ContainsKey("update")) + { + Update(); + } + else if (_context.Parameters.ContainsKey("printenv")) + { + PrintEnvironment(); + } + else + { + Helpers.WriteLine(_context, "No top-level command detected. Exiting..."); + } + } - if (_context.Parameters.ContainsKey("install")) + private static void Install() + { + if (_context.Parameters.ContainsKey("letsencrypt")) + { + _context.Config.SslManagedLetsEncrypt = + _context.Parameters["letsencrypt"].ToLowerInvariant() == "y"; + } + if (_context.Parameters.ContainsKey("domain")) + { + _context.Install.Domain = _context.Parameters["domain"].ToLowerInvariant(); + } + if (_context.Parameters.ContainsKey("dbname")) + { + _context.Install.Database = _context.Parameters["dbname"]; + } + + if (_context.Stub) + { + _context.Install.InstallationId = Guid.Empty; + _context.Install.InstallationKey = "SECRET_INSTALLATION_KEY"; + } + else if (!ValidateInstallation()) + { + return; + } + + var certBuilder = new CertBuilder(_context); + certBuilder.BuildForInstall(); + + // Set the URL + _context.Config.Url = string.Format("http{0}://{1}", + _context.Config.Ssl ? "s" : string.Empty, _context.Install.Domain); + + var nginxBuilder = new NginxConfigBuilder(_context); + nginxBuilder.BuildForInstaller(); + + var environmentFileBuilder = new EnvironmentFileBuilder(_context); + environmentFileBuilder.BuildForInstaller(); + + var appIdBuilder = new AppIdBuilder(_context); + appIdBuilder.Build(); + + var dockerComposeBuilder = new DockerComposeBuilder(_context); + dockerComposeBuilder.BuildForInstaller(); + + _context.SaveConfiguration(); + + Console.WriteLine("\nInstallation complete"); + + Console.WriteLine("\nIf you need to make additional configuration changes, you can modify\n" + + "the settings in `{0}` and then run:\n{1}", + _context.HostOS == "win" ? ".\\bwdata\\config.yml" : "./bwdata/config.yml", + _context.HostOS == "win" ? "`.\\bitwarden.ps1 -rebuild` or `.\\bitwarden.ps1 -update`" : + "`./bitwarden.sh rebuild` or `./bitwarden.sh update`"); + + Console.WriteLine("\nNext steps, run:"); + if (_context.HostOS == "win") + { + Console.WriteLine("`.\\bitwarden.ps1 -start`"); + } + else + { + Console.WriteLine("`./bitwarden.sh start`"); + } + Console.WriteLine(string.Empty); + } + + private static void Update() + { + // This portion of code checks for multiple certs in the Identity.pfx PKCS12 bag. If found, it generates + // a new cert and bag to replace the old Identity.pfx. This fixes an issue that came up as a result of + // moving the project to .NET 5. + _context.Install.IdentityCertPassword = Helpers.GetValueFromEnvFile("global", "globalSettings__identityServer__certificatePassword"); + var certCountString = Helpers.Exec("openssl pkcs12 -nokeys -info -in /bitwarden/identity/identity.pfx " + + $"-passin pass:{_context.Install.IdentityCertPassword} 2> /dev/null | grep -c \"\\-----BEGIN CERTIFICATE----\"", true); + if (int.TryParse(certCountString, out var certCount) && certCount > 1) + { + // Extract key from identity.pfx + Helpers.Exec("openssl pkcs12 -in /bitwarden/identity/identity.pfx -nocerts -nodes -out identity.key " + + $"-passin pass:{_context.Install.IdentityCertPassword} > /dev/null 2>&1"); + // Extract certificate from identity.pfx + Helpers.Exec("openssl pkcs12 -in /bitwarden/identity/identity.pfx -clcerts -nokeys -out identity.crt " + + $"-passin pass:{_context.Install.IdentityCertPassword} > /dev/null 2>&1"); + // Create new PKCS12 bag with certificate and key + Helpers.Exec("openssl pkcs12 -export -out /bitwarden/identity/identity.pfx -inkey identity.key " + + $"-in identity.crt -passout pass:{_context.Install.IdentityCertPassword} > /dev/null 2>&1"); + } + + if (_context.Parameters.ContainsKey("db")) + { + MigrateDatabase(); + } + else + { + RebuildConfigs(); + } + } + + private static void PrintEnvironment() + { + _context.LoadConfiguration(); + if (!_context.PrintToScreen()) + { + return; + } + Console.WriteLine("\nBitwarden is up and running!"); + Console.WriteLine("==================================================="); + Console.WriteLine("\nvisit {0}", _context.Config.Url); + Console.Write("to update, run "); + if (_context.HostOS == "win") + { + Console.Write("`.\\bitwarden.ps1 -updateself` and then `.\\bitwarden.ps1 -update`"); + } + else + { + Console.Write("`./bitwarden.sh updateself` and then `./bitwarden.sh update`"); + } + Console.WriteLine("\n"); + } + + private static void MigrateDatabase(int attempt = 1) + { + try + { + Helpers.WriteLine(_context, "Migrating database."); + var vaultConnectionString = Helpers.GetValueFromEnvFile("global", + "globalSettings__sqlServer__connectionString"); + var migrator = new DbMigrator(vaultConnectionString, null); + var success = migrator.MigrateMsSqlDatabase(false); + if (success) { - Install(); - } - else if (_context.Parameters.ContainsKey("update")) - { - Update(); - } - else if (_context.Parameters.ContainsKey("printenv")) - { - PrintEnvironment(); + Helpers.WriteLine(_context, "Migration successful."); } else { - Helpers.WriteLine(_context, "No top-level command detected. Exiting..."); + Helpers.WriteLine(_context, "Migration failed."); } } - - private static void Install() + catch (SqlException e) { - if (_context.Parameters.ContainsKey("letsencrypt")) - { - _context.Config.SslManagedLetsEncrypt = - _context.Parameters["letsencrypt"].ToLowerInvariant() == "y"; - } - if (_context.Parameters.ContainsKey("domain")) - { - _context.Install.Domain = _context.Parameters["domain"].ToLowerInvariant(); - } - if (_context.Parameters.ContainsKey("dbname")) - { - _context.Install.Database = _context.Parameters["dbname"]; - } - - if (_context.Stub) - { - _context.Install.InstallationId = Guid.Empty; - _context.Install.InstallationKey = "SECRET_INSTALLATION_KEY"; - } - else if (!ValidateInstallation()) + if (e.Message.Contains("Server is in script upgrade mode") && attempt < 10) { + var nextAttempt = attempt + 1; + Helpers.WriteLine(_context, "Database is in script upgrade mode. " + + "Trying again (attempt #{0})...", nextAttempt); + System.Threading.Thread.Sleep(20000); + MigrateDatabase(nextAttempt); return; } + throw; + } + } - var certBuilder = new CertBuilder(_context); - certBuilder.BuildForInstall(); + private static bool ValidateInstallation() + { + var installationId = string.Empty; + var installationKey = string.Empty; - // Set the URL - _context.Config.Url = string.Format("http{0}://{1}", - _context.Config.Ssl ? "s" : string.Empty, _context.Install.Domain); - - var nginxBuilder = new NginxConfigBuilder(_context); - nginxBuilder.BuildForInstaller(); - - var environmentFileBuilder = new EnvironmentFileBuilder(_context); - environmentFileBuilder.BuildForInstaller(); - - var appIdBuilder = new AppIdBuilder(_context); - appIdBuilder.Build(); - - var dockerComposeBuilder = new DockerComposeBuilder(_context); - dockerComposeBuilder.BuildForInstaller(); - - _context.SaveConfiguration(); - - Console.WriteLine("\nInstallation complete"); - - Console.WriteLine("\nIf you need to make additional configuration changes, you can modify\n" + - "the settings in `{0}` and then run:\n{1}", - _context.HostOS == "win" ? ".\\bwdata\\config.yml" : "./bwdata/config.yml", - _context.HostOS == "win" ? "`.\\bitwarden.ps1 -rebuild` or `.\\bitwarden.ps1 -update`" : - "`./bitwarden.sh rebuild` or `./bitwarden.sh update`"); - - Console.WriteLine("\nNext steps, run:"); - if (_context.HostOS == "win") - { - Console.WriteLine("`.\\bitwarden.ps1 -start`"); - } - else - { - Console.WriteLine("`./bitwarden.sh start`"); - } - Console.WriteLine(string.Empty); + if (_context.Parameters.ContainsKey("install-id")) + { + installationId = _context.Parameters["install-id"].ToLowerInvariant(); + } + else + { + installationId = Helpers.ReadInput("Enter your installation id (get at https://bitwarden.com/host)"); } - private static void Update() + if (!Guid.TryParse(installationId.Trim(), out var installationidGuid)) { - // This portion of code checks for multiple certs in the Identity.pfx PKCS12 bag. If found, it generates - // a new cert and bag to replace the old Identity.pfx. This fixes an issue that came up as a result of - // moving the project to .NET 5. - _context.Install.IdentityCertPassword = Helpers.GetValueFromEnvFile("global", "globalSettings__identityServer__certificatePassword"); - var certCountString = Helpers.Exec("openssl pkcs12 -nokeys -info -in /bitwarden/identity/identity.pfx " + - $"-passin pass:{_context.Install.IdentityCertPassword} 2> /dev/null | grep -c \"\\-----BEGIN CERTIFICATE----\"", true); - if (int.TryParse(certCountString, out var certCount) && certCount > 1) - { - // Extract key from identity.pfx - Helpers.Exec("openssl pkcs12 -in /bitwarden/identity/identity.pfx -nocerts -nodes -out identity.key " + - $"-passin pass:{_context.Install.IdentityCertPassword} > /dev/null 2>&1"); - // Extract certificate from identity.pfx - Helpers.Exec("openssl pkcs12 -in /bitwarden/identity/identity.pfx -clcerts -nokeys -out identity.crt " + - $"-passin pass:{_context.Install.IdentityCertPassword} > /dev/null 2>&1"); - // Create new PKCS12 bag with certificate and key - Helpers.Exec("openssl pkcs12 -export -out /bitwarden/identity/identity.pfx -inkey identity.key " + - $"-in identity.crt -passout pass:{_context.Install.IdentityCertPassword} > /dev/null 2>&1"); - } - - if (_context.Parameters.ContainsKey("db")) - { - MigrateDatabase(); - } - else - { - RebuildConfigs(); - } + Console.WriteLine("Invalid installation id."); + return false; } - private static void PrintEnvironment() + if (_context.Parameters.ContainsKey("install-key")) { - _context.LoadConfiguration(); - if (!_context.PrintToScreen()) - { - return; - } - Console.WriteLine("\nBitwarden is up and running!"); - Console.WriteLine("==================================================="); - Console.WriteLine("\nvisit {0}", _context.Config.Url); - Console.Write("to update, run "); - if (_context.HostOS == "win") - { - Console.Write("`.\\bitwarden.ps1 -updateself` and then `.\\bitwarden.ps1 -update`"); - } - else - { - Console.Write("`./bitwarden.sh updateself` and then `./bitwarden.sh update`"); - } - Console.WriteLine("\n"); + installationKey = _context.Parameters["install-key"]; + } + else + { + installationKey = Helpers.ReadInput("Enter your installation key"); } - private static void MigrateDatabase(int attempt = 1) + _context.Install.InstallationId = installationidGuid; + _context.Install.InstallationKey = installationKey; + + try { - try + var response = new HttpClient().GetAsync("https://api.bitwarden.com/installations/" + + _context.Install.InstallationId).GetAwaiter().GetResult(); + + if (!response.IsSuccessStatusCode) { - Helpers.WriteLine(_context, "Migrating database."); - var vaultConnectionString = Helpers.GetValueFromEnvFile("global", - "globalSettings__sqlServer__connectionString"); - var migrator = new DbMigrator(vaultConnectionString, null); - var success = migrator.MigrateMsSqlDatabase(false); - if (success) + if (response.StatusCode == System.Net.HttpStatusCode.NotFound) { - Helpers.WriteLine(_context, "Migration successful."); + Console.WriteLine("Invalid installation id."); } else { - Helpers.WriteLine(_context, "Migration failed."); + Console.WriteLine("Unable to validate installation id."); } - } - catch (SqlException e) - { - if (e.Message.Contains("Server is in script upgrade mode") && attempt < 10) - { - var nextAttempt = attempt + 1; - Helpers.WriteLine(_context, "Database is in script upgrade mode. " + - "Trying again (attempt #{0})...", nextAttempt); - System.Threading.Thread.Sleep(20000); - MigrateDatabase(nextAttempt); - return; - } - throw; - } - } - private static bool ValidateInstallation() - { - var installationId = string.Empty; - var installationKey = string.Empty; - - if (_context.Parameters.ContainsKey("install-id")) - { - installationId = _context.Parameters["install-id"].ToLowerInvariant(); - } - else - { - installationId = Helpers.ReadInput("Enter your installation id (get at https://bitwarden.com/host)"); - } - - if (!Guid.TryParse(installationId.Trim(), out var installationidGuid)) - { - Console.WriteLine("Invalid installation id."); return false; } - if (_context.Parameters.ContainsKey("install-key")) + var result = response.Content.ReadFromJsonAsync().GetAwaiter().GetResult(); + if (!result.Enabled) { - installationKey = _context.Parameters["install-key"]; - } - else - { - installationKey = Helpers.ReadInput("Enter your installation key"); - } - - _context.Install.InstallationId = installationidGuid; - _context.Install.InstallationKey = installationKey; - - try - { - var response = new HttpClient().GetAsync("https://api.bitwarden.com/installations/" + - _context.Install.InstallationId).GetAwaiter().GetResult(); - - if (!response.IsSuccessStatusCode) - { - if (response.StatusCode == System.Net.HttpStatusCode.NotFound) - { - Console.WriteLine("Invalid installation id."); - } - else - { - Console.WriteLine("Unable to validate installation id."); - } - - return false; - } - - var result = response.Content.ReadFromJsonAsync().GetAwaiter().GetResult(); - if (!result.Enabled) - { - Console.WriteLine("Installation id has been disabled."); - return false; - } - - return true; - } - catch - { - Console.WriteLine("Unable to validate installation id. Problem contacting Bitwarden server."); + Console.WriteLine("Installation id has been disabled."); return false; } + + return true; } - - private static void RebuildConfigs() + catch { - _context.LoadConfiguration(); - - var environmentFileBuilder = new EnvironmentFileBuilder(_context); - environmentFileBuilder.BuildForUpdater(); - - var certBuilder = new CertBuilder(_context); - certBuilder.BuildForUpdater(); - - var nginxBuilder = new NginxConfigBuilder(_context); - nginxBuilder.BuildForUpdater(); - - var appIdBuilder = new AppIdBuilder(_context); - appIdBuilder.Build(); - - var dockerComposeBuilder = new DockerComposeBuilder(_context); - dockerComposeBuilder.BuildForUpdater(); - - _context.SaveConfiguration(); - Console.WriteLine(string.Empty); - } - - private static void ParseParameters() - { - _context.Parameters = new Dictionary(); - for (var i = 0; i < _context.Args.Length; i = i + 2) - { - if (!_context.Args[i].StartsWith("-")) - { - continue; - } - - _context.Parameters.Add(_context.Args[i].Substring(1), _context.Args[i + 1]); - } - } - - class InstallationValidationResponseModel - { - public bool Enabled { get; init; } + Console.WriteLine("Unable to validate installation id. Problem contacting Bitwarden server."); + return false; } } + + private static void RebuildConfigs() + { + _context.LoadConfiguration(); + + var environmentFileBuilder = new EnvironmentFileBuilder(_context); + environmentFileBuilder.BuildForUpdater(); + + var certBuilder = new CertBuilder(_context); + certBuilder.BuildForUpdater(); + + var nginxBuilder = new NginxConfigBuilder(_context); + nginxBuilder.BuildForUpdater(); + + var appIdBuilder = new AppIdBuilder(_context); + appIdBuilder.Build(); + + var dockerComposeBuilder = new DockerComposeBuilder(_context); + dockerComposeBuilder.BuildForUpdater(); + + _context.SaveConfiguration(); + Console.WriteLine(string.Empty); + } + + private static void ParseParameters() + { + _context.Parameters = new Dictionary(); + for (var i = 0; i < _context.Args.Length; i = i + 2) + { + if (!_context.Args[i].StartsWith("-")) + { + continue; + } + + _context.Parameters.Add(_context.Args[i].Substring(1), _context.Args[i + 1]); + } + } + + class InstallationValidationResponseModel + { + public bool Enabled { get; init; } + } } diff --git a/util/Setup/YamlComments.cs b/util/Setup/YamlComments.cs index 32b935d50..5bdb6fddf 100644 --- a/util/Setup/YamlComments.cs +++ b/util/Setup/YamlComments.cs @@ -7,102 +7,101 @@ using YamlDotNet.Serialization.TypeInspectors; // ref: https://github.com/aaubry/YamlDotNet/issues/152#issuecomment-349034754 -namespace Bit.Setup +namespace Bit.Setup; + +public class CommentGatheringTypeInspector : TypeInspectorSkeleton { - public class CommentGatheringTypeInspector : TypeInspectorSkeleton + private readonly ITypeInspector _innerTypeDescriptor; + + public CommentGatheringTypeInspector(ITypeInspector innerTypeDescriptor) { - private readonly ITypeInspector _innerTypeDescriptor; - - public CommentGatheringTypeInspector(ITypeInspector innerTypeDescriptor) - { - _innerTypeDescriptor = innerTypeDescriptor ?? throw new ArgumentNullException(nameof(innerTypeDescriptor)); - } - - public override IEnumerable GetProperties(Type type, object container) - { - return _innerTypeDescriptor.GetProperties(type, container).Select(d => new CommentsPropertyDescriptor(d)); - } - - private sealed class CommentsPropertyDescriptor : IPropertyDescriptor - { - private readonly IPropertyDescriptor _baseDescriptor; - - public CommentsPropertyDescriptor(IPropertyDescriptor baseDescriptor) - { - _baseDescriptor = baseDescriptor; - Name = baseDescriptor.Name; - } - - public string Name { get; set; } - public int Order { get; set; } - public Type Type => _baseDescriptor.Type; - public bool CanWrite => _baseDescriptor.CanWrite; - - public Type TypeOverride - { - get { return _baseDescriptor.TypeOverride; } - set { _baseDescriptor.TypeOverride = value; } - } - - public ScalarStyle ScalarStyle - { - get { return _baseDescriptor.ScalarStyle; } - set { _baseDescriptor.ScalarStyle = value; } - } - - public void Write(object target, object value) - { - _baseDescriptor.Write(target, value); - } - - public T GetCustomAttribute() where T : Attribute - { - return _baseDescriptor.GetCustomAttribute(); - } - - public IObjectDescriptor Read(object target) - { - var description = _baseDescriptor.GetCustomAttribute(); - return description != null ? - new CommentsObjectDescriptor(_baseDescriptor.Read(target), description.Description) : - _baseDescriptor.Read(target); - } - } + _innerTypeDescriptor = innerTypeDescriptor ?? throw new ArgumentNullException(nameof(innerTypeDescriptor)); } - public sealed class CommentsObjectDescriptor : IObjectDescriptor + public override IEnumerable GetProperties(Type type, object container) { - private readonly IObjectDescriptor _innerDescriptor; - - public CommentsObjectDescriptor(IObjectDescriptor innerDescriptor, string comment) - { - _innerDescriptor = innerDescriptor; - Comment = comment; - } - - public string Comment { get; private set; } - public object Value => _innerDescriptor.Value; - public Type Type => _innerDescriptor.Type; - public Type StaticType => _innerDescriptor.StaticType; - public ScalarStyle ScalarStyle => _innerDescriptor.ScalarStyle; + return _innerTypeDescriptor.GetProperties(type, container).Select(d => new CommentsPropertyDescriptor(d)); } - public class CommentsObjectGraphVisitor : ChainedObjectGraphVisitor + private sealed class CommentsPropertyDescriptor : IPropertyDescriptor { - public CommentsObjectGraphVisitor(IObjectGraphVisitor nextVisitor) - : base(nextVisitor) { } + private readonly IPropertyDescriptor _baseDescriptor; - public override bool EnterMapping(IPropertyDescriptor key, IObjectDescriptor value, IEmitter context) + public CommentsPropertyDescriptor(IPropertyDescriptor baseDescriptor) { - if (value is CommentsObjectDescriptor commentsDescriptor && commentsDescriptor.Comment != null) - { - context.Emit(new Comment(string.Empty, false)); - foreach (var comment in commentsDescriptor.Comment.Split(Environment.NewLine)) - { - context.Emit(new Comment(comment, false)); - } - } - return base.EnterMapping(key, value, context); + _baseDescriptor = baseDescriptor; + Name = baseDescriptor.Name; + } + + public string Name { get; set; } + public int Order { get; set; } + public Type Type => _baseDescriptor.Type; + public bool CanWrite => _baseDescriptor.CanWrite; + + public Type TypeOverride + { + get { return _baseDescriptor.TypeOverride; } + set { _baseDescriptor.TypeOverride = value; } + } + + public ScalarStyle ScalarStyle + { + get { return _baseDescriptor.ScalarStyle; } + set { _baseDescriptor.ScalarStyle = value; } + } + + public void Write(object target, object value) + { + _baseDescriptor.Write(target, value); + } + + public T GetCustomAttribute() where T : Attribute + { + return _baseDescriptor.GetCustomAttribute(); + } + + public IObjectDescriptor Read(object target) + { + var description = _baseDescriptor.GetCustomAttribute(); + return description != null ? + new CommentsObjectDescriptor(_baseDescriptor.Read(target), description.Description) : + _baseDescriptor.Read(target); } } } + +public sealed class CommentsObjectDescriptor : IObjectDescriptor +{ + private readonly IObjectDescriptor _innerDescriptor; + + public CommentsObjectDescriptor(IObjectDescriptor innerDescriptor, string comment) + { + _innerDescriptor = innerDescriptor; + Comment = comment; + } + + public string Comment { get; private set; } + public object Value => _innerDescriptor.Value; + public Type Type => _innerDescriptor.Type; + public Type StaticType => _innerDescriptor.StaticType; + public ScalarStyle ScalarStyle => _innerDescriptor.ScalarStyle; +} + +public class CommentsObjectGraphVisitor : ChainedObjectGraphVisitor +{ + public CommentsObjectGraphVisitor(IObjectGraphVisitor nextVisitor) + : base(nextVisitor) { } + + public override bool EnterMapping(IPropertyDescriptor key, IObjectDescriptor value, IEmitter context) + { + if (value is CommentsObjectDescriptor commentsDescriptor && commentsDescriptor.Comment != null) + { + context.Emit(new Comment(string.Empty, false)); + foreach (var comment in commentsDescriptor.Comment.Split(Environment.NewLine)) + { + context.Emit(new Comment(comment, false)); + } + } + return base.EnterMapping(key, value, context); + } +}