From fa62b36d444d0b1ba070a2b664b62bf34fc8b286 Mon Sep 17 00:00:00 2001 From: Alex Morask <144709477+amorask-bitwarden@users.noreply.github.com> Date: Mon, 24 Jun 2024 11:15:47 -0400 Subject: [PATCH] [AC-2774] Consolidated issues for Consolidated Billing (#4201) * Add BaseProviderController, update some endpoints to ServiceUser permissions * Prevent service user from scaling provider seats above seat minimum * Expand invoice response to include DueDate --- .../Billing/ProviderBillingService.cs | 9 + .../Billing/ProviderBillingServiceTests.cs | 68 ++++++ .../Controllers/BaseProviderController.cs | 50 +++++ .../Controllers/ProviderBillingController.cs | 50 +---- .../Controllers/ProviderClientsController.cs | 34 +-- .../Models/Responses/InvoicesResponse.cs | 2 + .../ProviderBillingControllerTests.cs | 134 +++++++---- .../ProviderClientsControllerTests.cs | 208 +++--------------- test/Api.Test/Billing/Utilities.cs | 47 ++++ 9 files changed, 318 insertions(+), 284 deletions(-) create mode 100644 src/Api/Billing/Controllers/BaseProviderController.cs create mode 100644 test/Api.Test/Billing/Utilities.cs diff --git a/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs b/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs index 608e3653f..422043f04 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs @@ -13,6 +13,7 @@ using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; using Bit.Core.Billing.Repositories; using Bit.Core.Billing.Services; +using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Models.Business; using Bit.Core.Repositories; @@ -27,6 +28,7 @@ using static Bit.Core.Billing.Utilities; namespace Bit.Commercial.Core.Billing; public class ProviderBillingService( + ICurrentContext currentContext, IFeatureService featureService, IGlobalSettings globalSettings, ILogger logger, @@ -374,6 +376,13 @@ public class ProviderBillingService( else if (currentlyAssignedSeatTotal <= seatMinimum && newlyAssignedSeatTotal > seatMinimum) { + if (!currentContext.ProviderProviderAdmin(provider.Id)) + { + logger.LogError("Service user for provider ({ProviderID}) cannot scale a provider's seat count over the seat minimum", provider.Id); + + throw ContactSupport(); + } + await update( seatMinimum, newlyAssignedSeatTotal); diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs index a35213e35..b5e7ea632 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs @@ -14,6 +14,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Billing.Models; using Bit.Core.Billing.Repositories; using Bit.Core.Billing.Services; +using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Business; @@ -185,6 +186,71 @@ public class ProviderBillingServiceTests pPlan => pPlan.AllocatedSeats == 60)); } + [Theory, BitAutoData] + public async Task AssignSeatsToClientOrganization_BelowToAbove_NotProviderAdmin_ContactSupport( + Provider provider, + Organization organization, + SutProvider sutProvider) + { + organization.Seats = 10; + + organization.PlanType = PlanType.TeamsMonthly; + + // Scale up 10 seats + const int seats = 20; + + var providerPlans = new List + { + new() + { + Id = Guid.NewGuid(), + PlanType = PlanType.TeamsMonthly, + ProviderId = provider.Id, + PurchasedSeats = 0, + // 100 minimum + SeatMinimum = 100, + AllocatedSeats = 95 + }, + new() + { + Id = Guid.NewGuid(), + PlanType = PlanType.EnterpriseMonthly, + ProviderId = provider.Id, + PurchasedSeats = 0, + SeatMinimum = 500, + AllocatedSeats = 0 + } + }; + + sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); + + // 95 seats currently assigned with a seat minimum of 100 + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + + var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + + sutProvider.GetDependency().GetManyDetailsByProviderAsync(provider.Id).Returns( + [ + new ProviderOrganizationOrganizationDetails + { + Plan = teamsMonthlyPlan.Name, + Status = OrganizationStatusType.Managed, + Seats = 60 + }, + new ProviderOrganizationOrganizationDetails + { + Plan = teamsMonthlyPlan.Name, + Status = OrganizationStatusType.Managed, + Seats = 35 + } + ]); + + sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(false); + + await ThrowsContactSupportAsync(() => + sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats)); + } + [Theory, BitAutoData] public async Task AssignSeatsToClientOrganization_BelowToAbove_Succeeds( Provider provider, @@ -246,6 +312,8 @@ public class ProviderBillingServiceTests } ]); + sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); + await sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats); // 95 current + 10 seat scale = 105 seats, 5 above the minimum diff --git a/src/Api/Billing/Controllers/BaseProviderController.cs b/src/Api/Billing/Controllers/BaseProviderController.cs new file mode 100644 index 000000000..24fdf4864 --- /dev/null +++ b/src/Api/Billing/Controllers/BaseProviderController.cs @@ -0,0 +1,50 @@ +using Bit.Core; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Extensions; +using Bit.Core.Context; +using Bit.Core.Services; +using Microsoft.AspNetCore.Mvc; + +namespace Bit.Api.Billing.Controllers; + +public abstract class BaseProviderController( + ICurrentContext currentContext, + IFeatureService featureService, + IProviderRepository providerRepository) : Controller +{ + protected Task<(Provider, IResult)> TryGetBillableProviderForAdminOperation( + Guid providerId) => TryGetBillableProviderAsync(providerId, currentContext.ProviderProviderAdmin); + + protected Task<(Provider, IResult)> TryGetBillableProviderForServiceUserOperation( + Guid providerId) => TryGetBillableProviderAsync(providerId, currentContext.ProviderUser); + + private async Task<(Provider, IResult)> TryGetBillableProviderAsync( + Guid providerId, + Func checkAuthorization) + { + if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)) + { + return (null, TypedResults.NotFound()); + } + + var provider = await providerRepository.GetByIdAsync(providerId); + + if (provider == null) + { + return (null, TypedResults.NotFound()); + } + + if (!checkAuthorization(providerId)) + { + return (null, TypedResults.Unauthorized()); + } + + if (!provider.IsBillable()) + { + return (null, TypedResults.Unauthorized()); + } + + return (provider, null); + } +} diff --git a/src/Api/Billing/Controllers/ProviderBillingController.cs b/src/Api/Billing/Controllers/ProviderBillingController.cs index 246bf7360..fda7eddd0 100644 --- a/src/Api/Billing/Controllers/ProviderBillingController.cs +++ b/src/Api/Billing/Controllers/ProviderBillingController.cs @@ -1,10 +1,7 @@ using Bit.Api.Billing.Models.Requests; using Bit.Api.Billing.Models.Responses; -using Bit.Core; -using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Constants; -using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; using Bit.Core.Billing.Services; using Bit.Core.Context; @@ -23,12 +20,12 @@ public class ProviderBillingController( IProviderBillingService providerBillingService, IProviderRepository providerRepository, IStripeAdapter stripeAdapter, - ISubscriberService subscriberService) : Controller + ISubscriberService subscriberService) : BaseProviderController(currentContext, featureService, providerRepository) { [HttpGet("invoices")] public async Task GetInvoicesAsync([FromRoute] Guid providerId) { - var (provider, result) = await GetAuthorizedBillableProviderOrResultAsync(providerId); + var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); if (provider == null) { @@ -45,7 +42,7 @@ public class ProviderBillingController( [HttpGet("invoices/{invoiceId}")] public async Task GenerateClientInvoiceReportAsync([FromRoute] Guid providerId, string invoiceId) { - var (provider, result) = await GetAuthorizedBillableProviderOrResultAsync(providerId); + var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); if (provider == null) { @@ -67,7 +64,7 @@ public class ProviderBillingController( [HttpGet("payment-information")] public async Task GetPaymentInformationAsync([FromRoute] Guid providerId) { - var (provider, result) = await GetAuthorizedBillableProviderOrResultAsync(providerId); + var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); if (provider == null) { @@ -89,7 +86,7 @@ public class ProviderBillingController( [HttpGet("payment-method")] public async Task GetPaymentMethodAsync([FromRoute] Guid providerId) { - var (provider, result) = await GetAuthorizedBillableProviderOrResultAsync(providerId); + var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); if (provider == null) { @@ -113,7 +110,7 @@ public class ProviderBillingController( [FromRoute] Guid providerId, [FromBody] TokenizedPaymentMethodRequestBody requestBody) { - var (provider, result) = await GetAuthorizedBillableProviderOrResultAsync(providerId); + var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); if (provider == null) { @@ -141,7 +138,7 @@ public class ProviderBillingController( [FromRoute] Guid providerId, [FromBody] VerifyBankAccountRequestBody requestBody) { - var (provider, result) = await GetAuthorizedBillableProviderOrResultAsync(providerId); + var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); if (provider == null) { @@ -156,7 +153,7 @@ public class ProviderBillingController( [HttpGet("subscription")] public async Task GetSubscriptionAsync([FromRoute] Guid providerId) { - var (provider, result) = await GetAuthorizedBillableProviderOrResultAsync(providerId); + var (provider, result) = await TryGetBillableProviderForServiceUserOperation(providerId); if (provider == null) { @@ -178,7 +175,7 @@ public class ProviderBillingController( [HttpGet("tax-information")] public async Task GetTaxInformationAsync([FromRoute] Guid providerId) { - var (provider, result) = await GetAuthorizedBillableProviderOrResultAsync(providerId); + var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); if (provider == null) { @@ -202,7 +199,7 @@ public class ProviderBillingController( [FromRoute] Guid providerId, [FromBody] TaxInformationRequestBody requestBody) { - var (provider, result) = await GetAuthorizedBillableProviderOrResultAsync(providerId); + var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); if (provider == null) { @@ -222,31 +219,4 @@ public class ProviderBillingController( return TypedResults.Ok(); } - - private async Task<(Provider, IResult)> GetAuthorizedBillableProviderOrResultAsync(Guid providerId) - { - if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)) - { - return (null, TypedResults.NotFound()); - } - - var provider = await providerRepository.GetByIdAsync(providerId); - - if (provider == null) - { - return (null, TypedResults.NotFound()); - } - - if (!currentContext.ProviderProviderAdmin(providerId)) - { - return (null, TypedResults.Unauthorized()); - } - - if (!provider.IsBillable()) - { - return (null, TypedResults.Unauthorized()); - } - - return (provider, null); - } } diff --git a/src/Api/Billing/Controllers/ProviderClientsController.cs b/src/Api/Billing/Controllers/ProviderClientsController.cs index ffd74f811..eaf5c054f 100644 --- a/src/Api/Billing/Controllers/ProviderClientsController.cs +++ b/src/Api/Billing/Controllers/ProviderClientsController.cs @@ -1,5 +1,4 @@ using Bit.Api.Billing.Models.Requests; -using Bit.Core; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Services; @@ -22,16 +21,18 @@ public class ProviderClientsController( IProviderOrganizationRepository providerOrganizationRepository, IProviderRepository providerRepository, IProviderService providerService, - IUserService userService) : Controller + IUserService userService) : BaseProviderController(currentContext, featureService, providerRepository) { [HttpPost] public async Task CreateAsync( [FromRoute] Guid providerId, [FromBody] CreateClientOrganizationRequestBody requestBody) { - if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)) + var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); + + if (provider == null) { - return TypedResults.NotFound(); + return result; } var user = await userService.GetUserByPrincipalAsync(User); @@ -41,18 +42,6 @@ public class ProviderClientsController( return TypedResults.Unauthorized(); } - if (!currentContext.ManageProviderOrganizations(providerId)) - { - return TypedResults.Unauthorized(); - } - - var provider = await providerRepository.GetByIdAsync(providerId); - - if (provider == null) - { - return TypedResults.NotFound(); - } - var organizationSignup = new OrganizationSignup { Name = requestBody.Name, @@ -103,21 +92,16 @@ public class ProviderClientsController( [FromRoute] Guid providerOrganizationId, [FromBody] UpdateClientOrganizationRequestBody requestBody) { - if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)) - { - return TypedResults.NotFound(); - } + var (provider, result) = await TryGetBillableProviderForServiceUserOperation(providerId); - if (!currentContext.ProviderProviderAdmin(providerId)) + if (provider == null) { - return TypedResults.Unauthorized(); + return result; } - var provider = await providerRepository.GetByIdAsync(providerId); - var providerOrganization = await providerOrganizationRepository.GetByIdAsync(providerOrganizationId); - if (provider == null || providerOrganization == null) + if (providerOrganization == null) { return TypedResults.NotFound(); } diff --git a/src/Api/Billing/Models/Responses/InvoicesResponse.cs b/src/Api/Billing/Models/Responses/InvoicesResponse.cs index f5266947d..f9ab2a4ae 100644 --- a/src/Api/Billing/Models/Responses/InvoicesResponse.cs +++ b/src/Api/Billing/Models/Responses/InvoicesResponse.cs @@ -18,6 +18,7 @@ public record InvoiceDTO( string Number, decimal Total, string Status, + DateTime? DueDate, string Url, string PdfUrl) { @@ -27,6 +28,7 @@ public record InvoiceDTO( invoice.Number, invoice.Total / 100M, invoice.Status, + invoice.DueDate, invoice.HostedInvoiceUrl, invoice.InvoicePdf); } diff --git a/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs b/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs index bf16aa184..11d84f7d7 100644 --- a/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs @@ -21,6 +21,8 @@ using NSubstitute.ReturnsExtensions; using Stripe; using Xunit; +using static Bit.Api.Test.Billing.Utilities; + namespace Bit.Api.Test.Billing.Controllers; [ControllerCustomize(typeof(ProviderBillingController))] @@ -34,7 +36,7 @@ public class ProviderBillingControllerTests Provider provider, SutProvider sutProvider) { - ConfigureStableInputs(provider, sutProvider); + ConfigureStableAdminInputs(provider, sutProvider); var invoices = new List { @@ -54,6 +56,7 @@ public class ProviderBillingControllerTests Number = "B", Status = "open", Total = 100000, + DueDate = new DateTime(2024, 7, 1), HostedInvoiceUrl = "https://example.com/invoice/2", InvoicePdf = "https://example.com/invoice/2/pdf" }, @@ -64,6 +67,7 @@ public class ProviderBillingControllerTests Number = "A", Status = "paid", Total = 100000, + DueDate = new DateTime(2024, 6, 1), HostedInvoiceUrl = "https://example.com/invoice/1", InvoicePdf = "https://example.com/invoice/1/pdf" } @@ -86,6 +90,7 @@ public class ProviderBillingControllerTests Assert.Equal(new DateTime(2024, 6, 1), openInvoice.Date); Assert.Equal("B", openInvoice.Number); Assert.Equal(1000, openInvoice.Total); + Assert.Equal(new DateTime(2024, 7, 1), openInvoice.DueDate); Assert.Equal("https://example.com/invoice/2", openInvoice.Url); Assert.Equal("https://example.com/invoice/2/pdf", openInvoice.PdfUrl); @@ -96,6 +101,7 @@ public class ProviderBillingControllerTests Assert.Equal(new DateTime(2024, 5, 1), paidInvoice.Date); Assert.Equal("A", paidInvoice.Number); Assert.Equal(1000, paidInvoice.Total); + Assert.Equal(new DateTime(2024, 6, 1), paidInvoice.DueDate); Assert.Equal("https://example.com/invoice/1", paidInvoice.Url); Assert.Equal("https://example.com/invoice/1/pdf", paidInvoice.PdfUrl); } @@ -110,7 +116,7 @@ public class ProviderBillingControllerTests string invoiceId, SutProvider sutProvider) { - ConfigureStableInputs(provider, sutProvider); + ConfigureStableAdminInputs(provider, sutProvider); var reportContent = "Report"u8.ToArray(); @@ -129,18 +135,85 @@ public class ProviderBillingControllerTests #endregion - #region GetPaymentInformationAsync + #region GetPaymentInformationAsync & TryGetBillableProviderForAdminOperation + + [Theory, BitAutoData] + public async Task GetPaymentInformationAsync_FFDisabled_NotFound( + Guid providerId, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(false); + + var result = await sutProvider.Sut.GetPaymentInformationAsync(providerId); + + Assert.IsType(result); + } + + [Theory, BitAutoData] + public async Task GetPaymentInformationAsync_NullProvider_NotFound( + Guid providerId, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + sutProvider.GetDependency().GetByIdAsync(providerId).ReturnsNull(); + + var result = await sutProvider.Sut.GetPaymentInformationAsync(providerId); + + Assert.IsType(result); + } + + [Theory, BitAutoData] + public async Task GetPaymentInformationAsync_NotProviderUser_Unauthorized( + Provider provider, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + + sutProvider.GetDependency().ProviderProviderAdmin(provider.Id) + .Returns(false); + + var result = await sutProvider.Sut.GetPaymentInformationAsync(provider.Id); + + Assert.IsType(result); + } + + [Theory, BitAutoData] + public async Task GetPaymentInformationAsync_ProviderNotBillable_Unauthorized( + Provider provider, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + provider.Type = ProviderType.Reseller; + provider.Status = ProviderStatusType.Created; + + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + + sutProvider.GetDependency().ProviderProviderAdmin(provider.Id) + .Returns(true); + + var result = await sutProvider.Sut.GetPaymentInformationAsync(provider.Id); + + Assert.IsType(result); + } [Theory, BitAutoData] public async Task GetPaymentInformation_PaymentInformationNull_NotFound( Provider provider, SutProvider sutProvider) { - ConfigureStableInputs(provider, sutProvider); + ConfigureStableAdminInputs(provider, sutProvider); sutProvider.GetDependency().GetPaymentInformation(provider).ReturnsNull(); - var result = await sutProvider.Sut.GetSubscriptionAsync(provider.Id); + var result = await sutProvider.Sut.GetPaymentInformationAsync(provider.Id); Assert.IsType(result); } @@ -150,7 +223,7 @@ public class ProviderBillingControllerTests Provider provider, SutProvider sutProvider) { - ConfigureStableInputs(provider, sutProvider); + ConfigureStableAdminInputs(provider, sutProvider); var maskedPaymentMethod = new MaskedPaymentMethodDTO(PaymentMethodType.Card, "VISA *1234", false); @@ -182,11 +255,11 @@ public class ProviderBillingControllerTests Provider provider, SutProvider sutProvider) { - ConfigureStableInputs(provider, sutProvider); + ConfigureStableAdminInputs(provider, sutProvider); sutProvider.GetDependency().GetPaymentMethod(provider).ReturnsNull(); - var result = await sutProvider.Sut.GetSubscriptionAsync(provider.Id); + var result = await sutProvider.Sut.GetPaymentMethodAsync(provider.Id); Assert.IsType(result); } @@ -196,7 +269,7 @@ public class ProviderBillingControllerTests Provider provider, SutProvider sutProvider) { - ConfigureStableInputs(provider, sutProvider); + ConfigureStableAdminInputs(provider, sutProvider); sutProvider.GetDependency().GetPaymentMethod(provider).Returns(new MaskedPaymentMethodDTO( PaymentMethodType.Card, "Description", false)); @@ -214,7 +287,8 @@ public class ProviderBillingControllerTests #endregion - #region GetSubscriptionAsync + #region GetSubscriptionAsync & TryGetBillableProviderForServiceUserOperation + [Theory, BitAutoData] public async Task GetSubscriptionAsync_FFDisabled_NotFound( Guid providerId, @@ -244,7 +318,7 @@ public class ProviderBillingControllerTests } [Theory, BitAutoData] - public async Task GetSubscriptionAsync_NotProviderAdmin_Unauthorized( + public async Task GetSubscriptionAsync_NotProviderUser_Unauthorized( Provider provider, SutProvider sutProvider) { @@ -253,7 +327,7 @@ public class ProviderBillingControllerTests sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); - sutProvider.GetDependency().ProviderProviderAdmin(provider.Id) + sutProvider.GetDependency().ProviderUser(provider.Id) .Returns(false); var result = await sutProvider.Sut.GetSubscriptionAsync(provider.Id); @@ -274,8 +348,8 @@ public class ProviderBillingControllerTests sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); - sutProvider.GetDependency().ProviderProviderAdmin(provider.Id) - .Returns(false); + sutProvider.GetDependency().ProviderUser(provider.Id) + .Returns(true); var result = await sutProvider.Sut.GetSubscriptionAsync(provider.Id); @@ -287,7 +361,7 @@ public class ProviderBillingControllerTests Provider provider, SutProvider sutProvider) { - ConfigureStableInputs(provider, sutProvider); + ConfigureStableServiceUserInputs(provider, sutProvider); sutProvider.GetDependency().GetConsolidatedBillingSubscription(provider).ReturnsNull(); @@ -301,7 +375,7 @@ public class ProviderBillingControllerTests Provider provider, SutProvider sutProvider) { - ConfigureStableInputs(provider, sutProvider); + ConfigureStableServiceUserInputs(provider, sutProvider); var configuredProviderPlans = new List { @@ -369,11 +443,11 @@ public class ProviderBillingControllerTests Provider provider, SutProvider sutProvider) { - ConfigureStableInputs(provider, sutProvider); + ConfigureStableAdminInputs(provider, sutProvider); sutProvider.GetDependency().GetTaxInformation(provider).ReturnsNull(); - var result = await sutProvider.Sut.GetSubscriptionAsync(provider.Id); + var result = await sutProvider.Sut.GetTaxInformationAsync(provider.Id); Assert.IsType(result); } @@ -383,7 +457,7 @@ public class ProviderBillingControllerTests Provider provider, SutProvider sutProvider) { - ConfigureStableInputs(provider, sutProvider); + ConfigureStableAdminInputs(provider, sutProvider); sutProvider.GetDependency().GetTaxInformation(provider).Returns(new TaxInformationDTO( "US", @@ -419,7 +493,7 @@ public class ProviderBillingControllerTests TokenizedPaymentMethodRequestBody requestBody, SutProvider sutProvider) { - ConfigureStableInputs(provider, sutProvider); + ConfigureStableAdminInputs(provider, sutProvider); await sutProvider.Sut.UpdatePaymentMethodAsync(provider.Id, requestBody); @@ -442,7 +516,7 @@ public class ProviderBillingControllerTests TaxInformationRequestBody requestBody, SutProvider sutProvider) { - ConfigureStableInputs(provider, sutProvider); + ConfigureStableAdminInputs(provider, sutProvider); await sutProvider.Sut.UpdateTaxInformationAsync(provider.Id, requestBody); @@ -468,7 +542,7 @@ public class ProviderBillingControllerTests VerifyBankAccountRequestBody requestBody, SutProvider sutProvider) { - ConfigureStableInputs(provider, sutProvider); + ConfigureStableAdminInputs(provider, sutProvider); var result = await sutProvider.Sut.VerifyBankAccountAsync(provider.Id, requestBody); @@ -480,20 +554,4 @@ public class ProviderBillingControllerTests } #endregion - - private static void ConfigureStableInputs( - Provider provider, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - provider.Type = ProviderType.Msp; - provider.Status = ProviderStatusType.Billable; - - sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); - - sutProvider.GetDependency().ProviderProviderAdmin(provider.Id) - .Returns(true); - } } diff --git a/test/Api.Test/Billing/Controllers/ProviderClientsControllerTests.cs b/test/Api.Test/Billing/Controllers/ProviderClientsControllerTests.cs index fd445cd54..92d03f1e9 100644 --- a/test/Api.Test/Billing/Controllers/ProviderClientsControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/ProviderClientsControllerTests.cs @@ -1,13 +1,11 @@ using System.Security.Claims; using Bit.Api.Billing.Controllers; using Bit.Api.Billing.Models.Requests; -using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Services; -using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Models.Business; using Bit.Core.Repositories; @@ -19,6 +17,8 @@ using NSubstitute; using NSubstitute.ReturnsExtensions; using Xunit; +using static Bit.Api.Test.Billing.Utilities; + namespace Bit.Api.Test.Billing.Controllers; [ControllerCustomize(typeof(ProviderClientsController))] @@ -26,100 +26,38 @@ namespace Bit.Api.Test.Billing.Controllers; public class ProviderClientsControllerTests { #region CreateAsync - [Theory, BitAutoData] - public async Task CreateAsync_FFDisabled_NotFound( - Guid providerId, - CreateClientOrganizationRequestBody requestBody, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(false); - - var result = await sutProvider.Sut.CreateAsync(providerId, requestBody); - - Assert.IsType(result); - } [Theory, BitAutoData] public async Task CreateAsync_NoPrincipalUser_Unauthorized( - Guid providerId, + Provider provider, CreateClientOrganizationRequestBody requestBody, SutProvider sutProvider) { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); + ConfigureStableAdminInputs(provider, sutProvider); sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).ReturnsNull(); - var result = await sutProvider.Sut.CreateAsync(providerId, requestBody); + var result = await sutProvider.Sut.CreateAsync(provider.Id, requestBody); Assert.IsType(result); } - [Theory, BitAutoData] - public async Task CreateAsync_NotProviderAdmin_Unauthorized( - Guid providerId, - CreateClientOrganizationRequestBody requestBody, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).Returns(new User()); - - sutProvider.GetDependency().ManageProviderOrganizations(providerId) - .Returns(false); - - var result = await sutProvider.Sut.CreateAsync(providerId, requestBody); - - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task CreateAsync_NoProvider_NotFound( - Guid providerId, - CreateClientOrganizationRequestBody requestBody, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).Returns(new User()); - - sutProvider.GetDependency().ManageProviderOrganizations(providerId) - .Returns(true); - - sutProvider.GetDependency().GetByIdAsync(providerId) - .ReturnsNull(); - - var result = await sutProvider.Sut.CreateAsync(providerId, requestBody); - - Assert.IsType(result); - } - [Theory, BitAutoData] public async Task CreateAsync_MissingClientOrganization_ServerError( - Guid providerId, + Provider provider, CreateClientOrganizationRequestBody requestBody, SutProvider sutProvider) { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); + ConfigureStableAdminInputs(provider, sutProvider); var user = new User(); sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).Returns(user); - sutProvider.GetDependency().ManageProviderOrganizations(providerId) - .Returns(true); - - sutProvider.GetDependency().GetByIdAsync(providerId) - .Returns(new Provider()); - var clientOrganizationId = Guid.NewGuid(); sutProvider.GetDependency().CreateOrganizationAsync( - providerId, + provider.Id, Arg.Any(), requestBody.OwnerEmail, user) @@ -130,37 +68,28 @@ public class ProviderClientsControllerTests sutProvider.GetDependency().GetByIdAsync(clientOrganizationId).ReturnsNull(); - var result = await sutProvider.Sut.CreateAsync(providerId, requestBody); + var result = await sutProvider.Sut.CreateAsync(provider.Id, requestBody); Assert.IsType(result); } [Theory, BitAutoData] public async Task CreateAsync_OK( - Guid providerId, + Provider provider, CreateClientOrganizationRequestBody requestBody, SutProvider sutProvider) { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); + ConfigureStableAdminInputs(provider, sutProvider); var user = new User(); sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()) .Returns(user); - sutProvider.GetDependency().ManageProviderOrganizations(providerId) - .Returns(true); - - var provider = new Provider(); - - sutProvider.GetDependency().GetByIdAsync(providerId) - .Returns(provider); - var clientOrganizationId = Guid.NewGuid(); sutProvider.GetDependency().CreateOrganizationAsync( - providerId, + provider.Id, Arg.Is(signup => signup.Name == requestBody.Name && signup.Plan == requestBody.PlanType && @@ -181,7 +110,7 @@ public class ProviderClientsControllerTests sutProvider.GetDependency().GetByIdAsync(clientOrganizationId) .Returns(clientOrganization); - var result = await sutProvider.Sut.CreateAsync(providerId, requestBody); + var result = await sutProvider.Sut.CreateAsync(provider.Id, requestBody); Assert.IsType(result); @@ -189,105 +118,37 @@ public class ProviderClientsControllerTests provider, clientOrganization); } + #endregion #region UpdateAsync - [Theory, BitAutoData] - public async Task UpdateAsync_FFDisabled_NotFound( - Guid providerId, - Guid providerOrganizationId, - UpdateClientOrganizationRequestBody requestBody, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(false); - - var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); - - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_NotProviderAdmin_Unauthorized( - Guid providerId, - Guid providerOrganizationId, - UpdateClientOrganizationRequestBody requestBody, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - sutProvider.GetDependency().ProviderProviderAdmin(providerId) - .Returns(false); - - var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); - - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_NoProvider_NotFound( - Guid providerId, - Guid providerOrganizationId, - UpdateClientOrganizationRequestBody requestBody, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - sutProvider.GetDependency().ProviderProviderAdmin(providerId) - .Returns(true); - - sutProvider.GetDependency().GetByIdAsync(providerId) - .ReturnsNull(); - - var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); - - Assert.IsType(result); - } [Theory, BitAutoData] public async Task UpdateAsync_NoProviderOrganization_NotFound( - Guid providerId, + Provider provider, Guid providerOrganizationId, UpdateClientOrganizationRequestBody requestBody, - Provider provider, SutProvider sutProvider) { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - sutProvider.GetDependency().ProviderProviderAdmin(providerId) - .Returns(true); - - sutProvider.GetDependency().GetByIdAsync(providerId) - .Returns(provider); + ConfigureStableServiceUserInputs(provider, sutProvider); sutProvider.GetDependency().GetByIdAsync(providerOrganizationId) .ReturnsNull(); - var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); + var result = await sutProvider.Sut.UpdateAsync(provider.Id, providerOrganizationId, requestBody); Assert.IsType(result); } [Theory, BitAutoData] public async Task UpdateAsync_NoOrganization_ServerError( - Guid providerId, + Provider provider, Guid providerOrganizationId, UpdateClientOrganizationRequestBody requestBody, - Provider provider, ProviderOrganization providerOrganization, SutProvider sutProvider) { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - sutProvider.GetDependency().ProviderProviderAdmin(providerId) - .Returns(true); - - sutProvider.GetDependency().GetByIdAsync(providerId) - .Returns(provider); + ConfigureStableServiceUserInputs(provider, sutProvider); sutProvider.GetDependency().GetByIdAsync(providerOrganizationId) .Returns(providerOrganization); @@ -295,29 +156,21 @@ public class ProviderClientsControllerTests sutProvider.GetDependency().GetByIdAsync(providerOrganization.OrganizationId) .ReturnsNull(); - var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); + var result = await sutProvider.Sut.UpdateAsync(provider.Id, providerOrganizationId, requestBody); Assert.IsType(result); } [Theory, BitAutoData] public async Task UpdateAsync_AssignedSeats_NoContent( - Guid providerId, + Provider provider, Guid providerOrganizationId, UpdateClientOrganizationRequestBody requestBody, - Provider provider, ProviderOrganization providerOrganization, Organization organization, SutProvider sutProvider) { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - sutProvider.GetDependency().ProviderProviderAdmin(providerId) - .Returns(true); - - sutProvider.GetDependency().GetByIdAsync(providerId) - .Returns(provider); + ConfigureStableServiceUserInputs(provider, sutProvider); sutProvider.GetDependency().GetByIdAsync(providerOrganizationId) .Returns(providerOrganization); @@ -325,7 +178,7 @@ public class ProviderClientsControllerTests sutProvider.GetDependency().GetByIdAsync(providerOrganization.OrganizationId) .Returns(organization); - var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); + var result = await sutProvider.Sut.UpdateAsync(provider.Id, providerOrganizationId, requestBody); await sutProvider.GetDependency().Received(1) .AssignSeatsToClientOrganization( @@ -341,22 +194,14 @@ public class ProviderClientsControllerTests [Theory, BitAutoData] public async Task UpdateAsync_Name_NoContent( - Guid providerId, + Provider provider, Guid providerOrganizationId, UpdateClientOrganizationRequestBody requestBody, - Provider provider, ProviderOrganization providerOrganization, Organization organization, SutProvider sutProvider) { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - sutProvider.GetDependency().ProviderProviderAdmin(providerId) - .Returns(true); - - sutProvider.GetDependency().GetByIdAsync(providerId) - .Returns(provider); + ConfigureStableServiceUserInputs(provider, sutProvider); sutProvider.GetDependency().GetByIdAsync(providerOrganizationId) .Returns(providerOrganization); @@ -366,7 +211,7 @@ public class ProviderClientsControllerTests requestBody.AssignedSeats = organization.Seats!.Value; - var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); + var result = await sutProvider.Sut.UpdateAsync(provider.Id, providerOrganizationId, requestBody); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() .AssignSeatsToClientOrganization( @@ -379,5 +224,6 @@ public class ProviderClientsControllerTests Assert.IsType(result); } + #endregion } diff --git a/test/Api.Test/Billing/Utilities.cs b/test/Api.Test/Billing/Utilities.cs new file mode 100644 index 000000000..7c361b760 --- /dev/null +++ b/test/Api.Test/Billing/Utilities.cs @@ -0,0 +1,47 @@ +using Bit.Api.Billing.Controllers; +using Bit.Core; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Context; +using Bit.Core.Services; +using Bit.Test.Common.AutoFixture; +using NSubstitute; + +namespace Bit.Api.Test.Billing; + +public static class Utilities +{ + public static void ConfigureStableAdminInputs( + Provider provider, + SutProvider sutProvider) where T : BaseProviderController + { + ConfigureBaseInputs(provider, sutProvider); + + sutProvider.GetDependency().ProviderProviderAdmin(provider.Id) + .Returns(true); + } + + public static void ConfigureStableServiceUserInputs( + Provider provider, + SutProvider sutProvider) where T : BaseProviderController + { + ConfigureBaseInputs(provider, sutProvider); + + sutProvider.GetDependency().ProviderUser(provider.Id) + .Returns(true); + } + + private static void ConfigureBaseInputs( + Provider provider, + SutProvider sutProvider) where T : BaseProviderController + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + provider.Type = ProviderType.Msp; + provider.Status = ProviderStatusType.Billable; + + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + } +}