From f6c7be50cf7f82eca28d98158af70d6f7351ec1c Mon Sep 17 00:00:00 2001 From: Thomas Rittson Date: Fri, 4 Oct 2024 12:25:18 +1000 Subject: [PATCH] Reduce scope to just saving, implement RequiredPolicies --- .../Policies/IPolicyDefinition.cs | 28 ++-------- .../SingleOrgPolicyDefinition.cs | 29 +--------- .../Policies/PolicyDefinitionExtensions.cs | 8 --- .../PolicyServiceCollectionExtensions.cs | 2 +- .../Services/Implementations/PolicyService.cs | 53 +++++++++++++++---- .../Services/PolicyServiceTests.cs | 47 ++++++++-------- 6 files changed, 74 insertions(+), 93 deletions(-) delete mode 100644 src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyDefinitionExtensions.cs diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyDefinition.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyDefinition.cs index 8992ab611e..c5c1b12574 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyDefinition.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyDefinition.cs @@ -6,7 +6,7 @@ using Bit.Core.Entities; namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies; -public interface IPolicyDefinition +public interface IPolicyDefinition { /// /// The PolicyType that the strategy is responsible for handling. @@ -14,23 +14,14 @@ public interface IPolicyDefinition public PolicyType Type { get; } /// - /// A predicate function that returns true if a policy should be enforced against a user - /// and false otherwise. This does not need to check Organization.UsePolicies or Policy.Enabled. + /// PolicyTypes that must be enabled before this policy can be enabled, if any. /// - public Predicate<(OrganizationUser orgUser, Policy policy)> Filter { get; } + public IEnumerable RequiredPolicies { get; } - /// - /// A reducer function that reduces Policies into policy requirements (as defined by TRequirement). - /// This is used to reconcile policies of the same type from different organizations and combine them into - /// a single object that represents the requirements of the domain. - /// - public (Func reducer, TRequirement initialValue) Reducer { get; } - - // TODO: Currently interdependencies between policies must be checked in both definitions. - // TODO: Consider a separate definition for policy prerequisites that is automatically cross-checked on all handlers, - // TODO: so they can be declared once only. /// /// Validates a policy before saving it. + /// Basic interdependencies between policies are already handled by the definition. + /// Use this for additional or more complex validation, if any. /// /// The current policy, if any /// The modified policy to be saved @@ -45,12 +36,3 @@ public interface IPolicyDefinition /// The modified policy to be saved public Task OnSaveSideEffectsAsync(Policy? currentPolicy, Policy modifiedPolicy); } - -public interface IPolicyDefinition : IPolicyDefinition -{ - /// - /// A factory that transforms the untyped Policy.Data JSON object to a domain specific object, - /// usually used for additional policy configuration. - /// - public Func? DataFactory { get; } -} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SingleOrgPolicyDefinition.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SingleOrgPolicyDefinition.cs index 8f7fc572df..cb177113ba 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SingleOrgPolicyDefinition.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SingleOrgPolicyDefinition.cs @@ -15,11 +15,10 @@ using Bit.Core.Services; namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations; -public record SingleOrgRequirement(bool SingleOrgRequired); - -public class SingleOrgPolicyDefinition : IPolicyDefinition +public class SingleOrgPolicyDefinition : IPolicyDefinition { public PolicyType Type => PolicyType.SingleOrg; + public IEnumerable RequiredPolicies => Array.Empty(); private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IMailService _mailService; @@ -44,13 +43,6 @@ public class SingleOrgPolicyDefinition : IPolicyDefinition _currentContext = currentContext; } - - public Predicate<(OrganizationUser orgUser, Policy policy)> Filter => tuple => - tuple.orgUser is not { Type: OrganizationUserType.Owner or OrganizationUserType.Admin }; - - public (Func reducer, SingleOrgRequirement initialValue) Reducer() => - ((SingleOrgRequirement init, Policy next, SingleOrgRequirement ) => new SingleOrgRequirement(true), new SingleOrgRequirement(false)); - public async Task OnSaveSideEffectsAsync(Policy? currentPolicy, Policy modifiedPolicy) { if (currentPolicy is null or { Enabled: false } && modifiedPolicy is { Enabled: true }) @@ -100,23 +92,6 @@ public class SingleOrgPolicyDefinition : IPolicyDefinition { var organizationId = modifiedPolicy.OrganizationId; - // Do not allow this policy to be disabled if a dependent policy is still enabled - var policies = await _policyRepository.GetManyByOrganizationIdAsync(organizationId); - if (policies.Any(p => p.Type == PolicyType.RequireSso && p.Enabled)) - { - return "Single Sign-On Authentication policy is enabled."; - } - - if (policies.Any(p => p.Type == PolicyType.MaximumVaultTimeout && p.Enabled)) - { - return "Maximum Vault Timeout policy is enabled."; - } - - if (policies.Any(p => p.Type == PolicyType.ResetPassword && p.Enabled)) - { - return "Account Recovery policy is enabled."; - } - // Do not allow this policy to be disabled if Key Connector is being used var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(organizationId); if (ssoConfig?.GetData()?.MemberDecryptionType == MemberDecryptionType.KeyConnector) diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyDefinitionExtensions.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyDefinitionExtensions.cs deleted file mode 100644 index f9253456af..0000000000 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyDefinitionExtensions.cs +++ /dev/null @@ -1,8 +0,0 @@ -using Bit.Core.AdminConsole.Entities; - -namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies; - -public static class PolicyDefinitionExtensions -{ - public static void PolicyStateChanged(Policy? currentPolicy, Policy modifiedPolicy) -} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs index 25eabf9052..6c887b98b1 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs @@ -10,6 +10,6 @@ public static class PolicyServiceCollectionExtensions public static void AddPolicyServices(this IServiceCollection services) { services.AddScoped(); - services.AddScoped, SingleOrgPolicyDefinition>(); + services.AddScoped(); } } diff --git a/src/Core/AdminConsole/Services/Implementations/PolicyService.cs b/src/Core/AdminConsole/Services/Implementations/PolicyService.cs index 8ea88fd944..1af2e91090 100644 --- a/src/Core/AdminConsole/Services/Implementations/PolicyService.cs +++ b/src/Core/AdminConsole/Services/Implementations/PolicyService.cs @@ -19,35 +19,36 @@ public class PolicyService : IPolicyService { private readonly IApplicationCacheService _applicationCacheService; private readonly IEventService _eventService; - private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IPolicyRepository _policyRepository; private readonly GlobalSettings _globalSettings; - private readonly IEnumerable> _policyStrategies; + private readonly Dictionary _policyDefinitions = new(); public PolicyService( IApplicationCacheService applicationCacheService, IEventService eventService, - IOrganizationRepository organizationRepository, IOrganizationUserRepository organizationUserRepository, IPolicyRepository policyRepository, GlobalSettings globalSettings, - IEnumerable> policyStrategies) + IEnumerable policyDefinitions) { _applicationCacheService = applicationCacheService; _eventService = eventService; - _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; _policyRepository = policyRepository; _globalSettings = globalSettings; - _policyStrategies = policyStrategies; + + foreach (var policyDefinition in policyDefinitions) + { + _policyDefinitions.Add(policyDefinition.Type, policyDefinition); + // TODO: throw if any policyDefinition is missing + } } public async Task SaveAsync(Policy policy, IUserService userService, IOrganizationService organizationService, Guid? savingUserId) { - // TODO: this could use the cache - var org = await _organizationRepository.GetByIdAsync(policy.OrganizationId); + var org = await _applicationCacheService.GetOrganizationAbilityAsync(policy.OrganizationId); if (org == null) { throw new BadRequestException("Organization not found"); @@ -58,10 +59,40 @@ public class PolicyService : IPolicyService throw new BadRequestException("This organization cannot use policies."); } - var policyDefinition = _policyStrategies.Single(strategy => strategy.Type == policy.Type); - var currentPolicy = await _policyRepository.GetByIdAsync(policy.Id); + var policyDefinition = _policyDefinitions[policy.Type]; + var allSavedPolicies = await _policyRepository.GetManyByOrganizationIdAsync(org.Id); + var currentPolicy = allSavedPolicies.SingleOrDefault(p => p.Id == policy.Id); - // Validate + // If enabling this policy - check that all policy requirements are satisfied + if (currentPolicy is not { Enabled: true } && policy.Enabled) + { + foreach (var requiredPolicyType in policyDefinition.RequiredPolicies) + { + if (allSavedPolicies.SingleOrDefault(p => p.Type == requiredPolicyType) is not { Enabled: true }) + { + // TODO: would be better to reference the name instead of the enum + throw new BadRequestException("Policy requires PolicyType " + requiredPolicyType + " to be enabled first."); + } + } + } + + // If disabling this policy - ensure it's not required by any other policy + if (currentPolicy is { Enabled: true } && !policy.Enabled) + { + var dependentPolicies = _policyDefinitions.Values + .Where(policyDef => policyDef.RequiredPolicies.Contains(policy.Type)) + .Select(policyDef => policyDef.Type) + .Select(otherPolicyType => allSavedPolicies.SingleOrDefault(p => p.Type == otherPolicyType)) + .Where(otherPolicy => otherPolicy is { Enabled: true }) + .ToList(); + + if (dependentPolicies is { Count: > 0}) + { + throw new BadRequestException("This policy is required by " + dependentPolicies.First() + ". Try disabling that policy first." ); + } + } + + // Run other validation var validationError = await policyDefinition.ValidateAsync(currentPolicy, policy); if (validationError != null) { diff --git a/test/Core.Test/AdminConsole/Services/PolicyServiceTests.cs b/test/Core.Test/AdminConsole/Services/PolicyServiceTests.cs index fd7597a748..8d4bcf3fcb 100644 --- a/test/Core.Test/AdminConsole/Services/PolicyServiceTests.cs +++ b/test/Core.Test/AdminConsole/Services/PolicyServiceTests.cs @@ -9,6 +9,7 @@ using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Repositories; using Bit.Core.Enums; using Bit.Core.Exceptions; +using Bit.Core.Models.Data.Organizations; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; using Bit.Core.Services; @@ -51,9 +52,7 @@ public class PolicyServiceTests public async Task SaveAsync_OrganizationCannotUsePolicies_ThrowsBadRequest( [AdminConsoleFixtures.Policy(PolicyType.DisableSend)] Policy policy, SutProvider sutProvider) { - var orgId = Guid.NewGuid(); - - SetupOrg(sutProvider, policy.OrganizationId, new Organization + SetupOrg(sutProvider, policy.OrganizationId, new OrganizationAbility { UsePolicies = false, }); @@ -81,7 +80,7 @@ public class PolicyServiceTests { policy.Enabled = false; - SetupOrg(sutProvider, policy.OrganizationId, new Organization + SetupOrg(sutProvider, policy.OrganizationId, new OrganizationAbility { Id = policy.OrganizationId, UsePolicies = true, @@ -113,7 +112,7 @@ public class PolicyServiceTests { policy.Enabled = false; - SetupOrg(sutProvider, policy.OrganizationId, new Organization + SetupOrg(sutProvider, policy.OrganizationId, new OrganizationAbility { Id = policy.OrganizationId, UsePolicies = true, @@ -147,7 +146,7 @@ public class PolicyServiceTests policy.Enabled = false; policy.Type = policyType; - SetupOrg(sutProvider, policy.OrganizationId, new Organization + SetupOrg(sutProvider, policy.OrganizationId, new OrganizationAbility { Id = policy.OrganizationId, UsePolicies = true, @@ -180,7 +179,7 @@ public class PolicyServiceTests { policy.Enabled = true; - SetupOrg(sutProvider, policy.OrganizationId, new Organization + SetupOrg(sutProvider, policy.OrganizationId, new OrganizationAbility { Id = policy.OrganizationId, UsePolicies = true, @@ -214,7 +213,7 @@ public class PolicyServiceTests policy.Id = default; policy.Data = null; - SetupOrg(sutProvider, policy.OrganizationId, new Organization + SetupOrg(sutProvider, policy.OrganizationId, new OrganizationAbility { Id = policy.OrganizationId, UsePolicies = true, @@ -244,7 +243,7 @@ public class PolicyServiceTests { policy.Enabled = true; - SetupOrg(sutProvider, policy.OrganizationId, new Organization + SetupOrg(sutProvider, policy.OrganizationId, new OrganizationAbility { Id = policy.OrganizationId, UsePolicies = true, @@ -273,16 +272,19 @@ public class PolicyServiceTests [Theory, BitAutoData] public async Task SaveAsync_ExistingPolicy_UpdateTwoFactor( + OrganizationAbility organizationAbility, Organization organization, [AdminConsoleFixtures.Policy(PolicyType.TwoFactorAuthentication)] Policy policy, SutProvider sutProvider) { // If the policy that this is updating isn't enabled then do some work now that the current one is enabled - organization.UsePolicies = true; - policy.OrganizationId = organization.Id; + organizationAbility.UsePolicies = true; + policy.OrganizationId = organizationAbility.Id = organization.Id; - SetupOrg(sutProvider, organization.Id, organization); + SetupOrg(sutProvider, organization.Id, organizationAbility); + + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); sutProvider.GetDependency() .GetByIdAsync(policy.Id) @@ -392,7 +394,7 @@ public class PolicyServiceTests [Theory, BitAutoData] public async Task SaveAsync_EnableTwoFactor_WithoutMasterPasswordOr2FA_ThrowsBadRequest( - Organization organization, + OrganizationAbility organization, [AdminConsoleFixtures.Policy(PolicyType.TwoFactorAuthentication)] Policy policy, SutProvider sutProvider) { @@ -489,11 +491,10 @@ public class PolicyServiceTests { // If the policy that this is updating isn't enabled then do some work now that the current one is enabled - var org = new Organization + var org = new OrganizationAbility() { Id = policy.OrganizationId, UsePolicies = true, - Name = "TEST", }; SetupOrg(sutProvider, policy.OrganizationId, org); @@ -564,7 +565,7 @@ public class PolicyServiceTests AutoEnrollEnabled = autoEnrollEnabled }); - SetupOrg(sutProvider, policy.OrganizationId, new Organization + SetupOrg(sutProvider, policy.OrganizationId, new OrganizationAbility { Id = policy.OrganizationId, UsePolicies = true, @@ -601,7 +602,7 @@ public class PolicyServiceTests { policy.Enabled = false; - SetupOrg(sutProvider, policy.OrganizationId, new Organization + SetupOrg(sutProvider, policy.OrganizationId, new OrganizationAbility { Id = policy.OrganizationId, UsePolicies = true, @@ -638,7 +639,7 @@ public class PolicyServiceTests policy.Enabled = true; policy.SetDataModel(new ResetPasswordDataModel()); - SetupOrg(sutProvider, policy.OrganizationId, new Organization + SetupOrg(sutProvider, policy.OrganizationId, new OrganizationAbility { Id = policy.OrganizationId, UsePolicies = true, @@ -672,7 +673,7 @@ public class PolicyServiceTests { policy.Enabled = false; - SetupOrg(sutProvider, policy.OrganizationId, new Organization + SetupOrg(sutProvider, policy.OrganizationId, new OrganizationAbility { Id = policy.OrganizationId, UsePolicies = true, @@ -797,11 +798,11 @@ public class PolicyServiceTests Assert.True(result); } - private static void SetupOrg(SutProvider sutProvider, Guid organizationId, Organization organization) + private static void SetupOrg(SutProvider sutProvider, Guid organizationId, OrganizationAbility organizationAbility) { - sutProvider.GetDependency() - .GetByIdAsync(organizationId) - .Returns(Task.FromResult(organization)); + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(organizationId) + .Returns(organizationAbility); } private static void SetupUserPolicies(Guid userId, SutProvider sutProvider)