From 4b482f0a34b3811131fb4ab571267e11bd7b4de8 Mon Sep 17 00:00:00 2001 From: Thomas Avery <43214426+Thomas-Avery@users.noreply.github.com> Date: Thu, 7 Sep 2023 17:51:35 -0500 Subject: [PATCH] [SM-918] Enforce project maximums on import (#3253) * Refactor MaxProjectsQuery for multiple adds * Update unit tests * Add max project enforcement to imports --- .../Queries/Projects/MaxProjectsQuery.cs | 4 +- .../Queries/Projects/MaxProjectsQueryTests.cs | 44 +++++++++++++------ .../Controllers/ProjectsController.cs | 4 +- .../SecretsManagerPortingController.cs | 17 ++++++- .../Projects/Interfaces/IMaxProjectsQuery.cs | 2 +- .../Controllers/ProjectsControllerTests.cs | 2 +- 6 files changed, 53 insertions(+), 20 deletions(-) diff --git a/bitwarden_license/src/Commercial.Core/SecretsManager/Queries/Projects/MaxProjectsQuery.cs b/bitwarden_license/src/Commercial.Core/SecretsManager/Queries/Projects/MaxProjectsQuery.cs index 7cbb37f18..fc30c2c6e 100644 --- a/bitwarden_license/src/Commercial.Core/SecretsManager/Queries/Projects/MaxProjectsQuery.cs +++ b/bitwarden_license/src/Commercial.Core/SecretsManager/Queries/Projects/MaxProjectsQuery.cs @@ -20,7 +20,7 @@ public class MaxProjectsQuery : IMaxProjectsQuery _projectRepository = projectRepository; } - public async Task<(short? max, bool? atMax)> GetByOrgIdAsync(Guid organizationId) + public async Task<(short? max, bool? overMax)> GetByOrgIdAsync(Guid organizationId, int projectsToAdd) { var org = await _organizationRepository.GetByIdAsync(organizationId); if (org == null) @@ -37,7 +37,7 @@ public class MaxProjectsQuery : IMaxProjectsQuery if (plan.Type == PlanType.Free) { var projects = await _projectRepository.GetProjectCountByOrganizationIdAsync(organizationId); - return projects >= plan.MaxProjects ? (plan.MaxProjects, true) : (plan.MaxProjects, false); + return projects + projectsToAdd > plan.MaxProjects ? (plan.MaxProjects, true) : (plan.MaxProjects, false); } return (null, null); diff --git a/bitwarden_license/test/Commercial.Core.Test/SecretsManager/Queries/Projects/MaxProjectsQueryTests.cs b/bitwarden_license/test/Commercial.Core.Test/SecretsManager/Queries/Projects/MaxProjectsQueryTests.cs index 6706e0165..79ffb421e 100644 --- a/bitwarden_license/test/Commercial.Core.Test/SecretsManager/Queries/Projects/MaxProjectsQueryTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/SecretsManager/Queries/Projects/MaxProjectsQueryTests.cs @@ -22,7 +22,7 @@ public class MaxProjectsQueryTests { sutProvider.GetDependency().GetByIdAsync(default).ReturnsNull(); - await Assert.ThrowsAsync(async () => await sutProvider.Sut.GetByOrgIdAsync(organizationId)); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.GetByOrgIdAsync(organizationId, 1)); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() .GetProjectCountByOrganizationIdAsync(organizationId); @@ -43,7 +43,7 @@ public class MaxProjectsQueryTests sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); await Assert.ThrowsAsync( - async () => await sutProvider.Sut.GetByOrgIdAsync(organization.Id)); + async () => await sutProvider.Sut.GetByOrgIdAsync(organization.Id, 1)); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() .GetProjectCountByOrganizationIdAsync(organization.Id); @@ -60,7 +60,7 @@ public class MaxProjectsQueryTests organization.PlanType = planType; sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - var (limit, overLimit) = await sutProvider.Sut.GetByOrgIdAsync(organization.Id); + var (limit, overLimit) = await sutProvider.Sut.GetByOrgIdAsync(organization.Id, 1); Assert.Null(limit); Assert.Null(overLimit); @@ -70,13 +70,31 @@ public class MaxProjectsQueryTests } [Theory] - [BitAutoData(PlanType.Free, 0, false)] - [BitAutoData(PlanType.Free, 1, false)] - [BitAutoData(PlanType.Free, 2, false)] - [BitAutoData(PlanType.Free, 3, true)] - [BitAutoData(PlanType.Free, 4, true)] - [BitAutoData(PlanType.Free, 40, true)] - public async Task GetByOrgIdAsync_SmFreePlan_Success(PlanType planType, int projects, bool shouldBeAtMax, + [BitAutoData(PlanType.Free, 0, 1, false)] + [BitAutoData(PlanType.Free, 1, 1, false)] + [BitAutoData(PlanType.Free, 2, 1, false)] + [BitAutoData(PlanType.Free, 3, 1, true)] + [BitAutoData(PlanType.Free, 4, 1, true)] + [BitAutoData(PlanType.Free, 40, 1, true)] + [BitAutoData(PlanType.Free, 0, 2, false)] + [BitAutoData(PlanType.Free, 1, 2, false)] + [BitAutoData(PlanType.Free, 2, 2, true)] + [BitAutoData(PlanType.Free, 3, 2, true)] + [BitAutoData(PlanType.Free, 4, 2, true)] + [BitAutoData(PlanType.Free, 40, 2, true)] + [BitAutoData(PlanType.Free, 0, 3, false)] + [BitAutoData(PlanType.Free, 1, 3, true)] + [BitAutoData(PlanType.Free, 2, 3, true)] + [BitAutoData(PlanType.Free, 3, 3, true)] + [BitAutoData(PlanType.Free, 4, 3, true)] + [BitAutoData(PlanType.Free, 40, 3, true)] + [BitAutoData(PlanType.Free, 0, 4, true)] + [BitAutoData(PlanType.Free, 1, 4, true)] + [BitAutoData(PlanType.Free, 2, 4, true)] + [BitAutoData(PlanType.Free, 3, 4, true)] + [BitAutoData(PlanType.Free, 4, 4, true)] + [BitAutoData(PlanType.Free, 40, 4, true)] + public async Task GetByOrgIdAsync_SmFreePlan__Success(PlanType planType, int projects, int projectsToAdd, bool expectedOverMax, SutProvider sutProvider, Organization organization) { organization.PlanType = planType; @@ -84,12 +102,12 @@ public class MaxProjectsQueryTests sutProvider.GetDependency().GetProjectCountByOrganizationIdAsync(organization.Id) .Returns(projects); - var (max, atMax) = await sutProvider.Sut.GetByOrgIdAsync(organization.Id); + var (max, overMax) = await sutProvider.Sut.GetByOrgIdAsync(organization.Id, projectsToAdd); Assert.NotNull(max); - Assert.NotNull(atMax); + Assert.NotNull(overMax); Assert.Equal(3, max.Value); - Assert.Equal(shouldBeAtMax, atMax); + Assert.Equal(expectedOverMax, overMax); await sutProvider.GetDependency().Received(1) .GetProjectCountByOrganizationIdAsync(organization.Id); diff --git a/src/Api/SecretsManager/Controllers/ProjectsController.cs b/src/Api/SecretsManager/Controllers/ProjectsController.cs index 4f3815ca7..e59918593 100644 --- a/src/Api/SecretsManager/Controllers/ProjectsController.cs +++ b/src/Api/SecretsManager/Controllers/ProjectsController.cs @@ -79,8 +79,8 @@ public class ProjectsController : Controller throw new NotFoundException(); } - var (max, atMax) = await _maxProjectsQuery.GetByOrgIdAsync(organizationId); - if (atMax != null && atMax.Value) + var (max, overMax) = await _maxProjectsQuery.GetByOrgIdAsync(organizationId, 1); + if (overMax != null && overMax.Value) { throw new BadRequestException($"You have reached the maximum number of projects ({max}) for this plan."); } diff --git a/src/Api/SecretsManager/Controllers/SecretsManagerPortingController.cs b/src/Api/SecretsManager/Controllers/SecretsManagerPortingController.cs index ac2e3e5c2..3d26c8f70 100644 --- a/src/Api/SecretsManager/Controllers/SecretsManagerPortingController.cs +++ b/src/Api/SecretsManager/Controllers/SecretsManagerPortingController.cs @@ -4,6 +4,7 @@ using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.SecretsManager.Commands.Porting.Interfaces; +using Bit.Core.SecretsManager.Queries.Projects.Interfaces; using Bit.Core.SecretsManager.Repositories; using Bit.Core.Services; using Bit.Core.Utilities; @@ -19,14 +20,18 @@ public class SecretsManagerPortingController : Controller private readonly ISecretRepository _secretRepository; private readonly IProjectRepository _projectRepository; private readonly IUserService _userService; + private readonly IMaxProjectsQuery _maxProjectsQuery; private readonly IImportCommand _importCommand; private readonly ICurrentContext _currentContext; - public SecretsManagerPortingController(ISecretRepository secretRepository, IProjectRepository projectRepository, IUserService userService, IImportCommand importCommand, ICurrentContext currentContext) + public SecretsManagerPortingController(ISecretRepository secretRepository, IProjectRepository projectRepository, + IUserService userService, IMaxProjectsQuery maxProjectsQuery, IImportCommand importCommand, + ICurrentContext currentContext) { _secretRepository = secretRepository; _projectRepository = projectRepository; _userService = userService; + _maxProjectsQuery = maxProjectsQuery; _importCommand = importCommand; _currentContext = currentContext; } @@ -69,6 +74,16 @@ public class SecretsManagerPortingController : Controller throw new BadRequestException("A secret can only be in one project at a time."); } + var projectsToAdd = importRequest.Projects?.Count(); + if (projectsToAdd is > 0) + { + var (max, overMax) = await _maxProjectsQuery.GetByOrgIdAsync(organizationId, projectsToAdd.Value); + if (overMax != null && overMax.Value) + { + throw new BadRequestException($"The maximum number of projects for this plan is ({max})."); + } + } + await _importCommand.ImportAsync(organizationId, importRequest.ToSMImport()); } } diff --git a/src/Core/SecretsManager/Queries/Projects/Interfaces/IMaxProjectsQuery.cs b/src/Core/SecretsManager/Queries/Projects/Interfaces/IMaxProjectsQuery.cs index e00f5ed67..6bb9a40c7 100644 --- a/src/Core/SecretsManager/Queries/Projects/Interfaces/IMaxProjectsQuery.cs +++ b/src/Core/SecretsManager/Queries/Projects/Interfaces/IMaxProjectsQuery.cs @@ -2,5 +2,5 @@ public interface IMaxProjectsQuery { - Task<(short? max, bool? atMax)> GetByOrgIdAsync(Guid organizationId); + Task<(short? max, bool? overMax)> GetByOrgIdAsync(Guid organizationId, int projectsToAdd); } diff --git a/test/Api.Test/SecretsManager/Controllers/ProjectsControllerTests.cs b/test/Api.Test/SecretsManager/Controllers/ProjectsControllerTests.cs index 32239159a..27aa7ea71 100644 --- a/test/Api.Test/SecretsManager/Controllers/ProjectsControllerTests.cs +++ b/test/Api.Test/SecretsManager/Controllers/ProjectsControllerTests.cs @@ -132,7 +132,7 @@ public class ProjectsControllerTests .AuthorizeAsync(Arg.Any(), data.ToProject(orgId), Arg.Any>()).ReturnsForAnyArgs(AuthorizationResult.Success()); sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(Guid.NewGuid()); - sutProvider.GetDependency().GetByOrgIdAsync(orgId).Returns(((short)3, true)); + sutProvider.GetDependency().GetByOrgIdAsync(orgId, 1).Returns(((short)3, true)); await Assert.ThrowsAsync(() => sutProvider.Sut.CreateAsync(orgId, data));