diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs index 6c887b98b1..5c3ec7f036 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs @@ -10,6 +10,7 @@ public static class PolicyServiceCollectionExtensions public static void AddPolicyServices(this IServiceCollection services) { services.AddScoped(); + services.AddScoped(); services.AddScoped(); } } diff --git a/src/Core/AdminConsole/Services/IPolicyServicevNext.cs b/src/Core/AdminConsole/Services/IPolicyServicevNext.cs new file mode 100644 index 0000000000..6366bc552d --- /dev/null +++ b/src/Core/AdminConsole/Services/IPolicyServicevNext.cs @@ -0,0 +1,15 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; +using Bit.Core.Services; + +namespace Bit.Core.AdminConsole.Services; + +public interface IPolicyServicevNext +{ + Task SaveAsync(Policy policy, IUserService userService, IOrganizationService organizationService, + Guid? savingUserId); +} diff --git a/src/Core/AdminConsole/Services/Implementations/PolicyService.cs b/src/Core/AdminConsole/Services/Implementations/PolicyService.cs index 1af2e91090..1ffa2a0e26 100644 --- a/src/Core/AdminConsole/Services/Implementations/PolicyService.cs +++ b/src/Core/AdminConsole/Services/Implementations/PolicyService.cs @@ -1,10 +1,10 @@ -#nullable enable - -using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; -using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Auth.Enums; +using Bit.Core.Auth.Repositories; +using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -19,36 +19,43 @@ public class PolicyService : IPolicyService { private readonly IApplicationCacheService _applicationCacheService; private readonly IEventService _eventService; + private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IPolicyRepository _policyRepository; + private readonly ISsoConfigRepository _ssoConfigRepository; + private readonly IMailService _mailService; private readonly GlobalSettings _globalSettings; - private readonly Dictionary _policyDefinitions = new(); + private readonly IFeatureService _featureService; + private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; public PolicyService( IApplicationCacheService applicationCacheService, IEventService eventService, + IOrganizationRepository organizationRepository, IOrganizationUserRepository organizationUserRepository, IPolicyRepository policyRepository, + ISsoConfigRepository ssoConfigRepository, + IMailService mailService, GlobalSettings globalSettings, - IEnumerable policyDefinitions) + IFeatureService featureService, + ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery) { _applicationCacheService = applicationCacheService; _eventService = eventService; + _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; _policyRepository = policyRepository; + _ssoConfigRepository = ssoConfigRepository; + _mailService = mailService; _globalSettings = globalSettings; - - foreach (var policyDefinition in policyDefinitions) - { - _policyDefinitions.Add(policyDefinition.Type, policyDefinition); - // TODO: throw if any policyDefinition is missing - } + _featureService = featureService; + _twoFactorIsEnabledQuery = twoFactorIsEnabledQuery; } public async Task SaveAsync(Policy policy, IUserService userService, IOrganizationService organizationService, Guid? savingUserId) { - var org = await _applicationCacheService.GetOrganizationAbilityAsync(policy.OrganizationId); + var org = await _organizationRepository.GetByIdAsync(policy.OrganizationId); if (org == null) { throw new BadRequestException("Organization not found"); @@ -59,59 +66,36 @@ public class PolicyService : IPolicyService throw new BadRequestException("This organization cannot use policies."); } - var policyDefinition = _policyDefinitions[policy.Type]; - var allSavedPolicies = await _policyRepository.GetManyByOrganizationIdAsync(org.Id); - var currentPolicy = allSavedPolicies.SingleOrDefault(p => p.Id == policy.Id); - - // If enabling this policy - check that all policy requirements are satisfied - if (currentPolicy is not { Enabled: true } && policy.Enabled) - { - foreach (var requiredPolicyType in policyDefinition.RequiredPolicies) - { - if (allSavedPolicies.SingleOrDefault(p => p.Type == requiredPolicyType) is not { Enabled: true }) - { - // TODO: would be better to reference the name instead of the enum - throw new BadRequestException("Policy requires PolicyType " + requiredPolicyType + " to be enabled first."); - } - } - } - - // If disabling this policy - ensure it's not required by any other policy - if (currentPolicy is { Enabled: true } && !policy.Enabled) - { - var dependentPolicies = _policyDefinitions.Values - .Where(policyDef => policyDef.RequiredPolicies.Contains(policy.Type)) - .Select(policyDef => policyDef.Type) - .Select(otherPolicyType => allSavedPolicies.SingleOrDefault(p => p.Type == otherPolicyType)) - .Where(otherPolicy => otherPolicy is { Enabled: true }) - .ToList(); - - if (dependentPolicies is { Count: > 0}) - { - throw new BadRequestException("This policy is required by " + dependentPolicies.First() + ". Try disabling that policy first." ); - } - } - - // Run other validation - var validationError = await policyDefinition.ValidateAsync(currentPolicy, policy); - if (validationError != null) - { - throw new BadRequestException(validationError); - } - - // Run side effects - await policyDefinition.OnSaveSideEffectsAsync(currentPolicy, policy); + // FIXME: This method will throw a bunch of errors based on if the + // policy that is being applied requires some other policy that is + // not enabled. It may be advisable to refactor this into a domain + // object and get this kind of stuff out of the service. + await HandleDependentPoliciesAsync(policy, org); var now = DateTime.UtcNow; - if (policy.Id == default) + if (policy.Id == default(Guid)) { policy.CreationDate = now; } policy.RevisionDate = now; - await _policyRepository.UpsertAsync(policy); - await _eventService.LogPolicyEventAsync(policy, EventType.Policy_Updated); + // We can exit early for disable operations, because they are + // simpler. + if (!policy.Enabled) + { + await SetPolicyConfiguration(policy); + return; + } + + if (_featureService.IsEnabled(FeatureFlagKeys.MembersTwoFAQueryOptimization)) + { + await EnablePolicy_vNext(policy, org, organizationService, savingUserId); + return; + } + + await EnablePolicy(policy, org, userService, organizationService, savingUserId); + return; } public async Task GetMasterPasswordPolicyForUserAsync(User user) @@ -177,4 +161,223 @@ public class PolicyService : IPolicyService return new[] { OrganizationUserType.Owner, OrganizationUserType.Admin }; } + + private async Task DependsOnSingleOrgAsync(Organization org) + { + var singleOrg = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.SingleOrg); + if (singleOrg?.Enabled != true) + { + throw new BadRequestException("Single Organization policy not enabled."); + } + } + + private async Task RequiredBySsoAsync(Organization org) + { + var requireSso = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.RequireSso); + if (requireSso?.Enabled == true) + { + throw new BadRequestException("Single Sign-On Authentication policy is enabled."); + } + } + + private async Task RequiredByKeyConnectorAsync(Organization org) + { + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(org.Id); + if (ssoConfig?.GetData()?.MemberDecryptionType == MemberDecryptionType.KeyConnector) + { + throw new BadRequestException("Key Connector is enabled."); + } + } + + private async Task RequiredByAccountRecoveryAsync(Organization org) + { + var requireSso = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.ResetPassword); + if (requireSso?.Enabled == true) + { + throw new BadRequestException("Account recovery policy is enabled."); + } + } + + private async Task RequiredByVaultTimeoutAsync(Organization org) + { + var vaultTimeout = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.MaximumVaultTimeout); + if (vaultTimeout?.Enabled == true) + { + throw new BadRequestException("Maximum Vault Timeout policy is enabled."); + } + } + + private async Task RequiredBySsoTrustedDeviceEncryptionAsync(Organization org) + { + var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(org.Id); + if (ssoConfig?.GetData()?.MemberDecryptionType == MemberDecryptionType.TrustedDeviceEncryption) + { + throw new BadRequestException("Trusted device encryption is on and requires this policy."); + } + } + + private async Task HandleDependentPoliciesAsync(Policy policy, Organization org) + { + switch (policy.Type) + { + case PolicyType.SingleOrg: + if (!policy.Enabled) + { + await RequiredBySsoAsync(org); + await RequiredByVaultTimeoutAsync(org); + await RequiredByKeyConnectorAsync(org); + await RequiredByAccountRecoveryAsync(org); + } + break; + + case PolicyType.RequireSso: + if (policy.Enabled) + { + await DependsOnSingleOrgAsync(org); + } + else + { + await RequiredByKeyConnectorAsync(org); + await RequiredBySsoTrustedDeviceEncryptionAsync(org); + } + break; + + case PolicyType.ResetPassword: + if (!policy.Enabled || policy.GetDataModel()?.AutoEnrollEnabled == false) + { + await RequiredBySsoTrustedDeviceEncryptionAsync(org); + } + + if (policy.Enabled) + { + await DependsOnSingleOrgAsync(org); + } + break; + + case PolicyType.MaximumVaultTimeout: + if (policy.Enabled) + { + await DependsOnSingleOrgAsync(org); + } + break; + } + } + + private async Task SetPolicyConfiguration(Policy policy) + { + await _policyRepository.UpsertAsync(policy); + await _eventService.LogPolicyEventAsync(policy, EventType.Policy_Updated); + } + + private async Task EnablePolicy(Policy policy, Organization org, IUserService userService, IOrganizationService organizationService, Guid? savingUserId) + { + var currentPolicy = await _policyRepository.GetByIdAsync(policy.Id); + if (!currentPolicy?.Enabled ?? true) + { + var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(policy.OrganizationId); + var removableOrgUsers = orgUsers.Where(ou => + ou.Status != OrganizationUserStatusType.Invited && ou.Status != OrganizationUserStatusType.Revoked && + ou.Type != OrganizationUserType.Owner && ou.Type != OrganizationUserType.Admin && + ou.UserId != savingUserId); + switch (policy.Type) + { + case PolicyType.TwoFactorAuthentication: + // Reorder by HasMasterPassword to prioritize checking users without a master if they have 2FA enabled + foreach (var orgUser in removableOrgUsers.OrderBy(ou => ou.HasMasterPassword)) + { + if (!await userService.TwoFactorIsEnabledAsync(orgUser)) + { + if (!orgUser.HasMasterPassword) + { + throw new BadRequestException( + "Policy could not be enabled. Non-compliant members will lose access to their accounts. Identify members without two-step login from the policies column in the members page."); + } + + await organizationService.RemoveUserAsync(policy.OrganizationId, orgUser.Id, + savingUserId); + await _mailService.SendOrganizationUserRemovedForPolicyTwoStepEmailAsync( + org.DisplayName(), orgUser.Email); + } + } + break; + case PolicyType.SingleOrg: + var userOrgs = await _organizationUserRepository.GetManyByManyUsersAsync( + removableOrgUsers.Select(ou => ou.UserId.Value)); + foreach (var orgUser in removableOrgUsers) + { + if (userOrgs.Any(ou => ou.UserId == orgUser.UserId + && ou.OrganizationId != org.Id + && ou.Status != OrganizationUserStatusType.Invited)) + { + await organizationService.RemoveUserAsync(policy.OrganizationId, orgUser.Id, + savingUserId); + await _mailService.SendOrganizationUserRemovedForPolicySingleOrgEmailAsync( + org.DisplayName(), orgUser.Email); + } + } + break; + default: + break; + } + } + + await SetPolicyConfiguration(policy); + } + + private async Task EnablePolicy_vNext(Policy policy, Organization org, IOrganizationService organizationService, Guid? savingUserId) + { + var currentPolicy = await _policyRepository.GetByIdAsync(policy.Id); + if (!currentPolicy?.Enabled ?? true) + { + var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(policy.OrganizationId); + var organizationUsersTwoFactorEnabled = await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(orgUsers); + var removableOrgUsers = orgUsers.Where(ou => + ou.Status != OrganizationUserStatusType.Invited && ou.Status != OrganizationUserStatusType.Revoked && + ou.Type != OrganizationUserType.Owner && ou.Type != OrganizationUserType.Admin && + ou.UserId != savingUserId); + switch (policy.Type) + { + case PolicyType.TwoFactorAuthentication: + // Reorder by HasMasterPassword to prioritize checking users without a master if they have 2FA enabled + foreach (var orgUser in removableOrgUsers.OrderBy(ou => ou.HasMasterPassword)) + { + var userTwoFactorEnabled = organizationUsersTwoFactorEnabled.FirstOrDefault(u => u.user.Id == orgUser.Id).twoFactorIsEnabled; + if (!userTwoFactorEnabled) + { + if (!orgUser.HasMasterPassword) + { + throw new BadRequestException( + "Policy could not be enabled. Non-compliant members will lose access to their accounts. Identify members without two-step login from the policies column in the members page."); + } + + await organizationService.RemoveUserAsync(policy.OrganizationId, orgUser.Id, + savingUserId); + await _mailService.SendOrganizationUserRemovedForPolicyTwoStepEmailAsync( + org.DisplayName(), orgUser.Email); + } + } + break; + case PolicyType.SingleOrg: + var userOrgs = await _organizationUserRepository.GetManyByManyUsersAsync( + removableOrgUsers.Select(ou => ou.UserId.Value)); + foreach (var orgUser in removableOrgUsers) + { + if (userOrgs.Any(ou => ou.UserId == orgUser.UserId + && ou.OrganizationId != org.Id + && ou.Status != OrganizationUserStatusType.Invited)) + { + await organizationService.RemoveUserAsync(policy.OrganizationId, orgUser.Id, + savingUserId); + await _mailService.SendOrganizationUserRemovedForPolicySingleOrgEmailAsync( + org.DisplayName(), orgUser.Email); + } + } + break; + default: + break; + } + } + + await SetPolicyConfiguration(policy); + } } diff --git a/src/Core/AdminConsole/Services/Implementations/PolicyServicevNext.cs b/src/Core/AdminConsole/Services/Implementations/PolicyServicevNext.cs new file mode 100644 index 0000000000..8437a1fca4 --- /dev/null +++ b/src/Core/AdminConsole/Services/Implementations/PolicyServicevNext.cs @@ -0,0 +1,110 @@ +#nullable enable + +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Settings; + +namespace Bit.Core.AdminConsole.Services.Implementations; + +public class PolicyServicevNext : IPolicyServicevNext +{ + private readonly IApplicationCacheService _applicationCacheService; + private readonly IEventService _eventService; + private readonly IPolicyRepository _policyRepository; + private readonly Dictionary _policyDefinitions = new(); + + public PolicyServicevNext( + IApplicationCacheService applicationCacheService, + IEventService eventService, + IPolicyRepository policyRepository, + IEnumerable policyDefinitions) + { + _applicationCacheService = applicationCacheService; + _eventService = eventService; + _policyRepository = policyRepository; + + foreach (var policyDefinition in policyDefinitions) + { + _policyDefinitions.Add(policyDefinition.Type, policyDefinition); + // TODO: throw if any policyDefinition is missing + } + } + + public async Task SaveAsync(Policy policy, IUserService userService, IOrganizationService organizationService, + Guid? savingUserId) + { + var org = await _applicationCacheService.GetOrganizationAbilityAsync(policy.OrganizationId); + if (org == null) + { + throw new BadRequestException("Organization not found"); + } + + if (!org.UsePolicies) + { + throw new BadRequestException("This organization cannot use policies."); + } + + var policyDefinition = _policyDefinitions[policy.Type]; + var allSavedPolicies = await _policyRepository.GetManyByOrganizationIdAsync(org.Id); + var currentPolicy = allSavedPolicies.SingleOrDefault(p => p.Id == policy.Id); + + // If enabling this policy - check that all policy requirements are satisfied + if (currentPolicy is not { Enabled: true } && policy.Enabled) + { + foreach (var requiredPolicyType in policyDefinition.RequiredPolicies) + { + if (allSavedPolicies.SingleOrDefault(p => p.Type == requiredPolicyType) is not { Enabled: true }) + { + // TODO: would be better to reference the name instead of the enum + throw new BadRequestException("Policy requires PolicyType " + requiredPolicyType + " to be enabled first."); + } + } + } + + // If disabling this policy - ensure it's not required by any other policy + if (currentPolicy is { Enabled: true } && !policy.Enabled) + { + var dependentPolicies = _policyDefinitions.Values + .Where(policyDef => policyDef.RequiredPolicies.Contains(policy.Type)) + .Select(policyDef => policyDef.Type) + .Select(otherPolicyType => allSavedPolicies.SingleOrDefault(p => p.Type == otherPolicyType)) + .Where(otherPolicy => otherPolicy is { Enabled: true }) + .ToList(); + + if (dependentPolicies is { Count: > 0}) + { + throw new BadRequestException("This policy is required by " + dependentPolicies.First() + ". Try disabling that policy first." ); + } + } + + // Run other validation + var validationError = await policyDefinition.ValidateAsync(currentPolicy, policy); + if (validationError != null) + { + throw new BadRequestException(validationError); + } + + // Run side effects + await policyDefinition.OnSaveSideEffectsAsync(currentPolicy, policy); + + var now = DateTime.UtcNow; + if (policy.Id == default) + { + policy.CreationDate = now; + } + + policy.RevisionDate = now; + + await _policyRepository.UpsertAsync(policy); + await _eventService.LogPolicyEventAsync(policy, EventType.Policy_Updated); + } +} diff --git a/test/Core.Test/AdminConsole/Services/PolicyServiceTests.cs b/test/Core.Test/AdminConsole/Services/PolicyServiceTests.cs index 8d4bcf3fcb..fd7597a748 100644 --- a/test/Core.Test/AdminConsole/Services/PolicyServiceTests.cs +++ b/test/Core.Test/AdminConsole/Services/PolicyServiceTests.cs @@ -9,7 +9,6 @@ using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Repositories; using Bit.Core.Enums; using Bit.Core.Exceptions; -using Bit.Core.Models.Data.Organizations; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; using Bit.Core.Services; @@ -52,7 +51,9 @@ public class PolicyServiceTests public async Task SaveAsync_OrganizationCannotUsePolicies_ThrowsBadRequest( [AdminConsoleFixtures.Policy(PolicyType.DisableSend)] Policy policy, SutProvider sutProvider) { - SetupOrg(sutProvider, policy.OrganizationId, new OrganizationAbility + var orgId = Guid.NewGuid(); + + SetupOrg(sutProvider, policy.OrganizationId, new Organization { UsePolicies = false, }); @@ -80,7 +81,7 @@ public class PolicyServiceTests { policy.Enabled = false; - SetupOrg(sutProvider, policy.OrganizationId, new OrganizationAbility + SetupOrg(sutProvider, policy.OrganizationId, new Organization { Id = policy.OrganizationId, UsePolicies = true, @@ -112,7 +113,7 @@ public class PolicyServiceTests { policy.Enabled = false; - SetupOrg(sutProvider, policy.OrganizationId, new OrganizationAbility + SetupOrg(sutProvider, policy.OrganizationId, new Organization { Id = policy.OrganizationId, UsePolicies = true, @@ -146,7 +147,7 @@ public class PolicyServiceTests policy.Enabled = false; policy.Type = policyType; - SetupOrg(sutProvider, policy.OrganizationId, new OrganizationAbility + SetupOrg(sutProvider, policy.OrganizationId, new Organization { Id = policy.OrganizationId, UsePolicies = true, @@ -179,7 +180,7 @@ public class PolicyServiceTests { policy.Enabled = true; - SetupOrg(sutProvider, policy.OrganizationId, new OrganizationAbility + SetupOrg(sutProvider, policy.OrganizationId, new Organization { Id = policy.OrganizationId, UsePolicies = true, @@ -213,7 +214,7 @@ public class PolicyServiceTests policy.Id = default; policy.Data = null; - SetupOrg(sutProvider, policy.OrganizationId, new OrganizationAbility + SetupOrg(sutProvider, policy.OrganizationId, new Organization { Id = policy.OrganizationId, UsePolicies = true, @@ -243,7 +244,7 @@ public class PolicyServiceTests { policy.Enabled = true; - SetupOrg(sutProvider, policy.OrganizationId, new OrganizationAbility + SetupOrg(sutProvider, policy.OrganizationId, new Organization { Id = policy.OrganizationId, UsePolicies = true, @@ -272,19 +273,16 @@ public class PolicyServiceTests [Theory, BitAutoData] public async Task SaveAsync_ExistingPolicy_UpdateTwoFactor( - OrganizationAbility organizationAbility, Organization organization, [AdminConsoleFixtures.Policy(PolicyType.TwoFactorAuthentication)] Policy policy, SutProvider sutProvider) { // If the policy that this is updating isn't enabled then do some work now that the current one is enabled - organizationAbility.UsePolicies = true; - policy.OrganizationId = organizationAbility.Id = organization.Id; + organization.UsePolicies = true; + policy.OrganizationId = organization.Id; - SetupOrg(sutProvider, organization.Id, organizationAbility); - - sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + SetupOrg(sutProvider, organization.Id, organization); sutProvider.GetDependency() .GetByIdAsync(policy.Id) @@ -394,7 +392,7 @@ public class PolicyServiceTests [Theory, BitAutoData] public async Task SaveAsync_EnableTwoFactor_WithoutMasterPasswordOr2FA_ThrowsBadRequest( - OrganizationAbility organization, + Organization organization, [AdminConsoleFixtures.Policy(PolicyType.TwoFactorAuthentication)] Policy policy, SutProvider sutProvider) { @@ -491,10 +489,11 @@ public class PolicyServiceTests { // If the policy that this is updating isn't enabled then do some work now that the current one is enabled - var org = new OrganizationAbility() + var org = new Organization { Id = policy.OrganizationId, UsePolicies = true, + Name = "TEST", }; SetupOrg(sutProvider, policy.OrganizationId, org); @@ -565,7 +564,7 @@ public class PolicyServiceTests AutoEnrollEnabled = autoEnrollEnabled }); - SetupOrg(sutProvider, policy.OrganizationId, new OrganizationAbility + SetupOrg(sutProvider, policy.OrganizationId, new Organization { Id = policy.OrganizationId, UsePolicies = true, @@ -602,7 +601,7 @@ public class PolicyServiceTests { policy.Enabled = false; - SetupOrg(sutProvider, policy.OrganizationId, new OrganizationAbility + SetupOrg(sutProvider, policy.OrganizationId, new Organization { Id = policy.OrganizationId, UsePolicies = true, @@ -639,7 +638,7 @@ public class PolicyServiceTests policy.Enabled = true; policy.SetDataModel(new ResetPasswordDataModel()); - SetupOrg(sutProvider, policy.OrganizationId, new OrganizationAbility + SetupOrg(sutProvider, policy.OrganizationId, new Organization { Id = policy.OrganizationId, UsePolicies = true, @@ -673,7 +672,7 @@ public class PolicyServiceTests { policy.Enabled = false; - SetupOrg(sutProvider, policy.OrganizationId, new OrganizationAbility + SetupOrg(sutProvider, policy.OrganizationId, new Organization { Id = policy.OrganizationId, UsePolicies = true, @@ -798,11 +797,11 @@ public class PolicyServiceTests Assert.True(result); } - private static void SetupOrg(SutProvider sutProvider, Guid organizationId, OrganizationAbility organizationAbility) + private static void SetupOrg(SutProvider sutProvider, Guid organizationId, Organization organization) { - sutProvider.GetDependency() - .GetOrganizationAbilityAsync(organizationId) - .Returns(organizationAbility); + sutProvider.GetDependency() + .GetByIdAsync(organizationId) + .Returns(Task.FromResult(organization)); } private static void SetupUserPolicies(Guid userId, SutProvider sutProvider)