diff --git a/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs b/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs index a9592dfcc..dd97aaca0 100644 --- a/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs +++ b/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs @@ -9,6 +9,7 @@ using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -45,6 +46,7 @@ public class ProviderService : IProviderService private readonly IFeatureService _featureService; private readonly IDataProtectorTokenFactory _providerDeleteTokenDataFactory; private readonly IApplicationCacheService _applicationCacheService; + private readonly IProviderBillingService _providerBillingService; public ProviderService(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository, IProviderOrganizationRepository providerOrganizationRepository, IUserRepository userRepository, @@ -53,7 +55,7 @@ public class ProviderService : IProviderService IOrganizationRepository organizationRepository, GlobalSettings globalSettings, ICurrentContext currentContext, IStripeAdapter stripeAdapter, IFeatureService featureService, IDataProtectorTokenFactory providerDeleteTokenDataFactory, - IApplicationCacheService applicationCacheService) + IApplicationCacheService applicationCacheService, IProviderBillingService providerBillingService) { _providerRepository = providerRepository; _providerUserRepository = providerUserRepository; @@ -71,9 +73,10 @@ public class ProviderService : IProviderService _featureService = featureService; _providerDeleteTokenDataFactory = providerDeleteTokenDataFactory; _applicationCacheService = applicationCacheService; + _providerBillingService = providerBillingService; } - public async Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key) + public async Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo = null) { var owner = await _userService.GetUserByIdAsync(ownerUserId); if (owner == null) @@ -98,8 +101,24 @@ public class ProviderService : IProviderService throw new BadRequestException("Invalid owner."); } - provider.Status = ProviderStatusType.Created; - await _providerRepository.UpsertAsync(provider); + if (!_featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)) + { + provider.Status = ProviderStatusType.Created; + await _providerRepository.UpsertAsync(provider); + } + else + { + if (taxInfo == null || string.IsNullOrEmpty(taxInfo.BillingAddressCountry) || string.IsNullOrEmpty(taxInfo.BillingAddressPostalCode)) + { + throw new BadRequestException("Both address and postal code are required to set up your provider."); + } + var customer = await _providerBillingService.SetupCustomer(provider, taxInfo); + provider.GatewayCustomerId = customer.Id; + var subscription = await _providerBillingService.SetupSubscription(provider); + provider.GatewaySubscriptionId = subscription.Id; + provider.Status = ProviderStatusType.Billable; + await _providerRepository.UpsertAsync(provider); + } providerUser.Key = key; await _providerUserRepository.ReplaceAsync(providerUser); diff --git a/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs b/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs index 2ee7f606d..08c0da08b 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs @@ -9,11 +9,11 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Entities; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; -using Bit.Core.Billing.Models; using Bit.Core.Billing.Repositories; using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Enums; +using Bit.Core.Exceptions; using Bit.Core.Models.Business; using Bit.Core.Repositories; using Bit.Core.Services; @@ -22,7 +22,6 @@ using Bit.Core.Utilities; using CsvHelper; using Microsoft.Extensions.Logging; using Stripe; -using static Bit.Core.Billing.Utilities; namespace Bit.Commercial.Core.Billing; @@ -69,67 +68,6 @@ public class ProviderBillingService( await organizationRepository.ReplaceAsync(organization); } - public async Task CreateCustomer( - Provider provider, - TaxInfo taxInfo) - { - ArgumentNullException.ThrowIfNull(provider); - ArgumentNullException.ThrowIfNull(taxInfo); - - if (string.IsNullOrEmpty(taxInfo.BillingAddressCountry) || - string.IsNullOrEmpty(taxInfo.BillingAddressPostalCode)) - { - logger.LogError("Cannot create customer for provider ({ProviderID}) without both a country and postal code", provider.Id); - - throw ContactSupport(); - } - - var providerDisplayName = provider.DisplayName(); - - var customerCreateOptions = new CustomerCreateOptions - { - Address = new AddressOptions - { - Country = taxInfo.BillingAddressCountry, - PostalCode = taxInfo.BillingAddressPostalCode, - Line1 = taxInfo.BillingAddressLine1, - Line2 = taxInfo.BillingAddressLine2, - City = taxInfo.BillingAddressCity, - State = taxInfo.BillingAddressState - }, - Description = provider.DisplayBusinessName(), - Email = provider.BillingEmail, - InvoiceSettings = new CustomerInvoiceSettingsOptions - { - CustomFields = - [ - new CustomerInvoiceSettingsCustomFieldOptions - { - Name = provider.SubscriberType(), - Value = providerDisplayName.Length <= 30 - ? providerDisplayName - : providerDisplayName[..30] - } - ] - }, - Metadata = new Dictionary - { - { "region", globalSettings.BaseServiceUri.CloudRegion } - }, - TaxIdData = taxInfo.HasTaxId ? - [ - new CustomerTaxIdDataOptions { Type = taxInfo.TaxIdType, Value = taxInfo.TaxIdNumber } - ] - : null - }; - - var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions); - - provider.GatewayCustomerId = customer.Id; - - await providerRepository.ReplaceAsync(provider); - } - public async Task CreateCustomerForClientOrganization( Provider provider, Organization organization) @@ -204,15 +142,14 @@ public class ProviderBillingService( public async Task GenerateClientInvoiceReport( string invoiceId) { - if (string.IsNullOrEmpty(invoiceId)) - { - throw new ArgumentNullException(nameof(invoiceId)); - } + ArgumentException.ThrowIfNullOrEmpty(invoiceId); var invoiceItems = await providerInvoiceItemRepository.GetByInvoiceId(invoiceId); if (invoiceItems.Count == 0) { + logger.LogError("No provider invoice item records were found for invoice ({InvoiceID})", invoiceId); + return null; } @@ -245,14 +182,14 @@ public class ProviderBillingService( "Could not find provider ({ID}) when retrieving assigned seat total", providerId); - throw ContactSupport(); + throw new BillingException(); } if (provider.Type == ProviderType.Reseller) { logger.LogError("Assigned seats cannot be retrieved for reseller-type provider ({ID})", providerId); - throw ContactSupport("Consolidated billing does not support reseller-type providers"); + throw new BillingException(); } var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId); @@ -264,39 +201,6 @@ public class ProviderBillingService( .Sum(providerOrganization => providerOrganization.Seats ?? 0); } - public async Task GetConsolidatedBillingSubscription( - Provider provider) - { - ArgumentNullException.ThrowIfNull(provider); - - var subscription = await subscriberService.GetSubscription(provider, new SubscriptionGetOptions - { - Expand = ["customer", "test_clock"] - }); - - if (subscription == null) - { - return null; - } - - var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); - - var configuredProviderPlans = providerPlans - .Where(providerPlan => providerPlan.IsConfigured()) - .Select(ConfiguredProviderPlanDTO.From) - .ToList(); - - var taxInformation = await subscriberService.GetTaxInformation(provider); - - var suspension = await GetSuspensionAsync(stripeAdapter, subscription); - - return new ConsolidatedBillingSubscriptionDTO( - configuredProviderPlans, - subscription, - taxInformation, - suspension); - } - public async Task ScaleSeats( Provider provider, PlanType planType, @@ -308,14 +212,14 @@ public class ProviderBillingService( { logger.LogError("Non-MSP provider ({ProviderID}) cannot scale their seats", provider.Id); - throw ContactSupport(); + throw new BillingException(); } if (!planType.SupportsConsolidatedBilling()) { logger.LogError("Cannot scale provider ({ProviderID}) seats for plan type {PlanType} as it does not support consolidated billing", provider.Id, planType.ToString()); - throw ContactSupport(); + throw new BillingException(); } var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); @@ -326,7 +230,7 @@ public class ProviderBillingService( { logger.LogError("Cannot scale provider ({ProviderID}) seats for plan type {PlanType} when their matching provider plan is not configured", provider.Id, planType); - throw ContactSupport(); + throw new BillingException(); } var seatMinimum = providerPlan.SeatMinimum.GetValueOrDefault(0); @@ -362,7 +266,7 @@ public class ProviderBillingService( { logger.LogError("Service user for provider ({ProviderID}) cannot scale a provider's seat count over the seat minimum", provider.Id); - throw ContactSupport(); + throw new BillingException(); } await update( @@ -393,7 +297,64 @@ public class ProviderBillingService( } } - public async Task StartSubscription( + public async Task SetupCustomer( + Provider provider, + TaxInfo taxInfo) + { + ArgumentNullException.ThrowIfNull(provider); + ArgumentNullException.ThrowIfNull(taxInfo); + + if (string.IsNullOrEmpty(taxInfo.BillingAddressCountry) || + string.IsNullOrEmpty(taxInfo.BillingAddressPostalCode)) + { + logger.LogError("Cannot create customer for provider ({ProviderID}) without both a country and postal code", provider.Id); + + throw new BillingException(); + } + + var providerDisplayName = provider.DisplayName(); + + var customerCreateOptions = new CustomerCreateOptions + { + Address = new AddressOptions + { + Country = taxInfo.BillingAddressCountry, + PostalCode = taxInfo.BillingAddressPostalCode, + Line1 = taxInfo.BillingAddressLine1, + Line2 = taxInfo.BillingAddressLine2, + City = taxInfo.BillingAddressCity, + State = taxInfo.BillingAddressState + }, + Description = provider.DisplayBusinessName(), + Email = provider.BillingEmail, + InvoiceSettings = new CustomerInvoiceSettingsOptions + { + CustomFields = + [ + new CustomerInvoiceSettingsCustomFieldOptions + { + Name = provider.SubscriberType(), + Value = providerDisplayName?.Length <= 30 + ? providerDisplayName + : providerDisplayName?[..30] + } + ] + }, + Metadata = new Dictionary + { + { "region", globalSettings.BaseServiceUri.CloudRegion } + }, + TaxIdData = taxInfo.HasTaxId ? + [ + new CustomerTaxIdDataOptions { Type = taxInfo.TaxIdType, Value = taxInfo.TaxIdNumber } + ] + : null + }; + + return await stripeAdapter.CustomerCreateAsync(customerCreateOptions); + } + + public async Task SetupSubscription( Provider provider) { ArgumentNullException.ThrowIfNull(provider); @@ -406,7 +367,7 @@ public class ProviderBillingService( { logger.LogError("Cannot start subscription for provider ({ProviderID}) that has no configured plans", provider.Id); - throw ContactSupport(); + throw new BillingException(); } var subscriptionItemOptionsList = new List(); @@ -418,7 +379,7 @@ public class ProviderBillingService( { logger.LogError("Cannot start subscription for provider ({ProviderID}) that has no configured Teams plan", provider.Id); - throw ContactSupport(); + throw new BillingException(); } var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); @@ -436,7 +397,7 @@ public class ProviderBillingService( { logger.LogError("Cannot start subscription for provider ({ProviderID}) that has no configured Enterprise plan", provider.Id); - throw ContactSupport(); + throw new BillingException(); } var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); @@ -465,22 +426,27 @@ public class ProviderBillingService( ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations }; - var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); - - provider.GatewaySubscriptionId = subscription.Id; - - if (subscription.Status == StripeConstants.SubscriptionStatus.Incomplete) + try { - await providerRepository.ReplaceAsync(provider); + var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); - logger.LogError("Started incomplete provider ({ProviderID}) subscription ({SubscriptionID})", provider.Id, subscription.Id); + if (subscription.Status == StripeConstants.SubscriptionStatus.Active) + { + return subscription; + } - throw ContactSupport(); + logger.LogError( + "Newly created provider ({ProviderID}) subscription ({SubscriptionID}) has inactive status: {Status}", + provider.Id, + subscription.Id, + subscription.Status); + + throw new BillingException(); + } + catch (StripeException stripeException) when (stripeException.StripeError?.Code == StripeConstants.ErrorCodes.CustomerTaxLocationInvalid) + { + throw new BadRequestException("Your location wasn't recognized. Please ensure your country and postal code are valid."); } - - provider.Status = ProviderStatusType.Billable; - - await providerRepository.ReplaceAsync(provider); } private Func CurrySeatScalingUpdate( diff --git a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs index b34348c09..4beda0060 100644 --- a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs @@ -8,6 +8,7 @@ using Bit.Core.AdminConsole.Models.Business.Provider; using Bit.Core.AdminConsole.Models.Business.Tokenables; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -81,6 +82,51 @@ public class ProviderServiceTests .ReplaceAsync(Arg.Is(pu => pu.UserId == user.Id && pu.ProviderId == provider.Id && pu.Key == key)); } + [Theory, BitAutoData] + public async Task CompleteSetupAsync_ConsolidatedBilling_Success(User user, Provider provider, string key, TaxInfo taxInfo, + [ProviderUser(ProviderUserStatusType.Confirmed, ProviderUserType.ProviderAdmin)] ProviderUser providerUser, + SutProvider sutProvider) + { + providerUser.ProviderId = provider.Id; + providerUser.UserId = user.Id; + var userService = sutProvider.GetDependency(); + userService.GetUserByIdAsync(user.Id).Returns(user); + + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id).Returns(providerUser); + + var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName"); + var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); + sutProvider.GetDependency().CreateProtector("ProviderServiceDataProtector") + .Returns(protector); + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + var providerBillingService = sutProvider.GetDependency(); + + var customer = new Customer { Id = "customer_id" }; + providerBillingService.SetupCustomer(provider, taxInfo).Returns(customer); + + var subscription = new Subscription { Id = "subscription_id" }; + providerBillingService.SetupSubscription(provider).Returns(subscription); + + sutProvider.Create(); + + var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); + + await sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, taxInfo); + + await sutProvider.GetDependency().Received().UpsertAsync(Arg.Is( + p => + p.GatewayCustomerId == customer.Id && + p.GatewaySubscriptionId == subscription.Id && + p.Status == ProviderStatusType.Billable)); + + await sutProvider.GetDependency().Received() + .ReplaceAsync(Arg.Is(pu => pu.UserId == user.Id && pu.ProviderId == provider.Id && pu.Key == key)); + } + [Theory, BitAutoData] public async Task UpdateAsync_ProviderIdIsInvalid_Throws(Provider provider, SutProvider sutProvider) { diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs index a56a6f5ab..0aa1a164f 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs @@ -11,7 +11,6 @@ using Bit.Core.Billing; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Entities; using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models; using Bit.Core.Billing.Repositories; using Bit.Core.Billing.Services; using Bit.Core.Context; @@ -87,7 +86,7 @@ public class ProviderBillingServiceTests { organization.PlanType = PlanType.FamiliesAnnually; - await ThrowsContactSupportAsync(() => + await ThrowsBillingExceptionAsync(() => sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats)); } @@ -105,7 +104,7 @@ public class ProviderBillingServiceTests new() { Id = Guid.NewGuid(), PlanType = PlanType.TeamsMonthly, ProviderId = provider.Id } }); - await ThrowsContactSupportAsync(() => + await ThrowsBillingExceptionAsync(() => sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats)); } @@ -247,7 +246,7 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(false); - await ThrowsContactSupportAsync(() => + await ThrowsBillingExceptionAsync(() => sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats)); } @@ -493,105 +492,6 @@ public class ProviderBillingServiceTests #endregion - #region CreateCustomer - - [Theory, BitAutoData] - public async Task CreateCustomer_NullProvider_ThrowsArgumentNullException( - SutProvider sutProvider, - TaxInfo taxInfo) => - await Assert.ThrowsAsync(() => sutProvider.Sut.CreateCustomer(null, taxInfo)); - - [Theory, BitAutoData] - public async Task CreateCustomer_NullTaxInfo_ThrowsArgumentNullException( - SutProvider sutProvider, - Provider provider) => - await Assert.ThrowsAsync(() => sutProvider.Sut.CreateCustomer(provider, null)); - - [Theory, BitAutoData] - public async Task CreateCustomer_MissingCountry_ContactSupport( - SutProvider sutProvider, - Provider provider, - TaxInfo taxInfo) - { - taxInfo.BillingAddressCountry = null; - - await ThrowsContactSupportAsync(() => sutProvider.Sut.CreateCustomer(provider, taxInfo)); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .CustomerGetAsync(Arg.Any(), Arg.Any()); - } - - [Theory, BitAutoData] - public async Task CreateCustomer_MissingPostalCode_ContactSupport( - SutProvider sutProvider, - Provider provider, - TaxInfo taxInfo) - { - taxInfo.BillingAddressCountry = null; - - await ThrowsContactSupportAsync(() => sutProvider.Sut.CreateCustomer(provider, taxInfo)); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .CustomerGetAsync(Arg.Any(), Arg.Any()); - } - - [Theory, BitAutoData] - public async Task CreateCustomer_Success( - SutProvider sutProvider, - Provider provider, - TaxInfo taxInfo) - { - provider.Name = "MSP"; - - taxInfo.BillingAddressCountry = "AD"; - - var stripeAdapter = sutProvider.GetDependency(); - - stripeAdapter.CustomerCreateAsync(Arg.Is(o => - o.Address.Country == taxInfo.BillingAddressCountry && - o.Address.PostalCode == taxInfo.BillingAddressPostalCode && - o.Address.Line1 == taxInfo.BillingAddressLine1 && - o.Address.Line2 == taxInfo.BillingAddressLine2 && - o.Address.City == taxInfo.BillingAddressCity && - o.Address.State == taxInfo.BillingAddressState && - o.Description == WebUtility.HtmlDecode(provider.BusinessName) && - o.Email == provider.BillingEmail && - o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && - o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && - o.Metadata["region"] == "" && - o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && - o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) - .Returns(new Customer - { - Id = "customer_id", - Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } - }); - - await sutProvider.Sut.CreateCustomer(provider, taxInfo); - - await stripeAdapter.Received(1).CustomerCreateAsync(Arg.Is(o => - o.Address.Country == taxInfo.BillingAddressCountry && - o.Address.PostalCode == taxInfo.BillingAddressPostalCode && - o.Address.Line1 == taxInfo.BillingAddressLine1 && - o.Address.Line2 == taxInfo.BillingAddressLine2 && - o.Address.City == taxInfo.BillingAddressCity && - o.Address.State == taxInfo.BillingAddressState && - o.Description == WebUtility.HtmlDecode(provider.BusinessName) && - o.Email == provider.BillingEmail && - o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && - o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && - o.Metadata["region"] == "" && - o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && - o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)); - - await sutProvider.GetDependency() - .ReplaceAsync(Arg.Is(p => p.GatewayCustomerId == "customer_id")); - } - - #endregion - #region CreateCustomerForClientOrganization [Theory, BitAutoData] @@ -777,7 +677,7 @@ public class ProviderBillingServiceTests public async Task GetAssignedSeatTotalForPlanOrThrow_NullProvider_ContactSupport( Guid providerId, SutProvider sutProvider) - => await ThrowsContactSupportAsync(() => + => await ThrowsBillingExceptionAsync(() => sutProvider.Sut.GetAssignedSeatTotalForPlanOrThrow(providerId, PlanType.TeamsMonthly)); [Theory, BitAutoData] @@ -790,9 +690,8 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetByIdAsync(providerId).Returns(provider); - await ThrowsContactSupportAsync( - () => sutProvider.Sut.GetAssignedSeatTotalForPlanOrThrow(providerId, PlanType.TeamsMonthly), - internalMessage: "Consolidated billing does not support reseller-type providers"); + await ThrowsBillingExceptionAsync( + () => sutProvider.Sut.GetAssignedSeatTotalForPlanOrThrow(providerId, PlanType.TeamsMonthly)); } [Theory, BitAutoData] @@ -836,197 +735,100 @@ public class ProviderBillingServiceTests #endregion - #region GetConsolidatedBillingSubscription + #region SetupCustomer [Theory, BitAutoData] - public async Task GetConsolidatedBillingSubscription_NullProvider_ThrowsArgumentNullException( - SutProvider sutProvider) => - await Assert.ThrowsAsync(() => sutProvider.Sut.GetConsolidatedBillingSubscription(null)); - - [Theory, BitAutoData] - public async Task GetConsolidatedBillingSubscription_NullSubscription_ReturnsNull( + public async Task SetupCustomer_NullProvider_ThrowsArgumentNullException( SutProvider sutProvider, - Provider provider) + TaxInfo taxInfo) => + await Assert.ThrowsAsync(() => sutProvider.Sut.SetupCustomer(null, taxInfo)); + + [Theory, BitAutoData] + public async Task SetupCustomer_NullTaxInfo_ThrowsArgumentNullException( + SutProvider sutProvider, + Provider provider) => + await Assert.ThrowsAsync(() => sutProvider.Sut.SetupCustomer(provider, null)); + + [Theory, BitAutoData] + public async Task SetupCustomer_MissingCountry_ContactSupport( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) { - var consolidatedBillingSubscription = await sutProvider.Sut.GetConsolidatedBillingSubscription(provider); + taxInfo.BillingAddressCountry = null; - Assert.Null(consolidatedBillingSubscription); + await ThrowsBillingExceptionAsync(() => sutProvider.Sut.SetupCustomer(provider, taxInfo)); - await sutProvider.GetDependency().Received(1).GetSubscription( - provider, - Arg.Is( - options => options.Expand.Count == 2 && options.Expand.First() == "customer" && options.Expand.Last() == "test_clock")); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .CustomerGetAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] - public async Task GetConsolidatedBillingSubscription_Active_NoSuspension_Success( + public async Task SetupCustomer_MissingPostalCode_ContactSupport( SutProvider sutProvider, - Provider provider) + Provider provider, + TaxInfo taxInfo) { - var subscriberService = sutProvider.GetDependency(); + taxInfo.BillingAddressCountry = null; - var subscription = new Subscription - { - Status = "active" - }; + await ThrowsBillingExceptionAsync(() => sutProvider.Sut.SetupCustomer(provider, taxInfo)); - subscriberService.GetSubscription(provider, Arg.Is( - options => options.Expand.Count == 2 && options.Expand.First() == "customer" && options.Expand.Last() == "test_clock")).Returns(subscription); - - var providerPlanRepository = sutProvider.GetDependency(); - - var enterprisePlan = new ProviderPlan - { - Id = Guid.NewGuid(), - ProviderId = provider.Id, - PlanType = PlanType.EnterpriseMonthly, - SeatMinimum = 100, - PurchasedSeats = 0, - AllocatedSeats = 0 - }; - - var teamsPlan = new ProviderPlan - { - Id = Guid.NewGuid(), - ProviderId = provider.Id, - PlanType = PlanType.TeamsMonthly, - SeatMinimum = 50, - PurchasedSeats = 10, - AllocatedSeats = 60 - }; - - var providerPlans = new List { enterprisePlan, teamsPlan, }; - - providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); - - var taxInformation = - new TaxInformationDTO("US", "12345", "123456789", "123 Example St.", null, "Example Town", "NY"); - - subscriberService.GetTaxInformation(provider).Returns(taxInformation); - - var (gotProviderPlans, gotSubscription, gotTaxInformation, gotSuspension) = await sutProvider.Sut.GetConsolidatedBillingSubscription(provider); - - Assert.Equal(2, gotProviderPlans.Count); - - var configuredEnterprisePlan = - gotProviderPlans.FirstOrDefault(configuredPlan => - configuredPlan.PlanType == PlanType.EnterpriseMonthly); - - var configuredTeamsPlan = - gotProviderPlans.FirstOrDefault(configuredPlan => - configuredPlan.PlanType == PlanType.TeamsMonthly); - - Compare(enterprisePlan, configuredEnterprisePlan); - - Compare(teamsPlan, configuredTeamsPlan); - - Assert.Equivalent(subscription, gotSubscription); - - Assert.Equivalent(taxInformation, gotTaxInformation); - - Assert.Null(gotSuspension); - - return; - - void Compare(ProviderPlan providerPlan, ConfiguredProviderPlanDTO configuredProviderPlan) - { - Assert.NotNull(configuredProviderPlan); - Assert.Equal(providerPlan.Id, configuredProviderPlan.Id); - Assert.Equal(providerPlan.ProviderId, configuredProviderPlan.ProviderId); - Assert.Equal(providerPlan.SeatMinimum!.Value, configuredProviderPlan.SeatMinimum); - Assert.Equal(providerPlan.PurchasedSeats!.Value, configuredProviderPlan.PurchasedSeats); - Assert.Equal(providerPlan.AllocatedSeats!.Value, configuredProviderPlan.AssignedSeats); - } + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .CustomerGetAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] - public async Task GetConsolidatedBillingSubscription_PastDue_HasSuspension_Success( + public async Task SetupCustomer_Success( SutProvider sutProvider, - Provider provider) + Provider provider, + TaxInfo taxInfo) { - var subscriberService = sutProvider.GetDependency(); + provider.Name = "MSP"; - var subscription = new Subscription - { - Id = "subscription_id", - Status = "past_due", - CollectionMethod = "send_invoice" - }; - - subscriberService.GetSubscription(provider, Arg.Is( - options => options.Expand.Count == 2 && options.Expand.First() == "customer" && options.Expand.Last() == "test_clock")).Returns(subscription); - - var providerPlanRepository = sutProvider.GetDependency(); - - var enterprisePlan = new ProviderPlan - { - Id = Guid.NewGuid(), - ProviderId = provider.Id, - PlanType = PlanType.EnterpriseMonthly, - SeatMinimum = 100, - PurchasedSeats = 0, - AllocatedSeats = 0 - }; - - var teamsPlan = new ProviderPlan - { - Id = Guid.NewGuid(), - ProviderId = provider.Id, - PlanType = PlanType.TeamsMonthly, - SeatMinimum = 50, - PurchasedSeats = 10, - AllocatedSeats = 60 - }; - - var providerPlans = new List { enterprisePlan, teamsPlan, }; - - providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); - - var taxInformation = - new TaxInformationDTO("US", "12345", "123456789", "123 Example St.", null, "Example Town", "NY"); - - subscriberService.GetTaxInformation(provider).Returns(taxInformation); + taxInfo.BillingAddressCountry = "AD"; var stripeAdapter = sutProvider.GetDependency(); - var openInvoice = new Invoice + var expected = new Customer { - Id = "invoice_id", - Status = "open", - DueDate = new DateTime(2024, 6, 1), - Created = new DateTime(2024, 5, 1), - PeriodEnd = new DateTime(2024, 6, 1) + Id = "customer_id", + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } }; - stripeAdapter.InvoiceSearchAsync(Arg.Is(options => - options.Query == $"subscription:'{subscription.Id}' status:'open'")) - .Returns([openInvoice]); + stripeAdapter.CustomerCreateAsync(Arg.Is(o => + o.Address.Country == taxInfo.BillingAddressCountry && + o.Address.PostalCode == taxInfo.BillingAddressPostalCode && + o.Address.Line1 == taxInfo.BillingAddressLine1 && + o.Address.Line2 == taxInfo.BillingAddressLine2 && + o.Address.City == taxInfo.BillingAddressCity && + o.Address.State == taxInfo.BillingAddressState && + o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Email == provider.BillingEmail && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.Metadata["region"] == "" && + o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && + o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) + .Returns(expected); - var (gotProviderPlans, gotSubscription, gotTaxInformation, gotSuspension) = await sutProvider.Sut.GetConsolidatedBillingSubscription(provider); + var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo); - Assert.Equal(2, gotProviderPlans.Count); - - Assert.Equivalent(subscription, gotSubscription); - - Assert.Equivalent(taxInformation, gotTaxInformation); - - Assert.NotNull(gotSuspension); - Assert.Equal(openInvoice.DueDate.Value.AddDays(30), gotSuspension.SuspensionDate); - Assert.Equal(openInvoice.PeriodEnd, gotSuspension.UnpaidPeriodEndDate); - Assert.Equal(30, gotSuspension.GracePeriod); + Assert.Equivalent(expected, actual); } #endregion - #region StartSubscription + #region SetupSubscription [Theory, BitAutoData] - public async Task StartSubscription_NullProvider_ThrowsArgumentNullException( + public async Task SetupSubscription_NullProvider_ThrowsArgumentNullException( SutProvider sutProvider) => - await Assert.ThrowsAsync(() => sutProvider.Sut.StartSubscription(null)); + await Assert.ThrowsAsync(() => sutProvider.Sut.SetupSubscription(null)); [Theory, BitAutoData] - public async Task StartSubscription_NoProviderPlans_ContactSupport( + public async Task SetupSubscription_NoProviderPlans_ContactSupport( SutProvider sutProvider, Provider provider) { @@ -1041,7 +843,7 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetByProviderId(provider.Id) .Returns(new List()); - await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider)); + await ThrowsBillingExceptionAsync(() => sutProvider.Sut.SetupSubscription(provider)); await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() @@ -1049,7 +851,7 @@ public class ProviderBillingServiceTests } [Theory, BitAutoData] - public async Task StartSubscription_NoProviderTeamsPlan_ContactSupport( + public async Task SetupSubscription_NoProviderTeamsPlan_ContactSupport( SutProvider sutProvider, Provider provider) { @@ -1066,7 +868,7 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetByProviderId(provider.Id) .Returns(providerPlans); - await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider)); + await ThrowsBillingExceptionAsync(() => sutProvider.Sut.SetupSubscription(provider)); await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() @@ -1074,7 +876,7 @@ public class ProviderBillingServiceTests } [Theory, BitAutoData] - public async Task StartSubscription_NoProviderEnterprisePlan_ContactSupport( + public async Task SetupSubscription_NoProviderEnterprisePlan_ContactSupport( SutProvider sutProvider, Provider provider) { @@ -1091,7 +893,7 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetByProviderId(provider.Id) .Returns(providerPlans); - await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider)); + await ThrowsBillingExceptionAsync(() => sutProvider.Sut.SetupSubscription(provider)); await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() @@ -1099,7 +901,7 @@ public class ProviderBillingServiceTests } [Theory, BitAutoData] - public async Task StartSubscription_SubscriptionIncomplete_ThrowsBillingException( + public async Task SetupSubscription_SubscriptionIncomplete_ThrowsBillingException( SutProvider sutProvider, Provider provider) { @@ -1140,14 +942,11 @@ public class ProviderBillingServiceTests .Returns( new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Incomplete }); - await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider)); - - await sutProvider.GetDependency().Received(1) - .ReplaceAsync(Arg.Is(p => p.GatewaySubscriptionId == "subscription_id")); + await ThrowsBillingExceptionAsync(() => sutProvider.Sut.SetupSubscription(provider)); } [Theory, BitAutoData] - public async Task StartSubscription_Succeeds( + public async Task SetupSubscription_Succeeds( SutProvider sutProvider, Provider provider) { @@ -1187,6 +986,8 @@ public class ProviderBillingServiceTests var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); + var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; + sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( sub => sub.AutomaticTax.Enabled == true && @@ -1200,16 +1001,11 @@ public class ProviderBillingServiceTests sub.Items.ElementAt(1).Quantity == 100 && sub.Metadata["providerId"] == provider.Id.ToString() && sub.OffSession == true && - sub.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations)).Returns(new Subscription - { - Id = "subscription_id", - Status = StripeConstants.SubscriptionStatus.Active - }); + sub.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations)).Returns(expected); - await sutProvider.Sut.StartSubscription(provider); + var actual = await sutProvider.Sut.SetupSubscription(provider); - await sutProvider.GetDependency().Received(1) - .ReplaceAsync(Arg.Is(p => p.GatewaySubscriptionId == "subscription_id")); + Assert.Equivalent(expected, actual); } #endregion diff --git a/src/Api/AdminConsole/Controllers/ProvidersController.cs b/src/Api/AdminConsole/Controllers/ProvidersController.cs index 51cf4c7e3..be119744b 100644 --- a/src/Api/AdminConsole/Controllers/ProvidersController.cs +++ b/src/Api/AdminConsole/Controllers/ProvidersController.cs @@ -1,9 +1,7 @@ using Bit.Api.AdminConsole.Models.Request.Providers; using Bit.Api.AdminConsole.Models.Response.Providers; -using Bit.Core; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; -using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Exceptions; using Bit.Core.Models.Business; @@ -23,23 +21,15 @@ public class ProvidersController : Controller private readonly IProviderService _providerService; private readonly ICurrentContext _currentContext; private readonly GlobalSettings _globalSettings; - private readonly IFeatureService _featureService; - private readonly ILogger _logger; - private readonly IProviderBillingService _providerBillingService; public ProvidersController(IUserService userService, IProviderRepository providerRepository, - IProviderService providerService, ICurrentContext currentContext, GlobalSettings globalSettings, - IFeatureService featureService, ILogger logger, - IProviderBillingService providerBillingService) + IProviderService providerService, ICurrentContext currentContext, GlobalSettings globalSettings) { _userService = userService; _providerRepository = providerRepository; _providerService = providerService; _currentContext = currentContext; _globalSettings = globalSettings; - _featureService = featureService; - _logger = logger; - _providerBillingService = providerBillingService; } [HttpGet("{id:guid}")] @@ -94,12 +84,8 @@ public class ProvidersController : Controller var userId = _userService.GetProperUserId(User).Value; - var response = - await _providerService.CompleteSetupAsync(model.ToProvider(provider), userId, model.Token, model.Key); - - if (_featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)) - { - var taxInfo = new TaxInfo + var taxInfo = model.TaxInfo != null + ? new TaxInfo { BillingAddressCountry = model.TaxInfo.Country, BillingAddressPostalCode = model.TaxInfo.PostalCode, @@ -108,20 +94,12 @@ public class ProvidersController : Controller BillingAddressLine2 = model.TaxInfo.Line2, BillingAddressCity = model.TaxInfo.City, BillingAddressState = model.TaxInfo.State - }; - - try - { - await _providerBillingService.CreateCustomer(provider, taxInfo); - - await _providerBillingService.StartSubscription(provider); } - catch - { - // We don't want to trap the user on the setup page, so we'll let this go through but the provider will be in an un-billable state. - _logger.LogError("Failed to create subscription for provider with ID {ID} during setup", provider.Id); - } - } + : null; + + var response = + await _providerService.CompleteSetupAsync(model.ToProvider(provider), userId, model.Token, model.Key, + taxInfo); return new ProviderResponseModel(response); } diff --git a/src/Api/Billing/Controllers/BaseProviderController.cs b/src/Api/Billing/Controllers/BaseProviderController.cs index 24fdf4864..37d804498 100644 --- a/src/Api/Billing/Controllers/BaseProviderController.cs +++ b/src/Api/Billing/Controllers/BaseProviderController.cs @@ -3,7 +3,9 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Extensions; using Bit.Core.Context; +using Bit.Core.Models.Api; using Bit.Core.Services; +using Microsoft.AspNetCore.Http.HttpResults; using Microsoft.AspNetCore.Mvc; namespace Bit.Api.Billing.Controllers; @@ -11,8 +13,25 @@ namespace Bit.Api.Billing.Controllers; public abstract class BaseProviderController( ICurrentContext currentContext, IFeatureService featureService, - IProviderRepository providerRepository) : Controller + ILogger logger, + IProviderRepository providerRepository, + IUserService userService) : Controller { + protected readonly IUserService UserService = userService; + + protected static NotFound NotFoundResponse() => + TypedResults.NotFound(new ErrorResponseModel("Resource not found.")); + + protected static JsonHttpResult ServerErrorResponse(string errorMessage) => + TypedResults.Json( + new ErrorResponseModel(errorMessage), + statusCode: StatusCodes.Status500InternalServerError); + + protected static JsonHttpResult UnauthorizedResponse() => + TypedResults.Json( + new ErrorResponseModel("Unauthorized."), + statusCode: StatusCodes.Status401Unauthorized); + protected Task<(Provider, IResult)> TryGetBillableProviderForAdminOperation( Guid providerId) => TryGetBillableProviderAsync(providerId, currentContext.ProviderProviderAdmin); @@ -25,26 +44,53 @@ public abstract class BaseProviderController( { if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)) { - return (null, TypedResults.NotFound()); + logger.LogError( + "Cannot run Consolidated Billing operation for provider ({ProviderID}) while feature flag is disabled", + providerId); + + return (null, NotFoundResponse()); } var provider = await providerRepository.GetByIdAsync(providerId); if (provider == null) { - return (null, TypedResults.NotFound()); + logger.LogError( + "Cannot find provider ({ProviderID}) for Consolidated Billing operation", + providerId); + + return (null, NotFoundResponse()); } if (!checkAuthorization(providerId)) { - return (null, TypedResults.Unauthorized()); + var user = await UserService.GetUserByPrincipalAsync(User); + + logger.LogError( + "User ({UserID}) is not authorized to perform Consolidated Billing operation for provider ({ProviderID})", + user?.Id, providerId); + + return (null, UnauthorizedResponse()); } if (!provider.IsBillable()) { - return (null, TypedResults.Unauthorized()); + logger.LogError( + "Cannot run Consolidated Billing operation for provider ({ProviderID}) that is not billable", + providerId); + + return (null, UnauthorizedResponse()); } - return (provider, null); + if (provider.IsStripeEnabled()) + { + return (provider, null); + } + + logger.LogError( + "Cannot run Consolidated Billing operation for provider ({ProviderID}) that is missing Stripe configuration", + providerId); + + return (null, ServerErrorResponse("Something went wrong with your request. Please contact support.")); } } diff --git a/src/Api/Billing/Controllers/ProviderBillingController.cs b/src/Api/Billing/Controllers/ProviderBillingController.cs index fda7eddd0..40a1ebdf2 100644 --- a/src/Api/Billing/Controllers/ProviderBillingController.cs +++ b/src/Api/Billing/Controllers/ProviderBillingController.cs @@ -1,15 +1,19 @@ using Bit.Api.Billing.Models.Requests; using Bit.Api.Billing.Models.Responses; using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; +using Bit.Core.Billing.Repositories; using Bit.Core.Billing.Services; using Bit.Core.Context; +using Bit.Core.Models.Api; +using Bit.Core.Models.BitStripe; using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; using Stripe; +using static Bit.Core.Billing.Utilities; + namespace Bit.Api.Billing.Controllers; [Route("providers/{providerId:guid}/billing")] @@ -17,10 +21,13 @@ namespace Bit.Api.Billing.Controllers; public class ProviderBillingController( ICurrentContext currentContext, IFeatureService featureService, + ILogger logger, IProviderBillingService providerBillingService, + IProviderPlanRepository providerPlanRepository, IProviderRepository providerRepository, + ISubscriberService subscriberService, IStripeAdapter stripeAdapter, - ISubscriberService subscriberService) : BaseProviderController(currentContext, featureService, providerRepository) + IUserService userService) : BaseProviderController(currentContext, featureService, logger, providerRepository, userService) { [HttpGet("invoices")] public async Task GetInvoicesAsync([FromRoute] Guid providerId) @@ -32,7 +39,10 @@ public class ProviderBillingController( return result; } - var invoices = await subscriberService.GetInvoices(provider); + var invoices = await stripeAdapter.InvoiceListAsync(new StripeInvoiceListOptions + { + Customer = provider.GatewayCustomerId + }); var response = InvoicesResponse.From(invoices); @@ -53,7 +63,7 @@ public class ProviderBillingController( if (reportContent == null) { - return TypedResults.NotFound(); + return ServerErrorResponse("We had a problem generating your invoice CSV. Please contact support."); } return TypedResults.File( @@ -61,95 +71,6 @@ public class ProviderBillingController( "text/csv"); } - [HttpGet("payment-information")] - public async Task GetPaymentInformationAsync([FromRoute] Guid providerId) - { - var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); - - if (provider == null) - { - return result; - } - - var paymentInformation = await subscriberService.GetPaymentInformation(provider); - - if (paymentInformation == null) - { - return TypedResults.NotFound(); - } - - var response = PaymentInformationResponse.From(paymentInformation); - - return TypedResults.Ok(response); - } - - [HttpGet("payment-method")] - public async Task GetPaymentMethodAsync([FromRoute] Guid providerId) - { - var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); - - if (provider == null) - { - return result; - } - - var maskedPaymentMethod = await subscriberService.GetPaymentMethod(provider); - - if (maskedPaymentMethod == null) - { - return TypedResults.NotFound(); - } - - var response = MaskedPaymentMethodResponse.From(maskedPaymentMethod); - - return TypedResults.Ok(response); - } - - [HttpPut("payment-method")] - public async Task UpdatePaymentMethodAsync( - [FromRoute] Guid providerId, - [FromBody] TokenizedPaymentMethodRequestBody requestBody) - { - var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); - - if (provider == null) - { - return result; - } - - var tokenizedPaymentMethod = new TokenizedPaymentMethodDTO( - requestBody.Type, - requestBody.Token); - - await subscriberService.UpdatePaymentMethod(provider, tokenizedPaymentMethod); - - await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, - new SubscriptionUpdateOptions - { - CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically - }); - - return TypedResults.Ok(); - } - - [HttpPost] - [Route("payment-method/verify-bank-account")] - public async Task VerifyBankAccountAsync( - [FromRoute] Guid providerId, - [FromBody] VerifyBankAccountRequestBody requestBody) - { - var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); - - if (provider == null) - { - return result; - } - - await subscriberService.VerifyBankAccount(provider, (requestBody.Amount1, requestBody.Amount2)); - - return TypedResults.Ok(); - } - [HttpGet("subscription")] public async Task GetSubscriptionAsync([FromRoute] Guid providerId) { @@ -160,36 +81,20 @@ public class ProviderBillingController( return result; } - var consolidatedBillingSubscription = await providerBillingService.GetConsolidatedBillingSubscription(provider); + var subscription = await stripeAdapter.SubscriptionGetAsync(provider.GatewaySubscriptionId, + new SubscriptionGetOptions { Expand = ["customer.tax_ids", "test_clock"] }); - if (consolidatedBillingSubscription == null) - { - return TypedResults.NotFound(); - } + var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); - var response = ConsolidatedBillingSubscriptionResponse.From(consolidatedBillingSubscription); + var taxInformation = GetTaxInformation(subscription.Customer); - return TypedResults.Ok(response); - } + var subscriptionSuspension = await GetSubscriptionSuspensionAsync(stripeAdapter, subscription); - [HttpGet("tax-information")] - public async Task GetTaxInformationAsync([FromRoute] Guid providerId) - { - var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); - - if (provider == null) - { - return result; - } - - var taxInformation = await subscriberService.GetTaxInformation(provider); - - if (taxInformation == null) - { - return TypedResults.NotFound(); - } - - var response = TaxInformationResponse.From(taxInformation); + var response = ProviderSubscriptionResponse.From( + subscription, + providerPlans, + taxInformation, + subscriptionSuspension); return TypedResults.Ok(response); } @@ -206,7 +111,13 @@ public class ProviderBillingController( return result; } - var taxInformation = new TaxInformationDTO( + if (requestBody is not { Country: not null, PostalCode: not null }) + { + return TypedResults.BadRequest( + new ErrorResponseModel("Country and postal code are required to update your tax information.")); + } + + var taxInformation = new TaxInformation( requestBody.Country, requestBody.PostalCode, requestBody.TaxId, diff --git a/src/Api/Billing/Controllers/ProviderClientsController.cs b/src/Api/Billing/Controllers/ProviderClientsController.cs index eaf5c054f..3fec4570f 100644 --- a/src/Api/Billing/Controllers/ProviderClientsController.cs +++ b/src/Api/Billing/Controllers/ProviderClientsController.cs @@ -15,13 +15,13 @@ namespace Bit.Api.Billing.Controllers; public class ProviderClientsController( ICurrentContext currentContext, IFeatureService featureService, - ILogger logger, + ILogger logger, IOrganizationRepository organizationRepository, IProviderBillingService providerBillingService, IProviderOrganizationRepository providerOrganizationRepository, IProviderRepository providerRepository, IProviderService providerService, - IUserService userService) : BaseProviderController(currentContext, featureService, providerRepository) + IUserService userService) : BaseProviderController(currentContext, featureService, logger, providerRepository, userService) { [HttpPost] public async Task CreateAsync( @@ -35,11 +35,11 @@ public class ProviderClientsController( return result; } - var user = await userService.GetUserByPrincipalAsync(User); + var user = await UserService.GetUserByPrincipalAsync(User); if (user == null) { - return TypedResults.Unauthorized(); + return UnauthorizedResponse(); } var organizationSignup = new OrganizationSignup @@ -63,13 +63,6 @@ public class ProviderClientsController( var clientOrganization = await organizationRepository.GetByIdAsync(providerOrganization.OrganizationId); - if (clientOrganization == null) - { - logger.LogError("Newly created client organization ({ID}) could not be found", providerOrganization.OrganizationId); - - return TypedResults.Problem(); - } - await providerBillingService.ScaleSeats( provider, requestBody.PlanType, @@ -103,18 +96,11 @@ public class ProviderClientsController( if (providerOrganization == null) { - return TypedResults.NotFound(); + return NotFoundResponse(); } var clientOrganization = await organizationRepository.GetByIdAsync(providerOrganization.OrganizationId); - if (clientOrganization == null) - { - logger.LogError("The client organization ({OrganizationID}) represented by provider organization ({ProviderOrganizationID}) could not be found.", providerOrganization.OrganizationId, providerOrganization.Id); - - return TypedResults.Problem(); - } - if (clientOrganization.Seats != requestBody.AssignedSeats) { await providerBillingService.AssignSeatsToClientOrganization( diff --git a/src/Api/Billing/Models/Responses/InvoicesResponse.cs b/src/Api/Billing/Models/Responses/InvoicesResponse.cs index 384b2fdd7..befbb4e53 100644 --- a/src/Api/Billing/Models/Responses/InvoicesResponse.cs +++ b/src/Api/Billing/Models/Responses/InvoicesResponse.cs @@ -3,16 +3,16 @@ namespace Bit.Api.Billing.Models.Responses; public record InvoicesResponse( - List Invoices) + List Invoices) { public static InvoicesResponse From(IEnumerable invoices) => new( invoices .Where(i => i.Status is "open" or "paid" or "uncollectible") .OrderByDescending(i => i.Created) - .Select(InvoiceDTO.From).ToList()); + .Select(InvoiceResponse.From).ToList()); } -public record InvoiceDTO( +public record InvoiceResponse( string Id, DateTime Date, string Number, @@ -21,7 +21,7 @@ public record InvoiceDTO( DateTime? DueDate, string Url) { - public static InvoiceDTO From(Invoice invoice) => new( + public static InvoiceResponse From(Invoice invoice) => new( invoice.Id, invoice.Created, invoice.Number, diff --git a/src/Api/Billing/Models/Responses/PaymentInformationResponse.cs b/src/Api/Billing/Models/Responses/PaymentInformationResponse.cs index 8e532d845..4ccb5889d 100644 --- a/src/Api/Billing/Models/Responses/PaymentInformationResponse.cs +++ b/src/Api/Billing/Models/Responses/PaymentInformationResponse.cs @@ -5,7 +5,7 @@ namespace Bit.Api.Billing.Models.Responses; public record PaymentInformationResponse( long AccountCredit, MaskedPaymentMethodDTO PaymentMethod, - TaxInformationDTO TaxInformation) + TaxInformation TaxInformation) { public static PaymentInformationResponse From(PaymentInformationDTO paymentInformation) => new( diff --git a/src/Api/Billing/Models/Responses/ConsolidatedBillingSubscriptionResponse.cs b/src/Api/Billing/Models/Responses/ProviderSubscriptionResponse.cs similarity index 52% rename from src/Api/Billing/Models/Responses/ConsolidatedBillingSubscriptionResponse.cs rename to src/Api/Billing/Models/Responses/ProviderSubscriptionResponse.cs index 0e1656913..ded2e027c 100644 --- a/src/Api/Billing/Models/Responses/ConsolidatedBillingSubscriptionResponse.cs +++ b/src/Api/Billing/Models/Responses/ProviderSubscriptionResponse.cs @@ -1,43 +1,48 @@ -using Bit.Core.Billing.Models; +using Bit.Core.Billing.Entities; +using Bit.Core.Billing.Models; using Bit.Core.Utilities; +using Stripe; namespace Bit.Api.Billing.Models.Responses; -public record ConsolidatedBillingSubscriptionResponse( +public record ProviderSubscriptionResponse( string Status, DateTime CurrentPeriodEndDate, decimal? DiscountPercentage, string CollectionMethod, IEnumerable Plans, long AccountCredit, - TaxInformationDTO TaxInformation, + TaxInformation TaxInformation, DateTime? CancelAt, - SubscriptionSuspensionDTO Suspension) + SubscriptionSuspension Suspension) { private const string _annualCadence = "Annual"; private const string _monthlyCadence = "Monthly"; - public static ConsolidatedBillingSubscriptionResponse From( - ConsolidatedBillingSubscriptionDTO consolidatedBillingSubscription) + public static ProviderSubscriptionResponse From( + Subscription subscription, + ICollection providerPlans, + TaxInformation taxInformation, + SubscriptionSuspension subscriptionSuspension) { - var (providerPlans, subscription, taxInformation, suspension) = consolidatedBillingSubscription; - var providerPlanResponses = providerPlans - .Select(providerPlan => + .Where(providerPlan => providerPlan.IsConfigured()) + .Select(ConfiguredProviderPlan.From) + .Select(configuredProviderPlan => { - var plan = StaticStore.GetPlan(providerPlan.PlanType); - var cost = (providerPlan.SeatMinimum + providerPlan.PurchasedSeats) * plan.PasswordManager.ProviderPortalSeatPrice; + var plan = StaticStore.GetPlan(configuredProviderPlan.PlanType); + var cost = (configuredProviderPlan.SeatMinimum + configuredProviderPlan.PurchasedSeats) * plan.PasswordManager.ProviderPortalSeatPrice; var cadence = plan.IsAnnual ? _annualCadence : _monthlyCadence; return new ProviderPlanResponse( plan.Name, - providerPlan.SeatMinimum, - providerPlan.PurchasedSeats, - providerPlan.AssignedSeats, + configuredProviderPlan.SeatMinimum, + configuredProviderPlan.PurchasedSeats, + configuredProviderPlan.AssignedSeats, cost, cadence); }); - return new ConsolidatedBillingSubscriptionResponse( + return new ProviderSubscriptionResponse( subscription.Status, subscription.CurrentPeriodEnd, subscription.Customer?.Discount?.Coupon?.PercentOff, @@ -46,7 +51,7 @@ public record ConsolidatedBillingSubscriptionResponse( subscription.Customer?.Balance ?? 0, taxInformation, subscription.CancelAt, - suspension); + subscriptionSuspension); } } diff --git a/src/Api/Billing/Models/Responses/TaxInformationResponse.cs b/src/Api/Billing/Models/Responses/TaxInformationResponse.cs index 53e2de19d..02349d74f 100644 --- a/src/Api/Billing/Models/Responses/TaxInformationResponse.cs +++ b/src/Api/Billing/Models/Responses/TaxInformationResponse.cs @@ -11,7 +11,7 @@ public record TaxInformationResponse( string City, string State) { - public static TaxInformationResponse From(TaxInformationDTO taxInformation) + public static TaxInformationResponse From(TaxInformation taxInformation) => new( taxInformation.Country, taxInformation.PostalCode, diff --git a/src/Api/Utilities/ExceptionHandlerFilterAttribute.cs b/src/Api/Utilities/ExceptionHandlerFilterAttribute.cs index 4dadedeb3..15e8bb295 100644 --- a/src/Api/Utilities/ExceptionHandlerFilterAttribute.cs +++ b/src/Api/Utilities/ExceptionHandlerFilterAttribute.cs @@ -1,4 +1,6 @@ -using Bit.Api.Models.Public.Response; +using System.Text; +using Bit.Api.Models.Public.Response; +using Bit.Core.Billing; using Bit.Core.Exceptions; using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc.Filters; @@ -49,18 +51,18 @@ public class ExceptionHandlerFilterAttribute : ExceptionFilterAttribute errorMessage = badRequestException.Message; } } - else if (exception is StripeException stripeException && stripeException?.StripeError?.Type == "card_error") + else if (exception is StripeException { StripeError.Type: "card_error" } stripeCardErrorException) { context.HttpContext.Response.StatusCode = 400; if (_publicApi) { - publicErrorModel = new ErrorResponseModel(stripeException.StripeError.Param, - stripeException.Message); + publicErrorModel = new ErrorResponseModel(stripeCardErrorException.StripeError.Param, + stripeCardErrorException.Message); } else { - internalErrorModel = new InternalApi.ErrorResponseModel(stripeException.StripeError.Param, - stripeException.Message); + internalErrorModel = new InternalApi.ErrorResponseModel(stripeCardErrorException.StripeError.Param, + stripeCardErrorException.Message); } } else if (exception is GatewayException) @@ -68,6 +70,40 @@ public class ExceptionHandlerFilterAttribute : ExceptionFilterAttribute errorMessage = exception.Message; context.HttpContext.Response.StatusCode = 400; } + else if (exception is BillingException billingException) + { + errorMessage = billingException.Response; + context.HttpContext.Response.StatusCode = StatusCodes.Status500InternalServerError; + } + else if (exception is StripeException stripeException) + { + var logger = context.HttpContext.RequestServices.GetRequiredService>(); + + var error = stripeException.Message; + + if (stripeException.StripeError != null) + { + var stringBuilder = new StringBuilder(); + + if (!string.IsNullOrEmpty(stripeException.StripeError.Code)) + { + stringBuilder.Append($"{stripeException.StripeError.Code} | "); + } + + stringBuilder.Append(stripeException.StripeError.Message); + + if (!string.IsNullOrEmpty(stripeException.StripeError.DocUrl)) + { + stringBuilder.Append($" > {stripeException.StripeError.DocUrl}"); + } + + error = stringBuilder.ToString(); + } + + logger.LogError("An unhandled error occurred while communicating with Stripe: {Error}", error); + errorMessage = "Something went wrong with your request. Please contact support."; + context.HttpContext.Response.StatusCode = StatusCodes.Status500InternalServerError; + } else if (exception is NotSupportedException && !string.IsNullOrWhiteSpace(exception.Message)) { errorMessage = exception.Message; diff --git a/src/Core/AdminConsole/Services/IProviderService.cs b/src/Core/AdminConsole/Services/IProviderService.cs index c12bda37d..8999b3cb8 100644 --- a/src/Core/AdminConsole/Services/IProviderService.cs +++ b/src/Core/AdminConsole/Services/IProviderService.cs @@ -7,7 +7,7 @@ namespace Bit.Core.AdminConsole.Services; public interface IProviderService { - Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key); + Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo = null); Task UpdateAsync(Provider provider, bool updateBilling = false); Task> InviteUserAsync(ProviderUserInvite invite); diff --git a/src/Core/AdminConsole/Services/NoopImplementations/NoopProviderService.cs b/src/Core/AdminConsole/Services/NoopImplementations/NoopProviderService.cs index 26d8dae03..bd3a75766 100644 --- a/src/Core/AdminConsole/Services/NoopImplementations/NoopProviderService.cs +++ b/src/Core/AdminConsole/Services/NoopImplementations/NoopProviderService.cs @@ -7,7 +7,7 @@ namespace Bit.Core.AdminConsole.Services.NoopImplementations; public class NoopProviderService : IProviderService { - public Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key) => throw new NotImplementedException(); + public Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo = null) => throw new NotImplementedException(); public Task UpdateAsync(Provider provider, bool updateBilling = false) => throw new NotImplementedException(); diff --git a/src/Core/Billing/BillingException.cs b/src/Core/Billing/BillingException.cs index a6944b3ed..cdb3ce6b5 100644 --- a/src/Core/Billing/BillingException.cs +++ b/src/Core/Billing/BillingException.cs @@ -1,9 +1,9 @@ namespace Bit.Core.Billing; public class BillingException( - string clientFriendlyMessage, - string internalMessage = null, - Exception innerException = null) : Exception(internalMessage, innerException) + string response = null, + string message = null, + Exception innerException = null) : Exception(message, innerException) { - public string ClientFriendlyMessage { get; set; } = clientFriendlyMessage; + public string Response { get; } = response ?? "Something went wrong with your request. Please contact support."; } diff --git a/src/Core/Billing/Constants/StripeConstants.cs b/src/Core/Billing/Constants/StripeConstants.cs index aa5737e3d..026638ecd 100644 --- a/src/Core/Billing/Constants/StripeConstants.cs +++ b/src/Core/Billing/Constants/StripeConstants.cs @@ -21,6 +21,11 @@ public static class StripeConstants public const string SecretsManagerStandalone = "sm-standalone"; } + public static class ErrorCodes + { + public const string CustomerTaxLocationInvalid = "customer_tax_location_invalid"; + } + public static class PaymentMethodTypes { public const string Card = "card"; diff --git a/src/Core/Billing/Extensions/BillingExtensions.cs b/src/Core/Billing/Extensions/BillingExtensions.cs index c3ba756ed..d6fa0988b 100644 --- a/src/Core/Billing/Extensions/BillingExtensions.cs +++ b/src/Core/Billing/Extensions/BillingExtensions.cs @@ -2,6 +2,7 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.Billing.Enums; +using Bit.Core.Entities; using Bit.Core.Enums; using Stripe; @@ -24,9 +25,9 @@ public static class BillingExtensions PlanType: PlanType.TeamsMonthly or PlanType.EnterpriseMonthly }; - public static bool IsStripeEnabled(this Organization organization) - => !string.IsNullOrEmpty(organization.GatewayCustomerId) && - !string.IsNullOrEmpty(organization.GatewaySubscriptionId); + public static bool IsStripeEnabled(this ISubscriber subscriber) + => !string.IsNullOrEmpty(subscriber.GatewayCustomerId) && + !string.IsNullOrEmpty(subscriber.GatewaySubscriptionId); public static bool IsUnverifiedBankAccount(this SetupIntent setupIntent) => setupIntent is diff --git a/src/Core/Billing/Models/ConfiguredProviderPlanDTO.cs b/src/Core/Billing/Models/ConfiguredProviderPlan.cs similarity index 78% rename from src/Core/Billing/Models/ConfiguredProviderPlanDTO.cs rename to src/Core/Billing/Models/ConfiguredProviderPlan.cs index d8ada5716..dadb17653 100644 --- a/src/Core/Billing/Models/ConfiguredProviderPlanDTO.cs +++ b/src/Core/Billing/Models/ConfiguredProviderPlan.cs @@ -3,7 +3,7 @@ using Bit.Core.Billing.Enums; namespace Bit.Core.Billing.Models; -public record ConfiguredProviderPlanDTO( +public record ConfiguredProviderPlan( Guid Id, Guid ProviderId, PlanType PlanType, @@ -11,9 +11,9 @@ public record ConfiguredProviderPlanDTO( int PurchasedSeats, int AssignedSeats) { - public static ConfiguredProviderPlanDTO From(ProviderPlan providerPlan) => + public static ConfiguredProviderPlan From(ProviderPlan providerPlan) => providerPlan.IsConfigured() - ? new ConfiguredProviderPlanDTO( + ? new ConfiguredProviderPlan( providerPlan.Id, providerPlan.ProviderId, providerPlan.PlanType, diff --git a/src/Core/Billing/Models/ConsolidatedBillingSubscriptionDTO.cs b/src/Core/Billing/Models/ConsolidatedBillingSubscriptionDTO.cs deleted file mode 100644 index 4b2f46adc..000000000 --- a/src/Core/Billing/Models/ConsolidatedBillingSubscriptionDTO.cs +++ /dev/null @@ -1,9 +0,0 @@ -using Stripe; - -namespace Bit.Core.Billing.Models; - -public record ConsolidatedBillingSubscriptionDTO( - List ProviderPlans, - Subscription Subscription, - TaxInformationDTO TaxInformation, - SubscriptionSuspensionDTO Suspension); diff --git a/src/Core/Billing/Models/PaymentInformationDTO.cs b/src/Core/Billing/Models/PaymentInformationDTO.cs index fe3195b3e..897d6a950 100644 --- a/src/Core/Billing/Models/PaymentInformationDTO.cs +++ b/src/Core/Billing/Models/PaymentInformationDTO.cs @@ -3,4 +3,4 @@ public record PaymentInformationDTO( long AccountCredit, MaskedPaymentMethodDTO PaymentMethod, - TaxInformationDTO TaxInformation); + TaxInformation TaxInformation); diff --git a/src/Core/Billing/Models/SubscriptionSuspensionDTO.cs b/src/Core/Billing/Models/SubscriptionSuspension.cs similarity index 75% rename from src/Core/Billing/Models/SubscriptionSuspensionDTO.cs rename to src/Core/Billing/Models/SubscriptionSuspension.cs index ac0261f2c..889c6e2be 100644 --- a/src/Core/Billing/Models/SubscriptionSuspensionDTO.cs +++ b/src/Core/Billing/Models/SubscriptionSuspension.cs @@ -1,6 +1,6 @@ namespace Bit.Core.Billing.Models; -public record SubscriptionSuspensionDTO( +public record SubscriptionSuspension( DateTime SuspensionDate, DateTime UnpaidPeriodEndDate, int GracePeriod); diff --git a/src/Core/Billing/Models/TaxInformationDTO.cs b/src/Core/Billing/Models/TaxInformation.cs similarity index 99% rename from src/Core/Billing/Models/TaxInformationDTO.cs rename to src/Core/Billing/Models/TaxInformation.cs index a5243b9ea..a2e6e187f 100644 --- a/src/Core/Billing/Models/TaxInformationDTO.cs +++ b/src/Core/Billing/Models/TaxInformation.cs @@ -1,6 +1,6 @@ namespace Bit.Core.Billing.Models; -public record TaxInformationDTO( +public record TaxInformation( string Country, string PostalCode, string TaxId, diff --git a/src/Core/Billing/Services/IProviderBillingService.cs b/src/Core/Billing/Services/IProviderBillingService.cs index 5c215bd71..0d136b503 100644 --- a/src/Core/Billing/Services/IProviderBillingService.cs +++ b/src/Core/Billing/Services/IProviderBillingService.cs @@ -3,8 +3,8 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.Billing.Entities; using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models; using Bit.Core.Models.Business; +using Stripe; namespace Bit.Core.Billing.Services; @@ -24,16 +24,6 @@ public interface IProviderBillingService Organization organization, int seats); - /// - /// Create a Stripe for the specified utilizing the provided . - /// - /// The to create a Stripe customer for. - /// The to use for calculating the customer's automatic tax. - /// - Task CreateCustomer( - Provider provider, - TaxInfo taxInfo); - /// /// Create a Stripe for the provided client utilizing /// the address and tax information of its . @@ -65,15 +55,6 @@ public interface IProviderBillingService Guid providerId, PlanType planType); - /// - /// Retrieves the 's consolidated billing subscription, which includes their Stripe subscription and configured provider plans. - /// - /// The provider to retrieve the consolidated billing subscription for. - /// A containing the provider's Stripe and a list of s representing their configured plans. - /// This method opts for returning rather than throwing exceptions, making it ideal for surfacing data from API endpoints. - Task GetConsolidatedBillingSubscription( - Provider provider); - /// /// Scales the 's seats for the specified using the provided . /// This operation may autoscale the provider's Stripe depending on the 's seat minimum for the @@ -88,11 +69,23 @@ public interface IProviderBillingService int seatAdjustment); /// - /// Starts a Stripe for the given given it has an existing Stripe . + /// For use during the provider setup process, this method creates a Stripe for the specified utilizing the provided . + /// + /// The to create a Stripe customer for. + /// The to use for calculating the customer's automatic tax. + /// The newly created for the . + Task SetupCustomer( + Provider provider, + TaxInfo taxInfo); + + /// + /// For use during the provider setup process, this method starts a Stripe for the given . /// subscriptions will always be started with a for both the /// and plan, and the quantity for each item will be equal the provider's seat minimum for each respective plan. /// /// The provider to create the for. - Task StartSubscription( + /// The newly created for the . + /// This method requires the to already have a linked Stripe via its field. + Task SetupSubscription( Provider provider); } diff --git a/src/Core/Billing/Services/ISubscriberService.cs b/src/Core/Billing/Services/ISubscriberService.cs index 115bd6f32..5183d49be 100644 --- a/src/Core/Billing/Services/ISubscriberService.cs +++ b/src/Core/Billing/Services/ISubscriberService.cs @@ -1,7 +1,6 @@ using Bit.Core.Billing.Models; using Bit.Core.Entities; using Bit.Core.Enums; -using Bit.Core.Models.BitStripe; using Stripe; namespace Bit.Core.Billing.Services; @@ -47,18 +46,6 @@ public interface ISubscriberService ISubscriber subscriber, CustomerGetOptions customerGetOptions = null); - /// - /// Retrieves a list of Stripe objects using the 's property. - /// - /// The subscriber to retrieve the Stripe invoices for. - /// Optional parameters that can be passed to Stripe to expand, modify or filter the invoices. The 's - /// will be automatically attached to the provided options as the parameter. - /// A list of Stripe objects. - /// This method opts for returning an empty list rather than throwing exceptions, making it ideal for surfacing data from API endpoints. - Task> GetInvoices( - ISubscriber subscriber, - StripeInvoiceListOptions invoiceListOptions = null); - /// /// Retrieves the account credit, a masked representation of the default payment method and the tax information for the /// provided . This is essentially a consolidated invocation of the @@ -106,10 +93,10 @@ public interface ISubscriberService /// Retrieves the 's tax information using their Stripe 's . /// /// The subscriber to retrieve the tax information for. - /// A representing the 's tax information. + /// A representing the 's tax information. /// Thrown when the is . /// This method opts for returning rather than throwing exceptions, making it ideal for surfacing data from API endpoints. - Task GetTaxInformation( + Task GetTaxInformation( ISubscriber subscriber); /// @@ -137,10 +124,10 @@ public interface ISubscriberService /// Updates the tax information for the provided . /// /// The to update the tax information for. - /// A representing the 's updated tax information. + /// A representing the 's updated tax information. Task UpdateTaxInformation( ISubscriber subscriber, - TaxInformationDTO taxInformation); + TaxInformation taxInformation); /// /// Verifies the subscriber's pending bank account using the provided . diff --git a/src/Core/Billing/Services/Implementations/SubscriberService.cs b/src/Core/Billing/Services/Implementations/SubscriberService.cs index 92f245c3b..850e6737f 100644 --- a/src/Core/Billing/Services/Implementations/SubscriberService.cs +++ b/src/Core/Billing/Services/Implementations/SubscriberService.cs @@ -2,7 +2,6 @@ using Bit.Core.Billing.Models; using Bit.Core.Entities; using Bit.Core.Enums; -using Bit.Core.Models.BitStripe; using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Utilities; @@ -37,7 +36,7 @@ public class SubscriberService( { logger.LogWarning("Cannot cancel subscription ({ID}) that's already inactive", subscription.Id); - throw ContactSupport(); + throw new BillingException(); } var metadata = new Dictionary @@ -148,7 +147,7 @@ public class SubscriberService( { logger.LogError("Cannot retrieve customer for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewayCustomerId)); - throw ContactSupport(); + throw new BillingException(); } try @@ -163,48 +162,16 @@ public class SubscriberService( logger.LogError("Could not find Stripe customer ({CustomerID}) for subscriber ({SubscriberID})", subscriber.GatewayCustomerId, subscriber.Id); - throw ContactSupport(); + throw new BillingException(); } - catch (StripeException exception) + catch (StripeException stripeException) { logger.LogError("An error occurred while trying to retrieve Stripe customer ({CustomerID}) for subscriber ({SubscriberID}): {Error}", - subscriber.GatewayCustomerId, subscriber.Id, exception.Message); + subscriber.GatewayCustomerId, subscriber.Id, stripeException.Message); - throw ContactSupport("An error occurred while trying to retrieve a Stripe Customer", exception); - } - } - - public async Task> GetInvoices( - ISubscriber subscriber, - StripeInvoiceListOptions invoiceListOptions = null) - { - ArgumentNullException.ThrowIfNull(subscriber); - - if (string.IsNullOrEmpty(subscriber.GatewayCustomerId)) - { - logger.LogError("Cannot retrieve invoices for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewayCustomerId)); - - return []; - } - - try - { - if (invoiceListOptions == null) - { - invoiceListOptions = new StripeInvoiceListOptions { Customer = subscriber.GatewayCustomerId }; - } - else - { - invoiceListOptions.Customer = subscriber.GatewayCustomerId; - } - - return await stripeAdapter.InvoiceListAsync(invoiceListOptions); - } - catch (StripeException exception) - { - logger.LogError("An error occurred while trying to retrieve Stripe invoices for subscriber ({SubscriberID}): {Error}", subscriber.Id, exception.Message); - - return []; + throw new BillingException( + message: "An error occurred while trying to retrieve a Stripe customer", + innerException: stripeException); } } @@ -294,7 +261,7 @@ public class SubscriberService( { logger.LogError("Cannot retrieve subscription for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewaySubscriptionId)); - throw ContactSupport(); + throw new BillingException(); } try @@ -309,18 +276,20 @@ public class SubscriberService( logger.LogError("Could not find Stripe subscription ({SubscriptionID}) for subscriber ({SubscriberID})", subscriber.GatewaySubscriptionId, subscriber.Id); - throw ContactSupport(); + throw new BillingException(); } - catch (StripeException exception) + catch (StripeException stripeException) { logger.LogError("An error occurred while trying to retrieve Stripe subscription ({SubscriptionID}) for subscriber ({SubscriberID}): {Error}", - subscriber.GatewaySubscriptionId, subscriber.Id, exception.Message); + subscriber.GatewaySubscriptionId, subscriber.Id, stripeException.Message); - throw ContactSupport("An error occurred while trying to retrieve a Stripe Subscription", exception); + throw new BillingException( + message: "An error occurred while trying to retrieve a Stripe subscription", + innerException: stripeException); } } - public async Task GetTaxInformation( + public async Task GetTaxInformation( ISubscriber subscriber) { ArgumentNullException.ThrowIfNull(subscriber); @@ -337,7 +306,7 @@ public class SubscriberService( if (string.IsNullOrEmpty(subscriber.GatewayCustomerId)) { - throw ContactSupport(); + throw new BillingException(); } var stripeCustomer = await GetCustomerOrThrow(subscriber, new CustomerGetOptions @@ -353,7 +322,7 @@ public class SubscriberService( { logger.LogError("Failed to retrieve Braintree customer ({ID}) when removing payment method", braintreeCustomerId); - throw ContactSupport(); + throw new BillingException(); } if (braintreeCustomer.DefaultPaymentMethod != null) @@ -369,7 +338,7 @@ public class SubscriberService( logger.LogError("Failed to update payment method for Braintree customer ({ID}) | Message: {Message}", braintreeCustomerId, updateCustomerResult.Message); - throw ContactSupport(); + throw new BillingException(); } var deletePaymentMethodResult = await braintreeGateway.PaymentMethod.DeleteAsync(existingDefaultPaymentMethod.Token); @@ -384,7 +353,7 @@ public class SubscriberService( "Failed to delete Braintree payment method for Customer ({ID}), re-linked payment method. Message: {Message}", braintreeCustomerId, deletePaymentMethodResult.Message); - throw ContactSupport(); + throw new BillingException(); } } else @@ -437,7 +406,7 @@ public class SubscriberService( { logger.LogError("Updated payment method for ({SubscriberID}) must contain a token", subscriber.Id); - throw ContactSupport(); + throw new BillingException(); } // ReSharper disable once SwitchStatementHandlesSomeKnownEnumValuesWithDefault @@ -462,7 +431,7 @@ public class SubscriberService( { logger.LogError("There were more than 1 setup intents for subscriber's ({SubscriberID}) updated payment method", subscriber.Id); - throw ContactSupport(); + throw new BillingException(); } var matchingSetupIntent = setupIntentsForUpdatedPaymentMethod.First(); @@ -551,7 +520,7 @@ public class SubscriberService( { logger.LogError("Failed to retrieve Braintree customer ({BraintreeCustomerId}) when updating payment method for subscriber ({SubscriberID})", braintreeCustomerId, subscriber.Id); - throw ContactSupport(); + throw new BillingException(); } await ReplaceBraintreePaymentMethodAsync(braintreeCustomer, token); @@ -570,14 +539,14 @@ public class SubscriberService( { logger.LogError("Cannot update subscriber's ({SubscriberID}) payment method to type ({PaymentMethodType}) as it is not supported", subscriber.Id, type.ToString()); - throw ContactSupport(); + throw new BillingException(); } } } public async Task UpdateTaxInformation( ISubscriber subscriber, - TaxInformationDTO taxInformation) + TaxInformation taxInformation) { ArgumentNullException.ThrowIfNull(subscriber); ArgumentNullException.ThrowIfNull(taxInformation); @@ -635,7 +604,7 @@ public class SubscriberService( { logger.LogError("No setup intent ID exists to verify for subscriber with ID ({SubscriberID})", subscriber.Id); - throw ContactSupport(); + throw new BillingException(); } var (amount1, amount2) = microdeposits; @@ -706,7 +675,7 @@ public class SubscriberService( logger.LogError("Failed to create Braintree customer for subscriber ({ID})", subscriber.Id); - throw ContactSupport(); + throw new BillingException(); } private async Task GetMaskedPaymentMethodDTOAsync( @@ -751,7 +720,7 @@ public class SubscriberService( return MaskedPaymentMethodDTO.From(setupIntent); } - private static TaxInformationDTO GetTaxInformationDTOFrom( + private static TaxInformation GetTaxInformationDTOFrom( Customer customer) { if (customer.Address == null) @@ -759,7 +728,7 @@ public class SubscriberService( return null; } - return new TaxInformationDTO( + return new TaxInformation( customer.Address.Country, customer.Address.PostalCode, customer.TaxIds?.FirstOrDefault()?.Value, @@ -825,7 +794,7 @@ public class SubscriberService( { logger.LogError("Failed to replace payment method for Braintree customer ({ID}) - Creation of new payment method failed | Error: {Error}", customer.Id, createPaymentMethodResult.Message); - throw ContactSupport(); + throw new BillingException(); } var updateCustomerResult = await braintreeGateway.Customer.UpdateAsync( @@ -839,7 +808,7 @@ public class SubscriberService( await braintreeGateway.PaymentMethod.DeleteAsync(createPaymentMethodResult.Target.Token); - throw ContactSupport(); + throw new BillingException(); } if (existingDefaultPaymentMethod != null) diff --git a/src/Core/Billing/Utilities.cs b/src/Core/Billing/Utilities.cs index 2c5ad8547..b8bc1887b 100644 --- a/src/Core/Billing/Utilities.cs +++ b/src/Core/Billing/Utilities.cs @@ -8,12 +8,7 @@ public static class Utilities { public const string BraintreeCustomerIdKey = "btCustomerId"; - public static BillingException ContactSupport( - string internalMessage = null, - Exception innerException = null) => new("Something went wrong with your request. Please contact support.", - internalMessage, innerException); - - public static async Task GetSuspensionAsync( + public static async Task GetSubscriptionSuspensionAsync( IStripeAdapter stripeAdapter, Subscription subscription) { @@ -49,7 +44,7 @@ public static class Utilities const int gracePeriod = 14; - return new SubscriptionSuspensionDTO( + return new SubscriptionSuspension( firstOverdueInvoice.Created.AddDays(gracePeriod), firstOverdueInvoice.PeriodEnd, gracePeriod); @@ -67,7 +62,7 @@ public static class Utilities const int gracePeriod = 30; - return new SubscriptionSuspensionDTO( + return new SubscriptionSuspension( firstOverdueInvoice.DueDate.Value.AddDays(gracePeriod), firstOverdueInvoice.PeriodEnd, gracePeriod); @@ -75,4 +70,21 @@ public static class Utilities default: return null; } } + + public static TaxInformation GetTaxInformation(Customer customer) + { + if (customer.Address == null) + { + return null; + } + + return new TaxInformation( + customer.Address.Country, + customer.Address.PostalCode, + customer.TaxIds?.FirstOrDefault()?.Value, + customer.Address.Line1, + customer.Address.Line2, + customer.Address.City, + customer.Address.State); + } } diff --git a/src/Core/Models/Business/ProviderSubscriptionUpdate.cs b/src/Core/Models/Business/ProviderSubscriptionUpdate.cs index cffce8a7a..d66013ad1 100644 --- a/src/Core/Models/Business/ProviderSubscriptionUpdate.cs +++ b/src/Core/Models/Business/ProviderSubscriptionUpdate.cs @@ -1,9 +1,8 @@ -using Bit.Core.Billing.Enums; +using Bit.Core.Billing; +using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; using Stripe; -using static Bit.Core.Billing.Utilities; - namespace Bit.Core.Models.Business; public class ProviderSubscriptionUpdate : SubscriptionUpdate @@ -21,7 +20,8 @@ public class ProviderSubscriptionUpdate : SubscriptionUpdate { if (!planType.SupportsConsolidatedBilling()) { - throw ContactSupport($"Cannot create a {nameof(ProviderSubscriptionUpdate)} for {nameof(PlanType)} that doesn't support consolidated billing"); + throw new BillingException( + message: $"Cannot create a {nameof(ProviderSubscriptionUpdate)} for {nameof(PlanType)} that doesn't support consolidated billing"); } var plan = Utilities.StaticStore.GetPlan(planType); diff --git a/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs b/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs index acd6721a5..e596b5e3d 100644 --- a/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs @@ -6,15 +6,19 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Entities; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Models; +using Bit.Core.Billing.Repositories; using Bit.Core.Billing.Services; using Bit.Core.Context; -using Bit.Core.Enums; +using Bit.Core.Models.Api; +using Bit.Core.Models.BitStripe; using Bit.Core.Services; using Bit.Core.Utilities; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.HttpResults; using NSubstitute; using NSubstitute.ReturnsExtensions; @@ -29,7 +33,74 @@ namespace Bit.Api.Test.Billing.Controllers; [SutProviderCustomize] public class ProviderBillingControllerTests { - #region GetInvoicesAsync + #region GetInvoicesAsync & TryGetBillableProviderForAdminOperations + + [Theory, BitAutoData] + public async Task GetInvoicesAsync_FFDisabled_NotFound( + Guid providerId, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(false); + + var result = await sutProvider.Sut.GetInvoicesAsync(providerId); + + AssertNotFound(result); + } + + [Theory, BitAutoData] + public async Task GetInvoicesAsync_NullProvider_NotFound( + Guid providerId, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + sutProvider.GetDependency().GetByIdAsync(providerId).ReturnsNull(); + + var result = await sutProvider.Sut.GetInvoicesAsync(providerId); + + AssertNotFound(result); + } + + [Theory, BitAutoData] + public async Task GetInvoicesAsync_NotProviderUser_Unauthorized( + Provider provider, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + + sutProvider.GetDependency().ProviderProviderAdmin(provider.Id) + .Returns(false); + + var result = await sutProvider.Sut.GetInvoicesAsync(provider.Id); + + AssertUnauthorized(result); + } + + [Theory, BitAutoData] + public async Task GetInvoicesAsync_ProviderNotBillable_Unauthorized( + Provider provider, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); + + provider.Type = ProviderType.Reseller; + provider.Status = ProviderStatusType.Created; + + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + + sutProvider.GetDependency().ProviderProviderAdmin(provider.Id) + .Returns(true); + + var result = await sutProvider.Sut.GetInvoicesAsync(provider.Id); + + AssertUnauthorized(result); + } [Theory, BitAutoData] public async Task GetInvoices_Ok( @@ -73,7 +144,9 @@ public class ProviderBillingControllerTests } }; - sutProvider.GetDependency().GetInvoices(provider).Returns(invoices); + sutProvider.GetDependency().InvoiceListAsync(Arg.Is( + options => + options.Customer == provider.GatewayCustomerId)).Returns(invoices); var result = await sutProvider.Sut.GetInvoicesAsync(provider.Id); @@ -108,6 +181,27 @@ public class ProviderBillingControllerTests #region GenerateClientInvoiceReportAsync + [Theory, BitAutoData] + public async Task GenerateClientInvoiceReportAsync_NullReportContent_ServerError( + Provider provider, + string invoiceId, + SutProvider sutProvider) + { + ConfigureStableAdminInputs(provider, sutProvider); + + sutProvider.GetDependency().GenerateClientInvoiceReport(invoiceId) + .ReturnsNull(); + + var result = await sutProvider.Sut.GenerateClientInvoiceReportAsync(provider.Id, invoiceId); + + Assert.IsType>(result); + + var response = (JsonHttpResult)result; + + Assert.Equal(StatusCodes.Status500InternalServerError, response.StatusCode); + Assert.Equal("We had a problem generating your invoice CSV. Please contact support.", response.Value.Message); + } + [Theory, BitAutoData] public async Task GenerateClientInvoiceReportAsync_Ok( Provider provider, @@ -133,158 +227,6 @@ public class ProviderBillingControllerTests #endregion - #region GetPaymentInformationAsync & TryGetBillableProviderForAdminOperation - - [Theory, BitAutoData] - public async Task GetPaymentInformationAsync_FFDisabled_NotFound( - Guid providerId, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(false); - - var result = await sutProvider.Sut.GetPaymentInformationAsync(providerId); - - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task GetPaymentInformationAsync_NullProvider_NotFound( - Guid providerId, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - sutProvider.GetDependency().GetByIdAsync(providerId).ReturnsNull(); - - var result = await sutProvider.Sut.GetPaymentInformationAsync(providerId); - - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task GetPaymentInformationAsync_NotProviderUser_Unauthorized( - Provider provider, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); - - sutProvider.GetDependency().ProviderProviderAdmin(provider.Id) - .Returns(false); - - var result = await sutProvider.Sut.GetPaymentInformationAsync(provider.Id); - - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task GetPaymentInformationAsync_ProviderNotBillable_Unauthorized( - Provider provider, - SutProvider sutProvider) - { - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) - .Returns(true); - - provider.Type = ProviderType.Reseller; - provider.Status = ProviderStatusType.Created; - - sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); - - sutProvider.GetDependency().ProviderProviderAdmin(provider.Id) - .Returns(true); - - var result = await sutProvider.Sut.GetPaymentInformationAsync(provider.Id); - - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task GetPaymentInformation_PaymentInformationNull_NotFound( - Provider provider, - SutProvider sutProvider) - { - ConfigureStableAdminInputs(provider, sutProvider); - - sutProvider.GetDependency().GetPaymentInformation(provider).ReturnsNull(); - - var result = await sutProvider.Sut.GetPaymentInformationAsync(provider.Id); - - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task GetPaymentInformation_Ok( - Provider provider, - SutProvider sutProvider) - { - ConfigureStableAdminInputs(provider, sutProvider); - - var maskedPaymentMethod = new MaskedPaymentMethodDTO(PaymentMethodType.Card, "VISA *1234", false); - - var taxInformation = - new TaxInformationDTO("US", "12345", "123456789", "123 Example St.", null, "Example Town", "NY"); - - sutProvider.GetDependency().GetPaymentInformation(provider).Returns(new PaymentInformationDTO( - 100, - maskedPaymentMethod, - taxInformation)); - - var result = await sutProvider.Sut.GetPaymentInformationAsync(provider.Id); - - Assert.IsType>(result); - - var response = ((Ok)result).Value; - - Assert.Equal(100, response.AccountCredit); - Assert.Equal(maskedPaymentMethod.Description, response.PaymentMethod.Description); - Assert.Equal(taxInformation.TaxId, response.TaxInformation.TaxId); - } - - #endregion - - #region GetPaymentMethodAsync - - [Theory, BitAutoData] - public async Task GetPaymentMethod_PaymentMethodNull_NotFound( - Provider provider, - SutProvider sutProvider) - { - ConfigureStableAdminInputs(provider, sutProvider); - - sutProvider.GetDependency().GetPaymentMethod(provider).ReturnsNull(); - - var result = await sutProvider.Sut.GetPaymentMethodAsync(provider.Id); - - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task GetPaymentMethod_Ok( - Provider provider, - SutProvider sutProvider) - { - ConfigureStableAdminInputs(provider, sutProvider); - - sutProvider.GetDependency().GetPaymentMethod(provider).Returns(new MaskedPaymentMethodDTO( - PaymentMethodType.Card, "Description", false)); - - var result = await sutProvider.Sut.GetPaymentMethodAsync(provider.Id); - - Assert.IsType>(result); - - var response = ((Ok)result).Value; - - Assert.Equal(PaymentMethodType.Card, response.Type); - Assert.Equal("Description", response.Description); - Assert.False(response.NeedsVerification); - } - - #endregion - #region GetSubscriptionAsync & TryGetBillableProviderForServiceUserOperation [Theory, BitAutoData] @@ -297,7 +239,7 @@ public class ProviderBillingControllerTests var result = await sutProvider.Sut.GetSubscriptionAsync(providerId); - Assert.IsType(result); + AssertNotFound(result); } [Theory, BitAutoData] @@ -312,7 +254,7 @@ public class ProviderBillingControllerTests var result = await sutProvider.Sut.GetSubscriptionAsync(providerId); - Assert.IsType(result); + AssertNotFound(result); } [Theory, BitAutoData] @@ -330,7 +272,7 @@ public class ProviderBillingControllerTests var result = await sutProvider.Sut.GetSubscriptionAsync(provider.Id); - Assert.IsType(result); + AssertUnauthorized(result); } [Theory, BitAutoData] @@ -351,21 +293,7 @@ public class ProviderBillingControllerTests var result = await sutProvider.Sut.GetSubscriptionAsync(provider.Id); - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionAsync_NullConsolidatedBillingSubscription_NotFound( - Provider provider, - SutProvider sutProvider) - { - ConfigureStableServiceUserInputs(provider, sutProvider); - - sutProvider.GetDependency().GetConsolidatedBillingSubscription(provider).ReturnsNull(); - - var result = await sutProvider.Sut.GetSubscriptionAsync(provider.Id); - - Assert.IsType(result); + AssertUnauthorized(result); } [Theory, BitAutoData] @@ -375,51 +303,83 @@ public class ProviderBillingControllerTests { ConfigureStableServiceUserInputs(provider, sutProvider); - var configuredProviderPlans = new List - { - new (Guid.NewGuid(), provider.Id, PlanType.TeamsMonthly, 50, 10, 30), - new (Guid.NewGuid(), provider.Id , PlanType.EnterpriseMonthly, 100, 0, 90) - }; + var stripeAdapter = sutProvider.GetDependency(); + + var (thisYear, thisMonth, _) = DateTime.UtcNow; + var daysInThisMonth = DateTime.DaysInMonth(thisYear, thisMonth); var subscription = new Subscription { - Status = "unpaid", - CurrentPeriodEnd = new DateTime(2024, 6, 30), + CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically, + CurrentPeriodEnd = new DateTime(thisYear, thisMonth, daysInThisMonth), Customer = new Customer { - Balance = 100000, - Discount = new Discount + Address = new Address { - Coupon = new Coupon - { - PercentOff = 10 - } - } + Country = "US", + PostalCode = "12345", + Line1 = "123 Example St.", + Line2 = "Unit 1", + City = "Example Town", + State = "NY" + }, + Balance = 100000, + Discount = new Discount { Coupon = new Coupon { PercentOff = 10 } }, + TaxIds = new StripeList { Data = [new TaxId { Value = "123456789" }] } + }, + Status = "unpaid", + }; + + stripeAdapter.SubscriptionGetAsync(provider.GatewaySubscriptionId, Arg.Is( + options => + options.Expand.Contains("customer.tax_ids") && + options.Expand.Contains("test_clock"))).Returns(subscription); + + var lastMonth = thisMonth - 1; + var daysInLastMonth = DateTime.DaysInMonth(thisYear, lastMonth); + + var overdueInvoice = new Invoice + { + Id = "invoice_id", + Status = "open", + Created = new DateTime(thisYear, lastMonth, 1), + PeriodEnd = new DateTime(thisYear, lastMonth, daysInLastMonth), + Attempted = true + }; + + stripeAdapter.InvoiceSearchAsync(Arg.Is( + options => options.Query == $"subscription:'{subscription.Id}' status:'open'")) + .Returns([overdueInvoice]); + + var providerPlans = new List + { + new () + { + Id = Guid.NewGuid(), + ProviderId = provider.Id, + PlanType = PlanType.TeamsMonthly, + SeatMinimum = 50, + PurchasedSeats = 10, + AllocatedSeats = 60 + }, + new () + { + Id = Guid.NewGuid(), + ProviderId = provider.Id, + PlanType = PlanType.EnterpriseMonthly, + SeatMinimum = 100, + PurchasedSeats = 0, + AllocatedSeats = 90 } }; - var taxInformation = - new TaxInformationDTO("US", "12345", "123456789", "123 Example St.", null, "Example Town", "NY"); - - var suspension = new SubscriptionSuspensionDTO( - new DateTime(2024, 7, 30), - new DateTime(2024, 5, 30), - 30); - - var consolidatedBillingSubscription = new ConsolidatedBillingSubscriptionDTO( - configuredProviderPlans, - subscription, - taxInformation, - suspension); - - sutProvider.GetDependency().GetConsolidatedBillingSubscription(provider) - .Returns(consolidatedBillingSubscription); + sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); var result = await sutProvider.Sut.GetSubscriptionAsync(provider.Id); - Assert.IsType>(result); + Assert.IsType>(result); - var response = ((Ok)result).Value; + var response = ((Ok)result).Value; Assert.Equal(subscription.Status, response.Status); Assert.Equal(subscription.CurrentPeriodEnd, response.CurrentPeriodEndDate); @@ -431,7 +391,7 @@ public class ProviderBillingControllerTests Assert.NotNull(providerTeamsPlan); Assert.Equal(50, providerTeamsPlan.SeatMinimum); Assert.Equal(10, providerTeamsPlan.PurchasedSeats); - Assert.Equal(30, providerTeamsPlan.AssignedSeats); + Assert.Equal(60, providerTeamsPlan.AssignedSeats); Assert.Equal(60 * teamsPlan.PasswordManager.ProviderPortalSeatPrice, providerTeamsPlan.Cost); Assert.Equal("Monthly", providerTeamsPlan.Cadence); @@ -445,87 +405,46 @@ public class ProviderBillingControllerTests Assert.Equal("Monthly", providerEnterprisePlan.Cadence); Assert.Equal(100000, response.AccountCredit); - Assert.Equal(taxInformation, response.TaxInformation); + + var customer = subscription.Customer; + Assert.Equal(customer.Address.Country, response.TaxInformation.Country); + Assert.Equal(customer.Address.PostalCode, response.TaxInformation.PostalCode); + Assert.Equal(customer.TaxIds.First().Value, response.TaxInformation.TaxId); + Assert.Equal(customer.Address.Line1, response.TaxInformation.Line1); + Assert.Equal(customer.Address.Line2, response.TaxInformation.Line2); + Assert.Equal(customer.Address.City, response.TaxInformation.City); + Assert.Equal(customer.Address.State, response.TaxInformation.State); + Assert.Null(response.CancelAt); - Assert.Equal(suspension, response.Suspension); - } - #endregion - - #region GetTaxInformationAsync - - [Theory, BitAutoData] - public async Task GetTaxInformation_TaxInformationNull_NotFound( - Provider provider, - SutProvider sutProvider) - { - ConfigureStableAdminInputs(provider, sutProvider); - - sutProvider.GetDependency().GetTaxInformation(provider).ReturnsNull(); - - var result = await sutProvider.Sut.GetTaxInformationAsync(provider.Id); - - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task GetTaxInformation_Ok( - Provider provider, - SutProvider sutProvider) - { - ConfigureStableAdminInputs(provider, sutProvider); - - sutProvider.GetDependency().GetTaxInformation(provider).Returns(new TaxInformationDTO( - "US", - "12345", - "123456789", - "123 Example St.", - null, - "Example Town", - "NY")); - - var result = await sutProvider.Sut.GetTaxInformationAsync(provider.Id); - - Assert.IsType>(result); - - var response = ((Ok)result).Value; - - Assert.Equal("US", response.Country); - Assert.Equal("12345", response.PostalCode); - Assert.Equal("123456789", response.TaxId); - Assert.Equal("123 Example St.", response.Line1); - Assert.Null(response.Line2); - Assert.Equal("Example Town", response.City); - Assert.Equal("NY", response.State); - } - - #endregion - - #region UpdatePaymentMethodAsync - - [Theory, BitAutoData] - public async Task UpdatePaymentMethod_Ok( - Provider provider, - TokenizedPaymentMethodRequestBody requestBody, - SutProvider sutProvider) - { - ConfigureStableAdminInputs(provider, sutProvider); - - await sutProvider.Sut.UpdatePaymentMethodAsync(provider.Id, requestBody); - - await sutProvider.GetDependency().Received(1).UpdatePaymentMethod( - provider, Arg.Is( - options => options.Type == requestBody.Type && options.Token == requestBody.Token)); - - await sutProvider.GetDependency().Received(1).SubscriptionUpdateAsync( - provider.GatewaySubscriptionId, Arg.Is( - options => options.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically)); + Assert.Equal(overdueInvoice.Created.AddDays(14), response.Suspension.SuspensionDate); + Assert.Equal(overdueInvoice.PeriodEnd, response.Suspension.UnpaidPeriodEndDate); + Assert.Equal(14, response.Suspension.GracePeriod); } #endregion #region UpdateTaxInformationAsync + [Theory, BitAutoData] + public async Task UpdateTaxInformation_NoCountry_BadRequest( + Provider provider, + TaxInformationRequestBody requestBody, + SutProvider sutProvider) + { + ConfigureStableAdminInputs(provider, sutProvider); + + requestBody.Country = null; + + var result = await sutProvider.Sut.UpdateTaxInformationAsync(provider.Id, requestBody); + + Assert.IsType>(result); + + var response = (BadRequest)result; + + Assert.Equal("Country and postal code are required to update your tax information.", response.Value.Message); + } + [Theory, BitAutoData] public async Task UpdateTaxInformation_Ok( Provider provider, @@ -537,7 +456,7 @@ public class ProviderBillingControllerTests await sutProvider.Sut.UpdateTaxInformationAsync(provider.Id, requestBody); await sutProvider.GetDependency().Received(1).UpdateTaxInformation( - provider, Arg.Is( + provider, Arg.Is( options => options.Country == requestBody.Country && options.PostalCode == requestBody.PostalCode && @@ -549,25 +468,4 @@ public class ProviderBillingControllerTests } #endregion - - #region VerifyBankAccount - - [Theory, BitAutoData] - public async Task VerifyBankAccount_Ok( - Provider provider, - VerifyBankAccountRequestBody requestBody, - SutProvider sutProvider) - { - ConfigureStableAdminInputs(provider, sutProvider); - - var result = await sutProvider.Sut.VerifyBankAccountAsync(provider.Id, requestBody); - - Assert.IsType(result); - - await sutProvider.GetDependency().Received(1).VerifyBankAccount( - provider, - (requestBody.Amount1, requestBody.Amount2)); - } - - #endregion } diff --git a/test/Api.Test/Billing/Controllers/ProviderClientsControllerTests.cs b/test/Api.Test/Billing/Controllers/ProviderClientsControllerTests.cs index 92d03f1e9..d0a79e15c 100644 --- a/test/Api.Test/Billing/Controllers/ProviderClientsControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/ProviderClientsControllerTests.cs @@ -39,38 +39,7 @@ public class ProviderClientsControllerTests var result = await sutProvider.Sut.CreateAsync(provider.Id, requestBody); - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task CreateAsync_MissingClientOrganization_ServerError( - Provider provider, - CreateClientOrganizationRequestBody requestBody, - SutProvider sutProvider) - { - ConfigureStableAdminInputs(provider, sutProvider); - - var user = new User(); - - sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).Returns(user); - - var clientOrganizationId = Guid.NewGuid(); - - sutProvider.GetDependency().CreateOrganizationAsync( - provider.Id, - Arg.Any(), - requestBody.OwnerEmail, - user) - .Returns(new ProviderOrganization - { - OrganizationId = clientOrganizationId - }); - - sutProvider.GetDependency().GetByIdAsync(clientOrganizationId).ReturnsNull(); - - var result = await sutProvider.Sut.CreateAsync(provider.Id, requestBody); - - Assert.IsType(result); + AssertUnauthorized(result); } [Theory, BitAutoData] @@ -137,32 +106,11 @@ public class ProviderClientsControllerTests var result = await sutProvider.Sut.UpdateAsync(provider.Id, providerOrganizationId, requestBody); - Assert.IsType(result); + AssertNotFound(result); } [Theory, BitAutoData] - public async Task UpdateAsync_NoOrganization_ServerError( - Provider provider, - Guid providerOrganizationId, - UpdateClientOrganizationRequestBody requestBody, - ProviderOrganization providerOrganization, - SutProvider sutProvider) - { - ConfigureStableServiceUserInputs(provider, sutProvider); - - sutProvider.GetDependency().GetByIdAsync(providerOrganizationId) - .Returns(providerOrganization); - - sutProvider.GetDependency().GetByIdAsync(providerOrganization.OrganizationId) - .ReturnsNull(); - - var result = await sutProvider.Sut.UpdateAsync(provider.Id, providerOrganizationId, requestBody); - - Assert.IsType(result); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_AssignedSeats_NoContent( + public async Task UpdateAsync_AssignedSeats_Ok( Provider provider, Guid providerOrganizationId, UpdateClientOrganizationRequestBody requestBody, @@ -193,7 +141,7 @@ public class ProviderClientsControllerTests } [Theory, BitAutoData] - public async Task UpdateAsync_Name_NoContent( + public async Task UpdateAsync_Name_Ok( Provider provider, Guid providerOrganizationId, UpdateClientOrganizationRequestBody requestBody, diff --git a/test/Api.Test/Billing/Utilities.cs b/test/Api.Test/Billing/Utilities.cs index 7c361b760..ce528477d 100644 --- a/test/Api.Test/Billing/Utilities.cs +++ b/test/Api.Test/Billing/Utilities.cs @@ -4,14 +4,37 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Context; +using Bit.Core.Models.Api; using Bit.Core.Services; using Bit.Test.Common.AutoFixture; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.HttpResults; using NSubstitute; +using Xunit; namespace Bit.Api.Test.Billing; public static class Utilities { + public static void AssertNotFound(IResult result) + { + Assert.IsType>(result); + + var response = ((NotFound)result).Value; + + Assert.Equal("Resource not found.", response.Message); + } + + public static void AssertUnauthorized(IResult result) + { + Assert.IsType>(result); + + var response = (JsonHttpResult)result; + + Assert.Equal(StatusCodes.Status401Unauthorized, response.StatusCode); + Assert.Equal("Unauthorized.", response.Value.Message); + } + public static void ConfigureStableAdminInputs( Provider provider, SutProvider sutProvider) where T : BaseProviderController diff --git a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs index 6c2fdcd9f..6cbb3fb67 100644 --- a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs +++ b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs @@ -5,7 +5,6 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; using Bit.Core.Billing.Services.Implementations; using Bit.Core.Enums; -using Bit.Core.Models.BitStripe; using Bit.Core.Services; using Bit.Core.Settings; using Bit.Test.Common.AutoFixture; @@ -29,8 +28,9 @@ namespace Bit.Core.Test.Billing.Services; public class SubscriberServiceTests { #region CancelSubscription + [Theory, BitAutoData] - public async Task CancelSubscription_SubscriptionInactive_ContactSupport( + public async Task CancelSubscription_SubscriptionInactive_ThrowsBillingException( Organization organization, SutProvider sutProvider) { @@ -45,7 +45,7 @@ public class SubscriberServiceTests .SubscriptionGetAsync(organization.GatewaySubscriptionId) .Returns(subscription); - await ThrowsContactSupportAsync(() => + await ThrowsBillingExceptionAsync(() => sutProvider.Sut.CancelSubscription(organization, new OffboardingSurveyResponse(), false)); await stripeAdapter @@ -192,9 +192,11 @@ public class SubscriberServiceTests .DidNotReceiveWithAnyArgs() .SubscriptionCancelAsync(Arg.Any(), Arg.Any()); ; } + #endregion #region GetCustomer + [Theory, BitAutoData] public async Task GetCustomer_NullSubscriber_ThrowsArgumentNullException( SutProvider sutProvider) @@ -256,9 +258,11 @@ public class SubscriberServiceTests Assert.Equivalent(customer, gotCustomer); } + #endregion #region GetCustomerOrThrow + [Theory, BitAutoData] public async Task GetCustomerOrThrow_NullSubscriber_ThrowsArgumentNullException( SutProvider sutProvider) @@ -266,17 +270,17 @@ public class SubscriberServiceTests async () => await sutProvider.Sut.GetCustomerOrThrow(null)); [Theory, BitAutoData] - public async Task GetCustomerOrThrow_NoGatewayCustomerId_ContactSupport( + public async Task GetCustomerOrThrow_NoGatewayCustomerId_ThrowsBillingException( Organization organization, SutProvider sutProvider) { organization.GatewayCustomerId = null; - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetCustomerOrThrow(organization)); + await ThrowsBillingExceptionAsync(async () => await sutProvider.Sut.GetCustomerOrThrow(organization)); } [Theory, BitAutoData] - public async Task GetCustomerOrThrow_NoCustomer_ContactSupport( + public async Task GetCustomerOrThrow_NoCustomer_ThrowsBillingException( Organization organization, SutProvider sutProvider) { @@ -284,11 +288,11 @@ public class SubscriberServiceTests .CustomerGetAsync(organization.GatewayCustomerId) .ReturnsNull(); - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetCustomerOrThrow(organization)); + await ThrowsBillingExceptionAsync(async () => await sutProvider.Sut.GetCustomerOrThrow(organization)); } [Theory, BitAutoData] - public async Task GetCustomerOrThrow_StripeException_ContactSupport( + public async Task GetCustomerOrThrow_StripeException_ThrowsBillingException( Organization organization, SutProvider sutProvider) { @@ -298,10 +302,10 @@ public class SubscriberServiceTests .CustomerGetAsync(organization.GatewayCustomerId) .ThrowsAsync(stripeException); - await ThrowsContactSupportAsync( + await ThrowsBillingExceptionAsync( async () => await sutProvider.Sut.GetCustomerOrThrow(organization), - "An error occurred while trying to retrieve a Stripe Customer", - stripeException); + message: "An error occurred while trying to retrieve a Stripe customer", + innerException: stripeException); } [Theory, BitAutoData] @@ -319,108 +323,6 @@ public class SubscriberServiceTests Assert.Equivalent(customer, gotCustomer); } - #endregion - - #region GetInvoices - - [Theory, BitAutoData] - public async Task GetInvoices_NullSubscriber_ThrowsArgumentNullException( - SutProvider sutProvider) - => await Assert.ThrowsAsync( - async () => await sutProvider.Sut.GetInvoices(null)); - - [Theory, BitAutoData] - public async Task GetCustomer_NoGatewayCustomerId_ReturnsEmptyList( - Organization organization, - SutProvider sutProvider) - { - organization.GatewayCustomerId = null; - - var invoices = await sutProvider.Sut.GetInvoices(organization); - - Assert.Empty(invoices); - } - - [Theory, BitAutoData] - public async Task GetInvoices_StripeException_ReturnsEmptyList( - Organization organization, - SutProvider sutProvider) - { - sutProvider.GetDependency() - .InvoiceListAsync(Arg.Any()) - .ThrowsAsync(); - - var invoices = await sutProvider.Sut.GetInvoices(organization); - - Assert.Empty(invoices); - } - - [Theory, BitAutoData] - public async Task GetInvoices_NullOptions_Succeeds( - Organization organization, - SutProvider sutProvider) - { - var invoices = new List - { - new () - { - Created = new DateTime(2024, 6, 1), - Number = "2", - Status = "open", - Total = 100000, - HostedInvoiceUrl = "https://example.com/invoice/2", - InvoicePdf = "https://example.com/invoice/2/pdf" - }, - new () - { - Created = new DateTime(2024, 5, 1), - Number = "1", - Status = "paid", - Total = 100000, - HostedInvoiceUrl = "https://example.com/invoice/1", - InvoicePdf = "https://example.com/invoice/1/pdf" - } - }; - - sutProvider.GetDependency() - .InvoiceListAsync(Arg.Is(options => options.Customer == organization.GatewayCustomerId)) - .Returns(invoices); - - var gotInvoices = await sutProvider.Sut.GetInvoices(organization); - - Assert.Equivalent(invoices, gotInvoices); - } - - [Theory, BitAutoData] - public async Task GetInvoices_ProvidedOptions_Succeeds( - Organization organization, - SutProvider sutProvider) - { - var invoices = new List - { - new () - { - Created = new DateTime(2024, 5, 1), - Number = "1", - Status = "paid", - Total = 100000, - } - }; - - sutProvider.GetDependency() - .InvoiceListAsync(Arg.Is( - options => - options.Customer == organization.GatewayCustomerId && - options.Status == "paid")) - .Returns(invoices); - - var gotInvoices = await sutProvider.Sut.GetInvoices(organization, new StripeInvoiceListOptions - { - Status = "paid" - }); - - Assert.Equivalent(invoices, gotInvoices); - } #endregion @@ -795,17 +697,17 @@ public class SubscriberServiceTests async () => await sutProvider.Sut.GetSubscriptionOrThrow(null)); [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_NoGatewaySubscriptionId_ContactSupport( + public async Task GetSubscriptionOrThrow_NoGatewaySubscriptionId_ThrowsBillingException( Organization organization, SutProvider sutProvider) { organization.GatewaySubscriptionId = null; - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization)); + await ThrowsBillingExceptionAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization)); } [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_NoSubscription_ContactSupport( + public async Task GetSubscriptionOrThrow_NoSubscription_ThrowsBillingException( Organization organization, SutProvider sutProvider) { @@ -813,11 +715,11 @@ public class SubscriberServiceTests .SubscriptionGetAsync(organization.GatewaySubscriptionId) .ReturnsNull(); - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization)); + await ThrowsBillingExceptionAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization)); } [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_StripeException_ContactSupport( + public async Task GetSubscriptionOrThrow_StripeException_ThrowsBillingException( Organization organization, SutProvider sutProvider) { @@ -827,10 +729,10 @@ public class SubscriberServiceTests .SubscriptionGetAsync(organization.GatewaySubscriptionId) .ThrowsAsync(stripeException); - await ThrowsContactSupportAsync( + await ThrowsBillingExceptionAsync( async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization), - "An error occurred while trying to retrieve a Stripe Subscription", - stripeException); + message: "An error occurred while trying to retrieve a Stripe subscription", + innerException: stripeException); } [Theory, BitAutoData] @@ -911,12 +813,12 @@ public class SubscriberServiceTests #region RemovePaymentMethod [Theory, BitAutoData] - public async Task RemovePaymentMethod_NullSubscriber_ArgumentNullException( + public async Task RemovePaymentMethod_NullSubscriber_ThrowsArgumentNullException( SutProvider sutProvider) => await Assert.ThrowsAsync(() => sutProvider.Sut.RemovePaymentMethod(null)); [Theory, BitAutoData] - public async Task RemovePaymentMethod_Braintree_NoCustomer_ContactSupport( + public async Task RemovePaymentMethod_Braintree_NoCustomer_ThrowsBillingException( Organization organization, SutProvider sutProvider) { @@ -940,7 +842,7 @@ public class SubscriberServiceTests braintreeGateway.Customer.Returns(customerGateway); - await ThrowsContactSupportAsync(() => sutProvider.Sut.RemovePaymentMethod(organization)); + await ThrowsBillingExceptionAsync(() => sutProvider.Sut.RemovePaymentMethod(organization)); await customerGateway.Received(1).FindAsync(braintreeCustomerId); @@ -987,7 +889,7 @@ public class SubscriberServiceTests } [Theory, BitAutoData] - public async Task RemovePaymentMethod_Braintree_CustomerUpdateFails_ContactSupport( + public async Task RemovePaymentMethod_Braintree_CustomerUpdateFails_ThrowsBillingException( Organization organization, SutProvider sutProvider) { @@ -1028,7 +930,7 @@ public class SubscriberServiceTests Arg.Is(request => request.DefaultPaymentMethodToken == null)) .Returns(updateBraintreeCustomerResult); - await ThrowsContactSupportAsync(() => sutProvider.Sut.RemovePaymentMethod(organization)); + await ThrowsBillingExceptionAsync(() => sutProvider.Sut.RemovePaymentMethod(organization)); await customerGateway.Received(1).FindAsync(braintreeCustomerId); @@ -1042,7 +944,7 @@ public class SubscriberServiceTests } [Theory, BitAutoData] - public async Task RemovePaymentMethod_Braintree_PaymentMethodDeleteFails_RollBack_ContactSupport( + public async Task RemovePaymentMethod_Braintree_PaymentMethodDeleteFails_RollBack_ThrowsBillingException( Organization organization, SutProvider sutProvider) { @@ -1086,7 +988,7 @@ public class SubscriberServiceTests paymentMethodGateway.DeleteAsync(paymentMethod.Token).Returns(deleteBraintreePaymentMethodResult); - await ThrowsContactSupportAsync(() => sutProvider.Sut.RemovePaymentMethod(organization)); + await ThrowsBillingExceptionAsync(() => sutProvider.Sut.RemovePaymentMethod(organization)); await customerGateway.Received(1).FindAsync(braintreeCustomerId); @@ -1206,42 +1108,42 @@ public class SubscriberServiceTests #region UpdatePaymentMethod [Theory, BitAutoData] - public async Task UpdatePaymentMethod_NullSubscriber_ArgumentNullException( + public async Task UpdatePaymentMethod_NullSubscriber_ThrowsArgumentNullException( SutProvider sutProvider) => await Assert.ThrowsAsync(() => sutProvider.Sut.UpdatePaymentMethod(null, null)); [Theory, BitAutoData] - public async Task UpdatePaymentMethod_NullTokenizedPaymentMethod_ArgumentNullException( + public async Task UpdatePaymentMethod_NullTokenizedPaymentMethod_ThrowsArgumentNullException( Provider provider, SutProvider sutProvider) => await Assert.ThrowsAsync(() => sutProvider.Sut.UpdatePaymentMethod(provider, null)); [Theory, BitAutoData] - public async Task UpdatePaymentMethod_NoToken_ContactSupport( + public async Task UpdatePaymentMethod_NoToken_ThrowsBillingException( Provider provider, SutProvider sutProvider) { sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId) .Returns(new Customer()); - await ThrowsContactSupportAsync(() => + await ThrowsBillingExceptionAsync(() => sutProvider.Sut.UpdatePaymentMethod(provider, new TokenizedPaymentMethodDTO(PaymentMethodType.Card, null))); } [Theory, BitAutoData] - public async Task UpdatePaymentMethod_UnsupportedPaymentMethod_ContactSupport( + public async Task UpdatePaymentMethod_UnsupportedPaymentMethod_ThrowsBillingException( Provider provider, SutProvider sutProvider) { sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId) .Returns(new Customer()); - await ThrowsContactSupportAsync(() => + await ThrowsBillingExceptionAsync(() => sutProvider.Sut.UpdatePaymentMethod(provider, new TokenizedPaymentMethodDTO(PaymentMethodType.BitPay, "TOKEN"))); } [Theory, BitAutoData] - public async Task UpdatePaymentMethod_BankAccount_IncorrectNumberOfSetupIntentsForToken_ContactSupport( + public async Task UpdatePaymentMethod_BankAccount_IncorrectNumberOfSetupIntentsForToken_ThrowsBillingException( Provider provider, SutProvider sutProvider) { @@ -1253,7 +1155,7 @@ public class SubscriberServiceTests stripeAdapter.SetupIntentList(Arg.Is(options => options.PaymentMethod == "TOKEN")) .Returns([new SetupIntent(), new SetupIntent()]); - await ThrowsContactSupportAsync(() => + await ThrowsBillingExceptionAsync(() => sutProvider.Sut.UpdatePaymentMethod(provider, new TokenizedPaymentMethodDTO(PaymentMethodType.BankAccount, "TOKEN"))); } @@ -1348,7 +1250,7 @@ public class SubscriberServiceTests } [Theory, BitAutoData] - public async Task UpdatePaymentMethod_Braintree_NullCustomer_ContactSupport( + public async Task UpdatePaymentMethod_Braintree_NullCustomer_ThrowsBillingException( Provider provider, SutProvider sutProvider) { @@ -1368,13 +1270,13 @@ public class SubscriberServiceTests customerGateway.FindAsync(braintreeCustomerId).ReturnsNull(); - await ThrowsContactSupportAsync(() => sutProvider.Sut.UpdatePaymentMethod(provider, new TokenizedPaymentMethodDTO(PaymentMethodType.PayPal, "TOKEN"))); + await ThrowsBillingExceptionAsync(() => sutProvider.Sut.UpdatePaymentMethod(provider, new TokenizedPaymentMethodDTO(PaymentMethodType.PayPal, "TOKEN"))); await paymentMethodGateway.DidNotReceiveWithAnyArgs().CreateAsync(Arg.Any()); } [Theory, BitAutoData] - public async Task UpdatePaymentMethod_Braintree_ReplacePaymentMethod_CreatePaymentMethodFails_ContactSupport( + public async Task UpdatePaymentMethod_Braintree_ReplacePaymentMethod_CreatePaymentMethodFails_ThrowsBillingException( Provider provider, SutProvider sutProvider) { @@ -1406,13 +1308,13 @@ public class SubscriberServiceTests options => options.CustomerId == braintreeCustomerId && options.PaymentMethodNonce == "TOKEN")) .Returns(createPaymentMethodResult); - await ThrowsContactSupportAsync(() => sutProvider.Sut.UpdatePaymentMethod(provider, new TokenizedPaymentMethodDTO(PaymentMethodType.PayPal, "TOKEN"))); + await ThrowsBillingExceptionAsync(() => sutProvider.Sut.UpdatePaymentMethod(provider, new TokenizedPaymentMethodDTO(PaymentMethodType.PayPal, "TOKEN"))); await customerGateway.DidNotReceiveWithAnyArgs().UpdateAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] - public async Task UpdatePaymentMethod_Braintree_ReplacePaymentMethod_UpdateCustomerFails_DeletePaymentMethod_ContactSupport( + public async Task UpdatePaymentMethod_Braintree_ReplacePaymentMethod_UpdateCustomerFails_DeletePaymentMethod_ThrowsBillingException( Provider provider, SutProvider sutProvider) { @@ -1458,7 +1360,7 @@ public class SubscriberServiceTests options.DefaultPaymentMethodToken == createPaymentMethodResult.Target.Token)) .Returns(updateCustomerResult); - await ThrowsContactSupportAsync(() => sutProvider.Sut.UpdatePaymentMethod(provider, new TokenizedPaymentMethodDTO(PaymentMethodType.PayPal, "TOKEN"))); + await ThrowsBillingExceptionAsync(() => sutProvider.Sut.UpdatePaymentMethod(provider, new TokenizedPaymentMethodDTO(PaymentMethodType.PayPal, "TOKEN"))); await paymentMethodGateway.Received(1).DeleteAsync(createPaymentMethodResult.Target.Token); } @@ -1531,7 +1433,7 @@ public class SubscriberServiceTests } [Theory, BitAutoData] - public async Task UpdatePaymentMethod_Braintree_CreateCustomer_CustomerUpdateFails_ContactSupport( + public async Task UpdatePaymentMethod_Braintree_CreateCustomer_CustomerUpdateFails_ThrowsBillingException( Provider provider, SutProvider sutProvider) { @@ -1564,7 +1466,7 @@ public class SubscriberServiceTests options.PaymentMethodNonce == "TOKEN")) .Returns(createCustomerResult); - await ThrowsContactSupportAsync(() => + await ThrowsBillingExceptionAsync(() => sutProvider.Sut.UpdatePaymentMethod(provider, new TokenizedPaymentMethodDTO(PaymentMethodType.PayPal, "TOKEN"))); @@ -1648,7 +1550,7 @@ public class SubscriberServiceTests stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId, Arg.Is( options => options.Expand.Contains("tax_ids"))).Returns(customer); - var taxInformation = new TaxInformationDTO( + var taxInformation = new TaxInformation( "US", "12345", "123456789", @@ -1685,9 +1587,9 @@ public class SubscriberServiceTests () => sutProvider.Sut.VerifyBankAccount(null, (0, 0))); [Theory, BitAutoData] - public async Task VerifyBankAccount_NoSetupIntentId_ContactSupport( + public async Task VerifyBankAccount_NoSetupIntentId_ThrowsBillingException( Provider provider, - SutProvider sutProvider) => await ThrowsContactSupportAsync(() => sutProvider.Sut.VerifyBankAccount(provider, (1, 1))); + SutProvider sutProvider) => await ThrowsBillingExceptionAsync(() => sutProvider.Sut.VerifyBankAccount(provider, (1, 1))); [Theory, BitAutoData] public async Task VerifyBankAccount_MakesCorrectInvocations( diff --git a/test/Core.Test/Billing/Utilities.cs b/test/Core.Test/Billing/Utilities.cs index a66feebef..79383af6e 100644 --- a/test/Core.Test/Billing/Utilities.cs +++ b/test/Core.Test/Billing/Utilities.cs @@ -1,23 +1,22 @@ using Bit.Core.Billing; using Xunit; -using static Bit.Core.Billing.Utilities; - namespace Bit.Core.Test.Billing; public static class Utilities { - public static async Task ThrowsContactSupportAsync( + public static async Task ThrowsBillingExceptionAsync( Func function, - string internalMessage = null, + string response = null, + string message = null, Exception innerException = null) { - var contactSupport = ContactSupport(internalMessage, innerException); + var expected = new BillingException(response, message, innerException); - var exception = await Assert.ThrowsAsync(function); + var actual = await Assert.ThrowsAsync(function); - Assert.Equal(contactSupport.ClientFriendlyMessage, exception.ClientFriendlyMessage); - Assert.Equal(contactSupport.Message, exception.Message); - Assert.Equal(contactSupport.InnerException, exception.InnerException); + Assert.Equal(expected.Response, actual.Response); + Assert.Equal(expected.Message, actual.Message); + Assert.Equal(expected.InnerException, actual.InnerException); } }