diff --git a/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs b/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs index 3fb0b0598..ffb4b889b 100644 --- a/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs +++ b/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs @@ -4,15 +4,14 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Providers.Interfaces; using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Billing.Commands; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Utilities; -using Microsoft.Extensions.Logging; using Stripe; namespace Bit.Commercial.Core.AdminConsole.Providers; @@ -20,35 +19,35 @@ namespace Bit.Commercial.Core.AdminConsole.Providers; public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProviderCommand { private readonly IEventService _eventService; - private readonly ILogger _logger; private readonly IMailService _mailService; private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationService _organizationService; private readonly IProviderOrganizationRepository _providerOrganizationRepository; private readonly IStripeAdapter _stripeAdapter; - private readonly IScaleSeatsCommand _scaleSeatsCommand; private readonly IFeatureService _featureService; + private readonly IProviderBillingService _providerBillingService; + private readonly ISubscriberService _subscriberService; public RemoveOrganizationFromProviderCommand( IEventService eventService, - ILogger logger, IMailService mailService, IOrganizationRepository organizationRepository, IOrganizationService organizationService, IProviderOrganizationRepository providerOrganizationRepository, IStripeAdapter stripeAdapter, - IScaleSeatsCommand scaleSeatsCommand, - IFeatureService featureService) + IFeatureService featureService, + IProviderBillingService providerBillingService, + ISubscriberService subscriberService) { _eventService = eventService; - _logger = logger; _mailService = mailService; _organizationRepository = organizationRepository; _organizationService = organizationService; _providerOrganizationRepository = providerOrganizationRepository; _stripeAdapter = stripeAdapter; - _scaleSeatsCommand = scaleSeatsCommand; _featureService = featureService; + _providerBillingService = providerBillingService; + _subscriberService = subscriberService; } public async Task RemoveOrganizationFromProvider( @@ -99,23 +98,19 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv Provider provider, IEnumerable organizationOwnerEmails) { - if (!organization.IsStripeEnabled()) - { - return; - } - var isConsolidatedBillingEnabled = _featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling); - var customerUpdateOptions = new CustomerUpdateOptions + if (isConsolidatedBillingEnabled && + provider.Status == ProviderStatusType.Billable && + organization.Status == OrganizationStatusType.Managed && + !string.IsNullOrEmpty(organization.GatewayCustomerId)) { - Coupon = string.Empty, - Email = organization.BillingEmail - }; + await _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, new CustomerUpdateOptions + { + Description = string.Empty, + Email = organization.BillingEmail + }); - await _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, customerUpdateOptions); - - if (isConsolidatedBillingEnabled && provider.Status == ProviderStatusType.Billable) - { var plan = StaticStore.GetPlan(organization.PlanType).PasswordManager; var subscriptionCreateOptions = new SubscriptionCreateOptions @@ -136,19 +131,25 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv var subscription = await _stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); organization.GatewaySubscriptionId = subscription.Id; + organization.Status = OrganizationStatusType.Created; - await _scaleSeatsCommand.ScalePasswordManagerSeats(provider, organization.PlanType, - -(organization.Seats ?? 0)); + await _providerBillingService.ScaleSeats(provider, organization.PlanType, -organization.Seats ?? 0); } - else + else if (organization.IsStripeEnabled()) { - var subscriptionUpdateOptions = new SubscriptionUpdateOptions + await _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, new CustomerUpdateOptions { - CollectionMethod = "send_invoice", - DaysUntilDue = 30 - }; + Coupon = string.Empty, + Email = organization.BillingEmail + }); - await _stripeAdapter.SubscriptionUpdateAsync(organization.GatewaySubscriptionId, subscriptionUpdateOptions); + await _stripeAdapter.SubscriptionUpdateAsync(organization.GatewaySubscriptionId, new SubscriptionUpdateOptions + { + CollectionMethod = StripeConstants.CollectionMethod.SendInvoice, + DaysUntilDue = 30 + }); + + await _subscriberService.RemovePaymentMethod(organization); } await _mailService.SendProviderUpdatePaymentMethod( diff --git a/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs b/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs new file mode 100644 index 000000000..330ab1617 --- /dev/null +++ b/bitwarden_license/src/Commercial.Core/Billing/ProviderBillingService.cs @@ -0,0 +1,512 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Entities; +using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Models; +using Bit.Core.Billing.Repositories; +using Bit.Core.Billing.Services; +using Bit.Core.Enums; +using Bit.Core.Models.Business; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Settings; +using Bit.Core.Utilities; +using Microsoft.Extensions.Logging; +using Stripe; +using static Bit.Core.Billing.Utilities; + +namespace Bit.Commercial.Core.Billing; + +public class ProviderBillingService( + IGlobalSettings globalSettings, + ILogger logger, + IOrganizationRepository organizationRepository, + IPaymentService paymentService, + IProviderOrganizationRepository providerOrganizationRepository, + IProviderPlanRepository providerPlanRepository, + IProviderRepository providerRepository, + IStripeAdapter stripeAdapter, + ISubscriberService subscriberService) : IProviderBillingService +{ + public async Task AssignSeatsToClientOrganization( + Provider provider, + Organization organization, + int seats) + { + ArgumentNullException.ThrowIfNull(organization); + + if (seats < 0) + { + throw new BillingException( + "You cannot assign negative seats to a client.", + "MSP cannot assign negative seats to a client organization"); + } + + if (seats == organization.Seats) + { + logger.LogWarning("Client organization ({ID}) already has {Seats} seats assigned to it", organization.Id, organization.Seats); + + return; + } + + var seatAdjustment = seats - (organization.Seats ?? 0); + + await ScaleSeats(provider, organization.PlanType, seatAdjustment); + + organization.Seats = seats; + + await organizationRepository.ReplaceAsync(organization); + } + + public async Task CreateCustomer( + Provider provider, + TaxInfo taxInfo) + { + ArgumentNullException.ThrowIfNull(provider); + ArgumentNullException.ThrowIfNull(taxInfo); + + if (string.IsNullOrEmpty(taxInfo.BillingAddressCountry) || + string.IsNullOrEmpty(taxInfo.BillingAddressPostalCode)) + { + logger.LogError("Cannot create Stripe customer for provider ({ID}) - Both the provider's country and postal code are required", provider.Id); + + throw ContactSupport(); + } + + var providerDisplayName = provider.DisplayName(); + + var customerCreateOptions = new CustomerCreateOptions + { + Address = new AddressOptions + { + Country = taxInfo.BillingAddressCountry, + PostalCode = taxInfo.BillingAddressPostalCode, + Line1 = taxInfo.BillingAddressLine1, + Line2 = taxInfo.BillingAddressLine2, + City = taxInfo.BillingAddressCity, + State = taxInfo.BillingAddressState + }, + Coupon = "msp-discount-35", + Description = provider.DisplayBusinessName(), + Email = provider.BillingEmail, + InvoiceSettings = new CustomerInvoiceSettingsOptions + { + CustomFields = + [ + new CustomerInvoiceSettingsCustomFieldOptions + { + Name = provider.SubscriberType(), + Value = providerDisplayName.Length <= 30 + ? providerDisplayName + : providerDisplayName[..30] + } + ] + }, + Metadata = new Dictionary + { + { "region", globalSettings.BaseServiceUri.CloudRegion } + }, + TaxIdData = taxInfo.HasTaxId ? + [ + new CustomerTaxIdDataOptions { Type = taxInfo.TaxIdType, Value = taxInfo.TaxIdNumber } + ] + : null + }; + + var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions); + + provider.GatewayCustomerId = customer.Id; + + await providerRepository.ReplaceAsync(provider); + } + + public async Task CreateCustomerForClientOrganization( + Provider provider, + Organization organization) + { + ArgumentNullException.ThrowIfNull(provider); + ArgumentNullException.ThrowIfNull(organization); + + if (!string.IsNullOrEmpty(organization.GatewayCustomerId)) + { + logger.LogWarning("Client organization ({ID}) already has a populated {FieldName}", organization.Id, nameof(organization.GatewayCustomerId)); + + return; + } + + var providerCustomer = await subscriberService.GetCustomerOrThrow(provider, new CustomerGetOptions + { + Expand = ["tax_ids"] + }); + + var providerTaxId = providerCustomer.TaxIds.FirstOrDefault(); + + var organizationDisplayName = organization.DisplayName(); + + var customerCreateOptions = new CustomerCreateOptions + { + Address = new AddressOptions + { + Country = providerCustomer.Address?.Country, + PostalCode = providerCustomer.Address?.PostalCode, + Line1 = providerCustomer.Address?.Line1, + Line2 = providerCustomer.Address?.Line2, + City = providerCustomer.Address?.City, + State = providerCustomer.Address?.State + }, + Name = organizationDisplayName, + Description = $"{provider.Name} Client Organization", + Email = provider.BillingEmail, + InvoiceSettings = new CustomerInvoiceSettingsOptions + { + CustomFields = + [ + new CustomerInvoiceSettingsCustomFieldOptions + { + Name = organization.SubscriberType(), + Value = organizationDisplayName.Length <= 30 + ? organizationDisplayName + : organizationDisplayName[..30] + } + ] + }, + Metadata = new Dictionary + { + { "region", globalSettings.BaseServiceUri.CloudRegion } + }, + TaxIdData = providerTaxId == null ? null : + [ + new CustomerTaxIdDataOptions + { + Type = providerTaxId.Type, + Value = providerTaxId.Value + } + ] + }; + + var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions); + + organization.GatewayCustomerId = customer.Id; + + await organizationRepository.ReplaceAsync(organization); + } + + public async Task GetAssignedSeatTotalForPlanOrThrow( + Guid providerId, + PlanType planType) + { + var provider = await providerRepository.GetByIdAsync(providerId); + + if (provider == null) + { + logger.LogError( + "Could not find provider ({ID}) when retrieving assigned seat total", + providerId); + + throw ContactSupport(); + } + + if (provider.Type == ProviderType.Reseller) + { + logger.LogError("Assigned seats cannot be retrieved for reseller-type provider ({ID})", providerId); + + throw ContactSupport("Consolidated billing does not support reseller-type providers"); + } + + var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId); + + var plan = StaticStore.GetPlan(planType); + + return providerOrganizations + .Where(providerOrganization => providerOrganization.Plan == plan.Name && providerOrganization.Status == OrganizationStatusType.Managed) + .Sum(providerOrganization => providerOrganization.Seats ?? 0); + } + + public async Task GetSubscriptionDTO(Guid providerId) + { + var provider = await providerRepository.GetByIdAsync(providerId); + + if (provider == null) + { + logger.LogError( + "Could not find provider ({ID}) when retrieving subscription data.", + providerId); + + return null; + } + + if (provider.Type == ProviderType.Reseller) + { + logger.LogError("Subscription data cannot be retrieved for reseller-type provider ({ID})", providerId); + + throw ContactSupport("Consolidated billing does not support reseller-type providers"); + } + + var subscription = await subscriberService.GetSubscription(provider, new SubscriptionGetOptions + { + Expand = ["customer"] + }); + + if (subscription == null) + { + return null; + } + + var providerPlans = await providerPlanRepository.GetByProviderId(providerId); + + var configuredProviderPlans = providerPlans + .Where(providerPlan => providerPlan.IsConfigured()) + .Select(ConfiguredProviderPlanDTO.From) + .ToList(); + + return new ProviderSubscriptionDTO( + configuredProviderPlans, + subscription); + } + + public async Task ScaleSeats( + Provider provider, + PlanType planType, + int seatAdjustment) + { + ArgumentNullException.ThrowIfNull(provider); + + if (provider.Type != ProviderType.Msp) + { + logger.LogError("Non-MSP provider ({ProviderID}) cannot scale their seats", provider.Id); + + throw ContactSupport(); + } + + if (!planType.SupportsConsolidatedBilling()) + { + logger.LogError("Cannot scale provider ({ProviderID}) seats for plan type {PlanType} as it does not support consolidated billing", provider.Id, planType.ToString()); + + throw ContactSupport(); + } + + var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); + + var providerPlan = providerPlans.FirstOrDefault(providerPlan => providerPlan.PlanType == planType); + + if (providerPlan == null || !providerPlan.IsConfigured()) + { + logger.LogError("Cannot scale provider ({ProviderID}) seats for plan type {PlanType} when their matching provider plan is not configured", provider.Id, planType); + + throw ContactSupport(); + } + + var seatMinimum = providerPlan.SeatMinimum.GetValueOrDefault(0); + + var currentlyAssignedSeatTotal = await GetAssignedSeatTotalForPlanOrThrow(provider.Id, planType); + + var newlyAssignedSeatTotal = currentlyAssignedSeatTotal + seatAdjustment; + + var update = CurrySeatScalingUpdate( + provider, + providerPlan, + newlyAssignedSeatTotal); + + /* + * Below the limit => Below the limit: + * No subscription update required. We can safely update the provider's allocated seats. + */ + if (currentlyAssignedSeatTotal <= seatMinimum && + newlyAssignedSeatTotal <= seatMinimum) + { + providerPlan.AllocatedSeats = newlyAssignedSeatTotal; + + await providerPlanRepository.ReplaceAsync(providerPlan); + } + /* + * Below the limit => Above the limit: + * We have to scale the subscription up from the seat minimum to the newly assigned seat total. + */ + else if (currentlyAssignedSeatTotal <= seatMinimum && + newlyAssignedSeatTotal > seatMinimum) + { + await update( + seatMinimum, + newlyAssignedSeatTotal); + } + /* + * Above the limit => Above the limit: + * We have to scale the subscription from the currently assigned seat total to the newly assigned seat total. + */ + else if (currentlyAssignedSeatTotal > seatMinimum && + newlyAssignedSeatTotal > seatMinimum) + { + await update( + currentlyAssignedSeatTotal, + newlyAssignedSeatTotal); + } + /* + * Above the limit => Below the limit: + * We have to scale the subscription down from the currently assigned seat total to the seat minimum. + */ + else if (currentlyAssignedSeatTotal > seatMinimum && + newlyAssignedSeatTotal <= seatMinimum) + { + await update( + currentlyAssignedSeatTotal, + seatMinimum); + } + } + + public async Task StartSubscription( + Provider provider) + { + ArgumentNullException.ThrowIfNull(provider); + + if (!string.IsNullOrEmpty(provider.GatewaySubscriptionId)) + { + logger.LogWarning("Cannot start Provider subscription - Provider ({ID}) already has a {FieldName}", provider.Id, nameof(provider.GatewaySubscriptionId)); + + throw ContactSupport(); + } + + var customer = await subscriberService.GetCustomerOrThrow(provider); + + var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); + + if (providerPlans == null || providerPlans.Count == 0) + { + logger.LogError("Cannot start Provider subscription - Provider ({ID}) has no configured plans", provider.Id); + + throw ContactSupport(); + } + + var subscriptionItemOptionsList = new List(); + + var teamsProviderPlan = + providerPlans.SingleOrDefault(providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly); + + if (teamsProviderPlan == null) + { + logger.LogError("Cannot start Provider subscription - Provider ({ID}) has no configured Teams Monthly plan", provider.Id); + + throw ContactSupport(); + } + + var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + + subscriptionItemOptionsList.Add(new SubscriptionItemOptions + { + Price = teamsPlan.PasswordManager.StripeSeatPlanId, + Quantity = teamsProviderPlan.SeatMinimum + }); + + var enterpriseProviderPlan = + providerPlans.SingleOrDefault(providerPlan => providerPlan.PlanType == PlanType.EnterpriseMonthly); + + if (enterpriseProviderPlan == null) + { + logger.LogError("Cannot start Provider subscription - Provider ({ID}) has no configured Enterprise Monthly plan", provider.Id); + + throw ContactSupport(); + } + + var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); + + subscriptionItemOptionsList.Add(new SubscriptionItemOptions + { + Price = enterprisePlan.PasswordManager.StripeSeatPlanId, + Quantity = enterpriseProviderPlan.SeatMinimum + }); + + var subscriptionCreateOptions = new SubscriptionCreateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = true + }, + CollectionMethod = StripeConstants.CollectionMethod.SendInvoice, + Customer = customer.Id, + DaysUntilDue = 30, + Items = subscriptionItemOptionsList, + Metadata = new Dictionary + { + { "providerId", provider.Id.ToString() } + }, + OffSession = true, + ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations + }; + + var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); + + provider.GatewaySubscriptionId = subscription.Id; + + if (subscription.Status == StripeConstants.SubscriptionStatus.Incomplete) + { + await providerRepository.ReplaceAsync(provider); + + logger.LogError("Started incomplete Provider ({ProviderID}) subscription ({SubscriptionID})", provider.Id, subscription.Id); + + throw ContactSupport(); + } + + provider.Status = ProviderStatusType.Billable; + + await providerRepository.ReplaceAsync(provider); + } + + public async Task GetPaymentInformationAsync(Guid providerId) + { + var provider = await providerRepository.GetByIdAsync(providerId); + + if (provider == null) + { + logger.LogError( + "Could not find provider ({ID}) when retrieving payment information.", + providerId); + + return null; + } + + if (provider.Type == ProviderType.Reseller) + { + logger.LogError("payment information cannot be retrieved for reseller-type provider ({ID})", providerId); + + throw ContactSupport("Consolidated billing does not support reseller-type providers"); + } + + var taxInformation = await subscriberService.GetTaxInformationAsync(provider); + var billingInformation = await subscriberService.GetPaymentMethodAsync(provider); + + if (taxInformation == null && billingInformation == null) + { + return null; + } + + return new ProviderPaymentInfoDTO( + billingInformation, + taxInformation); + } + + private Func CurrySeatScalingUpdate( + Provider provider, + ProviderPlan providerPlan, + int newlyAssignedSeats) => async (currentlySubscribedSeats, newlySubscribedSeats) => + { + var plan = StaticStore.GetPlan(providerPlan.PlanType); + + await paymentService.AdjustSeats( + provider, + plan, + currentlySubscribedSeats, + newlySubscribedSeats); + + var newlyPurchasedSeats = newlySubscribedSeats > providerPlan.SeatMinimum + ? newlySubscribedSeats - providerPlan.SeatMinimum + : 0; + + providerPlan.PurchasedSeats = newlyPurchasedSeats; + providerPlan.AllocatedSeats = newlyAssignedSeats; + + await providerPlanRepository.ReplaceAsync(providerPlan); + }; +} diff --git a/bitwarden_license/src/Commercial.Core/Utilities/ServiceCollectionExtensions.cs b/bitwarden_license/src/Commercial.Core/Utilities/ServiceCollectionExtensions.cs index 53c089f9f..5ae5be884 100644 --- a/bitwarden_license/src/Commercial.Core/Utilities/ServiceCollectionExtensions.cs +++ b/bitwarden_license/src/Commercial.Core/Utilities/ServiceCollectionExtensions.cs @@ -1,7 +1,9 @@ using Bit.Commercial.Core.AdminConsole.Providers; using Bit.Commercial.Core.AdminConsole.Services; +using Bit.Commercial.Core.Billing; using Bit.Core.AdminConsole.Providers.Interfaces; using Bit.Core.AdminConsole.Services; +using Bit.Core.Billing.Services; using Microsoft.Extensions.DependencyInjection; namespace Bit.Commercial.Core.Utilities; @@ -13,5 +15,6 @@ public static class ServiceCollectionExtensions services.AddScoped(); services.AddScoped(); services.AddScoped(); + services.AddTransient(); } } diff --git a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs index e5b0a4e3d..a549c5007 100644 --- a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs @@ -4,17 +4,18 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Billing.Commands; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; +using Bit.Core.Utilities; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Stripe; using Xunit; -using IMailService = Bit.Core.Services.IMailService; namespace Bit.Commercial.Core.Test.AdminConsole.ProviderFeatures; @@ -74,9 +75,9 @@ public class RemoveOrganizationFromProviderCommandTests providerOrganization.ProviderId = provider.Id; sutProvider.GetDependency().HasConfirmedOwnersExceptAsync( - providerOrganization.OrganizationId, - Array.Empty(), - includeProvider: false) + providerOrganization.OrganizationId, + [], + includeProvider: false) .Returns(false); var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization)); @@ -85,56 +86,53 @@ public class RemoveOrganizationFromProviderCommandTests } [Theory, BitAutoData] - public async Task RemoveOrganizationFromProvider_NoStripeObjects_MakesCorrectInvocations( + public async Task RemoveOrganizationFromProvider_OrganizationNotStripeEnabled_MakesCorrectInvocations( Provider provider, ProviderOrganization providerOrganization, Organization organization, SutProvider sutProvider) { + providerOrganization.ProviderId = provider.Id; + organization.GatewayCustomerId = null; organization.GatewaySubscriptionId = null; - providerOrganization.ProviderId = provider.Id; - - var organizationRepository = sutProvider.GetDependency(); - sutProvider.GetDependency().HasConfirmedOwnersExceptAsync( providerOrganization.OrganizationId, - Array.Empty(), + [], includeProvider: false) .Returns(true); - var organizationOwnerEmails = new List { "a@gmail.com", "b@gmail.com" }; + var organizationRepository = sutProvider.GetDependency(); - organizationRepository.GetOwnerEmailAddressesById(organization.Id).Returns(organizationOwnerEmails); + organizationRepository.GetOwnerEmailAddressesById(organization.Id).Returns([ + "a@example.com", + "b@example.com" + ]); await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization); - await organizationRepository.Received(1).ReplaceAsync(Arg.Is( - org => org.Id == organization.Id && org.BillingEmail == "a@gmail.com")); - - var stripeAdapter = sutProvider.GetDependency(); - - await stripeAdapter.DidNotReceiveWithAnyArgs().CustomerUpdateAsync(Arg.Any(), Arg.Any()); - - await stripeAdapter.DidNotReceiveWithAnyArgs().SubscriptionUpdateAsync(Arg.Any(), Arg.Any()); - - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().SendProviderUpdatePaymentMethod( - Arg.Any(), - Arg.Any(), - Arg.Any(), - Arg.Any>()); + await organizationRepository.Received(1).ReplaceAsync(Arg.Is(org => org.BillingEmail == "a@example.com")); await sutProvider.GetDependency().Received(1) .DeleteAsync(providerOrganization); - await sutProvider.GetDependency().Received(1).LogProviderOrganizationEventAsync( - providerOrganization, - EventType.ProviderOrganization_Removed); + await sutProvider.GetDependency().Received(1) + .LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Removed); + + await sutProvider.GetDependency().Received(1) + .SendProviderUpdatePaymentMethod( + organization.Id, + organization.Name, + provider.Name, + Arg.Is>(emails => emails.FirstOrDefault() == "a@example.com")); + + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .CustomerUpdateAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] - public async Task RemoveOrganizationFromProvider_MakesCorrectInvocations__FeatureFlagOff( + public async Task RemoveOrganizationFromProvider_OrganizationStripeEnabled_NonConsolidatedBilling_MakesCorrectInvocations( Provider provider, ProviderOrganization providerOrganization, Organization organization, @@ -142,104 +140,126 @@ public class RemoveOrganizationFromProviderCommandTests { providerOrganization.ProviderId = provider.Id; - var organizationRepository = sutProvider.GetDependency(); - sutProvider.GetDependency().HasConfirmedOwnersExceptAsync( providerOrganization.OrganizationId, - Array.Empty(), + [], includeProvider: false) .Returns(true); - var organizationOwnerEmails = new List { "a@example.com", "b@example.com" }; + var organizationRepository = sutProvider.GetDependency(); - organizationRepository.GetOwnerEmailAddressesById(organization.Id).Returns(organizationOwnerEmails); - var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription - { - Id = "S-1", - CurrentPeriodEnd = DateTime.Today.AddDays(10), - }); + organizationRepository.GetOwnerEmailAddressesById(organization.Id).Returns([ + "a@example.com", + "b@example.com" + ]); + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(false); await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization); - await organizationRepository.Received(1).ReplaceAsync(Arg.Is( - org => org.Id == organization.Id && org.BillingEmail == "a@example.com")); + var stripeAdapter = sutProvider.GetDependency(); - await stripeAdapter.Received(1).CustomerUpdateAsync( - organization.GatewayCustomerId, Arg.Is( - options => options.Coupon == string.Empty && options.Email == "a@example.com")); + await stripeAdapter.Received(1).CustomerUpdateAsync(organization.GatewayCustomerId, + Arg.Is(options => + options.Coupon == string.Empty && options.Email == "a@example.com")); - await sutProvider.GetDependency().Received(1).SendProviderUpdatePaymentMethod( - organization.Id, - organization.Name, - provider.Name, - Arg.Is>(emails => emails.Contains("a@example.com") && emails.Contains("b@example.com"))); + await stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + Arg.Is(options => + options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && + options.DaysUntilDue == 30)); + + await sutProvider.GetDependency().Received(1).RemovePaymentMethod(organization); + + await organizationRepository.Received(1).ReplaceAsync(Arg.Is(org => org.BillingEmail == "a@example.com")); await sutProvider.GetDependency().Received(1) .DeleteAsync(providerOrganization); - await sutProvider.GetDependency().Received(1).LogProviderOrganizationEventAsync( - providerOrganization, - EventType.ProviderOrganization_Removed); + await sutProvider.GetDependency().Received(1) + .LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Removed); + + await sutProvider.GetDependency().Received(1) + .SendProviderUpdatePaymentMethod( + organization.Id, + organization.Name, + provider.Name, + Arg.Is>(emails => emails.FirstOrDefault() == "a@example.com")); } [Theory, BitAutoData] - public async Task RemoveOrganizationFromProvider_CreatesSubscriptionAndScalesSeats_FeatureFlagON(Provider provider, + public async Task RemoveOrganizationFromProvider_OrganizationStripeEnabled_ConsolidatedBilling_MakesCorrectInvocations( + Provider provider, ProviderOrganization providerOrganization, Organization organization, SutProvider sutProvider) { - providerOrganization.ProviderId = provider.Id; provider.Status = ProviderStatusType.Billable; - var organizationRepository = sutProvider.GetDependency(); + + providerOrganization.ProviderId = provider.Id; + + organization.Status = OrganizationStatusType.Managed; + + organization.PlanType = PlanType.TeamsMonthly; + + var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + sutProvider.GetDependency().HasConfirmedOwnersExceptAsync( providerOrganization.OrganizationId, - Array.Empty(), + [], includeProvider: false) .Returns(true); - var organizationOwnerEmails = new List { "a@example.com", "b@example.com" }; + var organizationRepository = sutProvider.GetDependency(); - organizationRepository.GetOwnerEmailAddressesById(organization.Id).Returns(organizationOwnerEmails); + organizationRepository.GetOwnerEmailAddressesById(organization.Id).Returns([ + "a@example.com", + "b@example.com" + ]); + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling) + .Returns(true); var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription + + stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(new Subscription { - Id = "S-1", - CurrentPeriodEnd = DateTime.Today.AddDays(10), + Id = "subscription_id" }); - sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling).Returns(true); + await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization); - await stripeAdapter.Received(1).CustomerUpdateAsync( - organization.GatewayCustomerId, Arg.Is( - options => options.Coupon == string.Empty && options.Email == "a@example.com")); - await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is(c => - c.Customer == organization.GatewayCustomerId && - c.CollectionMethod == "send_invoice" && - c.DaysUntilDue == 30 && - c.Items.Count == 1 - )); + await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is(options => + options.Customer == organization.GatewayCustomerId && + options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && + options.DaysUntilDue == 30 && + options.AutomaticTax.Enabled == true && + options.Metadata["organizationId"] == organization.Id.ToString() && + options.OffSession == true && + options.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations && + options.Items.First().Price == teamsMonthlyPlan.PasswordManager.StripeSeatPlanId && + options.Items.First().Quantity == organization.Seats)); - await sutProvider.GetDependency().Received(1) - .ScalePasswordManagerSeats(provider, organization.PlanType, -(int)organization.Seats); + await sutProvider.GetDependency().Received(1) + .ScaleSeats(provider, organization.PlanType, -organization.Seats ?? 0); await organizationRepository.Received(1).ReplaceAsync(Arg.Is( - org => org.Id == organization.Id && org.BillingEmail == "a@example.com" && - org.GatewaySubscriptionId == "S-1")); - - await sutProvider.GetDependency().Received(1).SendProviderUpdatePaymentMethod( - organization.Id, - organization.Name, - provider.Name, - Arg.Is>(emails => - emails.Contains("a@example.com") && emails.Contains("b@example.com"))); + org => + org.BillingEmail == "a@example.com" && + org.GatewaySubscriptionId == "subscription_id" && + org.Status == OrganizationStatusType.Created)); await sutProvider.GetDependency().Received(1) .DeleteAsync(providerOrganization); - await sutProvider.GetDependency().Received(1).LogProviderOrganizationEventAsync( - providerOrganization, - EventType.ProviderOrganization_Removed); + await sutProvider.GetDependency().Received(1) + .LogProviderOrganizationEventAsync(providerOrganization, EventType.ProviderOrganization_Removed); + + await sutProvider.GetDependency().Received(1) + .SendProviderUpdatePaymentMethod( + organization.Id, + organization.Name, + provider.Name, + Arg.Is>(emails => emails.FirstOrDefault() == "a@example.com")); } } diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs new file mode 100644 index 000000000..c5bcc4fc2 --- /dev/null +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/ProviderBillingServiceTests.cs @@ -0,0 +1,1110 @@ +using System.Net; +using Bit.Commercial.Core.Billing; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Models.Data.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Entities; +using Bit.Core.Billing.Models; +using Bit.Core.Billing.Repositories; +using Bit.Core.Billing.Services; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Models.Business; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Settings; +using Bit.Core.Utilities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using NSubstitute.ReturnsExtensions; +using Stripe; +using Xunit; +using static Bit.Core.Test.Billing.Utilities; + +namespace Bit.Commercial.Core.Test.Billing; + +[SutProviderCustomize] +public class ProviderBillingServiceTests +{ + #region AssignSeatsToClientOrganization & ScaleSeats + + [Theory, BitAutoData] + public Task AssignSeatsToClientOrganization_NullProvider_ArgumentNullException( + Organization organization, + int seats, + SutProvider sutProvider) + => Assert.ThrowsAsync(() => + sutProvider.Sut.AssignSeatsToClientOrganization(null, organization, seats)); + + [Theory, BitAutoData] + public Task AssignSeatsToClientOrganization_NullOrganization_ArgumentNullException( + Provider provider, + int seats, + SutProvider sutProvider) + => Assert.ThrowsAsync(() => + sutProvider.Sut.AssignSeatsToClientOrganization(provider, null, seats)); + + [Theory, BitAutoData] + public Task AssignSeatsToClientOrganization_NegativeSeats_BillingException( + Provider provider, + Organization organization, + SutProvider sutProvider) + => Assert.ThrowsAsync(() => + sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, -5)); + + [Theory, BitAutoData] + public async Task AssignSeatsToClientOrganization_CurrentSeatsMatchesNewSeats_NoOp( + Provider provider, + Organization organization, + int seats, + SutProvider sutProvider) + { + organization.PlanType = PlanType.TeamsMonthly; + + organization.Seats = seats; + + await sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats); + + await sutProvider.GetDependency().DidNotReceive().GetByProviderId(provider.Id); + } + + [Theory, BitAutoData] + public async Task + AssignSeatsToClientOrganization_OrganizationPlanTypeDoesNotSupportConsolidatedBilling_ContactSupport( + Provider provider, + Organization organization, + int seats, + SutProvider sutProvider) + { + organization.PlanType = PlanType.FamiliesAnnually; + + await ThrowsContactSupportAsync(() => + sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats)); + } + + [Theory, BitAutoData] + public async Task AssignSeatsToClientOrganization_ProviderPlanIsNotConfigured_ContactSupport( + Provider provider, + Organization organization, + int seats, + SutProvider sutProvider) + { + organization.PlanType = PlanType.TeamsMonthly; + + sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(new List + { + new() { Id = Guid.NewGuid(), PlanType = PlanType.TeamsMonthly, ProviderId = provider.Id } + }); + + await ThrowsContactSupportAsync(() => + sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats)); + } + + [Theory, BitAutoData] + public async Task AssignSeatsToClientOrganization_BelowToBelow_Succeeds( + Provider provider, + Organization organization, + SutProvider sutProvider) + { + organization.Seats = 10; + + organization.PlanType = PlanType.TeamsMonthly; + + // Scale up 10 seats + const int seats = 20; + + var providerPlans = new List + { + new() + { + Id = Guid.NewGuid(), + PlanType = PlanType.TeamsMonthly, + ProviderId = provider.Id, + PurchasedSeats = 0, + // 100 minimum + SeatMinimum = 100, + AllocatedSeats = 50 + }, + new() + { + Id = Guid.NewGuid(), + PlanType = PlanType.EnterpriseMonthly, + ProviderId = provider.Id, + PurchasedSeats = 0, + SeatMinimum = 500, + AllocatedSeats = 0 + } + }; + + var providerPlan = providerPlans.First(); + + sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); + + // 50 seats currently assigned with a seat minimum of 100 + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + + var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + + sutProvider.GetDependency().GetManyDetailsByProviderAsync(provider.Id).Returns( + [ + new ProviderOrganizationOrganizationDetails + { + Plan = teamsMonthlyPlan.Name, + Status = OrganizationStatusType.Managed, + Seats = 25 + }, + new ProviderOrganizationOrganizationDetails + { + Plan = teamsMonthlyPlan.Name, + Status = OrganizationStatusType.Managed, + Seats = 25 + } + ]); + + await sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats); + + // 50 assigned seats + 10 seat scale up = 60 seats, well below the 100 minimum + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().AdjustSeats( + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()); + + await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( + org => org.Id == organization.Id && org.Seats == seats)); + + await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( + pPlan => pPlan.AllocatedSeats == 60)); + } + + [Theory, BitAutoData] + public async Task AssignSeatsToClientOrganization_BelowToAbove_Succeeds( + Provider provider, + Organization organization, + SutProvider sutProvider) + { + organization.Seats = 10; + + organization.PlanType = PlanType.TeamsMonthly; + + // Scale up 10 seats + const int seats = 20; + + var providerPlans = new List + { + new() + { + Id = Guid.NewGuid(), + PlanType = PlanType.TeamsMonthly, + ProviderId = provider.Id, + PurchasedSeats = 0, + // 100 minimum + SeatMinimum = 100, + AllocatedSeats = 95 + }, + new() + { + Id = Guid.NewGuid(), + PlanType = PlanType.EnterpriseMonthly, + ProviderId = provider.Id, + PurchasedSeats = 0, + SeatMinimum = 500, + AllocatedSeats = 0 + } + }; + + var providerPlan = providerPlans.First(); + + sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); + + // 95 seats currently assigned with a seat minimum of 100 + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + + var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + + sutProvider.GetDependency().GetManyDetailsByProviderAsync(provider.Id).Returns( + [ + new ProviderOrganizationOrganizationDetails + { + Plan = teamsMonthlyPlan.Name, + Status = OrganizationStatusType.Managed, + Seats = 60 + }, + new ProviderOrganizationOrganizationDetails + { + Plan = teamsMonthlyPlan.Name, + Status = OrganizationStatusType.Managed, + Seats = 35 + } + ]); + + await sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats); + + // 95 current + 10 seat scale = 105 seats, 5 above the minimum + await sutProvider.GetDependency().Received(1).AdjustSeats( + provider, + StaticStore.GetPlan(providerPlan.PlanType), + providerPlan.SeatMinimum!.Value, + 105); + + await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( + org => org.Id == organization.Id && org.Seats == seats)); + + // 105 total seats - 100 minimum = 5 purchased seats + await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( + pPlan => pPlan.Id == providerPlan.Id && pPlan.PurchasedSeats == 5 && pPlan.AllocatedSeats == 105)); + } + + [Theory, BitAutoData] + public async Task AssignSeatsToClientOrganization_AboveToAbove_Succeeds( + Provider provider, + Organization organization, + SutProvider sutProvider) + { + provider.Type = ProviderType.Msp; + + organization.Seats = 10; + + organization.PlanType = PlanType.TeamsMonthly; + + // Scale up 10 seats + const int seats = 20; + + var providerPlans = new List + { + new() + { + Id = Guid.NewGuid(), + PlanType = PlanType.TeamsMonthly, + ProviderId = provider.Id, + // 10 additional purchased seats + PurchasedSeats = 10, + // 100 seat minimum + SeatMinimum = 100, + AllocatedSeats = 110 + }, + new() + { + Id = Guid.NewGuid(), + PlanType = PlanType.EnterpriseMonthly, + ProviderId = provider.Id, + PurchasedSeats = 0, + SeatMinimum = 500, + AllocatedSeats = 0 + } + }; + + var providerPlan = providerPlans.First(); + + sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); + + // 110 seats currently assigned with a seat minimum of 100 + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + + var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + + sutProvider.GetDependency().GetManyDetailsByProviderAsync(provider.Id).Returns( + [ + new ProviderOrganizationOrganizationDetails + { + Plan = teamsMonthlyPlan.Name, + Status = OrganizationStatusType.Managed, + Seats = 60 + }, + new ProviderOrganizationOrganizationDetails + { + Plan = teamsMonthlyPlan.Name, + Status = OrganizationStatusType.Managed, + Seats = 50 + } + ]); + + await sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats); + + // 110 current + 10 seat scale up = 120 seats + await sutProvider.GetDependency().Received(1).AdjustSeats( + provider, + StaticStore.GetPlan(providerPlan.PlanType), + 110, + 120); + + await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( + org => org.Id == organization.Id && org.Seats == seats)); + + // 120 total seats - 100 seat minimum = 20 purchased seats + await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( + pPlan => pPlan.Id == providerPlan.Id && pPlan.PurchasedSeats == 20 && pPlan.AllocatedSeats == 120)); + } + + [Theory, BitAutoData] + public async Task AssignSeatsToClientOrganization_AboveToBelow_Succeeds( + Provider provider, + Organization organization, + SutProvider sutProvider) + { + organization.Seats = 50; + + organization.PlanType = PlanType.TeamsMonthly; + + // Scale down 30 seats + const int seats = 20; + + var providerPlans = new List + { + new() + { + Id = Guid.NewGuid(), + PlanType = PlanType.TeamsMonthly, + ProviderId = provider.Id, + // 10 additional purchased seats + PurchasedSeats = 10, + // 100 seat minimum + SeatMinimum = 100, + AllocatedSeats = 110 + }, + new() + { + Id = Guid.NewGuid(), + PlanType = PlanType.EnterpriseMonthly, + ProviderId = provider.Id, + PurchasedSeats = 0, + SeatMinimum = 500, + AllocatedSeats = 0 + } + }; + + var providerPlan = providerPlans.First(); + + sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); + + // 110 seats currently assigned with a seat minimum of 100 + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + + var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + + sutProvider.GetDependency().GetManyDetailsByProviderAsync(provider.Id).Returns( + [ + new ProviderOrganizationOrganizationDetails + { + Plan = teamsMonthlyPlan.Name, + Status = OrganizationStatusType.Managed, + Seats = 60 + }, + new ProviderOrganizationOrganizationDetails + { + Plan = teamsMonthlyPlan.Name, + Status = OrganizationStatusType.Managed, + Seats = 50 + } + ]); + + await sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats); + + // 110 seats - 30 scale down seats = 80 seats, below the 100 seat minimum. + await sutProvider.GetDependency().Received(1).AdjustSeats( + provider, + StaticStore.GetPlan(providerPlan.PlanType), + 110, + providerPlan.SeatMinimum!.Value); + + await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( + org => org.Id == organization.Id && org.Seats == seats)); + + // Being below the seat minimum means no purchased seats. + await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( + pPlan => pPlan.Id == providerPlan.Id && pPlan.PurchasedSeats == 0 && pPlan.AllocatedSeats == 80)); + } + + #endregion + + #region CreateCustomer + + [Theory, BitAutoData] + public async Task CreateCustomer_NullProvider_ThrowsArgumentNullException( + SutProvider sutProvider, + TaxInfo taxInfo) => + await Assert.ThrowsAsync(() => sutProvider.Sut.CreateCustomer(null, taxInfo)); + + [Theory, BitAutoData] + public async Task CreateCustomer_NullTaxInfo_ThrowsArgumentNullException( + SutProvider sutProvider, + Provider provider) => + await Assert.ThrowsAsync(() => sutProvider.Sut.CreateCustomer(provider, null)); + + [Theory, BitAutoData] + public async Task CreateCustomer_MissingCountry_ContactSupport( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + taxInfo.BillingAddressCountry = null; + + await ThrowsContactSupportAsync(() => sutProvider.Sut.CreateCustomer(provider, taxInfo)); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .CustomerGetAsync(Arg.Any(), Arg.Any()); + } + + [Theory, BitAutoData] + public async Task CreateCustomer_MissingPostalCode_ContactSupport( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + taxInfo.BillingAddressCountry = null; + + await ThrowsContactSupportAsync(() => sutProvider.Sut.CreateCustomer(provider, taxInfo)); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .CustomerGetAsync(Arg.Any(), Arg.Any()); + } + + [Theory, BitAutoData] + public async Task CreateCustomer_Success( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + provider.Name = "MSP"; + + taxInfo.BillingAddressCountry = "AD"; + + var stripeAdapter = sutProvider.GetDependency(); + + stripeAdapter.CustomerCreateAsync(Arg.Is(o => + o.Address.Country == taxInfo.BillingAddressCountry && + o.Address.PostalCode == taxInfo.BillingAddressPostalCode && + o.Address.Line1 == taxInfo.BillingAddressLine1 && + o.Address.Line2 == taxInfo.BillingAddressLine2 && + o.Address.City == taxInfo.BillingAddressCity && + o.Address.State == taxInfo.BillingAddressState && + o.Coupon == "msp-discount-35" && + o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Email == provider.BillingEmail && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.Metadata["region"] == "" && + o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && + o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) + .Returns(new Customer + { + Id = "customer_id", + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }); + + await sutProvider.Sut.CreateCustomer(provider, taxInfo); + + await stripeAdapter.Received(1).CustomerCreateAsync(Arg.Is(o => + o.Address.Country == taxInfo.BillingAddressCountry && + o.Address.PostalCode == taxInfo.BillingAddressPostalCode && + o.Address.Line1 == taxInfo.BillingAddressLine1 && + o.Address.Line2 == taxInfo.BillingAddressLine2 && + o.Address.City == taxInfo.BillingAddressCity && + o.Address.State == taxInfo.BillingAddressState && + o.Coupon == "msp-discount-35" && + o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Email == provider.BillingEmail && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.Metadata["region"] == "" && + o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && + o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)); + + await sutProvider.GetDependency() + .ReplaceAsync(Arg.Is(p => p.GatewayCustomerId == "customer_id")); + } + + #endregion + + #region CreateCustomerForClientOrganization + + [Theory, BitAutoData] + public async Task CreateCustomerForClientOrganization_ProviderNull_ThrowsArgumentNullException( + Organization organization, + SutProvider sutProvider) => + await Assert.ThrowsAsync(() => + sutProvider.Sut.CreateCustomerForClientOrganization(null, organization)); + + [Theory, BitAutoData] + public async Task CreateCustomerForClientOrganization_OrganizationNull_ThrowsArgumentNullException( + Provider provider, + SutProvider sutProvider) => + await Assert.ThrowsAsync(() => + sutProvider.Sut.CreateCustomerForClientOrganization(provider, null)); + + [Theory, BitAutoData] + public async Task CreateCustomerForClientOrganization_HasGatewayCustomerId_NoOp( + Provider provider, + Organization organization, + SutProvider sutProvider) + { + organization.GatewayCustomerId = "customer_id"; + + await sutProvider.Sut.CreateCustomerForClientOrganization(provider, organization); + + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .GetCustomerOrThrow(Arg.Any(), Arg.Any()); + } + + [Theory, BitAutoData] + public async Task CreateCustomer_ForClientOrg_Succeeds( + Provider provider, + Organization organization, + SutProvider sutProvider) + { + organization.GatewayCustomerId = null; + organization.Name = "Name"; + organization.BusinessName = "BusinessName"; + + var providerCustomer = new Customer + { + Address = new Address + { + Country = "USA", + PostalCode = "12345", + Line1 = "123 Main St.", + Line2 = "Unit 4", + City = "Fake Town", + State = "Fake State" + }, + TaxIds = new StripeList + { + Data = + [ + new TaxId { Type = "TYPE", Value = "VALUE" } + ] + } + }; + + sutProvider.GetDependency().GetCustomerOrThrow(provider, Arg.Is( + options => options.Expand.FirstOrDefault() == "tax_ids")) + .Returns(providerCustomer); + + sutProvider.GetDependency().BaseServiceUri + .Returns(new Bit.Core.Settings.GlobalSettings.BaseServiceUriSettings(new Bit.Core.Settings.GlobalSettings()) + { + CloudRegion = "US" + }); + + sutProvider.GetDependency().CustomerCreateAsync(Arg.Is( + options => + options.Address.Country == providerCustomer.Address.Country && + options.Address.PostalCode == providerCustomer.Address.PostalCode && + options.Address.Line1 == providerCustomer.Address.Line1 && + options.Address.Line2 == providerCustomer.Address.Line2 && + options.Address.City == providerCustomer.Address.City && + options.Address.State == providerCustomer.Address.State && + options.Name == organization.DisplayName() && + options.Description == $"{provider.Name} Client Organization" && + options.Email == provider.BillingEmail && + options.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Organization" && + options.InvoiceSettings.CustomFields.FirstOrDefault().Value == "Name" && + options.Metadata["region"] == "US" && + options.TaxIdData.FirstOrDefault().Type == providerCustomer.TaxIds.FirstOrDefault().Type && + options.TaxIdData.FirstOrDefault().Value == providerCustomer.TaxIds.FirstOrDefault().Value)) + .Returns(new Customer { Id = "customer_id" }); + + await sutProvider.Sut.CreateCustomerForClientOrganization(provider, organization); + + await sutProvider.GetDependency().Received(1).CustomerCreateAsync(Arg.Is( + options => + options.Address.Country == providerCustomer.Address.Country && + options.Address.PostalCode == providerCustomer.Address.PostalCode && + options.Address.Line1 == providerCustomer.Address.Line1 && + options.Address.Line2 == providerCustomer.Address.Line2 && + options.Address.City == providerCustomer.Address.City && + options.Address.State == providerCustomer.Address.State && + options.Name == organization.DisplayName() && + options.Description == $"{provider.Name} Client Organization" && + options.Email == provider.BillingEmail && + options.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Organization" && + options.InvoiceSettings.CustomFields.FirstOrDefault().Value == "Name" && + options.Metadata["region"] == "US" && + options.TaxIdData.FirstOrDefault().Type == providerCustomer.TaxIds.FirstOrDefault().Type && + options.TaxIdData.FirstOrDefault().Value == providerCustomer.TaxIds.FirstOrDefault().Value)); + + await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( + org => org.GatewayCustomerId == "customer_id")); + } + + #endregion + + #region GetAssignedSeatTotalForPlanOrThrow + + [Theory, BitAutoData] + public async Task GetAssignedSeatTotalForPlanOrThrow_NullProvider_ContactSupport( + Guid providerId, + SutProvider sutProvider) + => await ThrowsContactSupportAsync(() => + sutProvider.Sut.GetAssignedSeatTotalForPlanOrThrow(providerId, PlanType.TeamsMonthly)); + + [Theory, BitAutoData] + public async Task GetAssignedSeatTotalForPlanOrThrow_ResellerProvider_ContactSupport( + Guid providerId, + Provider provider, + SutProvider sutProvider) + { + provider.Type = ProviderType.Reseller; + + sutProvider.GetDependency().GetByIdAsync(providerId).Returns(provider); + + await ThrowsContactSupportAsync( + () => sutProvider.Sut.GetAssignedSeatTotalForPlanOrThrow(providerId, PlanType.TeamsMonthly), + internalMessage: "Consolidated billing does not support reseller-type providers"); + } + + [Theory, BitAutoData] + public async Task GetAssignedSeatTotalForPlanOrThrow_Succeeds( + Guid providerId, + Provider provider, + SutProvider sutProvider) + { + provider.Type = ProviderType.Msp; + + sutProvider.GetDependency().GetByIdAsync(providerId).Returns(provider); + + var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var enterpriseMonthlyPlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); + + var providerOrganizationOrganizationDetailList = new List + { + new() { Plan = teamsMonthlyPlan.Name, Status = OrganizationStatusType.Managed, Seats = 10 }, + new() { Plan = teamsMonthlyPlan.Name, Status = OrganizationStatusType.Managed, Seats = 10 }, + new() + { + // Ignored because of status. + Plan = teamsMonthlyPlan.Name, Status = OrganizationStatusType.Created, Seats = 100 + }, + new() + { + // Ignored because of plan. + Plan = enterpriseMonthlyPlan.Name, Status = OrganizationStatusType.Managed, Seats = 30 + } + }; + + sutProvider.GetDependency() + .GetManyDetailsByProviderAsync(providerId) + .Returns(providerOrganizationOrganizationDetailList); + + var assignedSeatTotal = + await sutProvider.Sut.GetAssignedSeatTotalForPlanOrThrow(providerId, PlanType.TeamsMonthly); + + Assert.Equal(20, assignedSeatTotal); + } + + #endregion + + #region GetSubscriptionData + + [Theory, BitAutoData] + public async Task GetSubscriptionData_NullProvider_ReturnsNull( + SutProvider sutProvider, + Guid providerId) + { + var providerRepository = sutProvider.GetDependency(); + + providerRepository.GetByIdAsync(providerId).ReturnsNull(); + + var subscriptionData = await sutProvider.Sut.GetSubscriptionDTO(providerId); + + Assert.Null(subscriptionData); + + await providerRepository.Received(1).GetByIdAsync(providerId); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionData_NullSubscription_ReturnsNull( + SutProvider sutProvider, + Guid providerId, + Provider provider) + { + var providerRepository = sutProvider.GetDependency(); + + providerRepository.GetByIdAsync(providerId).Returns(provider); + + var subscriberService = sutProvider.GetDependency(); + + subscriberService.GetSubscription(provider).ReturnsNull(); + + var subscriptionData = await sutProvider.Sut.GetSubscriptionDTO(providerId); + + Assert.Null(subscriptionData); + + await providerRepository.Received(1).GetByIdAsync(providerId); + + await subscriberService.Received(1).GetSubscription( + provider, + Arg.Is( + options => options.Expand.Count == 1 && options.Expand.First() == "customer")); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionData_Success( + SutProvider sutProvider, + Guid providerId, + Provider provider) + { + var providerRepository = sutProvider.GetDependency(); + + providerRepository.GetByIdAsync(providerId).Returns(provider); + + var subscriberService = sutProvider.GetDependency(); + + var subscription = new Subscription(); + + subscriberService.GetSubscription(provider, Arg.Is( + options => options.Expand.Count == 1 && options.Expand.First() == "customer")).Returns(subscription); + + var providerPlanRepository = sutProvider.GetDependency(); + + var enterprisePlan = new ProviderPlan + { + Id = Guid.NewGuid(), + ProviderId = providerId, + PlanType = PlanType.EnterpriseMonthly, + SeatMinimum = 100, + PurchasedSeats = 0, + AllocatedSeats = 0 + }; + + var teamsPlan = new ProviderPlan + { + Id = Guid.NewGuid(), + ProviderId = providerId, + PlanType = PlanType.TeamsMonthly, + SeatMinimum = 50, + PurchasedSeats = 10, + AllocatedSeats = 60 + }; + + var providerPlans = new List { enterprisePlan, teamsPlan, }; + + providerPlanRepository.GetByProviderId(providerId).Returns(providerPlans); + + var subscriptionData = await sutProvider.Sut.GetSubscriptionDTO(providerId); + + Assert.NotNull(subscriptionData); + + Assert.Equivalent(subscriptionData.Subscription, subscription); + + Assert.Equal(2, subscriptionData.ProviderPlans.Count); + + var configuredEnterprisePlan = + subscriptionData.ProviderPlans.FirstOrDefault(configuredPlan => + configuredPlan.PlanType == PlanType.EnterpriseMonthly); + + var configuredTeamsPlan = + subscriptionData.ProviderPlans.FirstOrDefault(configuredPlan => + configuredPlan.PlanType == PlanType.TeamsMonthly); + + Compare(enterprisePlan, configuredEnterprisePlan); + + Compare(teamsPlan, configuredTeamsPlan); + + await providerRepository.Received(1).GetByIdAsync(providerId); + + await subscriberService.Received(1).GetSubscription( + provider, + Arg.Is( + options => options.Expand.Count == 1 && options.Expand.First() == "customer")); + + await providerPlanRepository.Received(1).GetByProviderId(providerId); + + return; + + void Compare(ProviderPlan providerPlan, ConfiguredProviderPlanDTO configuredProviderPlan) + { + Assert.NotNull(configuredProviderPlan); + Assert.Equal(providerPlan.Id, configuredProviderPlan.Id); + Assert.Equal(providerPlan.ProviderId, configuredProviderPlan.ProviderId); + Assert.Equal(providerPlan.SeatMinimum!.Value, configuredProviderPlan.SeatMinimum); + Assert.Equal(providerPlan.PurchasedSeats!.Value, configuredProviderPlan.PurchasedSeats); + Assert.Equal(providerPlan.AllocatedSeats!.Value, configuredProviderPlan.AssignedSeats); + } + } + + #endregion + + #region StartSubscription + + [Theory, BitAutoData] + public async Task StartSubscription_NullProvider_ThrowsArgumentNullException( + SutProvider sutProvider) => + await Assert.ThrowsAsync(() => sutProvider.Sut.StartSubscription(null)); + + [Theory, BitAutoData] + public async Task StartSubscription_AlreadyHasGatewaySubscriptionId_ContactSupport( + SutProvider sutProvider, + Provider provider) + { + provider.GatewaySubscriptionId = "subscription_id"; + + await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider)); + } + + [Theory, BitAutoData] + public async Task StartSubscription_NoProviderPlans_ContactSupport( + SutProvider sutProvider, + Provider provider) + { + provider.GatewaySubscriptionId = null; + + sutProvider.GetDependency().GetCustomerOrThrow(provider).Returns(new Customer + { + Id = "customer_id", + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }); + + sutProvider.GetDependency().GetByProviderId(provider.Id) + .Returns(new List()); + + await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider)); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SubscriptionCreateAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task StartSubscription_NoProviderTeamsPlan_ContactSupport( + SutProvider sutProvider, + Provider provider) + { + provider.GatewaySubscriptionId = null; + + sutProvider.GetDependency().GetCustomerOrThrow(provider).Returns(new Customer + { + Id = "customer_id", + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }); + + var providerPlans = new List { new() { PlanType = PlanType.EnterpriseMonthly } }; + + sutProvider.GetDependency().GetByProviderId(provider.Id) + .Returns(providerPlans); + + await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider)); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SubscriptionCreateAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task StartSubscription_NoProviderEnterprisePlan_ContactSupport( + SutProvider sutProvider, + Provider provider) + { + provider.GatewaySubscriptionId = null; + + sutProvider.GetDependency().GetCustomerOrThrow(provider).Returns(new Customer + { + Id = "customer_id", + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }); + + var providerPlans = new List { new() { PlanType = PlanType.TeamsMonthly } }; + + sutProvider.GetDependency().GetByProviderId(provider.Id) + .Returns(providerPlans); + + await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider)); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SubscriptionCreateAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task StartSubscription_SubscriptionIncomplete_ThrowsBillingException( + SutProvider sutProvider, + Provider provider) + { + provider.GatewaySubscriptionId = null; + + sutProvider.GetDependency().GetCustomerOrThrow(provider).Returns(new Customer + { + Id = "customer_id", + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }); + + var providerPlans = new List + { + new() { PlanType = PlanType.TeamsMonthly, SeatMinimum = 100 }, + new() { PlanType = PlanType.EnterpriseMonthly, SeatMinimum = 100 } + }; + + sutProvider.GetDependency().GetByProviderId(provider.Id) + .Returns(providerPlans); + + sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Any()) + .Returns( + new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Incomplete }); + + await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider)); + + await sutProvider.GetDependency().Received(1) + .ReplaceAsync(Arg.Is(p => p.GatewaySubscriptionId == "subscription_id")); + } + + [Theory, BitAutoData] + public async Task StartSubscription_Succeeds( + SutProvider sutProvider, + Provider provider) + { + provider.GatewaySubscriptionId = null; + + sutProvider.GetDependency().GetCustomerOrThrow(provider).Returns(new Customer + { + Id = "customer_id", + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }); + + var providerPlans = new List + { + new() { PlanType = PlanType.TeamsMonthly, SeatMinimum = 100 }, + new() { PlanType = PlanType.EnterpriseMonthly, SeatMinimum = 100 } + }; + + sutProvider.GetDependency().GetByProviderId(provider.Id) + .Returns(providerPlans); + + var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); + + sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sub => + sub.AutomaticTax.Enabled == true && + sub.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && + sub.Customer == "customer_id" && + sub.DaysUntilDue == 30 && + sub.Items.Count == 2 && + sub.Items.ElementAt(0).Price == teamsPlan.PasswordManager.StripeSeatPlanId && + sub.Items.ElementAt(0).Quantity == 100 && + sub.Items.ElementAt(1).Price == enterprisePlan.PasswordManager.StripeSeatPlanId && + sub.Items.ElementAt(1).Quantity == 100 && + sub.Metadata["providerId"] == provider.Id.ToString() && + sub.OffSession == true && + sub.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations)).Returns(new Subscription + { + Id = "subscription_id", + Status = StripeConstants.SubscriptionStatus.Active + }); + + await sutProvider.Sut.StartSubscription(provider); + + await sutProvider.GetDependency().Received(1) + .ReplaceAsync(Arg.Is(p => p.GatewaySubscriptionId == "subscription_id")); + } + + #endregion + + #region GetPaymentInformationAsync + [Theory, BitAutoData] + public async Task GetPaymentInformationAsync_NullProvider_ReturnsNull( + SutProvider sutProvider, + Guid providerId) + { + var providerRepository = sutProvider.GetDependency(); + providerRepository.GetByIdAsync(providerId).ReturnsNull(); + + var paymentService = sutProvider.GetDependency(); + paymentService.GetTaxInformationAsync(Arg.Any()).ReturnsNull(); + paymentService.GetPaymentMethodAsync(Arg.Any()).ReturnsNull(); + + var sut = sutProvider.Sut; + + var paymentInfo = await sut.GetPaymentInformationAsync(providerId); + + Assert.Null(paymentInfo); + await providerRepository.Received(1).GetByIdAsync(providerId); + await paymentService.DidNotReceive().GetTaxInformationAsync(Arg.Any()); + await paymentService.DidNotReceive().GetPaymentMethodAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task GetPaymentInformationAsync_NullSubscription_ReturnsNull( + SutProvider sutProvider, + Guid providerId, + Provider provider) + { + var providerRepository = sutProvider.GetDependency(); + + providerRepository.GetByIdAsync(providerId).Returns(provider); + + var subscriberService = sutProvider.GetDependency(); + + subscriberService.GetTaxInformationAsync(provider).ReturnsNull(); + subscriberService.GetPaymentMethodAsync(provider).ReturnsNull(); + + var paymentInformation = await sutProvider.Sut.GetPaymentInformationAsync(providerId); + + Assert.Null(paymentInformation); + await providerRepository.Received(1).GetByIdAsync(providerId); + await subscriberService.Received(1).GetTaxInformationAsync(provider); + await subscriberService.Received(1).GetPaymentMethodAsync(provider); + } + + [Theory, BitAutoData] + public async Task GetPaymentInformationAsync_ResellerProvider_ThrowContactSupport( + SutProvider sutProvider, + Guid providerId, + Provider provider) + { + provider.Id = providerId; + provider.Type = ProviderType.Reseller; + var providerRepository = sutProvider.GetDependency(); + providerRepository.GetByIdAsync(providerId).Returns(provider); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.GetPaymentInformationAsync(providerId)); + + Assert.Equal("Consolidated billing does not support reseller-type providers", exception.Message); + } + + [Theory, BitAutoData] + public async Task GetPaymentInformationAsync_Success_ReturnsProviderPaymentInfoDTO( + SutProvider sutProvider, + Guid providerId, + Provider provider) + { + provider.Id = providerId; + provider.Type = ProviderType.Msp; + var taxInformation = new TaxInfo { TaxIdNumber = "12345" }; + var paymentMethod = new PaymentMethod + { + Id = "pm_test123", + Type = "card", + Card = new PaymentMethodCard + { + Brand = "visa", + Last4 = "4242", + ExpMonth = 12, + ExpYear = 2024 + } + }; + var billingInformation = new BillingInfo { PaymentSource = new BillingInfo.BillingSource(paymentMethod) }; + + var providerRepository = sutProvider.GetDependency(); + providerRepository.GetByIdAsync(providerId).Returns(provider); + + var subscriberService = sutProvider.GetDependency(); + subscriberService.GetTaxInformationAsync(provider).Returns(taxInformation); + subscriberService.GetPaymentMethodAsync(provider).Returns(billingInformation.PaymentSource); + + var result = await sutProvider.Sut.GetPaymentInformationAsync(providerId); + + // Assert + Assert.NotNull(result); + Assert.Equal(billingInformation.PaymentSource, result.billingSource); + Assert.Equal(taxInformation, result.taxInfo); + } + #endregion +} diff --git a/src/Admin/AdminConsole/Controllers/OrganizationsController.cs b/src/Admin/AdminConsole/Controllers/OrganizationsController.cs index 88c046746..c72beb421 100644 --- a/src/Admin/AdminConsole/Controllers/OrganizationsController.cs +++ b/src/Admin/AdminConsole/Controllers/OrganizationsController.cs @@ -7,8 +7,8 @@ using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Providers.Interfaces; using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Billing.Commands; using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -54,9 +54,8 @@ public class OrganizationsController : Controller private readonly IServiceAccountRepository _serviceAccountRepository; private readonly IProviderOrganizationRepository _providerOrganizationRepository; private readonly IRemoveOrganizationFromProviderCommand _removeOrganizationFromProviderCommand; - private readonly IRemovePaymentMethodCommand _removePaymentMethodCommand; private readonly IFeatureService _featureService; - private readonly IScaleSeatsCommand _scaleSeatsCommand; + private readonly IProviderBillingService _providerBillingService; public OrganizationsController( IOrganizationService organizationService, @@ -82,9 +81,8 @@ public class OrganizationsController : Controller IServiceAccountRepository serviceAccountRepository, IProviderOrganizationRepository providerOrganizationRepository, IRemoveOrganizationFromProviderCommand removeOrganizationFromProviderCommand, - IRemovePaymentMethodCommand removePaymentMethodCommand, IFeatureService featureService, - IScaleSeatsCommand scaleSeatsCommand) + IProviderBillingService providerBillingService) { _organizationService = organizationService; _organizationRepository = organizationRepository; @@ -109,9 +107,8 @@ public class OrganizationsController : Controller _serviceAccountRepository = serviceAccountRepository; _providerOrganizationRepository = providerOrganizationRepository; _removeOrganizationFromProviderCommand = removeOrganizationFromProviderCommand; - _removePaymentMethodCommand = removePaymentMethodCommand; _featureService = featureService; - _scaleSeatsCommand = scaleSeatsCommand; + _providerBillingService = providerBillingService; } [RequirePermission(Permission.Org_List_View)] @@ -256,7 +253,7 @@ public class OrganizationsController : Controller if (provider.IsBillable()) { - await _scaleSeatsCommand.ScalePasswordManagerSeats( + await _providerBillingService.ScaleSeats( provider, organization.PlanType, -organization.Seats ?? 0); @@ -378,11 +375,6 @@ public class OrganizationsController : Controller providerOrganization, organization); - if (organization.IsStripeEnabled()) - { - await _removePaymentMethodCommand.RemovePaymentMethod(organization); - } - return Json(null); } private async Task GetOrganization(Guid id, OrganizationEditModel model) @@ -443,5 +435,4 @@ public class OrganizationsController : Controller return organization; } - } diff --git a/src/Admin/AdminConsole/Controllers/ProviderOrganizationsController.cs b/src/Admin/AdminConsole/Controllers/ProviderOrganizationsController.cs index b143a2c42..217d6fc18 100644 --- a/src/Admin/AdminConsole/Controllers/ProviderOrganizationsController.cs +++ b/src/Admin/AdminConsole/Controllers/ProviderOrganizationsController.cs @@ -2,8 +2,6 @@ using Bit.Admin.Utilities; using Bit.Core.AdminConsole.Providers.Interfaces; using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Billing.Commands; -using Bit.Core.Billing.Extensions; using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Utilities; @@ -20,19 +18,16 @@ public class ProviderOrganizationsController : Controller private readonly IProviderOrganizationRepository _providerOrganizationRepository; private readonly IOrganizationRepository _organizationRepository; private readonly IRemoveOrganizationFromProviderCommand _removeOrganizationFromProviderCommand; - private readonly IRemovePaymentMethodCommand _removePaymentMethodCommand; public ProviderOrganizationsController(IProviderRepository providerRepository, IProviderOrganizationRepository providerOrganizationRepository, IOrganizationRepository organizationRepository, - IRemoveOrganizationFromProviderCommand removeOrganizationFromProviderCommand, - IRemovePaymentMethodCommand removePaymentMethodCommand) + IRemoveOrganizationFromProviderCommand removeOrganizationFromProviderCommand) { _providerRepository = providerRepository; _providerOrganizationRepository = providerOrganizationRepository; _organizationRepository = organizationRepository; _removeOrganizationFromProviderCommand = removeOrganizationFromProviderCommand; - _removePaymentMethodCommand = removePaymentMethodCommand; } [HttpPost] @@ -69,12 +64,6 @@ public class ProviderOrganizationsController : Controller return BadRequest(ex.Message); } - - if (organization.IsStripeEnabled()) - { - await _removePaymentMethodCommand.RemovePaymentMethod(organization); - } - return Json(null); } } diff --git a/src/Api/AdminConsole/Controllers/OrganizationsController.cs b/src/Api/AdminConsole/Controllers/OrganizationsController.cs index 979c5d16d..b75d053c4 100644 --- a/src/Api/AdminConsole/Controllers/OrganizationsController.cs +++ b/src/Api/AdminConsole/Controllers/OrganizationsController.cs @@ -19,8 +19,8 @@ using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Repositories; using Bit.Core.Auth.Services; -using Bit.Core.Billing.Commands; using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -55,7 +55,7 @@ public class OrganizationsController : Controller private readonly IPushNotificationService _pushNotificationService; private readonly IOrganizationEnableCollectionEnhancementsCommand _organizationEnableCollectionEnhancementsCommand; private readonly IProviderRepository _providerRepository; - private readonly IScaleSeatsCommand _scaleSeatsCommand; + private readonly IProviderBillingService _providerBillingService; private readonly IDataProtectorTokenFactory _orgDeleteTokenDataFactory; public OrganizationsController( @@ -76,7 +76,7 @@ public class OrganizationsController : Controller IPushNotificationService pushNotificationService, IOrganizationEnableCollectionEnhancementsCommand organizationEnableCollectionEnhancementsCommand, IProviderRepository providerRepository, - IScaleSeatsCommand scaleSeatsCommand, + IProviderBillingService providerBillingService, IDataProtectorTokenFactory orgDeleteTokenDataFactory) { _organizationRepository = organizationRepository; @@ -96,7 +96,7 @@ public class OrganizationsController : Controller _pushNotificationService = pushNotificationService; _organizationEnableCollectionEnhancementsCommand = organizationEnableCollectionEnhancementsCommand; _providerRepository = providerRepository; - _scaleSeatsCommand = scaleSeatsCommand; + _providerBillingService = providerBillingService; _orgDeleteTokenDataFactory = orgDeleteTokenDataFactory; } @@ -274,7 +274,7 @@ public class OrganizationsController : Controller if (provider.IsBillable()) { - await _scaleSeatsCommand.ScalePasswordManagerSeats( + await _providerBillingService.ScaleSeats( provider, organization.PlanType, -organization.Seats ?? 0); @@ -305,7 +305,7 @@ public class OrganizationsController : Controller var provider = await _providerRepository.GetByOrganizationIdAsync(organization.Id); if (provider.IsBillable()) { - await _scaleSeatsCommand.ScalePasswordManagerSeats( + await _providerBillingService.ScaleSeats( provider, organization.PlanType, -organization.Seats ?? 0); diff --git a/src/Api/AdminConsole/Controllers/ProviderOrganizationsController.cs b/src/Api/AdminConsole/Controllers/ProviderOrganizationsController.cs index e2becc9b8..7cdab7348 100644 --- a/src/Api/AdminConsole/Controllers/ProviderOrganizationsController.cs +++ b/src/Api/AdminConsole/Controllers/ProviderOrganizationsController.cs @@ -4,8 +4,6 @@ using Bit.Api.Models.Response; using Bit.Core.AdminConsole.Providers.Interfaces; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; -using Bit.Core.Billing.Commands; -using Bit.Core.Billing.Extensions; using Bit.Core.Context; using Bit.Core.Exceptions; using Bit.Core.Repositories; @@ -26,7 +24,6 @@ public class ProviderOrganizationsController : Controller private readonly IProviderRepository _providerRepository; private readonly IProviderService _providerService; private readonly IRemoveOrganizationFromProviderCommand _removeOrganizationFromProviderCommand; - private readonly IRemovePaymentMethodCommand _removePaymentMethodCommand; private readonly IUserService _userService; public ProviderOrganizationsController( @@ -36,7 +33,6 @@ public class ProviderOrganizationsController : Controller IProviderRepository providerRepository, IProviderService providerService, IRemoveOrganizationFromProviderCommand removeOrganizationFromProviderCommand, - IRemovePaymentMethodCommand removePaymentMethodCommand, IUserService userService) { _currentContext = currentContext; @@ -45,7 +41,6 @@ public class ProviderOrganizationsController : Controller _providerRepository = providerRepository; _providerService = providerService; _removeOrganizationFromProviderCommand = removeOrganizationFromProviderCommand; - _removePaymentMethodCommand = removePaymentMethodCommand; _userService = userService; } @@ -112,10 +107,5 @@ public class ProviderOrganizationsController : Controller provider, providerOrganization, organization); - - if (organization.IsStripeEnabled()) - { - await _removePaymentMethodCommand.RemovePaymentMethod(organization); - } } } diff --git a/src/Api/AdminConsole/Controllers/ProvidersController.cs b/src/Api/AdminConsole/Controllers/ProvidersController.cs index e1c908a1b..51cf4c7e3 100644 --- a/src/Api/AdminConsole/Controllers/ProvidersController.cs +++ b/src/Api/AdminConsole/Controllers/ProvidersController.cs @@ -3,7 +3,7 @@ using Bit.Api.AdminConsole.Models.Response.Providers; using Bit.Core; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; -using Bit.Core.Billing.Commands; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Exceptions; using Bit.Core.Models.Business; @@ -24,13 +24,13 @@ public class ProvidersController : Controller private readonly ICurrentContext _currentContext; private readonly GlobalSettings _globalSettings; private readonly IFeatureService _featureService; - private readonly IStartSubscriptionCommand _startSubscriptionCommand; private readonly ILogger _logger; + private readonly IProviderBillingService _providerBillingService; public ProvidersController(IUserService userService, IProviderRepository providerRepository, IProviderService providerService, ICurrentContext currentContext, GlobalSettings globalSettings, - IFeatureService featureService, IStartSubscriptionCommand startSubscriptionCommand, - ILogger logger) + IFeatureService featureService, ILogger logger, + IProviderBillingService providerBillingService) { _userService = userService; _providerRepository = providerRepository; @@ -38,8 +38,8 @@ public class ProvidersController : Controller _currentContext = currentContext; _globalSettings = globalSettings; _featureService = featureService; - _startSubscriptionCommand = startSubscriptionCommand; _logger = logger; + _providerBillingService = providerBillingService; } [HttpGet("{id:guid}")] @@ -112,7 +112,9 @@ public class ProvidersController : Controller try { - await _startSubscriptionCommand.StartSubscription(provider, taxInfo); + await _providerBillingService.CreateCustomer(provider, taxInfo); + + await _providerBillingService.StartSubscription(provider); } catch { diff --git a/src/Api/Auth/Controllers/AccountsController.cs b/src/Api/Auth/Controllers/AccountsController.cs index da76f3540..0f8543378 100644 --- a/src/Api/Auth/Controllers/AccountsController.cs +++ b/src/Api/Auth/Controllers/AccountsController.cs @@ -21,9 +21,8 @@ using Bit.Core.Auth.Services; using Bit.Core.Auth.UserFeatures.UserKey; using Bit.Core.Auth.UserFeatures.UserMasterPassword.Interfaces; using Bit.Core.Auth.Utilities; -using Bit.Core.Billing.Commands; using Bit.Core.Billing.Models; -using Bit.Core.Billing.Queries; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -67,8 +66,7 @@ public class AccountsController : Controller private readonly ISetInitialMasterPasswordCommand _setInitialMasterPasswordCommand; private readonly IRotateUserKeyCommand _rotateUserKeyCommand; private readonly IFeatureService _featureService; - private readonly ICancelSubscriptionCommand _cancelSubscriptionCommand; - private readonly ISubscriberQueries _subscriberQueries; + private readonly ISubscriberService _subscriberService; private readonly IReferenceEventService _referenceEventService; private readonly ICurrentContext _currentContext; @@ -102,8 +100,7 @@ public class AccountsController : Controller ISetInitialMasterPasswordCommand setInitialMasterPasswordCommand, IRotateUserKeyCommand rotateUserKeyCommand, IFeatureService featureService, - ICancelSubscriptionCommand cancelSubscriptionCommand, - ISubscriberQueries subscriberQueries, + ISubscriberService subscriberService, IReferenceEventService referenceEventService, ICurrentContext currentContext, IRotationValidator, IEnumerable> cipherValidator, @@ -131,8 +128,7 @@ public class AccountsController : Controller _setInitialMasterPasswordCommand = setInitialMasterPasswordCommand; _rotateUserKeyCommand = rotateUserKeyCommand; _featureService = featureService; - _cancelSubscriptionCommand = cancelSubscriptionCommand; - _subscriberQueries = subscriberQueries; + _subscriberService = subscriberService; _referenceEventService = referenceEventService; _currentContext = currentContext; _cipherValidator = cipherValidator; @@ -798,9 +794,7 @@ public class AccountsController : Controller throw new UnauthorizedAccessException(); } - var subscription = await _subscriberQueries.GetSubscriptionOrThrow(user); - - await _cancelSubscriptionCommand.CancelSubscription(subscription, + await _subscriberService.CancelSubscription(user, new OffboardingSurveyResponse { UserId = user.Id, @@ -841,7 +835,7 @@ public class AccountsController : Controller throw new UnauthorizedAccessException(); } - var taxInfo = await _paymentService.GetTaxInfoAsync(user); + var taxInfo = await _subscriberService.GetTaxInformationAsync(user); return new TaxInfoResponseModel(taxInfo); } diff --git a/src/Api/Billing/Controllers/OrganizationBillingController.cs b/src/Api/Billing/Controllers/OrganizationBillingController.cs index a16c0c42f..b0c754589 100644 --- a/src/Api/Billing/Controllers/OrganizationBillingController.cs +++ b/src/Api/Billing/Controllers/OrganizationBillingController.cs @@ -1,8 +1,7 @@ using Bit.Api.Billing.Models.Responses; using Bit.Api.Models.Response; -using Bit.Core.Billing.Queries; +using Bit.Core.Billing.Services; using Bit.Core.Context; -using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Utilities; @@ -14,15 +13,15 @@ namespace Bit.Api.Billing.Controllers; [Route("organizations/{organizationId:guid}/billing")] [Authorize("Application")] public class OrganizationBillingController( - IOrganizationBillingQueries organizationBillingQueries, ICurrentContext currentContext, + IOrganizationBillingService organizationBillingService, IOrganizationRepository organizationRepository, IPaymentService paymentService) : Controller { [HttpGet("metadata")] public async Task GetMetadataAsync([FromRoute] Guid organizationId) { - var metadata = await organizationBillingQueries.GetMetadata(organizationId); + var metadata = await organizationBillingService.GetMetadata(organizationId); if (metadata == null) { @@ -36,20 +35,24 @@ public class OrganizationBillingController( [HttpGet] [SelfHosted(NotSelfHostedOnly = true)] - public async Task GetBilling(Guid organizationId) + public async Task GetBillingAsync(Guid organizationId) { if (!await currentContext.ViewBillingHistory(organizationId)) { - throw new NotFoundException(); + return TypedResults.Unauthorized(); } var organization = await organizationRepository.GetByIdAsync(organizationId); + if (organization == null) { - throw new NotFoundException(); + return TypedResults.NotFound(); } var billingInfo = await paymentService.GetBillingAsync(organization); - return new BillingResponseModel(billingInfo); + + var response = new BillingResponseModel(billingInfo); + + return TypedResults.Ok(response); } } diff --git a/src/Api/Billing/Controllers/OrganizationsController.cs b/src/Api/Billing/Controllers/OrganizationsController.cs index f418e07f9..f3718ab10 100644 --- a/src/Api/Billing/Controllers/OrganizationsController.cs +++ b/src/Api/Billing/Controllers/OrganizationsController.cs @@ -5,10 +5,9 @@ using Bit.Api.Models.Request; using Bit.Api.Models.Request.Organizations; using Bit.Api.Models.Response; using Bit.Core.AdminConsole.Entities; -using Bit.Core.Billing.Commands; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; -using Bit.Core.Billing.Queries; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -42,9 +41,8 @@ public class OrganizationsController( IUpdateSecretsManagerSubscriptionCommand updateSecretsManagerSubscriptionCommand, IUpgradeOrganizationPlanCommand upgradeOrganizationPlanCommand, IAddSecretsManagerSubscriptionCommand addSecretsManagerSubscriptionCommand, - ICancelSubscriptionCommand cancelSubscriptionCommand, - ISubscriberQueries subscriberQueries, - IReferenceEventService referenceEventService) + IReferenceEventService referenceEventService, + ISubscriberService subscriberService) : Controller { [HttpGet("{id}/billing-status")] @@ -261,9 +259,7 @@ public class OrganizationsController( throw new NotFoundException(); } - var subscription = await subscriberQueries.GetSubscriptionOrThrow(organization); - - await cancelSubscriptionCommand.CancelSubscription(subscription, + await subscriberService.CancelSubscription(organization, new OffboardingSurveyResponse { UserId = currentContext.UserId!.Value, @@ -308,7 +304,7 @@ public class OrganizationsController( throw new NotFoundException(); } - var taxInfo = await paymentService.GetTaxInfoAsync(organization); + var taxInfo = await subscriberService.GetTaxInformationAsync(organization); return new TaxInfoResponseModel(taxInfo); } diff --git a/src/Api/Billing/Controllers/ProviderBillingController.cs b/src/Api/Billing/Controllers/ProviderBillingController.cs index 2f33dd50d..3bc932fc4 100644 --- a/src/Api/Billing/Controllers/ProviderBillingController.cs +++ b/src/Api/Billing/Controllers/ProviderBillingController.cs @@ -1,6 +1,6 @@ using Bit.Api.Billing.Models.Responses; using Bit.Core; -using Bit.Core.Billing.Queries; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; @@ -13,7 +13,7 @@ namespace Bit.Api.Billing.Controllers; public class ProviderBillingController( ICurrentContext currentContext, IFeatureService featureService, - IProviderBillingQueries providerBillingQueries) : Controller + IProviderBillingService providerBillingService) : Controller { [HttpGet("subscription")] public async Task GetSubscriptionAsync([FromRoute] Guid providerId) @@ -28,7 +28,7 @@ public class ProviderBillingController( return TypedResults.Unauthorized(); } - var providerSubscriptionDTO = await providerBillingQueries.GetSubscriptionDTO(providerId); + var providerSubscriptionDTO = await providerBillingService.GetSubscriptionDTO(providerId); if (providerSubscriptionDTO == null) { @@ -41,4 +41,31 @@ public class ProviderBillingController( return TypedResults.Ok(providerSubscriptionResponse); } + + [HttpGet("payment-information")] + public async Task GetPaymentInformationAsync([FromRoute] Guid providerId) + { + if (!featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)) + { + return TypedResults.NotFound(); + } + + if (!currentContext.ProviderProviderAdmin(providerId)) + { + return TypedResults.Unauthorized(); + } + + var providerPaymentInformationDto = await providerBillingService.GetPaymentInformationAsync(providerId); + + if (providerPaymentInformationDto == null) + { + return TypedResults.NotFound(); + } + + var (paymentSource, taxInfo) = providerPaymentInformationDto; + + var providerPaymentInformationResponse = PaymentInformationResponse.From(paymentSource, taxInfo); + + return TypedResults.Ok(providerPaymentInformationResponse); + } } diff --git a/src/Api/Billing/Controllers/ProviderClientsController.cs b/src/Api/Billing/Controllers/ProviderClientsController.cs index a47ab568b..ffd74f811 100644 --- a/src/Api/Billing/Controllers/ProviderClientsController.cs +++ b/src/Api/Billing/Controllers/ProviderClientsController.cs @@ -2,7 +2,7 @@ using Bit.Core; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; -using Bit.Core.Billing.Commands; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Models.Business; @@ -14,16 +14,14 @@ namespace Bit.Api.Billing.Controllers; [Route("providers/{providerId:guid}/clients")] public class ProviderClientsController( - IAssignSeatsToClientOrganizationCommand assignSeatsToClientOrganizationCommand, - ICreateCustomerCommand createCustomerCommand, ICurrentContext currentContext, IFeatureService featureService, ILogger logger, IOrganizationRepository organizationRepository, + IProviderBillingService providerBillingService, IProviderOrganizationRepository providerOrganizationRepository, IProviderRepository providerRepository, IProviderService providerService, - IScaleSeatsCommand scaleSeatsCommand, IUserService userService) : Controller { [HttpPost] @@ -83,12 +81,12 @@ public class ProviderClientsController( return TypedResults.Problem(); } - await scaleSeatsCommand.ScalePasswordManagerSeats( + await providerBillingService.ScaleSeats( provider, requestBody.PlanType, requestBody.Seats); - await createCustomerCommand.CreateCustomer( + await providerBillingService.CreateCustomerForClientOrganization( provider, clientOrganization); @@ -135,7 +133,7 @@ public class ProviderClientsController( if (clientOrganization.Seats != requestBody.AssignedSeats) { - await assignSeatsToClientOrganizationCommand.AssignSeatsToClientOrganization( + await providerBillingService.AssignSeatsToClientOrganization( provider, clientOrganization, requestBody.AssignedSeats); diff --git a/src/Api/Billing/Models/Responses/PaymentInformationResponse.cs b/src/Api/Billing/Models/Responses/PaymentInformationResponse.cs new file mode 100644 index 000000000..6d6088e99 --- /dev/null +++ b/src/Api/Billing/Models/Responses/PaymentInformationResponse.cs @@ -0,0 +1,37 @@ +using Bit.Core.Enums; +using Bit.Core.Models.Business; + +namespace Bit.Api.Billing.Models.Responses; + +public record PaymentInformationResponse(PaymentMethod PaymentMethod, TaxInformation TaxInformation) +{ + public static PaymentInformationResponse From(BillingInfo.BillingSource billingSource, TaxInfo taxInfo) + { + var paymentMethodDto = new PaymentMethod( + billingSource.Type, billingSource.Description, billingSource.CardBrand + ); + + var taxInformationDto = new TaxInformation( + taxInfo.BillingAddressCountry, taxInfo.BillingAddressPostalCode, taxInfo.TaxIdNumber, + taxInfo.BillingAddressLine1, taxInfo.BillingAddressLine2, taxInfo.BillingAddressCity, + taxInfo.BillingAddressState + ); + + return new PaymentInformationResponse(paymentMethodDto, taxInformationDto); + } + +} + +public record PaymentMethod( + PaymentMethodType Type, + string Description, + string CardBrand); + +public record TaxInformation( + string Country, + string PostalCode, + string TaxId, + string Line1, + string Line2, + string City, + string State); diff --git a/src/Core/Billing/Commands/IAssignSeatsToClientOrganizationCommand.cs b/src/Core/Billing/Commands/IAssignSeatsToClientOrganizationCommand.cs deleted file mode 100644 index 43adc73d8..000000000 --- a/src/Core/Billing/Commands/IAssignSeatsToClientOrganizationCommand.cs +++ /dev/null @@ -1,21 +0,0 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Entities.Provider; - -namespace Bit.Core.Billing.Commands; - -public interface IAssignSeatsToClientOrganizationCommand -{ - /// - /// Assigns a specified number of to a client on behalf of - /// its . Seat adjustments for the client organization may autoscale the provider's Stripe - /// depending on the provider's seat minimum for the client 's - /// . - /// - /// The MSP that manages the client . - /// The client organization whose you want to update. - /// The number of seats to assign to the client organization. - Task AssignSeatsToClientOrganization( - Provider provider, - Organization organization, - int seats); -} diff --git a/src/Core/Billing/Commands/ICancelSubscriptionCommand.cs b/src/Core/Billing/Commands/ICancelSubscriptionCommand.cs deleted file mode 100644 index 88708d3d2..000000000 --- a/src/Core/Billing/Commands/ICancelSubscriptionCommand.cs +++ /dev/null @@ -1,23 +0,0 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.Billing.Models; -using Bit.Core.Entities; -using Stripe; - -namespace Bit.Core.Billing.Commands; - -public interface ICancelSubscriptionCommand -{ - /// - /// Cancels a user or organization's subscription while including user-provided feedback via the . - /// If the flag is , - /// this command sets the subscription's "cancel_at_end_of_period" property to . - /// Otherwise, this command cancels the subscription immediately. - /// - /// The or with the subscription to cancel. - /// An DTO containing user-provided feedback on why they are cancelling the subscription. - /// A flag indicating whether to cancel the subscription immediately or at the end of the subscription period. - Task CancelSubscription( - Subscription subscription, - OffboardingSurveyResponse offboardingSurveyResponse, - bool cancelImmediately); -} diff --git a/src/Core/Billing/Commands/ICreateCustomerCommand.cs b/src/Core/Billing/Commands/ICreateCustomerCommand.cs deleted file mode 100644 index 0d7994223..000000000 --- a/src/Core/Billing/Commands/ICreateCustomerCommand.cs +++ /dev/null @@ -1,17 +0,0 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Entities.Provider; - -namespace Bit.Core.Billing.Commands; - -public interface ICreateCustomerCommand -{ - /// - /// Create a Stripe for the provided client utilizing - /// the address and tax information of its . - /// - /// The MSP that owns the client organization. - /// The client organization to create a Stripe for. - Task CreateCustomer( - Provider provider, - Organization organization); -} diff --git a/src/Core/Billing/Commands/IRemovePaymentMethodCommand.cs b/src/Core/Billing/Commands/IRemovePaymentMethodCommand.cs deleted file mode 100644 index e2be6f45e..000000000 --- a/src/Core/Billing/Commands/IRemovePaymentMethodCommand.cs +++ /dev/null @@ -1,15 +0,0 @@ -using Bit.Core.AdminConsole.Entities; - -namespace Bit.Core.Billing.Commands; - -public interface IRemovePaymentMethodCommand -{ - /// - /// Attempts to remove an Organization's saved payment method. If the Stripe representing the - /// contains a valid "btCustomerId" key in its property, - /// this command will attempt to remove the Braintree . Otherwise, it will attempt to remove the - /// Stripe . - /// - /// The organization to remove the saved payment method for. - Task RemovePaymentMethod(Organization organization); -} diff --git a/src/Core/Billing/Commands/IScaleSeatsCommand.cs b/src/Core/Billing/Commands/IScaleSeatsCommand.cs deleted file mode 100644 index 97fe9e2e3..000000000 --- a/src/Core/Billing/Commands/IScaleSeatsCommand.cs +++ /dev/null @@ -1,12 +0,0 @@ -using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.Enums; - -namespace Bit.Core.Billing.Commands; - -public interface IScaleSeatsCommand -{ - Task ScalePasswordManagerSeats( - Provider provider, - PlanType planType, - int seatAdjustment); -} diff --git a/src/Core/Billing/Commands/IStartSubscriptionCommand.cs b/src/Core/Billing/Commands/IStartSubscriptionCommand.cs deleted file mode 100644 index 74e9367c4..000000000 --- a/src/Core/Billing/Commands/IStartSubscriptionCommand.cs +++ /dev/null @@ -1,20 +0,0 @@ -using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.Enums; -using Bit.Core.Models.Business; - -namespace Bit.Core.Billing.Commands; - -public interface IStartSubscriptionCommand -{ - /// - /// Starts a Stripe for the given utilizing the provided - /// to handle automatic taxation and non-US tax identification. subscriptions - /// will always be started with a for both the and - /// plan, and the quantity for each item will be equal the provider's seat minimum for each respective plan. - /// - /// The provider to create the for. - /// The tax information to use for automatic taxation and non-US tax identification. - Task StartSubscription( - Provider provider, - TaxInfo taxInfo); -} diff --git a/src/Core/Billing/Commands/Implementations/AssignSeatsToClientOrganizationCommand.cs b/src/Core/Billing/Commands/Implementations/AssignSeatsToClientOrganizationCommand.cs deleted file mode 100644 index be2c6be96..000000000 --- a/src/Core/Billing/Commands/Implementations/AssignSeatsToClientOrganizationCommand.cs +++ /dev/null @@ -1,174 +0,0 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.AdminConsole.Enums.Provider; -using Bit.Core.Billing.Entities; -using Bit.Core.Billing.Extensions; -using Bit.Core.Billing.Queries; -using Bit.Core.Billing.Repositories; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Bit.Core.Utilities; -using Microsoft.Extensions.Logging; -using static Bit.Core.Billing.Utilities; - -namespace Bit.Core.Billing.Commands.Implementations; - -public class AssignSeatsToClientOrganizationCommand( - ILogger logger, - IOrganizationRepository organizationRepository, - IPaymentService paymentService, - IProviderBillingQueries providerBillingQueries, - IProviderPlanRepository providerPlanRepository) : IAssignSeatsToClientOrganizationCommand -{ - public async Task AssignSeatsToClientOrganization( - Provider provider, - Organization organization, - int seats) - { - ArgumentNullException.ThrowIfNull(provider); - ArgumentNullException.ThrowIfNull(organization); - - if (provider.Type == ProviderType.Reseller) - { - logger.LogError("Reseller-type provider ({ID}) cannot assign seats to client organizations", provider.Id); - - throw ContactSupport("Consolidated billing does not support reseller-type providers"); - } - - if (seats < 0) - { - throw new BillingException( - "You cannot assign negative seats to a client.", - "MSP cannot assign negative seats to a client organization"); - } - - if (seats == organization.Seats) - { - logger.LogWarning("Client organization ({ID}) already has {Seats} seats assigned", organization.Id, organization.Seats); - - return; - } - - var providerPlan = await GetProviderPlanAsync(provider, organization); - - var providerSeatMinimum = providerPlan.SeatMinimum.GetValueOrDefault(0); - - // How many seats the provider has assigned to all their client organizations that have the specified plan type. - var providerCurrentlyAssignedSeatTotal = await providerBillingQueries.GetAssignedSeatTotalForPlanOrThrow(provider.Id, providerPlan.PlanType); - - // How many seats are being added to or subtracted from this client organization. - var seatDifference = seats - (organization.Seats ?? 0); - - // How many seats the provider will have assigned to all of their client organizations after the update. - var providerNewlyAssignedSeatTotal = providerCurrentlyAssignedSeatTotal + seatDifference; - - var update = CurryUpdateFunction( - provider, - providerPlan, - organization, - seats, - providerNewlyAssignedSeatTotal); - - /* - * Below the limit => Below the limit: - * No subscription update required. We can safely update the organization's seats. - */ - if (providerCurrentlyAssignedSeatTotal <= providerSeatMinimum && - providerNewlyAssignedSeatTotal <= providerSeatMinimum) - { - organization.Seats = seats; - - await organizationRepository.ReplaceAsync(organization); - - providerPlan.AllocatedSeats = providerNewlyAssignedSeatTotal; - - await providerPlanRepository.ReplaceAsync(providerPlan); - } - /* - * Below the limit => Above the limit: - * We have to scale the subscription up from the seat minimum to the newly assigned seat total. - */ - else if (providerCurrentlyAssignedSeatTotal <= providerSeatMinimum && - providerNewlyAssignedSeatTotal > providerSeatMinimum) - { - await update( - providerSeatMinimum, - providerNewlyAssignedSeatTotal); - } - /* - * Above the limit => Above the limit: - * We have to scale the subscription from the currently assigned seat total to the newly assigned seat total. - */ - else if (providerCurrentlyAssignedSeatTotal > providerSeatMinimum && - providerNewlyAssignedSeatTotal > providerSeatMinimum) - { - await update( - providerCurrentlyAssignedSeatTotal, - providerNewlyAssignedSeatTotal); - } - /* - * Above the limit => Below the limit: - * We have to scale the subscription down from the currently assigned seat total to the seat minimum. - */ - else if (providerCurrentlyAssignedSeatTotal > providerSeatMinimum && - providerNewlyAssignedSeatTotal <= providerSeatMinimum) - { - await update( - providerCurrentlyAssignedSeatTotal, - providerSeatMinimum); - } - } - - // ReSharper disable once SuggestBaseTypeForParameter - private async Task GetProviderPlanAsync(Provider provider, Organization organization) - { - if (!organization.PlanType.SupportsConsolidatedBilling()) - { - logger.LogError("Cannot assign seats to a client organization ({ID}) with a plan type that does not support consolidated billing: {PlanType}", organization.Id, organization.PlanType); - - throw ContactSupport(); - } - - var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); - - var providerPlan = providerPlans.FirstOrDefault(providerPlan => providerPlan.PlanType == organization.PlanType); - - if (providerPlan != null && providerPlan.IsConfigured()) - { - return providerPlan; - } - - logger.LogError("Cannot assign seats to client organization ({ClientOrganizationID}) when provider's ({ProviderID}) matching plan is not configured", organization.Id, provider.Id); - - throw ContactSupport(); - } - - private Func CurryUpdateFunction( - Provider provider, - ProviderPlan providerPlan, - Organization organization, - int organizationNewlyAssignedSeats, - int providerNewlyAssignedSeats) => async (providerCurrentlySubscribedSeats, providerNewlySubscribedSeats) => - { - var plan = StaticStore.GetPlan(providerPlan.PlanType); - - await paymentService.AdjustSeats( - provider, - plan, - providerCurrentlySubscribedSeats, - providerNewlySubscribedSeats); - - organization.Seats = organizationNewlyAssignedSeats; - - await organizationRepository.ReplaceAsync(organization); - - var providerNewlyPurchasedSeats = providerNewlySubscribedSeats > providerPlan.SeatMinimum - ? providerNewlySubscribedSeats - providerPlan.SeatMinimum - : 0; - - providerPlan.PurchasedSeats = providerNewlyPurchasedSeats; - providerPlan.AllocatedSeats = providerNewlyAssignedSeats; - - await providerPlanRepository.ReplaceAsync(providerPlan); - }; -} diff --git a/src/Core/Billing/Commands/Implementations/CancelSubscriptionCommand.cs b/src/Core/Billing/Commands/Implementations/CancelSubscriptionCommand.cs deleted file mode 100644 index 09dc5dde9..000000000 --- a/src/Core/Billing/Commands/Implementations/CancelSubscriptionCommand.cs +++ /dev/null @@ -1,118 +0,0 @@ -using Bit.Core.Billing.Models; -using Bit.Core.Services; -using Microsoft.Extensions.Logging; -using Stripe; - -using static Bit.Core.Billing.Utilities; - -namespace Bit.Core.Billing.Commands.Implementations; - -public class CancelSubscriptionCommand( - ILogger logger, - IStripeAdapter stripeAdapter) - : ICancelSubscriptionCommand -{ - private static readonly List _validReasons = - [ - "customer_service", - "low_quality", - "missing_features", - "other", - "switched_service", - "too_complex", - "too_expensive", - "unused" - ]; - - public async Task CancelSubscription( - Subscription subscription, - OffboardingSurveyResponse offboardingSurveyResponse, - bool cancelImmediately) - { - if (IsInactive(subscription)) - { - logger.LogWarning("Cannot cancel subscription ({ID}) that's already inactive.", subscription.Id); - throw ContactSupport(); - } - - var metadata = new Dictionary - { - { "cancellingUserId", offboardingSurveyResponse.UserId.ToString() } - }; - - if (cancelImmediately) - { - if (BelongsToOrganization(subscription)) - { - await stripeAdapter.SubscriptionUpdateAsync(subscription.Id, new SubscriptionUpdateOptions - { - Metadata = metadata - }); - } - - await CancelSubscriptionImmediatelyAsync(subscription.Id, offboardingSurveyResponse); - } - else - { - await CancelSubscriptionAtEndOfPeriodAsync(subscription.Id, offboardingSurveyResponse, metadata); - } - } - - private static bool BelongsToOrganization(IHasMetadata subscription) - => subscription.Metadata != null && subscription.Metadata.ContainsKey("organizationId"); - - private async Task CancelSubscriptionImmediatelyAsync( - string subscriptionId, - OffboardingSurveyResponse offboardingSurveyResponse) - { - var options = new SubscriptionCancelOptions - { - CancellationDetails = new SubscriptionCancellationDetailsOptions - { - Comment = offboardingSurveyResponse.Feedback - } - }; - - if (IsValidCancellationReason(offboardingSurveyResponse.Reason)) - { - options.CancellationDetails.Feedback = offboardingSurveyResponse.Reason; - } - - await stripeAdapter.SubscriptionCancelAsync(subscriptionId, options); - } - - private static bool IsInactive(Subscription subscription) => - subscription.CanceledAt.HasValue || - subscription.Status == "canceled" || - subscription.Status == "unpaid" || - subscription.Status == "incomplete_expired"; - - private static bool IsValidCancellationReason(string reason) => _validReasons.Contains(reason); - - private async Task CancelSubscriptionAtEndOfPeriodAsync( - string subscriptionId, - OffboardingSurveyResponse offboardingSurveyResponse, - Dictionary metadata = null) - { - var options = new SubscriptionUpdateOptions - { - CancelAtPeriodEnd = true, - CancellationDetails = new SubscriptionCancellationDetailsOptions - { - Comment = offboardingSurveyResponse.Feedback - } - }; - - if (IsValidCancellationReason(offboardingSurveyResponse.Reason)) - { - options.CancellationDetails.Feedback = offboardingSurveyResponse.Reason; - } - - if (metadata != null) - { - options.Metadata = metadata; - } - - await stripeAdapter.SubscriptionUpdateAsync(subscriptionId, options); - } -} diff --git a/src/Core/Billing/Commands/Implementations/CreateCustomerCommand.cs b/src/Core/Billing/Commands/Implementations/CreateCustomerCommand.cs deleted file mode 100644 index 9a9714f24..000000000 --- a/src/Core/Billing/Commands/Implementations/CreateCustomerCommand.cs +++ /dev/null @@ -1,89 +0,0 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.Billing.Queries; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Bit.Core.Settings; -using Microsoft.Extensions.Logging; -using Stripe; - -namespace Bit.Core.Billing.Commands.Implementations; - -public class CreateCustomerCommand( - IGlobalSettings globalSettings, - ILogger logger, - IOrganizationRepository organizationRepository, - IStripeAdapter stripeAdapter, - ISubscriberQueries subscriberQueries) : ICreateCustomerCommand -{ - public async Task CreateCustomer( - Provider provider, - Organization organization) - { - ArgumentNullException.ThrowIfNull(provider); - ArgumentNullException.ThrowIfNull(organization); - - if (!string.IsNullOrEmpty(organization.GatewayCustomerId)) - { - logger.LogWarning("Client organization ({ID}) already has a populated {FieldName}", organization.Id, nameof(organization.GatewayCustomerId)); - - return; - } - - var providerCustomer = await subscriberQueries.GetCustomerOrThrow(provider, new CustomerGetOptions - { - Expand = ["tax_ids"] - }); - - var providerTaxId = providerCustomer.TaxIds.FirstOrDefault(); - - var organizationDisplayName = organization.DisplayName(); - - var customerCreateOptions = new CustomerCreateOptions - { - Address = new AddressOptions - { - Country = providerCustomer.Address?.Country, - PostalCode = providerCustomer.Address?.PostalCode, - Line1 = providerCustomer.Address?.Line1, - Line2 = providerCustomer.Address?.Line2, - City = providerCustomer.Address?.City, - State = providerCustomer.Address?.State - }, - Name = organizationDisplayName, - Description = $"{provider.Name} Client Organization", - Email = provider.BillingEmail, - InvoiceSettings = new CustomerInvoiceSettingsOptions - { - CustomFields = - [ - new CustomerInvoiceSettingsCustomFieldOptions - { - Name = organization.SubscriberType(), - Value = organizationDisplayName.Length <= 30 - ? organizationDisplayName - : organizationDisplayName[..30] - } - ] - }, - Metadata = new Dictionary - { - { "region", globalSettings.BaseServiceUri.CloudRegion } - }, - TaxIdData = providerTaxId == null ? null : - [ - new CustomerTaxIdDataOptions - { - Type = providerTaxId.Type, - Value = providerTaxId.Value - } - ] - }; - - var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions); - - organization.GatewayCustomerId = customer.Id; - - await organizationRepository.ReplaceAsync(organization); - } -} diff --git a/src/Core/Billing/Commands/Implementations/RemovePaymentMethodCommand.cs b/src/Core/Billing/Commands/Implementations/RemovePaymentMethodCommand.cs deleted file mode 100644 index be8479ea9..000000000 --- a/src/Core/Billing/Commands/Implementations/RemovePaymentMethodCommand.cs +++ /dev/null @@ -1,124 +0,0 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.Enums; -using Bit.Core.Services; -using Braintree; -using Microsoft.Extensions.Logging; - -using static Bit.Core.Billing.Utilities; - -namespace Bit.Core.Billing.Commands.Implementations; - -public class RemovePaymentMethodCommand( - IBraintreeGateway braintreeGateway, - ILogger logger, - IStripeAdapter stripeAdapter) - : IRemovePaymentMethodCommand -{ - public async Task RemovePaymentMethod(Organization organization) - { - ArgumentNullException.ThrowIfNull(organization); - - if (organization.Gateway is not GatewayType.Stripe || string.IsNullOrEmpty(organization.GatewayCustomerId)) - { - throw ContactSupport(); - } - - var stripeCustomer = await stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId, new Stripe.CustomerGetOptions - { - Expand = ["invoice_settings.default_payment_method", "sources"] - }); - - if (stripeCustomer == null) - { - logger.LogError("Could not find Stripe customer ({ID}) when removing payment method", organization.GatewayCustomerId); - - throw ContactSupport(); - } - - if (stripeCustomer.Metadata?.TryGetValue(BraintreeCustomerIdKey, out var braintreeCustomerId) ?? false) - { - await RemoveBraintreePaymentMethodAsync(braintreeCustomerId); - } - else - { - await RemoveStripePaymentMethodsAsync(stripeCustomer); - } - } - - private async Task RemoveBraintreePaymentMethodAsync(string braintreeCustomerId) - { - var customer = await braintreeGateway.Customer.FindAsync(braintreeCustomerId); - - if (customer == null) - { - logger.LogError("Failed to retrieve Braintree customer ({ID}) when removing payment method", braintreeCustomerId); - - throw ContactSupport(); - } - - if (customer.DefaultPaymentMethod != null) - { - var existingDefaultPaymentMethod = customer.DefaultPaymentMethod; - - var updateCustomerResult = await braintreeGateway.Customer.UpdateAsync( - braintreeCustomerId, - new CustomerRequest { DefaultPaymentMethodToken = null }); - - if (!updateCustomerResult.IsSuccess()) - { - logger.LogError("Failed to update payment method for Braintree customer ({ID}) | Message: {Message}", - braintreeCustomerId, updateCustomerResult.Message); - - throw ContactSupport(); - } - - var deletePaymentMethodResult = await braintreeGateway.PaymentMethod.DeleteAsync(existingDefaultPaymentMethod.Token); - - if (!deletePaymentMethodResult.IsSuccess()) - { - await braintreeGateway.Customer.UpdateAsync( - braintreeCustomerId, - new CustomerRequest { DefaultPaymentMethodToken = existingDefaultPaymentMethod.Token }); - - logger.LogError( - "Failed to delete Braintree payment method for Customer ({ID}), re-linked payment method. Message: {Message}", - braintreeCustomerId, deletePaymentMethodResult.Message); - - throw ContactSupport(); - } - } - else - { - logger.LogWarning("Tried to remove non-existent Braintree payment method for Customer ({ID})", braintreeCustomerId); - } - } - - private async Task RemoveStripePaymentMethodsAsync(Stripe.Customer customer) - { - if (customer.Sources != null && customer.Sources.Any()) - { - foreach (var source in customer.Sources) - { - switch (source) - { - case Stripe.BankAccount: - await stripeAdapter.BankAccountDeleteAsync(customer.Id, source.Id); - break; - case Stripe.Card: - await stripeAdapter.CardDeleteAsync(customer.Id, source.Id); - break; - } - } - } - - var paymentMethods = stripeAdapter.PaymentMethodListAutoPagingAsync(new Stripe.PaymentMethodListOptions - { - Customer = customer.Id - }); - - await foreach (var paymentMethod in paymentMethods) - { - await stripeAdapter.PaymentMethodDetachAsync(paymentMethod.Id, new Stripe.PaymentMethodDetachOptions()); - } - } -} diff --git a/src/Core/Billing/Commands/Implementations/ScaleSeatsCommand.cs b/src/Core/Billing/Commands/Implementations/ScaleSeatsCommand.cs deleted file mode 100644 index 8d6d9a90e..000000000 --- a/src/Core/Billing/Commands/Implementations/ScaleSeatsCommand.cs +++ /dev/null @@ -1,130 +0,0 @@ -using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.AdminConsole.Enums.Provider; -using Bit.Core.Billing.Entities; -using Bit.Core.Billing.Extensions; -using Bit.Core.Billing.Queries; -using Bit.Core.Billing.Repositories; -using Bit.Core.Enums; -using Bit.Core.Services; -using Bit.Core.Utilities; -using Microsoft.Extensions.Logging; -using static Bit.Core.Billing.Utilities; - -namespace Bit.Core.Billing.Commands.Implementations; - -public class ScaleSeatsCommand( - ILogger logger, - IPaymentService paymentService, - IProviderBillingQueries providerBillingQueries, - IProviderPlanRepository providerPlanRepository) : IScaleSeatsCommand -{ - public async Task ScalePasswordManagerSeats(Provider provider, PlanType planType, int seatAdjustment) - { - ArgumentNullException.ThrowIfNull(provider); - - if (provider.Type != ProviderType.Msp) - { - logger.LogError("Non-MSP provider ({ProviderID}) cannot scale their Password Manager seats", provider.Id); - - throw ContactSupport(); - } - - if (!planType.SupportsConsolidatedBilling()) - { - logger.LogError("Cannot scale provider ({ProviderID}) Password Manager seats for plan type {PlanType} as it does not support consolidated billing", provider.Id, planType.ToString()); - - throw ContactSupport(); - } - - var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); - - var providerPlan = providerPlans.FirstOrDefault(providerPlan => providerPlan.PlanType == planType); - - if (providerPlan == null || !providerPlan.IsConfigured()) - { - logger.LogError("Cannot scale provider ({ProviderID}) Password Manager seats for plan type {PlanType} when their matching provider plan is not configured", provider.Id, planType); - - throw ContactSupport(); - } - - var seatMinimum = providerPlan.SeatMinimum.GetValueOrDefault(0); - - var currentlyAssignedSeatTotal = - await providerBillingQueries.GetAssignedSeatTotalForPlanOrThrow(provider.Id, planType); - - var newlyAssignedSeatTotal = currentlyAssignedSeatTotal + seatAdjustment; - - var update = CurryUpdateFunction( - provider, - providerPlan, - newlyAssignedSeatTotal); - - /* - * Below the limit => Below the limit: - * No subscription update required. We can safely update the organization's seats. - */ - if (currentlyAssignedSeatTotal <= seatMinimum && - newlyAssignedSeatTotal <= seatMinimum) - { - providerPlan.AllocatedSeats = newlyAssignedSeatTotal; - - await providerPlanRepository.ReplaceAsync(providerPlan); - } - /* - * Below the limit => Above the limit: - * We have to scale the subscription up from the seat minimum to the newly assigned seat total. - */ - else if (currentlyAssignedSeatTotal <= seatMinimum && - newlyAssignedSeatTotal > seatMinimum) - { - await update( - seatMinimum, - newlyAssignedSeatTotal); - } - /* - * Above the limit => Above the limit: - * We have to scale the subscription from the currently assigned seat total to the newly assigned seat total. - */ - else if (currentlyAssignedSeatTotal > seatMinimum && - newlyAssignedSeatTotal > seatMinimum) - { - await update( - currentlyAssignedSeatTotal, - newlyAssignedSeatTotal); - } - /* - * Above the limit => Below the limit: - * We have to scale the subscription down from the currently assigned seat total to the seat minimum. - */ - else if (currentlyAssignedSeatTotal > seatMinimum && - newlyAssignedSeatTotal <= seatMinimum) - { - await update( - currentlyAssignedSeatTotal, - seatMinimum); - } - } - - private Func CurryUpdateFunction( - Provider provider, - ProviderPlan providerPlan, - int newlyAssignedSeats) => async (currentlySubscribedSeats, newlySubscribedSeats) => - { - var plan = StaticStore.GetPlan(providerPlan.PlanType); - - await paymentService.AdjustSeats( - provider, - plan, - currentlySubscribedSeats, - newlySubscribedSeats); - - var newlyPurchasedSeats = newlySubscribedSeats > providerPlan.SeatMinimum - ? newlySubscribedSeats - providerPlan.SeatMinimum - : 0; - - providerPlan.PurchasedSeats = newlyPurchasedSeats; - providerPlan.AllocatedSeats = newlyAssignedSeats; - - await providerPlanRepository.ReplaceAsync(providerPlan); - }; -} diff --git a/src/Core/Billing/Commands/Implementations/StartSubscriptionCommand.cs b/src/Core/Billing/Commands/Implementations/StartSubscriptionCommand.cs deleted file mode 100644 index 45cab1e0c..000000000 --- a/src/Core/Billing/Commands/Implementations/StartSubscriptionCommand.cs +++ /dev/null @@ -1,202 +0,0 @@ -using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.AdminConsole.Enums.Provider; -using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Billing.Constants; -using Bit.Core.Billing.Repositories; -using Bit.Core.Enums; -using Bit.Core.Models.Business; -using Bit.Core.Services; -using Bit.Core.Settings; -using Bit.Core.Utilities; -using Microsoft.Extensions.Logging; -using Stripe; -using static Bit.Core.Billing.Utilities; - -namespace Bit.Core.Billing.Commands.Implementations; - -public class StartSubscriptionCommand( - IGlobalSettings globalSettings, - ILogger logger, - IProviderPlanRepository providerPlanRepository, - IProviderRepository providerRepository, - IStripeAdapter stripeAdapter) : IStartSubscriptionCommand -{ - public async Task StartSubscription( - Provider provider, - TaxInfo taxInfo) - { - ArgumentNullException.ThrowIfNull(provider); - ArgumentNullException.ThrowIfNull(taxInfo); - - if (!string.IsNullOrEmpty(provider.GatewaySubscriptionId)) - { - logger.LogWarning("Cannot start Provider subscription - Provider ({ID}) already has a {FieldName}", provider.Id, nameof(provider.GatewaySubscriptionId)); - - throw ContactSupport(); - } - - if (string.IsNullOrEmpty(taxInfo.BillingAddressCountry) || - string.IsNullOrEmpty(taxInfo.BillingAddressPostalCode)) - { - logger.LogError("Cannot start Provider subscription - Both the Provider's ({ID}) country and postal code are required", provider.Id); - - throw ContactSupport(); - } - - var customer = await GetOrCreateCustomerAsync(provider, taxInfo); - - var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); - - if (providerPlans == null || providerPlans.Count == 0) - { - logger.LogError("Cannot start Provider subscription - Provider ({ID}) has no configured plans", provider.Id); - - throw ContactSupport(); - } - - var subscriptionItemOptionsList = new List(); - - var teamsProviderPlan = - providerPlans.SingleOrDefault(providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly); - - if (teamsProviderPlan == null) - { - logger.LogError("Cannot start Provider subscription - Provider ({ID}) has no configured Teams Monthly plan", provider.Id); - - throw ContactSupport(); - } - - var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); - - subscriptionItemOptionsList.Add(new SubscriptionItemOptions - { - Price = teamsPlan.PasswordManager.StripeSeatPlanId, - Quantity = teamsProviderPlan.SeatMinimum - }); - - var enterpriseProviderPlan = - providerPlans.SingleOrDefault(providerPlan => providerPlan.PlanType == PlanType.EnterpriseMonthly); - - if (enterpriseProviderPlan == null) - { - logger.LogError("Cannot start Provider subscription - Provider ({ID}) has no configured Enterprise Monthly plan", provider.Id); - - throw ContactSupport(); - } - - var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); - - subscriptionItemOptionsList.Add(new SubscriptionItemOptions - { - Price = enterprisePlan.PasswordManager.StripeSeatPlanId, - Quantity = enterpriseProviderPlan.SeatMinimum - }); - - var subscriptionCreateOptions = new SubscriptionCreateOptions - { - AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = true - }, - CollectionMethod = StripeConstants.CollectionMethod.SendInvoice, - Customer = customer.Id, - DaysUntilDue = 30, - Items = subscriptionItemOptionsList, - Metadata = new Dictionary - { - { "providerId", provider.Id.ToString() } - }, - OffSession = true, - ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations - }; - - var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); - - provider.GatewaySubscriptionId = subscription.Id; - - if (subscription.Status == StripeConstants.SubscriptionStatus.Incomplete) - { - await providerRepository.ReplaceAsync(provider); - - logger.LogError("Started incomplete Provider ({ProviderID}) subscription ({SubscriptionID})", provider.Id, subscription.Id); - - throw ContactSupport(); - } - - provider.Status = ProviderStatusType.Billable; - - await providerRepository.ReplaceAsync(provider); - } - - // ReSharper disable once SuggestBaseTypeForParameter - private async Task GetOrCreateCustomerAsync( - Provider provider, - TaxInfo taxInfo) - { - if (!string.IsNullOrEmpty(provider.GatewayCustomerId)) - { - var existingCustomer = await stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId, new CustomerGetOptions - { - Expand = ["tax"] - }); - - if (existingCustomer != null) - { - return existingCustomer; - } - - logger.LogError("Cannot start Provider subscription - Provider's ({ProviderID}) {CustomerIDFieldName} did not relate to a Stripe customer", provider.Id, nameof(provider.GatewayCustomerId)); - - throw ContactSupport(); - } - - var providerDisplayName = provider.DisplayName(); - - var customerCreateOptions = new CustomerCreateOptions - { - Address = new AddressOptions - { - Country = taxInfo.BillingAddressCountry, - PostalCode = taxInfo.BillingAddressPostalCode, - Line1 = taxInfo.BillingAddressLine1, - Line2 = taxInfo.BillingAddressLine2, - City = taxInfo.BillingAddressCity, - State = taxInfo.BillingAddressState - }, - Coupon = "msp-discount-35", - Description = provider.DisplayBusinessName(), - Email = provider.BillingEmail, - Expand = ["tax"], - InvoiceSettings = new CustomerInvoiceSettingsOptions - { - CustomFields = - [ - new CustomerInvoiceSettingsCustomFieldOptions - { - Name = provider.SubscriberType(), - Value = providerDisplayName.Length <= 30 - ? providerDisplayName - : providerDisplayName[..30] - } - ] - }, - Metadata = new Dictionary - { - { "region", globalSettings.BaseServiceUri.CloudRegion } - }, - TaxIdData = taxInfo.HasTaxId ? - [ - new CustomerTaxIdDataOptions { Type = taxInfo.TaxIdType, Value = taxInfo.TaxIdNumber } - ] - : null - }; - - var createdCustomer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions); - - provider.GatewayCustomerId = createdCustomer.Id; - - await providerRepository.ReplaceAsync(provider); - - return createdCustomer; - } -} diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs index 28c3ace06..d225193e7 100644 --- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs +++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs @@ -1,7 +1,5 @@ -using Bit.Core.Billing.Commands; -using Bit.Core.Billing.Commands.Implementations; -using Bit.Core.Billing.Queries; -using Bit.Core.Billing.Queries.Implementations; +using Bit.Core.Billing.Services; +using Bit.Core.Billing.Services.Implementations; namespace Bit.Core.Billing.Extensions; @@ -11,17 +9,7 @@ public static class ServiceCollectionExtensions { public static void AddBillingOperations(this IServiceCollection services) { - // Queries - services.AddTransient(); - services.AddTransient(); - services.AddTransient(); - - // Commands - services.AddTransient(); - services.AddTransient(); - services.AddTransient(); - services.AddTransient(); - services.AddTransient(); - services.AddTransient(); + services.AddTransient(); + services.AddTransient(); } } diff --git a/src/Core/Billing/Models/ProviderPaymentInfoDTO.cs b/src/Core/Billing/Models/ProviderPaymentInfoDTO.cs new file mode 100644 index 000000000..810fae9a5 --- /dev/null +++ b/src/Core/Billing/Models/ProviderPaymentInfoDTO.cs @@ -0,0 +1,6 @@ +using Bit.Core.Models.Business; + +namespace Bit.Core.Billing.Models; + +public record ProviderPaymentInfoDTO(BillingInfo.BillingSource billingSource, + TaxInfo taxInfo); diff --git a/src/Core/Billing/Queries/IProviderBillingQueries.cs b/src/Core/Billing/Queries/IProviderBillingQueries.cs deleted file mode 100644 index 1347ea4b8..000000000 --- a/src/Core/Billing/Queries/IProviderBillingQueries.cs +++ /dev/null @@ -1,27 +0,0 @@ -using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.AdminConsole.Enums.Provider; -using Bit.Core.Billing.Models; -using Bit.Core.Enums; - -namespace Bit.Core.Billing.Queries; - -public interface IProviderBillingQueries -{ - /// - /// Retrieves the number of seats an MSP has assigned to its client organizations with a specified . - /// - /// The ID of the MSP to retrieve the assigned seat total for. - /// The type of plan to retrieve the assigned seat total for. - /// An representing the number of seats the provider has assigned to its client organizations with the specified . - /// Thrown when the provider represented by the is . - /// Thrown when the provider represented by the has . - Task GetAssignedSeatTotalForPlanOrThrow(Guid providerId, PlanType planType); - - /// - /// Retrieves a provider's billing subscription data. - /// - /// The ID of the provider to retrieve subscription data for. - /// A object containing the provider's Stripe and their s. - /// This method opts for returning rather than throwing exceptions, making it ideal for surfacing data from API endpoints. - Task GetSubscriptionDTO(Guid providerId); -} diff --git a/src/Core/Billing/Queries/ISubscriberQueries.cs b/src/Core/Billing/Queries/ISubscriberQueries.cs deleted file mode 100644 index 013ae3e1d..000000000 --- a/src/Core/Billing/Queries/ISubscriberQueries.cs +++ /dev/null @@ -1,58 +0,0 @@ -using Bit.Core.Entities; -using Bit.Core.Exceptions; -using Stripe; - -namespace Bit.Core.Billing.Queries; - -public interface ISubscriberQueries -{ - /// - /// Retrieves a Stripe using the 's property. - /// - /// The organization, provider or user to retrieve the customer for. - /// Optional parameters that can be passed to Stripe to expand or modify the . - /// A Stripe . - /// Thrown when the is . - /// This method opts for returning rather than throwing exceptions, making it ideal for surfacing data from API endpoints. - Task GetCustomer( - ISubscriber subscriber, - CustomerGetOptions customerGetOptions = null); - - /// - /// Retrieves a Stripe using the 's property. - /// - /// The organization, provider or user to retrieve the subscription for. - /// Optional parameters that can be passed to Stripe to expand or modify the . - /// A Stripe . - /// Thrown when the is . - /// This method opts for returning rather than throwing exceptions, making it ideal for surfacing data from API endpoints. - Task GetSubscription( - ISubscriber subscriber, - SubscriptionGetOptions subscriptionGetOptions = null); - - /// - /// Retrieves a Stripe using the 's property. - /// - /// The organization or user to retrieve the subscription for. - /// Optional parameters that can be passed to Stripe to expand or modify the . - /// A Stripe . - /// Thrown when the is . - /// Thrown when the subscriber's is or empty. - /// Thrown when the returned from Stripe's API is null. - Task GetCustomerOrThrow( - ISubscriber subscriber, - CustomerGetOptions customerGetOptions = null); - - /// - /// Retrieves a Stripe using the 's property. - /// - /// The organization or user to retrieve the subscription for. - /// Optional parameters that can be passed to Stripe to expand or modify the . - /// A Stripe . - /// Thrown when the is . - /// Thrown when the subscriber's is or empty. - /// Thrown when the returned from Stripe's API is null. - Task GetSubscriptionOrThrow( - ISubscriber subscriber, - SubscriptionGetOptions subscriptionGetOptions = null); -} diff --git a/src/Core/Billing/Queries/Implementations/ProviderBillingQueries.cs b/src/Core/Billing/Queries/Implementations/ProviderBillingQueries.cs deleted file mode 100644 index a941b6f94..000000000 --- a/src/Core/Billing/Queries/Implementations/ProviderBillingQueries.cs +++ /dev/null @@ -1,92 +0,0 @@ -using Bit.Core.AdminConsole.Enums.Provider; -using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Billing.Models; -using Bit.Core.Billing.Repositories; -using Bit.Core.Enums; -using Bit.Core.Utilities; -using Microsoft.Extensions.Logging; -using Stripe; -using static Bit.Core.Billing.Utilities; - -namespace Bit.Core.Billing.Queries.Implementations; - -public class ProviderBillingQueries( - ILogger logger, - IProviderOrganizationRepository providerOrganizationRepository, - IProviderPlanRepository providerPlanRepository, - IProviderRepository providerRepository, - ISubscriberQueries subscriberQueries) : IProviderBillingQueries -{ - public async Task GetAssignedSeatTotalForPlanOrThrow( - Guid providerId, - PlanType planType) - { - var provider = await providerRepository.GetByIdAsync(providerId); - - if (provider == null) - { - logger.LogError( - "Could not find provider ({ID}) when retrieving assigned seat total", - providerId); - - throw ContactSupport(); - } - - if (provider.Type == ProviderType.Reseller) - { - logger.LogError("Assigned seats cannot be retrieved for reseller-type provider ({ID})", providerId); - - throw ContactSupport("Consolidated billing does not support reseller-type providers"); - } - - var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId); - - var plan = StaticStore.GetPlan(planType); - - return providerOrganizations - .Where(providerOrganization => providerOrganization.Plan == plan.Name && providerOrganization.Status == OrganizationStatusType.Managed) - .Sum(providerOrganization => providerOrganization.Seats ?? 0); - } - - public async Task GetSubscriptionDTO(Guid providerId) - { - var provider = await providerRepository.GetByIdAsync(providerId); - - if (provider == null) - { - logger.LogError( - "Could not find provider ({ID}) when retrieving subscription data.", - providerId); - - return null; - } - - if (provider.Type == ProviderType.Reseller) - { - logger.LogError("Subscription data cannot be retrieved for reseller-type provider ({ID})", providerId); - - throw ContactSupport("Consolidated billing does not support reseller-type providers"); - } - - var subscription = await subscriberQueries.GetSubscription(provider, new SubscriptionGetOptions - { - Expand = ["customer"] - }); - - if (subscription == null) - { - return null; - } - - var providerPlans = await providerPlanRepository.GetByProviderId(providerId); - - var configuredProviderPlans = providerPlans - .Where(providerPlan => providerPlan.IsConfigured()) - .Select(ConfiguredProviderPlanDTO.From) - .ToList(); - - return new ProviderSubscriptionDTO( - configuredProviderPlans, - subscription); - } -} diff --git a/src/Core/Billing/Queries/Implementations/SubscriberQueries.cs b/src/Core/Billing/Queries/Implementations/SubscriberQueries.cs deleted file mode 100644 index b9fe492a1..000000000 --- a/src/Core/Billing/Queries/Implementations/SubscriberQueries.cs +++ /dev/null @@ -1,159 +0,0 @@ -using Bit.Core.Entities; -using Bit.Core.Services; -using Microsoft.Extensions.Logging; -using Stripe; - -using static Bit.Core.Billing.Utilities; - -namespace Bit.Core.Billing.Queries.Implementations; - -public class SubscriberQueries( - ILogger logger, - IStripeAdapter stripeAdapter) : ISubscriberQueries -{ - public async Task GetCustomer( - ISubscriber subscriber, - CustomerGetOptions customerGetOptions = null) - { - ArgumentNullException.ThrowIfNull(subscriber); - - if (string.IsNullOrEmpty(subscriber.GatewayCustomerId)) - { - logger.LogError("Cannot retrieve customer for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewayCustomerId)); - - return null; - } - - try - { - var customer = await stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerGetOptions); - - if (customer != null) - { - return customer; - } - - logger.LogError("Could not find Stripe customer ({CustomerID}) for subscriber ({SubscriberID})", - subscriber.GatewayCustomerId, subscriber.Id); - - return null; - } - catch (StripeException exception) - { - logger.LogError("An error occurred while trying to retrieve Stripe customer ({CustomerID}) for subscriber ({SubscriberID}): {Error}", - subscriber.GatewayCustomerId, subscriber.Id, exception.Message); - - return null; - } - } - - public async Task GetSubscription( - ISubscriber subscriber, - SubscriptionGetOptions subscriptionGetOptions = null) - { - ArgumentNullException.ThrowIfNull(subscriber); - - if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) - { - logger.LogError("Cannot retrieve subscription for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewaySubscriptionId)); - - return null; - } - - try - { - var subscription = - await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); - - if (subscription != null) - { - return subscription; - } - - logger.LogError("Could not find Stripe subscription ({SubscriptionID}) for subscriber ({SubscriberID})", - subscriber.GatewaySubscriptionId, subscriber.Id); - - return null; - } - catch (StripeException exception) - { - logger.LogError("An error occurred while trying to retrieve Stripe subscription ({SubscriptionID}) for subscriber ({SubscriberID}): {Error}", - subscriber.GatewaySubscriptionId, subscriber.Id, exception.Message); - - return null; - } - } - - public async Task GetCustomerOrThrow( - ISubscriber subscriber, - CustomerGetOptions customerGetOptions = null) - { - ArgumentNullException.ThrowIfNull(subscriber); - - if (string.IsNullOrEmpty(subscriber.GatewayCustomerId)) - { - logger.LogError("Cannot retrieve customer for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewayCustomerId)); - - throw ContactSupport(); - } - - try - { - var customer = await stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerGetOptions); - - if (customer != null) - { - return customer; - } - - logger.LogError("Could not find Stripe customer ({CustomerID}) for subscriber ({SubscriberID})", - subscriber.GatewayCustomerId, subscriber.Id); - - throw ContactSupport(); - } - catch (StripeException exception) - { - logger.LogError("An error occurred while trying to retrieve Stripe customer ({CustomerID}) for subscriber ({SubscriberID}): {Error}", - subscriber.GatewayCustomerId, subscriber.Id, exception.Message); - - throw ContactSupport("An error occurred while trying to retrieve a Stripe Customer", exception); - } - } - - public async Task GetSubscriptionOrThrow( - ISubscriber subscriber, - SubscriptionGetOptions subscriptionGetOptions = null) - { - ArgumentNullException.ThrowIfNull(subscriber); - - if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) - { - logger.LogError("Cannot retrieve subscription for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewaySubscriptionId)); - - throw ContactSupport(); - } - - try - { - var subscription = - await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); - - if (subscription != null) - { - return subscription; - } - - logger.LogError("Could not find Stripe subscription ({SubscriptionID}) for subscriber ({SubscriberID})", - subscriber.GatewaySubscriptionId, subscriber.Id); - - throw ContactSupport(); - } - catch (StripeException exception) - { - logger.LogError("An error occurred while trying to retrieve Stripe subscription ({SubscriptionID}) for subscriber ({SubscriberID}): {Error}", - subscriber.GatewaySubscriptionId, subscriber.Id, exception.Message); - - throw ContactSupport("An error occurred while trying to retrieve a Stripe Subscription", exception); - } - } -} diff --git a/src/Core/Billing/Queries/IOrganizationBillingQueries.cs b/src/Core/Billing/Services/IOrganizationBillingService.cs similarity index 56% rename from src/Core/Billing/Queries/IOrganizationBillingQueries.cs rename to src/Core/Billing/Services/IOrganizationBillingService.cs index f0d3434c5..e030cd487 100644 --- a/src/Core/Billing/Queries/IOrganizationBillingQueries.cs +++ b/src/Core/Billing/Services/IOrganizationBillingService.cs @@ -1,8 +1,8 @@ using Bit.Core.Billing.Models; -namespace Bit.Core.Billing.Queries; +namespace Bit.Core.Billing.Services; -public interface IOrganizationBillingQueries +public interface IOrganizationBillingService { Task GetMetadata(Guid organizationId); } diff --git a/src/Core/Billing/Services/IProviderBillingService.cs b/src/Core/Billing/Services/IProviderBillingService.cs new file mode 100644 index 000000000..6ff1fbf0f --- /dev/null +++ b/src/Core/Billing/Services/IProviderBillingService.cs @@ -0,0 +1,96 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.Billing.Models; +using Bit.Core.Enums; +using Bit.Core.Models.Business; + +namespace Bit.Core.Billing.Services; + +public interface IProviderBillingService +{ + /// + /// Assigns a specified number of to a client on behalf of + /// its . Seat adjustments for the client organization may autoscale the provider's Stripe + /// depending on the provider's seat minimum for the client 's + /// . + /// + /// The that manages the client . + /// The client whose you want to update. + /// The number of seats to assign to the client organization. + Task AssignSeatsToClientOrganization( + Provider provider, + Organization organization, + int seats); + + /// + /// Create a Stripe for the specified utilizing the provided . + /// + /// The to create a Stripe customer for. + /// The to use for calculating the customer's automatic tax. + /// + Task CreateCustomer( + Provider provider, + TaxInfo taxInfo); + + /// + /// Create a Stripe for the provided client utilizing + /// the address and tax information of its . + /// + /// The MSP that owns the client organization. + /// The client organization to create a Stripe for. + Task CreateCustomerForClientOrganization( + Provider provider, + Organization organization); + + /// + /// Retrieves the number of seats an MSP has assigned to its client organizations with a specified . + /// + /// The ID of the MSP to retrieve the assigned seat total for. + /// The type of plan to retrieve the assigned seat total for. + /// An representing the number of seats the provider has assigned to its client organizations with the specified . + /// Thrown when the provider represented by the is . + /// Thrown when the provider represented by the has . + Task GetAssignedSeatTotalForPlanOrThrow( + Guid providerId, + PlanType planType); + + /// + /// Retrieves a provider's billing subscription data. + /// + /// The ID of the provider to retrieve subscription data for. + /// A object containing the provider's Stripe and their s. + /// This method opts for returning rather than throwing exceptions, making it ideal for surfacing data from API endpoints. + Task GetSubscriptionDTO( + Guid providerId); + + /// + /// Scales the 's seats for the specified using the provided . + /// This operation may autoscale the provider's Stripe depending on the 's seat minimum for the + /// specified . + /// + /// The to scale seats for. + /// The to scale seats for. + /// The change in the number of seats you'd like to apply to the . + Task ScaleSeats( + Provider provider, + PlanType planType, + int seatAdjustment); + + /// + /// Starts a Stripe for the given given it has an existing Stripe . + /// subscriptions will always be started with a for both the + /// and plan, and the quantity for each item will be equal the provider's seat minimum for each respective plan. + /// + /// The provider to create the for. + Task StartSubscription( + Provider provider); + + /// + /// Retrieves a provider's billing payment information. + /// + /// The ID of the provider to retrieve payment information for. + /// A object containing the provider's Stripe and their s. + /// This method opts for returning rather than throwing exceptions, making it ideal for surfacing data from API endpoints. + Task GetPaymentInformationAsync(Guid providerId); +} diff --git a/src/Core/Billing/Services/ISubscriberService.cs b/src/Core/Billing/Services/ISubscriberService.cs new file mode 100644 index 000000000..dd825e39c --- /dev/null +++ b/src/Core/Billing/Services/ISubscriberService.cs @@ -0,0 +1,100 @@ +using Bit.Core.Billing.Models; +using Bit.Core.Entities; +using Bit.Core.Models.Business; +using Stripe; + +namespace Bit.Core.Billing.Services; + +public interface ISubscriberService +{ + /// + /// Cancels a subscriber's subscription while including user-provided feedback via the . + /// If the flag is , + /// this command sets the subscription's "cancel_at_end_of_period" property to . + /// Otherwise, this command cancels the subscription immediately. + /// + /// The subscriber with the subscription to cancel. + /// An DTO containing user-provided feedback on why they are cancelling the subscription. + /// A flag indicating whether to cancel the subscription immediately or at the end of the subscription period. + Task CancelSubscription( + ISubscriber subscriber, + OffboardingSurveyResponse offboardingSurveyResponse, + bool cancelImmediately); + + /// + /// Retrieves a Stripe using the 's property. + /// + /// The subscriber to retrieve the Stripe customer for. + /// Optional parameters that can be passed to Stripe to expand or modify the customer. + /// A Stripe . + /// Thrown when the is . + /// This method opts for returning rather than throwing exceptions, making it ideal for surfacing data from API endpoints. + Task GetCustomer( + ISubscriber subscriber, + CustomerGetOptions customerGetOptions = null); + + /// + /// Retrieves a Stripe using the 's property. + /// + /// The subscriber to retrieve the Stripe customer for. + /// Optional parameters that can be passed to Stripe to expand or modify the customer. + /// A Stripe . + /// Thrown when the is . + /// Thrown when the subscriber's is or empty. + /// Thrown when the returned from Stripe's API is null. + Task GetCustomerOrThrow( + ISubscriber subscriber, + CustomerGetOptions customerGetOptions = null); + + /// + /// Retrieves a Stripe using the 's property. + /// + /// The subscriber to retrieve the Stripe subscription for. + /// Optional parameters that can be passed to Stripe to expand or modify the subscription. + /// A Stripe . + /// Thrown when the is . + /// This method opts for returning rather than throwing exceptions, making it ideal for surfacing data from API endpoints. + Task GetSubscription( + ISubscriber subscriber, + SubscriptionGetOptions subscriptionGetOptions = null); + + /// + /// Retrieves a Stripe using the 's property. + /// + /// The subscriber to retrieve the Stripe subscription for. + /// Optional parameters that can be passed to Stripe to expand or modify the subscription. + /// A Stripe . + /// Thrown when the is . + /// Thrown when the subscriber's is or empty. + /// Thrown when the returned from Stripe's API is null. + Task GetSubscriptionOrThrow( + ISubscriber subscriber, + SubscriptionGetOptions subscriptionGetOptions = null); + + /// + /// Attempts to remove a subscriber's saved payment method. If the Stripe representing the + /// contains a valid "btCustomerId" key in its property, + /// this command will attempt to remove the Braintree . Otherwise, it will attempt to remove the + /// Stripe . + /// + /// The subscriber to remove the saved payment method for. + Task RemovePaymentMethod(ISubscriber subscriber); + + /// + /// Retrieves a Stripe using the 's property. + /// + /// The subscriber to retrieve the Stripe customer for. + /// A Stripe . + /// Thrown when the is . + /// This method opts for returning rather than throwing exceptions, making it ideal for surfacing data from API endpoints. + Task GetTaxInformationAsync(ISubscriber subscriber); + + /// + /// Retrieves a Stripe using the 's property. + /// + /// The subscriber to retrieve the Stripe customer for. + /// A Stripe . + /// Thrown when the is . + /// This method opts for returning rather than throwing exceptions, making it ideal for surfacing data from API endpoints. + Task GetPaymentMethodAsync(ISubscriber subscriber); +} diff --git a/src/Core/Billing/Queries/Implementations/OrganizationBillingQueries.cs b/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs similarity index 85% rename from src/Core/Billing/Queries/Implementations/OrganizationBillingQueries.cs rename to src/Core/Billing/Services/Implementations/OrganizationBillingService.cs index 9f6a8b2ec..3013a269e 100644 --- a/src/Core/Billing/Queries/Implementations/OrganizationBillingQueries.cs +++ b/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs @@ -5,11 +5,11 @@ using Bit.Core.Repositories; using Bit.Core.Utilities; using Stripe; -namespace Bit.Core.Billing.Queries.Implementations; +namespace Bit.Core.Billing.Services.Implementations; -public class OrganizationBillingQueries( +public class OrganizationBillingService( IOrganizationRepository organizationRepository, - ISubscriberQueries subscriberQueries) : IOrganizationBillingQueries + ISubscriberService subscriberService) : IOrganizationBillingService { public async Task GetMetadata(Guid organizationId) { @@ -20,12 +20,12 @@ public class OrganizationBillingQueries( return null; } - var customer = await subscriberQueries.GetCustomer(organization, new CustomerGetOptions + var customer = await subscriberService.GetCustomer(organization, new CustomerGetOptions { Expand = ["discount.coupon.applies_to"] }); - var subscription = await subscriberQueries.GetSubscription(organization); + var subscription = await subscriberService.GetSubscription(organization); if (customer == null || subscription == null) { diff --git a/src/Core/Billing/Services/Implementations/SubscriberService.cs b/src/Core/Billing/Services/Implementations/SubscriberService.cs new file mode 100644 index 000000000..5cf21b1f4 --- /dev/null +++ b/src/Core/Billing/Services/Implementations/SubscriberService.cs @@ -0,0 +1,444 @@ +using Bit.Core.Billing.Models; +using Bit.Core.Entities; +using Bit.Core.Models.Business; +using Bit.Core.Services; +using Braintree; +using Microsoft.Extensions.Logging; +using Stripe; + +using static Bit.Core.Billing.Utilities; +using Customer = Stripe.Customer; +using Subscription = Stripe.Subscription; + +namespace Bit.Core.Billing.Services.Implementations; + +public class SubscriberService( + IBraintreeGateway braintreeGateway, + ILogger logger, + IStripeAdapter stripeAdapter) : ISubscriberService +{ + public async Task CancelSubscription( + ISubscriber subscriber, + OffboardingSurveyResponse offboardingSurveyResponse, + bool cancelImmediately) + { + var subscription = await GetSubscriptionOrThrow(subscriber); + + if (subscription.CanceledAt.HasValue || + subscription.Status == "canceled" || + subscription.Status == "unpaid" || + subscription.Status == "incomplete_expired") + { + logger.LogWarning("Cannot cancel subscription ({ID}) that's already inactive", subscription.Id); + + throw ContactSupport(); + } + + var metadata = new Dictionary + { + { "cancellingUserId", offboardingSurveyResponse.UserId.ToString() } + }; + + List validCancellationReasons = [ + "customer_service", + "low_quality", + "missing_features", + "other", + "switched_service", + "too_complex", + "too_expensive", + "unused" + ]; + + if (cancelImmediately) + { + if (subscription.Metadata != null && subscription.Metadata.ContainsKey("organizationId")) + { + await stripeAdapter.SubscriptionUpdateAsync(subscription.Id, new SubscriptionUpdateOptions + { + Metadata = metadata + }); + } + + var options = new SubscriptionCancelOptions + { + CancellationDetails = new SubscriptionCancellationDetailsOptions + { + Comment = offboardingSurveyResponse.Feedback + } + }; + + if (validCancellationReasons.Contains(offboardingSurveyResponse.Reason)) + { + options.CancellationDetails.Feedback = offboardingSurveyResponse.Reason; + } + + await stripeAdapter.SubscriptionCancelAsync(subscription.Id, options); + } + else + { + var options = new SubscriptionUpdateOptions + { + CancelAtPeriodEnd = true, + CancellationDetails = new SubscriptionCancellationDetailsOptions + { + Comment = offboardingSurveyResponse.Feedback + }, + Metadata = metadata + }; + + if (validCancellationReasons.Contains(offboardingSurveyResponse.Reason)) + { + options.CancellationDetails.Feedback = offboardingSurveyResponse.Reason; + } + + await stripeAdapter.SubscriptionUpdateAsync(subscription.Id, options); + } + } + + public async Task GetCustomer( + ISubscriber subscriber, + CustomerGetOptions customerGetOptions = null) + { + ArgumentNullException.ThrowIfNull(subscriber); + + if (string.IsNullOrEmpty(subscriber.GatewayCustomerId)) + { + logger.LogError("Cannot retrieve customer for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewayCustomerId)); + + return null; + } + + try + { + var customer = await stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerGetOptions); + + if (customer != null) + { + return customer; + } + + logger.LogError("Could not find Stripe customer ({CustomerID}) for subscriber ({SubscriberID})", + subscriber.GatewayCustomerId, subscriber.Id); + + return null; + } + catch (StripeException exception) + { + logger.LogError("An error occurred while trying to retrieve Stripe customer ({CustomerID}) for subscriber ({SubscriberID}): {Error}", + subscriber.GatewayCustomerId, subscriber.Id, exception.Message); + + return null; + } + } + + public async Task GetCustomerOrThrow( + ISubscriber subscriber, + CustomerGetOptions customerGetOptions = null) + { + ArgumentNullException.ThrowIfNull(subscriber); + + if (string.IsNullOrEmpty(subscriber.GatewayCustomerId)) + { + logger.LogError("Cannot retrieve customer for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewayCustomerId)); + + throw ContactSupport(); + } + + try + { + var customer = await stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerGetOptions); + + if (customer != null) + { + return customer; + } + + logger.LogError("Could not find Stripe customer ({CustomerID}) for subscriber ({SubscriberID})", + subscriber.GatewayCustomerId, subscriber.Id); + + throw ContactSupport(); + } + catch (StripeException exception) + { + logger.LogError("An error occurred while trying to retrieve Stripe customer ({CustomerID}) for subscriber ({SubscriberID}): {Error}", + subscriber.GatewayCustomerId, subscriber.Id, exception.Message); + + throw ContactSupport("An error occurred while trying to retrieve a Stripe Customer", exception); + } + } + + public async Task GetSubscription( + ISubscriber subscriber, + SubscriptionGetOptions subscriptionGetOptions = null) + { + ArgumentNullException.ThrowIfNull(subscriber); + + if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) + { + logger.LogError("Cannot retrieve subscription for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewaySubscriptionId)); + + return null; + } + + try + { + var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); + + if (subscription != null) + { + return subscription; + } + + logger.LogError("Could not find Stripe subscription ({SubscriptionID}) for subscriber ({SubscriberID})", + subscriber.GatewaySubscriptionId, subscriber.Id); + + return null; + } + catch (StripeException exception) + { + logger.LogError("An error occurred while trying to retrieve Stripe subscription ({SubscriptionID}) for subscriber ({SubscriberID}): {Error}", + subscriber.GatewaySubscriptionId, subscriber.Id, exception.Message); + + return null; + } + } + + public async Task GetSubscriptionOrThrow( + ISubscriber subscriber, + SubscriptionGetOptions subscriptionGetOptions = null) + { + ArgumentNullException.ThrowIfNull(subscriber); + + if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) + { + logger.LogError("Cannot retrieve subscription for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewaySubscriptionId)); + + throw ContactSupport(); + } + + try + { + var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); + + if (subscription != null) + { + return subscription; + } + + logger.LogError("Could not find Stripe subscription ({SubscriptionID}) for subscriber ({SubscriberID})", + subscriber.GatewaySubscriptionId, subscriber.Id); + + throw ContactSupport(); + } + catch (StripeException exception) + { + logger.LogError("An error occurred while trying to retrieve Stripe subscription ({SubscriptionID}) for subscriber ({SubscriberID}): {Error}", + subscriber.GatewaySubscriptionId, subscriber.Id, exception.Message); + + throw ContactSupport("An error occurred while trying to retrieve a Stripe Subscription", exception); + } + } + + public async Task RemovePaymentMethod( + ISubscriber subscriber) + { + ArgumentNullException.ThrowIfNull(subscriber); + + if (string.IsNullOrEmpty(subscriber.GatewayCustomerId)) + { + throw ContactSupport(); + } + + var stripeCustomer = await GetCustomerOrThrow(subscriber, new CustomerGetOptions + { + Expand = ["invoice_settings.default_payment_method", "sources"] + }); + + if (stripeCustomer.Metadata?.TryGetValue(BraintreeCustomerIdKey, out var braintreeCustomerId) ?? false) + { + var braintreeCustomer = await braintreeGateway.Customer.FindAsync(braintreeCustomerId); + + if (braintreeCustomer == null) + { + logger.LogError("Failed to retrieve Braintree customer ({ID}) when removing payment method", braintreeCustomerId); + + throw ContactSupport(); + } + + if (braintreeCustomer.DefaultPaymentMethod != null) + { + var existingDefaultPaymentMethod = braintreeCustomer.DefaultPaymentMethod; + + var updateCustomerResult = await braintreeGateway.Customer.UpdateAsync( + braintreeCustomerId, + new CustomerRequest { DefaultPaymentMethodToken = null }); + + if (!updateCustomerResult.IsSuccess()) + { + logger.LogError("Failed to update payment method for Braintree customer ({ID}) | Message: {Message}", + braintreeCustomerId, updateCustomerResult.Message); + + throw ContactSupport(); + } + + var deletePaymentMethodResult = await braintreeGateway.PaymentMethod.DeleteAsync(existingDefaultPaymentMethod.Token); + + if (!deletePaymentMethodResult.IsSuccess()) + { + await braintreeGateway.Customer.UpdateAsync( + braintreeCustomerId, + new CustomerRequest { DefaultPaymentMethodToken = existingDefaultPaymentMethod.Token }); + + logger.LogError( + "Failed to delete Braintree payment method for Customer ({ID}), re-linked payment method. Message: {Message}", + braintreeCustomerId, deletePaymentMethodResult.Message); + + throw ContactSupport(); + } + } + else + { + logger.LogWarning("Tried to remove non-existent Braintree payment method for Customer ({ID})", braintreeCustomerId); + } + } + else + { + if (stripeCustomer.Sources != null && stripeCustomer.Sources.Any()) + { + foreach (var source in stripeCustomer.Sources) + { + switch (source) + { + case BankAccount: + await stripeAdapter.BankAccountDeleteAsync(stripeCustomer.Id, source.Id); + break; + case Card: + await stripeAdapter.CardDeleteAsync(stripeCustomer.Id, source.Id); + break; + } + } + } + + var paymentMethods = stripeAdapter.PaymentMethodListAutoPagingAsync(new PaymentMethodListOptions + { + Customer = stripeCustomer.Id + }); + + await foreach (var paymentMethod in paymentMethods) + { + await stripeAdapter.PaymentMethodDetachAsync(paymentMethod.Id); + } + } + } + + public async Task GetTaxInformationAsync(ISubscriber subscriber) + { + ArgumentNullException.ThrowIfNull(subscriber); + + if (string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) + { + logger.LogError("Cannot retrieve GatewayCustomerId for subscriber ({SubscriberID}) with no {FieldName}", subscriber.Id, nameof(subscriber.GatewaySubscriptionId)); + + return null; + } + + var customer = await GetCustomerOrThrow(subscriber, new CustomerGetOptions { Expand = ["tax_ids"] }); + + if (customer is null) + { + logger.LogError("Could not find Stripe customer ({CustomerID}) for subscriber ({SubscriberID})", + subscriber.GatewayCustomerId, subscriber.Id); + + return null; + } + + var address = customer.Address; + + // Line1 is required, so if missing we're using the subscriber name + // see: https://stripe.com/docs/api/customers/create#create_customer-address-line1 + if (address is not null && string.IsNullOrWhiteSpace(address.Line1)) + { + address.Line1 = null; + } + + return MapToTaxInfo(customer); + } + + public async Task GetPaymentMethodAsync(ISubscriber subscriber) + { + ArgumentNullException.ThrowIfNull(subscriber); + var customer = await GetCustomerOrThrow(subscriber, GetCustomerPaymentOptions()); + if (customer == null) + { + logger.LogError("Could not find Stripe customer ({CustomerID}) for subscriber ({SubscriberID})", + subscriber.GatewayCustomerId, subscriber.Id); + return null; + } + + if (customer.Metadata?.ContainsKey("btCustomerId") ?? false) + { + try + { + var braintreeCustomer = await braintreeGateway.Customer.FindAsync( + customer.Metadata["btCustomerId"]); + if (braintreeCustomer?.DefaultPaymentMethod != null) + { + return new BillingInfo.BillingSource( + braintreeCustomer.DefaultPaymentMethod); + } + } + catch (Braintree.Exceptions.NotFoundException ex) + { + logger.LogError("An error occurred while trying to retrieve braintree customer ({SubscriberID}): {Error}", subscriber.Id, ex.Message); + } + } + + if (customer.InvoiceSettings?.DefaultPaymentMethod?.Type == "card") + { + return new BillingInfo.BillingSource( + customer.InvoiceSettings.DefaultPaymentMethod); + } + + if (customer.DefaultSource != null && + (customer.DefaultSource is Card || customer.DefaultSource is BankAccount)) + { + return new BillingInfo.BillingSource(customer.DefaultSource); + } + + var paymentMethod = GetLatestCardPaymentMethod(customer.Id); + return paymentMethod != null ? new BillingInfo.BillingSource(paymentMethod) : null; + } + + private static CustomerGetOptions GetCustomerPaymentOptions() + { + var customerOptions = new CustomerGetOptions(); + customerOptions.AddExpand("default_source"); + customerOptions.AddExpand("invoice_settings.default_payment_method"); + return customerOptions; + } + + private Stripe.PaymentMethod GetLatestCardPaymentMethod(string customerId) + { + var cardPaymentMethods = stripeAdapter.PaymentMethodListAutoPaging( + new PaymentMethodListOptions { Customer = customerId, Type = "card" }); + return cardPaymentMethods.MaxBy(m => m.Created); + } + + private TaxInfo MapToTaxInfo(Customer customer) + { + var address = customer.Address; + var taxId = customer.TaxIds?.FirstOrDefault(); + + return new TaxInfo + { + TaxIdNumber = taxId?.Value, + BillingAddressLine1 = address?.Line1, + BillingAddressLine2 = address?.Line2, + BillingAddressCity = address?.City, + BillingAddressState = address?.State, + BillingAddressPostalCode = address?.PostalCode, + BillingAddressCountry = address?.Country, + }; + } +} diff --git a/src/Core/Services/IPaymentService.cs b/src/Core/Services/IPaymentService.cs index 3c78c585f..52bdab4bb 100644 --- a/src/Core/Services/IPaymentService.cs +++ b/src/Core/Services/IPaymentService.cs @@ -49,7 +49,6 @@ public interface IPaymentService Task GetBillingHistoryAsync(ISubscriber subscriber); Task GetBillingBalanceAndSourceAsync(ISubscriber subscriber); Task GetSubscriptionAsync(ISubscriber subscriber); - Task GetTaxInfoAsync(ISubscriber subscriber); Task SaveTaxInfoAsync(ISubscriber subscriber, TaxInfo taxInfo); Task CreateTaxRateAsync(TaxRate taxRate); Task UpdateTaxRateAsync(TaxRate taxRate); diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Services/Implementations/StripePaymentService.cs index cc2bee06b..47185da80 100644 --- a/src/Core/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Services/Implementations/StripePaymentService.cs @@ -1651,43 +1651,6 @@ public class StripePaymentService : IPaymentService return subscriptionInfo; } - public async Task GetTaxInfoAsync(ISubscriber subscriber) - { - if (subscriber == null || string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) - { - return null; - } - - var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, - new CustomerGetOptions { Expand = ["tax_ids"] }); - - if (customer == null) - { - return null; - } - - var address = customer.Address; - var taxId = customer.TaxIds?.FirstOrDefault(); - - // Line1 is required, so if missing we're using the subscriber name - // see: https://stripe.com/docs/api/customers/create#create_customer-address-line1 - if (address != null && string.IsNullOrWhiteSpace(address.Line1)) - { - address.Line1 = null; - } - - return new TaxInfo - { - TaxIdNumber = taxId?.Value, - BillingAddressLine1 = address?.Line1, - BillingAddressLine2 = address?.Line2, - BillingAddressCity = address?.City, - BillingAddressState = address?.State, - BillingAddressPostalCode = address?.PostalCode, - BillingAddressCountry = address?.Country, - }; - } - public async Task SaveTaxInfoAsync(ISubscriber subscriber, TaxInfo taxInfo) { if (subscriber != null && !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) diff --git a/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs b/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs index a6844d8c2..abc065547 100644 --- a/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs +++ b/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs @@ -14,7 +14,7 @@ using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Repositories; using Bit.Core.Auth.Services; -using Bit.Core.Billing.Commands; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -48,7 +48,7 @@ public class OrganizationsControllerTests : IDisposable private readonly IPushNotificationService _pushNotificationService; private readonly IOrganizationEnableCollectionEnhancementsCommand _organizationEnableCollectionEnhancementsCommand; private readonly IProviderRepository _providerRepository; - private readonly IScaleSeatsCommand _scaleSeatsCommand; + private readonly IProviderBillingService _providerBillingService; private readonly IDataProtectorTokenFactory _orgDeleteTokenDataFactory; private readonly OrganizationsController _sut; @@ -72,7 +72,7 @@ public class OrganizationsControllerTests : IDisposable _pushNotificationService = Substitute.For(); _organizationEnableCollectionEnhancementsCommand = Substitute.For(); _providerRepository = Substitute.For(); - _scaleSeatsCommand = Substitute.For(); + _providerBillingService = Substitute.For(); _orgDeleteTokenDataFactory = Substitute.For>(); _sut = new OrganizationsController( @@ -93,7 +93,7 @@ public class OrganizationsControllerTests : IDisposable _pushNotificationService, _organizationEnableCollectionEnhancementsCommand, _providerRepository, - _scaleSeatsCommand, + _providerBillingService, _orgDeleteTokenDataFactory); } @@ -233,8 +233,8 @@ public class OrganizationsControllerTests : IDisposable await _sut.Delete(organizationId.ToString(), requestModel); - await _scaleSeatsCommand.Received(1) - .ScalePasswordManagerSeats(provider, organization.PlanType, -organization.Seats.Value); + await _providerBillingService.Received(1) + .ScaleSeats(provider, organization.PlanType, -organization.Seats.Value); await _organizationService.Received(1).DeleteAsync(organization); } diff --git a/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs b/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs index 4af60689c..9b6566bf6 100644 --- a/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs +++ b/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs @@ -14,8 +14,7 @@ using Bit.Core.Auth.Models.Api.Request.Accounts; using Bit.Core.Auth.Services; using Bit.Core.Auth.UserFeatures.UserKey; using Bit.Core.Auth.UserFeatures.UserMasterPassword.Interfaces; -using Bit.Core.Billing.Commands; -using Bit.Core.Billing.Queries; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -56,8 +55,7 @@ public class AccountsControllerTests : IDisposable private readonly ISetInitialMasterPasswordCommand _setInitialMasterPasswordCommand; private readonly IRotateUserKeyCommand _rotateUserKeyCommand; private readonly IFeatureService _featureService; - private readonly ICancelSubscriptionCommand _cancelSubscriptionCommand; - private readonly ISubscriberQueries _subscriberQueries; + private readonly ISubscriberService _subscriberService; private readonly IReferenceEventService _referenceEventService; private readonly ICurrentContext _currentContext; @@ -89,8 +87,7 @@ public class AccountsControllerTests : IDisposable _setInitialMasterPasswordCommand = Substitute.For(); _rotateUserKeyCommand = Substitute.For(); _featureService = Substitute.For(); - _cancelSubscriptionCommand = Substitute.For(); - _subscriberQueries = Substitute.For(); + _subscriberService = Substitute.For(); _referenceEventService = Substitute.For(); _currentContext = Substitute.For(); _cipherValidator = @@ -121,8 +118,7 @@ public class AccountsControllerTests : IDisposable _setInitialMasterPasswordCommand, _rotateUserKeyCommand, _featureService, - _cancelSubscriptionCommand, - _subscriberQueries, + _subscriberService, _referenceEventService, _currentContext, _cipherValidator, diff --git a/test/Api.Test/Billing/Controllers/OrganizationBillingControllerTests.cs b/test/Api.Test/Billing/Controllers/OrganizationBillingControllerTests.cs index 8e495aa28..021705bed 100644 --- a/test/Api.Test/Billing/Controllers/OrganizationBillingControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/OrganizationBillingControllerTests.cs @@ -1,7 +1,7 @@ using Bit.Api.Billing.Controllers; using Bit.Api.Billing.Models.Responses; using Bit.Core.Billing.Models; -using Bit.Core.Billing.Queries; +using Bit.Core.Billing.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.AspNetCore.Http.HttpResults; @@ -29,7 +29,7 @@ public class OrganizationBillingControllerTests Guid organizationId, SutProvider sutProvider) { - sutProvider.GetDependency().GetMetadata(organizationId) + sutProvider.GetDependency().GetMetadata(organizationId) .Returns(new OrganizationMetadataDTO(true)); var result = await sutProvider.Sut.GetMetadataAsync(organizationId); diff --git a/test/Api.Test/Billing/Controllers/OrganizationsControllerTests.cs b/test/Api.Test/Billing/Controllers/OrganizationsControllerTests.cs index b5737837e..1a28c344c 100644 --- a/test/Api.Test/Billing/Controllers/OrganizationsControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/OrganizationsControllerTests.cs @@ -10,8 +10,7 @@ using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Repositories; using Bit.Core.Auth.Services; -using Bit.Core.Billing.Commands; -using Bit.Core.Billing.Queries; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -45,9 +44,8 @@ public class OrganizationsControllerTests : IDisposable private readonly IUpdateSecretsManagerSubscriptionCommand _updateSecretsManagerSubscriptionCommand; private readonly IUpgradeOrganizationPlanCommand _upgradeOrganizationPlanCommand; private readonly IAddSecretsManagerSubscriptionCommand _addSecretsManagerSubscriptionCommand; - private readonly ICancelSubscriptionCommand _cancelSubscriptionCommand; - private readonly ISubscriberQueries _subscriberQueries; private readonly IReferenceEventService _referenceEventService; + private readonly ISubscriberService _subscriberService; private readonly OrganizationsController _sut; @@ -68,9 +66,8 @@ public class OrganizationsControllerTests : IDisposable _updateSecretsManagerSubscriptionCommand = Substitute.For(); _upgradeOrganizationPlanCommand = Substitute.For(); _addSecretsManagerSubscriptionCommand = Substitute.For(); - _cancelSubscriptionCommand = Substitute.For(); - _subscriberQueries = Substitute.For(); _referenceEventService = Substitute.For(); + _subscriberService = Substitute.For(); _sut = new OrganizationsController( _organizationRepository, @@ -85,9 +82,8 @@ public class OrganizationsControllerTests : IDisposable _updateSecretsManagerSubscriptionCommand, _upgradeOrganizationPlanCommand, _addSecretsManagerSubscriptionCommand, - _cancelSubscriptionCommand, - _subscriberQueries, - _referenceEventService); + _referenceEventService, + _subscriberService); } public void Dispose() diff --git a/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs b/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs index 8e82e0209..ec7b3a28f 100644 --- a/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs @@ -2,7 +2,7 @@ using Bit.Api.Billing.Models.Responses; using Bit.Core; using Bit.Core.Billing.Models; -using Bit.Core.Billing.Queries; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Services; @@ -61,7 +61,7 @@ public class ProviderBillingControllerTests sutProvider.GetDependency().ProviderProviderAdmin(providerId) .Returns(true); - sutProvider.GetDependency().GetSubscriptionDTO(providerId).ReturnsNull(); + sutProvider.GetDependency().GetSubscriptionDTO(providerId).ReturnsNull(); var result = await sutProvider.Sut.GetSubscriptionAsync(providerId); @@ -96,7 +96,7 @@ public class ProviderBillingControllerTests configuredProviderPlanDTOList, subscription); - sutProvider.GetDependency().GetSubscriptionDTO(providerId) + sutProvider.GetDependency().GetSubscriptionDTO(providerId) .Returns(providerSubscriptionDTO); var result = await sutProvider.Sut.GetSubscriptionAsync(providerId); diff --git a/test/Api.Test/Billing/Controllers/ProviderClientsControllerTests.cs b/test/Api.Test/Billing/Controllers/ProviderClientsControllerTests.cs index e0c9a27a6..fd445cd54 100644 --- a/test/Api.Test/Billing/Controllers/ProviderClientsControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/ProviderClientsControllerTests.cs @@ -6,7 +6,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; -using Bit.Core.Billing.Commands; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Models.Business; @@ -185,7 +185,7 @@ public class ProviderClientsControllerTests Assert.IsType(result); - await sutProvider.GetDependency().Received(1).CreateCustomer( + await sutProvider.GetDependency().Received(1).CreateCustomerForClientOrganization( provider, clientOrganization); } @@ -327,7 +327,7 @@ public class ProviderClientsControllerTests var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); - await sutProvider.GetDependency().Received(1) + await sutProvider.GetDependency().Received(1) .AssignSeatsToClientOrganization( provider, organization, @@ -368,7 +368,7 @@ public class ProviderClientsControllerTests var result = await sutProvider.Sut.UpdateAsync(providerId, providerOrganizationId, requestBody); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() .AssignSeatsToClientOrganization( Arg.Any(), Arg.Any(), diff --git a/test/Core.Test/Billing/Commands/AssignSeatsToClientOrganizationCommandTests.cs b/test/Core.Test/Billing/Commands/AssignSeatsToClientOrganizationCommandTests.cs deleted file mode 100644 index 918b7c47a..000000000 --- a/test/Core.Test/Billing/Commands/AssignSeatsToClientOrganizationCommandTests.cs +++ /dev/null @@ -1,339 +0,0 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.Billing; -using Bit.Core.Billing.Commands.Implementations; -using Bit.Core.Billing.Entities; -using Bit.Core.Billing.Queries; -using Bit.Core.Billing.Repositories; -using Bit.Core.Enums; -using Bit.Core.Models.StaticStore; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Bit.Core.Utilities; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using NSubstitute; -using Xunit; - -using static Bit.Core.Test.Billing.Utilities; - -namespace Bit.Core.Test.Billing.Commands; - -[SutProviderCustomize] -public class AssignSeatsToClientOrganizationCommandTests -{ - [Theory, BitAutoData] - public Task AssignSeatsToClientOrganization_NullProvider_ArgumentNullException( - Organization organization, - int seats, - SutProvider sutProvider) - => Assert.ThrowsAsync(() => - sutProvider.Sut.AssignSeatsToClientOrganization(null, organization, seats)); - - [Theory, BitAutoData] - public Task AssignSeatsToClientOrganization_NullOrganization_ArgumentNullException( - Provider provider, - int seats, - SutProvider sutProvider) - => Assert.ThrowsAsync(() => - sutProvider.Sut.AssignSeatsToClientOrganization(provider, null, seats)); - - [Theory, BitAutoData] - public Task AssignSeatsToClientOrganization_NegativeSeats_BillingException( - Provider provider, - Organization organization, - SutProvider sutProvider) - => Assert.ThrowsAsync(() => - sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, -5)); - - [Theory, BitAutoData] - public async Task AssignSeatsToClientOrganization_CurrentSeatsMatchesNewSeats_NoOp( - Provider provider, - Organization organization, - int seats, - SutProvider sutProvider) - { - organization.PlanType = PlanType.TeamsMonthly; - - organization.Seats = seats; - - await sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats); - - await sutProvider.GetDependency().DidNotReceive().GetByProviderId(provider.Id); - } - - [Theory, BitAutoData] - public async Task AssignSeatsToClientOrganization_OrganizationPlanTypeDoesNotSupportConsolidatedBilling_ContactSupport( - Provider provider, - Organization organization, - int seats, - SutProvider sutProvider) - { - organization.PlanType = PlanType.FamiliesAnnually; - - await ThrowsContactSupportAsync(() => sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats)); - } - - [Theory, BitAutoData] - public async Task AssignSeatsToClientOrganization_ProviderPlanIsNotConfigured_ContactSupport( - Provider provider, - Organization organization, - int seats, - SutProvider sutProvider) - { - organization.PlanType = PlanType.TeamsMonthly; - - sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(new List - { - new () - { - Id = Guid.NewGuid(), - PlanType = PlanType.TeamsMonthly, - ProviderId = provider.Id - } - }); - - await ThrowsContactSupportAsync(() => sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats)); - } - - [Theory, BitAutoData] - public async Task AssignSeatsToClientOrganization_BelowToBelow_Succeeds( - Provider provider, - Organization organization, - SutProvider sutProvider) - { - organization.Seats = 10; - - organization.PlanType = PlanType.TeamsMonthly; - - // Scale up 10 seats - const int seats = 20; - - var providerPlans = new List - { - new() - { - Id = Guid.NewGuid(), - PlanType = PlanType.TeamsMonthly, - ProviderId = provider.Id, - PurchasedSeats = 0, - // 100 minimum - SeatMinimum = 100, - AllocatedSeats = 50 - }, - new() - { - Id = Guid.NewGuid(), - PlanType = PlanType.EnterpriseMonthly, - ProviderId = provider.Id, - PurchasedSeats = 0, - SeatMinimum = 500, - AllocatedSeats = 0 - } - }; - - var providerPlan = providerPlans.First(); - - sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); - - // 50 seats currently assigned with a seat minimum of 100 - sutProvider.GetDependency().GetAssignedSeatTotalForPlanOrThrow(provider.Id, providerPlan.PlanType).Returns(50); - - await sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats); - - // 50 assigned seats + 10 seat scale up = 60 seats, well below the 100 minimum - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().AdjustSeats( - Arg.Any(), - Arg.Any(), - Arg.Any(), - Arg.Any()); - - await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( - org => org.Id == organization.Id && org.Seats == seats)); - - await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( - pPlan => pPlan.AllocatedSeats == 60)); - } - - [Theory, BitAutoData] - public async Task AssignSeatsToClientOrganization_BelowToAbove_Succeeds( - Provider provider, - Organization organization, - SutProvider sutProvider) - { - organization.Seats = 10; - - organization.PlanType = PlanType.TeamsMonthly; - - // Scale up 10 seats - const int seats = 20; - - var providerPlans = new List - { - new() - { - Id = Guid.NewGuid(), - PlanType = PlanType.TeamsMonthly, - ProviderId = provider.Id, - PurchasedSeats = 0, - // 100 minimum - SeatMinimum = 100, - AllocatedSeats = 95 - }, - new() - { - Id = Guid.NewGuid(), - PlanType = PlanType.EnterpriseMonthly, - ProviderId = provider.Id, - PurchasedSeats = 0, - SeatMinimum = 500, - AllocatedSeats = 0 - } - }; - - var providerPlan = providerPlans.First(); - - sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); - - // 95 seats currently assigned with a seat minimum of 100 - sutProvider.GetDependency().GetAssignedSeatTotalForPlanOrThrow(provider.Id, providerPlan.PlanType).Returns(95); - - await sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats); - - // 95 current + 10 seat scale = 105 seats, 5 above the minimum - await sutProvider.GetDependency().Received(1).AdjustSeats( - provider, - StaticStore.GetPlan(providerPlan.PlanType), - providerPlan.SeatMinimum!.Value, - 105); - - await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( - org => org.Id == organization.Id && org.Seats == seats)); - - // 105 total seats - 100 minimum = 5 purchased seats - await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( - pPlan => pPlan.Id == providerPlan.Id && pPlan.PurchasedSeats == 5 && pPlan.AllocatedSeats == 105)); - } - - [Theory, BitAutoData] - public async Task AssignSeatsToClientOrganization_AboveToAbove_Succeeds( - Provider provider, - Organization organization, - SutProvider sutProvider) - { - organization.Seats = 10; - - organization.PlanType = PlanType.TeamsMonthly; - - // Scale up 10 seats - const int seats = 20; - - var providerPlans = new List - { - new() - { - Id = Guid.NewGuid(), - PlanType = PlanType.TeamsMonthly, - ProviderId = provider.Id, - // 10 additional purchased seats - PurchasedSeats = 10, - // 100 seat minimum - SeatMinimum = 100, - AllocatedSeats = 110 - }, - new() - { - Id = Guid.NewGuid(), - PlanType = PlanType.EnterpriseMonthly, - ProviderId = provider.Id, - PurchasedSeats = 0, - SeatMinimum = 500, - AllocatedSeats = 0 - } - }; - - var providerPlan = providerPlans.First(); - - sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); - - // 110 seats currently assigned with a seat minimum of 100 - sutProvider.GetDependency().GetAssignedSeatTotalForPlanOrThrow(provider.Id, providerPlan.PlanType).Returns(110); - - await sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats); - - // 110 current + 10 seat scale up = 120 seats - await sutProvider.GetDependency().Received(1).AdjustSeats( - provider, - StaticStore.GetPlan(providerPlan.PlanType), - 110, - 120); - - await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( - org => org.Id == organization.Id && org.Seats == seats)); - - // 120 total seats - 100 seat minimum = 20 purchased seats - await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( - pPlan => pPlan.Id == providerPlan.Id && pPlan.PurchasedSeats == 20 && pPlan.AllocatedSeats == 120)); - } - - [Theory, BitAutoData] - public async Task AssignSeatsToClientOrganization_AboveToBelow_Succeeds( - Provider provider, - Organization organization, - SutProvider sutProvider) - { - organization.Seats = 50; - - organization.PlanType = PlanType.TeamsMonthly; - - // Scale down 30 seats - const int seats = 20; - - var providerPlans = new List - { - new() - { - Id = Guid.NewGuid(), - PlanType = PlanType.TeamsMonthly, - ProviderId = provider.Id, - // 10 additional purchased seats - PurchasedSeats = 10, - // 100 seat minimum - SeatMinimum = 100, - AllocatedSeats = 110 - }, - new() - { - Id = Guid.NewGuid(), - PlanType = PlanType.EnterpriseMonthly, - ProviderId = provider.Id, - PurchasedSeats = 0, - SeatMinimum = 500, - AllocatedSeats = 0 - } - }; - - var providerPlan = providerPlans.First(); - - sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); - - // 110 seats currently assigned with a seat minimum of 100 - sutProvider.GetDependency().GetAssignedSeatTotalForPlanOrThrow(provider.Id, providerPlan.PlanType).Returns(110); - - await sutProvider.Sut.AssignSeatsToClientOrganization(provider, organization, seats); - - // 110 seats - 30 scale down seats = 80 seats, below the 100 seat minimum. - await sutProvider.GetDependency().Received(1).AdjustSeats( - provider, - StaticStore.GetPlan(providerPlan.PlanType), - 110, - providerPlan.SeatMinimum!.Value); - - await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( - org => org.Id == organization.Id && org.Seats == seats)); - - // Being below the seat minimum means no purchased seats. - await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( - pPlan => pPlan.Id == providerPlan.Id && pPlan.PurchasedSeats == 0 && pPlan.AllocatedSeats == 80)); - } -} diff --git a/test/Core.Test/Billing/Commands/CancelSubscriptionCommandTests.cs b/test/Core.Test/Billing/Commands/CancelSubscriptionCommandTests.cs deleted file mode 100644 index ba98c26a5..000000000 --- a/test/Core.Test/Billing/Commands/CancelSubscriptionCommandTests.cs +++ /dev/null @@ -1,163 +0,0 @@ -using System.Linq.Expressions; -using Bit.Core.AdminConsole.Entities; -using Bit.Core.Billing.Commands.Implementations; -using Bit.Core.Billing.Models; -using Bit.Core.Services; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using NSubstitute; -using Stripe; -using Xunit; - -using static Bit.Core.Test.Billing.Utilities; - -namespace Bit.Core.Test.Billing.Commands; - -[SutProviderCustomize] -public class CancelSubscriptionCommandTests -{ - private const string _subscriptionId = "subscription_id"; - private const string _cancellingUserIdKey = "cancellingUserId"; - - [Theory, BitAutoData] - public async Task CancelSubscription_SubscriptionInactive_ThrowsGatewayException( - SutProvider sutProvider) - { - var subscription = new Subscription - { - Status = "canceled" - }; - - await ThrowsContactSupportAsync(() => - sutProvider.Sut.CancelSubscription(subscription, new OffboardingSurveyResponse(), false)); - - await DidNotUpdateSubscription(sutProvider); - - await DidNotCancelSubscription(sutProvider); - } - - [Theory, BitAutoData] - public async Task CancelSubscription_CancelImmediately_BelongsToOrganization_UpdatesSubscription_CancelSubscriptionImmediately( - SutProvider sutProvider) - { - var userId = Guid.NewGuid(); - - var subscription = new Subscription - { - Id = _subscriptionId, - Status = "active", - Metadata = new Dictionary - { - { "organizationId", "organization_id" } - } - }; - - var offboardingSurveyResponse = new OffboardingSurveyResponse - { - UserId = userId, - Reason = "missing_features", - Feedback = "Lorem ipsum" - }; - - await sutProvider.Sut.CancelSubscription(subscription, offboardingSurveyResponse, true); - - await UpdatedSubscriptionWith(sutProvider, options => options.Metadata[_cancellingUserIdKey] == userId.ToString()); - - await CancelledSubscriptionWith(sutProvider, options => - options.CancellationDetails.Comment == offboardingSurveyResponse.Feedback && - options.CancellationDetails.Feedback == offboardingSurveyResponse.Reason); - } - - [Theory, BitAutoData] - public async Task CancelSubscription_CancelImmediately_BelongsToUser_CancelSubscriptionImmediately( - SutProvider sutProvider) - { - var userId = Guid.NewGuid(); - - var subscription = new Subscription - { - Id = _subscriptionId, - Status = "active", - Metadata = new Dictionary - { - { "userId", "user_id" } - } - }; - - var offboardingSurveyResponse = new OffboardingSurveyResponse - { - UserId = userId, - Reason = "missing_features", - Feedback = "Lorem ipsum" - }; - - await sutProvider.Sut.CancelSubscription(subscription, offboardingSurveyResponse, true); - - await DidNotUpdateSubscription(sutProvider); - - await CancelledSubscriptionWith(sutProvider, options => - options.CancellationDetails.Comment == offboardingSurveyResponse.Feedback && - options.CancellationDetails.Feedback == offboardingSurveyResponse.Reason); - } - - [Theory, BitAutoData] - public async Task CancelSubscription_DoNotCancelImmediately_UpdateSubscriptionToCancelAtEndOfPeriod( - Organization organization, - SutProvider sutProvider) - { - var userId = Guid.NewGuid(); - - organization.ExpirationDate = DateTime.UtcNow.AddDays(5); - - var subscription = new Subscription - { - Id = _subscriptionId, - Status = "active" - }; - - var offboardingSurveyResponse = new OffboardingSurveyResponse - { - UserId = userId, - Reason = "missing_features", - Feedback = "Lorem ipsum" - }; - - await sutProvider.Sut.CancelSubscription(subscription, offboardingSurveyResponse, false); - - await UpdatedSubscriptionWith(sutProvider, options => - options.CancelAtPeriodEnd == true && - options.CancellationDetails.Comment == offboardingSurveyResponse.Feedback && - options.CancellationDetails.Feedback == offboardingSurveyResponse.Reason && - options.Metadata[_cancellingUserIdKey] == userId.ToString()); - - await DidNotCancelSubscription(sutProvider); - } - - private static Task DidNotCancelSubscription(SutProvider sutProvider) - => sutProvider - .GetDependency() - .DidNotReceiveWithAnyArgs() - .SubscriptionCancelAsync(Arg.Any(), Arg.Any()); - - private static Task DidNotUpdateSubscription(SutProvider sutProvider) - => sutProvider - .GetDependency() - .DidNotReceiveWithAnyArgs() - .SubscriptionUpdateAsync(Arg.Any(), Arg.Any()); - - private static Task CancelledSubscriptionWith( - SutProvider sutProvider, - Expression> predicate) - => sutProvider - .GetDependency() - .Received(1) - .SubscriptionCancelAsync(_subscriptionId, Arg.Is(predicate)); - - private static Task UpdatedSubscriptionWith( - SutProvider sutProvider, - Expression> predicate) - => sutProvider - .GetDependency() - .Received(1) - .SubscriptionUpdateAsync(_subscriptionId, Arg.Is(predicate)); -} diff --git a/test/Core.Test/Billing/Commands/CreateCustomerCommandTests.cs b/test/Core.Test/Billing/Commands/CreateCustomerCommandTests.cs deleted file mode 100644 index b532879e9..000000000 --- a/test/Core.Test/Billing/Commands/CreateCustomerCommandTests.cs +++ /dev/null @@ -1,129 +0,0 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.Billing.Commands.Implementations; -using Bit.Core.Billing.Queries; -using Bit.Core.Entities; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Bit.Core.Settings; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using NSubstitute; -using Stripe; -using Xunit; -using GlobalSettings = Bit.Core.Settings.GlobalSettings; - -namespace Bit.Core.Test.Billing.Commands; - -[SutProviderCustomize] -public class CreateCustomerCommandTests -{ - private const string _customerId = "customer_id"; - - [Theory, BitAutoData] - public async Task CreateCustomer_ForClientOrg_ProviderNull_ThrowsArgumentNullException( - Organization organization, - SutProvider sutProvider) => - await Assert.ThrowsAsync(() => sutProvider.Sut.CreateCustomer(null, organization)); - - [Theory, BitAutoData] - public async Task CreateCustomer_ForClientOrg_OrganizationNull_ThrowsArgumentNullException( - Provider provider, - SutProvider sutProvider) => - await Assert.ThrowsAsync(() => sutProvider.Sut.CreateCustomer(provider, null)); - - [Theory, BitAutoData] - public async Task CreateCustomer_ForClientOrg_HasGatewayCustomerId_NoOp( - Provider provider, - Organization organization, - SutProvider sutProvider) - { - organization.GatewayCustomerId = _customerId; - - await sutProvider.Sut.CreateCustomer(provider, organization); - - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .GetCustomerOrThrow(Arg.Any(), Arg.Any()); - } - - [Theory, BitAutoData] - public async Task CreateCustomer_ForClientOrg_Succeeds( - Provider provider, - Organization organization, - SutProvider sutProvider) - { - organization.GatewayCustomerId = null; - organization.Name = "Name"; - organization.BusinessName = "BusinessName"; - - var providerCustomer = new Customer - { - Address = new Address - { - Country = "USA", - PostalCode = "12345", - Line1 = "123 Main St.", - Line2 = "Unit 4", - City = "Fake Town", - State = "Fake State" - }, - TaxIds = new StripeList - { - Data = - [ - new TaxId { Type = "TYPE", Value = "VALUE" } - ] - } - }; - - sutProvider.GetDependency().GetCustomerOrThrow(provider, Arg.Is( - options => options.Expand.FirstOrDefault() == "tax_ids")) - .Returns(providerCustomer); - - sutProvider.GetDependency().BaseServiceUri - .Returns(new GlobalSettings.BaseServiceUriSettings(new GlobalSettings()) { CloudRegion = "US" }); - - sutProvider.GetDependency().CustomerCreateAsync(Arg.Is( - options => - options.Address.Country == providerCustomer.Address.Country && - options.Address.PostalCode == providerCustomer.Address.PostalCode && - options.Address.Line1 == providerCustomer.Address.Line1 && - options.Address.Line2 == providerCustomer.Address.Line2 && - options.Address.City == providerCustomer.Address.City && - options.Address.State == providerCustomer.Address.State && - options.Name == organization.DisplayName() && - options.Description == $"{provider.Name} Client Organization" && - options.Email == provider.BillingEmail && - options.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Organization" && - options.InvoiceSettings.CustomFields.FirstOrDefault().Value == "Name" && - options.Metadata["region"] == "US" && - options.TaxIdData.FirstOrDefault().Type == providerCustomer.TaxIds.FirstOrDefault().Type && - options.TaxIdData.FirstOrDefault().Value == providerCustomer.TaxIds.FirstOrDefault().Value)) - .Returns(new Customer - { - Id = "customer_id" - }); - - await sutProvider.Sut.CreateCustomer(provider, organization); - - await sutProvider.GetDependency().Received(1).CustomerCreateAsync(Arg.Is( - options => - options.Address.Country == providerCustomer.Address.Country && - options.Address.PostalCode == providerCustomer.Address.PostalCode && - options.Address.Line1 == providerCustomer.Address.Line1 && - options.Address.Line2 == providerCustomer.Address.Line2 && - options.Address.City == providerCustomer.Address.City && - options.Address.State == providerCustomer.Address.State && - options.Name == organization.DisplayName() && - options.Description == $"{provider.Name} Client Organization" && - options.Email == provider.BillingEmail && - options.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Organization" && - options.InvoiceSettings.CustomFields.FirstOrDefault().Value == "Name" && - options.Metadata["region"] == "US" && - options.TaxIdData.FirstOrDefault().Type == providerCustomer.TaxIds.FirstOrDefault().Type && - options.TaxIdData.FirstOrDefault().Value == providerCustomer.TaxIds.FirstOrDefault().Value)); - - await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is( - org => org.GatewayCustomerId == "customer_id")); - } -} diff --git a/test/Core.Test/Billing/Commands/RemovePaymentMethodCommandTests.cs b/test/Core.Test/Billing/Commands/RemovePaymentMethodCommandTests.cs deleted file mode 100644 index 968bfeb84..000000000 --- a/test/Core.Test/Billing/Commands/RemovePaymentMethodCommandTests.cs +++ /dev/null @@ -1,358 +0,0 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.Billing.Commands.Implementations; -using Bit.Core.Enums; -using Bit.Core.Services; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using NSubstitute; -using NSubstitute.ReturnsExtensions; -using Xunit; -using static Bit.Core.Test.Billing.Utilities; -using BT = Braintree; -using S = Stripe; - -namespace Bit.Core.Test.Billing.Commands; - -[SutProviderCustomize] -public class RemovePaymentMethodCommandTests -{ - [Theory, BitAutoData] - public async Task RemovePaymentMethod_NullOrganization_ArgumentNullException( - SutProvider sutProvider) => - await Assert.ThrowsAsync(() => sutProvider.Sut.RemovePaymentMethod(null)); - - [Theory, BitAutoData] - public async Task RemovePaymentMethod_NonStripeGateway_ContactSupport( - Organization organization, - SutProvider sutProvider) - { - organization.Gateway = GatewayType.BitPay; - - await ThrowsContactSupportAsync(() => sutProvider.Sut.RemovePaymentMethod(organization)); - } - - [Theory, BitAutoData] - public async Task RemovePaymentMethod_NoGatewayCustomerId_ContactSupport( - Organization organization, - SutProvider sutProvider) - { - organization.Gateway = GatewayType.Stripe; - organization.GatewayCustomerId = null; - - await ThrowsContactSupportAsync(() => sutProvider.Sut.RemovePaymentMethod(organization)); - } - - [Theory, BitAutoData] - public async Task RemovePaymentMethod_NoStripeCustomer_ContactSupport( - Organization organization, - SutProvider sutProvider) - { - organization.Gateway = GatewayType.Stripe; - - sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) - .ReturnsNull(); - - await ThrowsContactSupportAsync(() => sutProvider.Sut.RemovePaymentMethod(organization)); - } - - [Theory, BitAutoData] - public async Task RemovePaymentMethod_Braintree_NoCustomer_ContactSupport( - Organization organization, - SutProvider sutProvider) - { - organization.Gateway = GatewayType.Stripe; - - const string braintreeCustomerId = "1"; - - var stripeCustomer = new S.Customer - { - Metadata = new Dictionary - { - { "btCustomerId", braintreeCustomerId } - } - }; - - sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) - .Returns(stripeCustomer); - - var (braintreeGateway, customerGateway, paymentMethodGateway) = Setup(sutProvider.GetDependency()); - - customerGateway.FindAsync(braintreeCustomerId).ReturnsNull(); - - braintreeGateway.Customer.Returns(customerGateway); - - await ThrowsContactSupportAsync(() => sutProvider.Sut.RemovePaymentMethod(organization)); - - await customerGateway.Received(1).FindAsync(braintreeCustomerId); - - await customerGateway.DidNotReceiveWithAnyArgs() - .UpdateAsync(Arg.Any(), Arg.Any()); - - await paymentMethodGateway.DidNotReceiveWithAnyArgs().DeleteAsync(Arg.Any()); - } - - [Theory, BitAutoData] - public async Task RemovePaymentMethod_Braintree_NoPaymentMethod_NoOp( - Organization organization, - SutProvider sutProvider) - { - organization.Gateway = GatewayType.Stripe; - - const string braintreeCustomerId = "1"; - - var stripeCustomer = new S.Customer - { - Metadata = new Dictionary - { - { "btCustomerId", braintreeCustomerId } - } - }; - - sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) - .Returns(stripeCustomer); - - var (_, customerGateway, paymentMethodGateway) = Setup(sutProvider.GetDependency()); - - var braintreeCustomer = Substitute.For(); - - braintreeCustomer.PaymentMethods.Returns(Array.Empty()); - - customerGateway.FindAsync(braintreeCustomerId).Returns(braintreeCustomer); - - await sutProvider.Sut.RemovePaymentMethod(organization); - - await customerGateway.Received(1).FindAsync(braintreeCustomerId); - - await customerGateway.DidNotReceiveWithAnyArgs().UpdateAsync(Arg.Any(), Arg.Any()); - - await paymentMethodGateway.DidNotReceiveWithAnyArgs().DeleteAsync(Arg.Any()); - } - - [Theory, BitAutoData] - public async Task RemovePaymentMethod_Braintree_CustomerUpdateFails_ContactSupport( - Organization organization, - SutProvider sutProvider) - { - organization.Gateway = GatewayType.Stripe; - - const string braintreeCustomerId = "1"; - const string braintreePaymentMethodToken = "TOKEN"; - - var stripeCustomer = new S.Customer - { - Metadata = new Dictionary - { - { "btCustomerId", braintreeCustomerId } - } - }; - - sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) - .Returns(stripeCustomer); - - var (_, customerGateway, paymentMethodGateway) = Setup(sutProvider.GetDependency()); - - var braintreeCustomer = Substitute.For(); - - var paymentMethod = Substitute.For(); - paymentMethod.Token.Returns(braintreePaymentMethodToken); - paymentMethod.IsDefault.Returns(true); - - braintreeCustomer.PaymentMethods.Returns(new[] - { - paymentMethod - }); - - customerGateway.FindAsync(braintreeCustomerId).Returns(braintreeCustomer); - - var updateBraintreeCustomerResult = Substitute.For>(); - updateBraintreeCustomerResult.IsSuccess().Returns(false); - - customerGateway.UpdateAsync( - braintreeCustomerId, - Arg.Is(request => request.DefaultPaymentMethodToken == null)) - .Returns(updateBraintreeCustomerResult); - - await ThrowsContactSupportAsync(() => sutProvider.Sut.RemovePaymentMethod(organization)); - - await customerGateway.Received(1).FindAsync(braintreeCustomerId); - - await customerGateway.Received(1).UpdateAsync(braintreeCustomerId, Arg.Is(request => - request.DefaultPaymentMethodToken == null)); - - await paymentMethodGateway.DidNotReceiveWithAnyArgs().DeleteAsync(paymentMethod.Token); - - await customerGateway.DidNotReceive().UpdateAsync(braintreeCustomerId, Arg.Is(request => - request.DefaultPaymentMethodToken == paymentMethod.Token)); - } - - [Theory, BitAutoData] - public async Task RemovePaymentMethod_Braintree_PaymentMethodDeleteFails_RollBack_ContactSupport( - Organization organization, - SutProvider sutProvider) - { - organization.Gateway = GatewayType.Stripe; - - const string braintreeCustomerId = "1"; - const string braintreePaymentMethodToken = "TOKEN"; - - var stripeCustomer = new S.Customer - { - Metadata = new Dictionary - { - { "btCustomerId", braintreeCustomerId } - } - }; - - sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) - .Returns(stripeCustomer); - - var (_, customerGateway, paymentMethodGateway) = Setup(sutProvider.GetDependency()); - - var braintreeCustomer = Substitute.For(); - - var paymentMethod = Substitute.For(); - paymentMethod.Token.Returns(braintreePaymentMethodToken); - paymentMethod.IsDefault.Returns(true); - - braintreeCustomer.PaymentMethods.Returns(new[] - { - paymentMethod - }); - - customerGateway.FindAsync(braintreeCustomerId).Returns(braintreeCustomer); - - var updateBraintreeCustomerResult = Substitute.For>(); - updateBraintreeCustomerResult.IsSuccess().Returns(true); - - customerGateway.UpdateAsync(braintreeCustomerId, Arg.Any()) - .Returns(updateBraintreeCustomerResult); - - var deleteBraintreePaymentMethodResult = Substitute.For>(); - deleteBraintreePaymentMethodResult.IsSuccess().Returns(false); - - paymentMethodGateway.DeleteAsync(paymentMethod.Token).Returns(deleteBraintreePaymentMethodResult); - - await ThrowsContactSupportAsync(() => sutProvider.Sut.RemovePaymentMethod(organization)); - - await customerGateway.Received(1).FindAsync(braintreeCustomerId); - - await customerGateway.Received(1).UpdateAsync(braintreeCustomerId, Arg.Is(request => - request.DefaultPaymentMethodToken == null)); - - await paymentMethodGateway.Received(1).DeleteAsync(paymentMethod.Token); - - await customerGateway.Received(1).UpdateAsync(braintreeCustomerId, Arg.Is(request => - request.DefaultPaymentMethodToken == paymentMethod.Token)); - } - - [Theory, BitAutoData] - public async Task RemovePaymentMethod_Stripe_Legacy_RemovesSources( - Organization organization, - SutProvider sutProvider) - { - organization.Gateway = GatewayType.Stripe; - - const string bankAccountId = "bank_account_id"; - const string cardId = "card_id"; - - var sources = new List - { - new S.BankAccount { Id = bankAccountId }, new S.Card { Id = cardId } - }; - - var stripeCustomer = new S.Customer { Sources = new S.StripeList { Data = sources } }; - - var stripeAdapter = sutProvider.GetDependency(); - - stripeAdapter - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) - .Returns(stripeCustomer); - - stripeAdapter - .PaymentMethodListAutoPagingAsync(Arg.Any()) - .Returns(GetPaymentMethodsAsync(new List())); - - await sutProvider.Sut.RemovePaymentMethod(organization); - - await stripeAdapter.Received(1).BankAccountDeleteAsync(stripeCustomer.Id, bankAccountId); - - await stripeAdapter.Received(1).CardDeleteAsync(stripeCustomer.Id, cardId); - - await stripeAdapter.DidNotReceiveWithAnyArgs() - .PaymentMethodDetachAsync(Arg.Any(), Arg.Any()); - } - - [Theory, BitAutoData] - public async Task RemovePaymentMethod_Stripe_DetachesPaymentMethods( - Organization organization, - SutProvider sutProvider) - { - organization.Gateway = GatewayType.Stripe; - const string bankAccountId = "bank_account_id"; - const string cardId = "card_id"; - - var sources = new List(); - - var stripeCustomer = new S.Customer { Sources = new S.StripeList { Data = sources } }; - - var stripeAdapter = sutProvider.GetDependency(); - - stripeAdapter - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) - .Returns(stripeCustomer); - - stripeAdapter - .PaymentMethodListAutoPagingAsync(Arg.Any()) - .Returns(GetPaymentMethodsAsync(new List - { - new () - { - Id = bankAccountId - }, - new () - { - Id = cardId - } - })); - - await sutProvider.Sut.RemovePaymentMethod(organization); - - await stripeAdapter.DidNotReceiveWithAnyArgs().BankAccountDeleteAsync(Arg.Any(), Arg.Any()); - - await stripeAdapter.DidNotReceiveWithAnyArgs().CardDeleteAsync(Arg.Any(), Arg.Any()); - - await stripeAdapter.Received(1) - .PaymentMethodDetachAsync(bankAccountId, Arg.Any()); - - await stripeAdapter.Received(1) - .PaymentMethodDetachAsync(cardId, Arg.Any()); - } - - private static async IAsyncEnumerable GetPaymentMethodsAsync( - IEnumerable paymentMethods) - { - foreach (var paymentMethod in paymentMethods) - { - yield return paymentMethod; - } - - await Task.CompletedTask; - } - - private static (BT.IBraintreeGateway, BT.ICustomerGateway, BT.IPaymentMethodGateway) Setup( - BT.IBraintreeGateway braintreeGateway) - { - var customerGateway = Substitute.For(); - var paymentMethodGateway = Substitute.For(); - - braintreeGateway.Customer.Returns(customerGateway); - braintreeGateway.PaymentMethod.Returns(paymentMethodGateway); - - return (braintreeGateway, customerGateway, paymentMethodGateway); - } -} diff --git a/test/Core.Test/Billing/Commands/StartSubscriptionCommandTests.cs b/test/Core.Test/Billing/Commands/StartSubscriptionCommandTests.cs deleted file mode 100644 index 6e8213c2d..000000000 --- a/test/Core.Test/Billing/Commands/StartSubscriptionCommandTests.cs +++ /dev/null @@ -1,420 +0,0 @@ -using System.Net; -using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Billing.Commands.Implementations; -using Bit.Core.Billing.Constants; -using Bit.Core.Billing.Entities; -using Bit.Core.Billing.Repositories; -using Bit.Core.Enums; -using Bit.Core.Models.Business; -using Bit.Core.Services; -using Bit.Core.Utilities; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using NSubstitute; -using Stripe; -using Xunit; - -using static Bit.Core.Test.Billing.Utilities; - -namespace Bit.Core.Test.Billing.Commands; - -[SutProviderCustomize] -public class StartSubscriptionCommandTests -{ - private const string _customerId = "customer_id"; - private const string _subscriptionId = "subscription_id"; - - // These tests are only trying to assert on the thrown exceptions and thus use the least amount of data setup possible. - #region Error Cases - [Theory, BitAutoData] - public async Task StartSubscription_NullProvider_ThrowsArgumentNullException( - SutProvider sutProvider, - TaxInfo taxInfo) => - await Assert.ThrowsAsync(() => sutProvider.Sut.StartSubscription(null, taxInfo)); - - [Theory, BitAutoData] - public async Task StartSubscription_NullTaxInfo_ThrowsArgumentNullException( - SutProvider sutProvider, - Provider provider) => - await Assert.ThrowsAsync(() => sutProvider.Sut.StartSubscription(provider, null)); - - [Theory, BitAutoData] - public async Task StartSubscription_AlreadyHasGatewaySubscriptionId_ThrowsBillingException( - SutProvider sutProvider, - Provider provider, - TaxInfo taxInfo) - { - provider.GatewayCustomerId = _customerId; - - provider.GatewaySubscriptionId = _subscriptionId; - - await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider, taxInfo)); - - await DidNotRetrieveCustomerAsync(sutProvider); - } - - [Theory, BitAutoData] - public async Task StartSubscription_MissingCountry_ThrowsBillingException( - SutProvider sutProvider, - Provider provider, - TaxInfo taxInfo) - { - provider.GatewayCustomerId = _customerId; - - provider.GatewaySubscriptionId = null; - - taxInfo.BillingAddressCountry = null; - - await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider, taxInfo)); - - await DidNotRetrieveCustomerAsync(sutProvider); - } - - [Theory, BitAutoData] - public async Task StartSubscription_MissingPostalCode_ThrowsBillingException( - SutProvider sutProvider, - Provider provider, - TaxInfo taxInfo) - { - provider.GatewayCustomerId = _customerId; - - provider.GatewaySubscriptionId = null; - - taxInfo.BillingAddressPostalCode = null; - - await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider, taxInfo)); - - await DidNotRetrieveCustomerAsync(sutProvider); - } - - [Theory, BitAutoData] - public async Task StartSubscription_MissingStripeCustomer_ThrowsBillingException( - SutProvider sutProvider, - Provider provider, - TaxInfo taxInfo) - { - provider.GatewayCustomerId = _customerId; - - provider.GatewaySubscriptionId = null; - - SetCustomerRetrieval(sutProvider, null); - - await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider, taxInfo)); - - await DidNotRetrieveProviderPlansAsync(sutProvider); - } - - [Theory, BitAutoData] - public async Task StartSubscription_NoProviderPlans_ThrowsBillingException( - SutProvider sutProvider, - Provider provider, - TaxInfo taxInfo) - { - provider.GatewayCustomerId = _customerId; - - provider.GatewaySubscriptionId = null; - - SetCustomerRetrieval(sutProvider, new Customer - { - Id = _customerId, - Tax = new CustomerTax - { - AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported - } - }); - - sutProvider.GetDependency().GetByProviderId(provider.Id) - .Returns(new List()); - - await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider, taxInfo)); - - await DidNotCreateSubscriptionAsync(sutProvider); - } - - [Theory, BitAutoData] - public async Task StartSubscription_NoProviderTeamsPlan_ThrowsBillingException( - SutProvider sutProvider, - Provider provider, - TaxInfo taxInfo) - { - provider.GatewayCustomerId = _customerId; - - provider.GatewaySubscriptionId = null; - - SetCustomerRetrieval(sutProvider, new Customer - { - Id = _customerId, - Tax = new CustomerTax - { - AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported - } - }); - - var providerPlans = new List - { - new () - { - PlanType = PlanType.EnterpriseMonthly - } - }; - - sutProvider.GetDependency().GetByProviderId(provider.Id) - .Returns(providerPlans); - - await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider, taxInfo)); - - await DidNotCreateSubscriptionAsync(sutProvider); - } - - [Theory, BitAutoData] - public async Task StartSubscription_NoProviderEnterprisePlan_ThrowsBillingException( - SutProvider sutProvider, - Provider provider, - TaxInfo taxInfo) - { - provider.GatewayCustomerId = _customerId; - - provider.GatewaySubscriptionId = null; - - SetCustomerRetrieval(sutProvider, new Customer - { - Id = _customerId, - Tax = new CustomerTax - { - AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported - } - }); - - var providerPlans = new List - { - new () - { - PlanType = PlanType.TeamsMonthly - } - }; - - sutProvider.GetDependency().GetByProviderId(provider.Id) - .Returns(providerPlans); - - await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider, taxInfo)); - - await DidNotCreateSubscriptionAsync(sutProvider); - } - - [Theory, BitAutoData] - public async Task StartSubscription_SubscriptionIncomplete_ThrowsBillingException( - SutProvider sutProvider, - Provider provider, - TaxInfo taxInfo) - { - provider.GatewayCustomerId = _customerId; - - provider.GatewaySubscriptionId = null; - - SetCustomerRetrieval(sutProvider, new Customer - { - Id = _customerId, - Tax = new CustomerTax - { - AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported - } - }); - - var providerPlans = new List - { - new () - { - PlanType = PlanType.TeamsMonthly, - SeatMinimum = 100 - }, - new () - { - PlanType = PlanType.EnterpriseMonthly, - SeatMinimum = 100 - } - }; - - sutProvider.GetDependency().GetByProviderId(provider.Id) - .Returns(providerPlans); - - sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Any()).Returns(new Subscription - { - Id = _subscriptionId, - Status = StripeConstants.SubscriptionStatus.Incomplete - }); - - await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider, taxInfo)); - - await sutProvider.GetDependency().Received(1).ReplaceAsync(provider); - } - #endregion - - #region Success Cases - [Theory, BitAutoData] - public async Task StartSubscription_ExistingCustomer_Succeeds( - SutProvider sutProvider, - Provider provider, - TaxInfo taxInfo) - { - provider.GatewayCustomerId = _customerId; - - provider.GatewaySubscriptionId = null; - - SetCustomerRetrieval(sutProvider, new Customer - { - Id = _customerId, - Tax = new CustomerTax - { - AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported - } - }); - - var providerPlans = new List - { - new () - { - PlanType = PlanType.TeamsMonthly, - SeatMinimum = 100 - }, - new () - { - PlanType = PlanType.EnterpriseMonthly, - SeatMinimum = 100 - } - }; - - sutProvider.GetDependency().GetByProviderId(provider.Id) - .Returns(providerPlans); - - var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); - var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); - - sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( - sub => - sub.AutomaticTax.Enabled == true && - sub.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && - sub.Customer == _customerId && - sub.DaysUntilDue == 30 && - sub.Items.Count == 2 && - sub.Items.ElementAt(0).Price == teamsPlan.PasswordManager.StripeSeatPlanId && - sub.Items.ElementAt(0).Quantity == 100 && - sub.Items.ElementAt(1).Price == enterprisePlan.PasswordManager.StripeSeatPlanId && - sub.Items.ElementAt(1).Quantity == 100 && - sub.Metadata["providerId"] == provider.Id.ToString() && - sub.OffSession == true && - sub.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations)).Returns(new Subscription - { - Id = _subscriptionId, - Status = StripeConstants.SubscriptionStatus.Active - }); - - await sutProvider.Sut.StartSubscription(provider, taxInfo); - - await sutProvider.GetDependency().Received(1).ReplaceAsync(provider); - } - - [Theory, BitAutoData] - public async Task StartSubscription_NewCustomer_Succeeds( - SutProvider sutProvider, - Provider provider, - TaxInfo taxInfo) - { - provider.GatewayCustomerId = null; - - provider.GatewaySubscriptionId = null; - - provider.Name = "MSP"; - - taxInfo.BillingAddressCountry = "AD"; - - sutProvider.GetDependency().CustomerCreateAsync(Arg.Is(o => - o.Address.Country == taxInfo.BillingAddressCountry && - o.Address.PostalCode == taxInfo.BillingAddressPostalCode && - o.Address.Line1 == taxInfo.BillingAddressLine1 && - o.Address.Line2 == taxInfo.BillingAddressLine2 && - o.Address.City == taxInfo.BillingAddressCity && - o.Address.State == taxInfo.BillingAddressState && - o.Coupon == "msp-discount-35" && - o.Description == WebUtility.HtmlDecode(provider.BusinessName) && - o.Email == provider.BillingEmail && - o.Expand.FirstOrDefault() == "tax" && - o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && - o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && - o.Metadata["region"] == "" && - o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && - o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) - .Returns(new Customer - { - Id = _customerId, - Tax = new CustomerTax - { - AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported - } - }); - - var providerPlans = new List - { - new () - { - PlanType = PlanType.TeamsMonthly, - SeatMinimum = 100 - }, - new () - { - PlanType = PlanType.EnterpriseMonthly, - SeatMinimum = 100 - } - }; - - sutProvider.GetDependency().GetByProviderId(provider.Id) - .Returns(providerPlans); - - var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); - var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); - - sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( - sub => - sub.AutomaticTax.Enabled == true && - sub.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && - sub.Customer == _customerId && - sub.DaysUntilDue == 30 && - sub.Items.Count == 2 && - sub.Items.ElementAt(0).Price == teamsPlan.PasswordManager.StripeSeatPlanId && - sub.Items.ElementAt(0).Quantity == 100 && - sub.Items.ElementAt(1).Price == enterprisePlan.PasswordManager.StripeSeatPlanId && - sub.Items.ElementAt(1).Quantity == 100 && - sub.Metadata["providerId"] == provider.Id.ToString() && - sub.OffSession == true && - sub.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations)).Returns(new Subscription - { - Id = _subscriptionId, - Status = StripeConstants.SubscriptionStatus.Active - }); - - await sutProvider.Sut.StartSubscription(provider, taxInfo); - - await sutProvider.GetDependency().Received(2).ReplaceAsync(provider); - } - #endregion - - private static async Task DidNotCreateSubscriptionAsync(SutProvider sutProvider) => - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .SubscriptionCreateAsync(Arg.Any()); - - private static async Task DidNotRetrieveCustomerAsync(SutProvider sutProvider) => - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .CustomerGetAsync(Arg.Any(), Arg.Any()); - - private static async Task DidNotRetrieveProviderPlansAsync(SutProvider sutProvider) => - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .GetByProviderId(Arg.Any()); - - private static void SetCustomerRetrieval(SutProvider sutProvider, - Customer customer) => sutProvider.GetDependency() - .CustomerGetAsync(_customerId, Arg.Is(o => o.Expand.FirstOrDefault() == "tax")) - .Returns(customer); -} diff --git a/test/Core.Test/Billing/Queries/ProviderBillingQueriesTests.cs b/test/Core.Test/Billing/Queries/ProviderBillingQueriesTests.cs deleted file mode 100644 index afa361781..000000000 --- a/test/Core.Test/Billing/Queries/ProviderBillingQueriesTests.cs +++ /dev/null @@ -1,154 +0,0 @@ -using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Billing.Entities; -using Bit.Core.Billing.Models; -using Bit.Core.Billing.Queries; -using Bit.Core.Billing.Queries.Implementations; -using Bit.Core.Billing.Repositories; -using Bit.Core.Enums; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using NSubstitute; -using NSubstitute.ReturnsExtensions; -using Stripe; -using Xunit; - -namespace Bit.Core.Test.Billing.Queries; - -[SutProviderCustomize] -public class ProviderBillingQueriesTests -{ - #region GetSubscriptionData - - [Theory, BitAutoData] - public async Task GetSubscriptionData_NullProvider_ReturnsNull( - SutProvider sutProvider, - Guid providerId) - { - var providerRepository = sutProvider.GetDependency(); - - providerRepository.GetByIdAsync(providerId).ReturnsNull(); - - var subscriptionData = await sutProvider.Sut.GetSubscriptionDTO(providerId); - - Assert.Null(subscriptionData); - - await providerRepository.Received(1).GetByIdAsync(providerId); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionData_NullSubscription_ReturnsNull( - SutProvider sutProvider, - Guid providerId, - Provider provider) - { - var providerRepository = sutProvider.GetDependency(); - - providerRepository.GetByIdAsync(providerId).Returns(provider); - - var subscriberQueries = sutProvider.GetDependency(); - - subscriberQueries.GetSubscription(provider).ReturnsNull(); - - var subscriptionData = await sutProvider.Sut.GetSubscriptionDTO(providerId); - - Assert.Null(subscriptionData); - - await providerRepository.Received(1).GetByIdAsync(providerId); - - await subscriberQueries.Received(1).GetSubscription( - provider, - Arg.Is( - options => options.Expand.Count == 1 && options.Expand.First() == "customer")); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionData_Success( - SutProvider sutProvider, - Guid providerId, - Provider provider) - { - var providerRepository = sutProvider.GetDependency(); - - providerRepository.GetByIdAsync(providerId).Returns(provider); - - var subscriberQueries = sutProvider.GetDependency(); - - var subscription = new Subscription(); - - subscriberQueries.GetSubscription(provider, Arg.Is( - options => options.Expand.Count == 1 && options.Expand.First() == "customer")).Returns(subscription); - - var providerPlanRepository = sutProvider.GetDependency(); - - var enterprisePlan = new ProviderPlan - { - Id = Guid.NewGuid(), - ProviderId = providerId, - PlanType = PlanType.EnterpriseMonthly, - SeatMinimum = 100, - PurchasedSeats = 0, - AllocatedSeats = 0 - }; - - var teamsPlan = new ProviderPlan - { - Id = Guid.NewGuid(), - ProviderId = providerId, - PlanType = PlanType.TeamsMonthly, - SeatMinimum = 50, - PurchasedSeats = 10, - AllocatedSeats = 60 - }; - - var providerPlans = new List - { - enterprisePlan, - teamsPlan, - }; - - providerPlanRepository.GetByProviderId(providerId).Returns(providerPlans); - - var subscriptionData = await sutProvider.Sut.GetSubscriptionDTO(providerId); - - Assert.NotNull(subscriptionData); - - Assert.Equivalent(subscriptionData.Subscription, subscription); - - Assert.Equal(2, subscriptionData.ProviderPlans.Count); - - var configuredEnterprisePlan = - subscriptionData.ProviderPlans.FirstOrDefault(configuredPlan => - configuredPlan.PlanType == PlanType.EnterpriseMonthly); - - var configuredTeamsPlan = - subscriptionData.ProviderPlans.FirstOrDefault(configuredPlan => - configuredPlan.PlanType == PlanType.TeamsMonthly); - - Compare(enterprisePlan, configuredEnterprisePlan); - - Compare(teamsPlan, configuredTeamsPlan); - - await providerRepository.Received(1).GetByIdAsync(providerId); - - await subscriberQueries.Received(1).GetSubscription( - provider, - Arg.Is( - options => options.Expand.Count == 1 && options.Expand.First() == "customer")); - - await providerPlanRepository.Received(1).GetByProviderId(providerId); - - return; - - void Compare(ProviderPlan providerPlan, ConfiguredProviderPlanDTO configuredProviderPlan) - { - Assert.NotNull(configuredProviderPlan); - Assert.Equal(providerPlan.Id, configuredProviderPlan.Id); - Assert.Equal(providerPlan.ProviderId, configuredProviderPlan.ProviderId); - Assert.Equal(providerPlan.SeatMinimum!.Value, configuredProviderPlan.SeatMinimum); - Assert.Equal(providerPlan.PurchasedSeats!.Value, configuredProviderPlan.PurchasedSeats); - Assert.Equal(providerPlan.AllocatedSeats!.Value, configuredProviderPlan.AssignedSeats); - } - } - #endregion -} diff --git a/test/Core.Test/Billing/Queries/SubscriberQueriesTests.cs b/test/Core.Test/Billing/Queries/SubscriberQueriesTests.cs deleted file mode 100644 index c1539e868..000000000 --- a/test/Core.Test/Billing/Queries/SubscriberQueriesTests.cs +++ /dev/null @@ -1,272 +0,0 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.Billing.Queries.Implementations; -using Bit.Core.Services; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using NSubstitute; -using NSubstitute.ExceptionExtensions; -using NSubstitute.ReturnsExtensions; -using Stripe; -using Xunit; - -using static Bit.Core.Test.Billing.Utilities; - -namespace Bit.Core.Test.Billing.Queries; - -[SutProviderCustomize] -public class SubscriberQueriesTests -{ - #region GetCustomer - [Theory, BitAutoData] - public async Task GetCustomer_NullSubscriber_ThrowsArgumentNullException( - SutProvider sutProvider) - => await Assert.ThrowsAsync( - async () => await sutProvider.Sut.GetCustomer(null)); - - [Theory, BitAutoData] - public async Task GetCustomer_NoGatewayCustomerId_ReturnsNull( - Organization organization, - SutProvider sutProvider) - { - organization.GatewayCustomerId = null; - - var customer = await sutProvider.Sut.GetCustomer(organization); - - Assert.Null(customer); - } - - [Theory, BitAutoData] - public async Task GetCustomer_NoCustomer_ReturnsNull( - Organization organization, - SutProvider sutProvider) - { - sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId) - .ReturnsNull(); - - var customer = await sutProvider.Sut.GetCustomer(organization); - - Assert.Null(customer); - } - - [Theory, BitAutoData] - public async Task GetCustomer_StripeException_ReturnsNull( - Organization organization, - SutProvider sutProvider) - { - sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId) - .ThrowsAsync(); - - var customer = await sutProvider.Sut.GetCustomer(organization); - - Assert.Null(customer); - } - - [Theory, BitAutoData] - public async Task GetCustomer_Succeeds( - Organization organization, - SutProvider sutProvider) - { - var customer = new Customer(); - - sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId) - .Returns(customer); - - var gotCustomer = await sutProvider.Sut.GetCustomer(organization); - - Assert.Equivalent(customer, gotCustomer); - } - #endregion - - #region GetSubscription - [Theory, BitAutoData] - public async Task GetSubscription_NullSubscriber_ThrowsArgumentNullException( - SutProvider sutProvider) - => await Assert.ThrowsAsync( - async () => await sutProvider.Sut.GetSubscription(null)); - - [Theory, BitAutoData] - public async Task GetSubscription_NoGatewaySubscriptionId_ReturnsNull( - Organization organization, - SutProvider sutProvider) - { - organization.GatewaySubscriptionId = null; - - var subscription = await sutProvider.Sut.GetSubscription(organization); - - Assert.Null(subscription); - } - - [Theory, BitAutoData] - public async Task GetSubscription_NoSubscription_ReturnsNull( - Organization organization, - SutProvider sutProvider) - { - sutProvider.GetDependency() - .SubscriptionGetAsync(organization.GatewaySubscriptionId) - .ReturnsNull(); - - var subscription = await sutProvider.Sut.GetSubscription(organization); - - Assert.Null(subscription); - } - - [Theory, BitAutoData] - public async Task GetSubscription_StripeException_ReturnsNull( - Organization organization, - SutProvider sutProvider) - { - sutProvider.GetDependency() - .SubscriptionGetAsync(organization.GatewaySubscriptionId) - .ThrowsAsync(); - - var subscription = await sutProvider.Sut.GetSubscription(organization); - - Assert.Null(subscription); - } - - [Theory, BitAutoData] - public async Task GetSubscription_Succeeds( - Organization organization, - SutProvider sutProvider) - { - var subscription = new Subscription(); - - sutProvider.GetDependency() - .SubscriptionGetAsync(organization.GatewaySubscriptionId) - .Returns(subscription); - - var gotSubscription = await sutProvider.Sut.GetSubscription(organization); - - Assert.Equivalent(subscription, gotSubscription); - } - #endregion - - #region GetCustomerOrThrow - [Theory, BitAutoData] - public async Task GetCustomerOrThrow_NullSubscriber_ThrowsArgumentNullException( - SutProvider sutProvider) - => await Assert.ThrowsAsync( - async () => await sutProvider.Sut.GetCustomerOrThrow(null)); - - [Theory, BitAutoData] - public async Task GetCustomerOrThrow_NoGatewayCustomerId_ContactSupport( - Organization organization, - SutProvider sutProvider) - { - organization.GatewayCustomerId = null; - - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetCustomerOrThrow(organization)); - } - - [Theory, BitAutoData] - public async Task GetCustomerOrThrow_NoCustomer_ContactSupport( - Organization organization, - SutProvider sutProvider) - { - sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId) - .ReturnsNull(); - - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetCustomerOrThrow(organization)); - } - - [Theory, BitAutoData] - public async Task GetCustomerOrThrow_StripeException_ContactSupport( - Organization organization, - SutProvider sutProvider) - { - var stripeException = new StripeException(); - - sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId) - .ThrowsAsync(stripeException); - - await ThrowsContactSupportAsync( - async () => await sutProvider.Sut.GetCustomerOrThrow(organization), - "An error occurred while trying to retrieve a Stripe Customer", - stripeException); - } - - [Theory, BitAutoData] - public async Task GetCustomerOrThrow_Succeeds( - Organization organization, - SutProvider sutProvider) - { - var customer = new Customer(); - - sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId) - .Returns(customer); - - var gotCustomer = await sutProvider.Sut.GetCustomerOrThrow(organization); - - Assert.Equivalent(customer, gotCustomer); - } - #endregion - - #region GetSubscriptionOrThrow - [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_NullSubscriber_ThrowsArgumentNullException( - SutProvider sutProvider) - => await Assert.ThrowsAsync( - async () => await sutProvider.Sut.GetSubscriptionOrThrow(null)); - - [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_NoGatewaySubscriptionId_ContactSupport( - Organization organization, - SutProvider sutProvider) - { - organization.GatewaySubscriptionId = null; - - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization)); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_NoSubscription_ContactSupport( - Organization organization, - SutProvider sutProvider) - { - sutProvider.GetDependency() - .SubscriptionGetAsync(organization.GatewaySubscriptionId) - .ReturnsNull(); - - await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization)); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_StripeException_ContactSupport( - Organization organization, - SutProvider sutProvider) - { - var stripeException = new StripeException(); - - sutProvider.GetDependency() - .SubscriptionGetAsync(organization.GatewaySubscriptionId) - .ThrowsAsync(stripeException); - - await ThrowsContactSupportAsync( - async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization), - "An error occurred while trying to retrieve a Stripe Subscription", - stripeException); - } - - [Theory, BitAutoData] - public async Task GetSubscriptionOrThrow_Succeeds( - Organization organization, - SutProvider sutProvider) - { - var subscription = new Subscription(); - - sutProvider.GetDependency() - .SubscriptionGetAsync(organization.GatewaySubscriptionId) - .Returns(subscription); - - var gotSubscription = await sutProvider.Sut.GetSubscriptionOrThrow(organization); - - Assert.Equivalent(subscription, gotSubscription); - } - #endregion -} diff --git a/test/Core.Test/Billing/Queries/OrganizationBillingQueriesTests.cs b/test/Core.Test/Billing/Services/OrganizationBillingServiceTests.cs similarity index 81% rename from test/Core.Test/Billing/Queries/OrganizationBillingQueriesTests.cs rename to test/Core.Test/Billing/Services/OrganizationBillingServiceTests.cs index f98bf58e5..2e1782f5c 100644 --- a/test/Core.Test/Billing/Queries/OrganizationBillingQueriesTests.cs +++ b/test/Core.Test/Billing/Services/OrganizationBillingServiceTests.cs @@ -1,7 +1,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Constants; -using Bit.Core.Billing.Queries; -using Bit.Core.Billing.Queries.Implementations; +using Bit.Core.Billing.Services; +using Bit.Core.Billing.Services.Implementations; using Bit.Core.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; @@ -9,16 +9,16 @@ using NSubstitute; using Stripe; using Xunit; -namespace Bit.Core.Test.Billing.Queries; +namespace Bit.Core.Test.Billing.Services; [SutProviderCustomize] -public class OrganizationBillingQueriesTests +public class OrganizationBillingServiceTests { #region GetMetadata [Theory, BitAutoData] public async Task GetMetadata_OrganizationNull_ReturnsNull( Guid organizationId, - SutProvider sutProvider) + SutProvider sutProvider) { var metadata = await sutProvider.Sut.GetMetadata(organizationId); @@ -29,7 +29,7 @@ public class OrganizationBillingQueriesTests public async Task GetMetadata_CustomerNull_ReturnsNull( Guid organizationId, Organization organization, - SutProvider sutProvider) + SutProvider sutProvider) { sutProvider.GetDependency().GetByIdAsync(organizationId).Returns(organization); @@ -42,11 +42,11 @@ public class OrganizationBillingQueriesTests public async Task GetMetadata_SubscriptionNull_ReturnsNull( Guid organizationId, Organization organization, - SutProvider sutProvider) + SutProvider sutProvider) { sutProvider.GetDependency().GetByIdAsync(organizationId).Returns(organization); - sutProvider.GetDependency().GetCustomer(organization).Returns(new Customer()); + sutProvider.GetDependency().GetCustomer(organization).Returns(new Customer()); var metadata = await sutProvider.Sut.GetMetadata(organizationId); @@ -57,13 +57,13 @@ public class OrganizationBillingQueriesTests public async Task GetMetadata_Succeeds( Guid organizationId, Organization organization, - SutProvider sutProvider) + SutProvider sutProvider) { sutProvider.GetDependency().GetByIdAsync(organizationId).Returns(organization); - var subscriberQueries = sutProvider.GetDependency(); + var subscriberService = sutProvider.GetDependency(); - subscriberQueries + subscriberService .GetCustomer(organization, Arg.Is(options => options.Expand.FirstOrDefault() == "discount.coupon.applies_to")) .Returns(new Customer { @@ -80,7 +80,7 @@ public class OrganizationBillingQueriesTests } }); - subscriberQueries.GetSubscription(organization).Returns(new Subscription + subscriberService.GetSubscription(organization).Returns(new Subscription { Items = new StripeList { diff --git a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs new file mode 100644 index 000000000..f052fb92d --- /dev/null +++ b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs @@ -0,0 +1,881 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.Billing; +using Bit.Core.Billing.Models; +using Bit.Core.Billing.Services.Implementations; +using Bit.Core.Services; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Braintree; +using NSubstitute; +using NSubstitute.ExceptionExtensions; +using NSubstitute.ReturnsExtensions; +using Stripe; +using Xunit; + +using static Bit.Core.Test.Billing.Utilities; +using Customer = Stripe.Customer; +using PaymentMethod = Stripe.PaymentMethod; +using Subscription = Stripe.Subscription; + +namespace Bit.Core.Test.Billing.Services; + +[SutProviderCustomize] +public class SubscriberServiceTests +{ + #region CancelSubscription + [Theory, BitAutoData] + public async Task CancelSubscription_SubscriptionInactive_ContactSupport( + Organization organization, + SutProvider sutProvider) + { + var subscription = new Subscription + { + Status = "canceled" + }; + + var stripeAdapter = sutProvider.GetDependency(); + + stripeAdapter + .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .Returns(subscription); + + await ThrowsContactSupportAsync(() => + sutProvider.Sut.CancelSubscription(organization, new OffboardingSurveyResponse(), false)); + + await stripeAdapter + .DidNotReceiveWithAnyArgs() + .SubscriptionUpdateAsync(Arg.Any(), Arg.Any()); + + await stripeAdapter + .DidNotReceiveWithAnyArgs() + .SubscriptionCancelAsync(Arg.Any(), Arg.Any()); + } + + [Theory, BitAutoData] + public async Task CancelSubscription_CancelImmediately_BelongsToOrganization_UpdatesSubscription_CancelSubscriptionImmediately( + Organization organization, + SutProvider sutProvider) + { + var userId = Guid.NewGuid(); + + const string subscriptionId = "subscription_id"; + + var subscription = new Subscription + { + Id = subscriptionId, + Status = "active", + Metadata = new Dictionary + { + { "organizationId", "organization_id" } + } + }; + + var stripeAdapter = sutProvider.GetDependency(); + + stripeAdapter + .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .Returns(subscription); + + var offboardingSurveyResponse = new OffboardingSurveyResponse + { + UserId = userId, + Reason = "missing_features", + Feedback = "Lorem ipsum" + }; + + await sutProvider.Sut.CancelSubscription(organization, offboardingSurveyResponse, true); + + await stripeAdapter + .Received(1) + .SubscriptionUpdateAsync(subscriptionId, Arg.Is( + options => options.Metadata["cancellingUserId"] == userId.ToString())); + + await stripeAdapter + .Received(1) + .SubscriptionCancelAsync(subscriptionId, Arg.Is(options => + options.CancellationDetails.Comment == offboardingSurveyResponse.Feedback && + options.CancellationDetails.Feedback == offboardingSurveyResponse.Reason)); + } + + [Theory, BitAutoData] + public async Task CancelSubscription_CancelImmediately_BelongsToUser_CancelSubscriptionImmediately( + Organization organization, + SutProvider sutProvider) + { + var userId = Guid.NewGuid(); + + const string subscriptionId = "subscription_id"; + + var subscription = new Subscription + { + Id = subscriptionId, + Status = "active", + Metadata = new Dictionary + { + { "userId", "user_id" } + } + }; + + var stripeAdapter = sutProvider.GetDependency(); + + stripeAdapter + .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .Returns(subscription); + + var offboardingSurveyResponse = new OffboardingSurveyResponse + { + UserId = userId, + Reason = "missing_features", + Feedback = "Lorem ipsum" + }; + + await sutProvider.Sut.CancelSubscription(organization, offboardingSurveyResponse, true); + + await stripeAdapter + .DidNotReceiveWithAnyArgs() + .SubscriptionUpdateAsync(Arg.Any(), Arg.Any()); + + await stripeAdapter + .Received(1) + .SubscriptionCancelAsync(subscriptionId, Arg.Is(options => + options.CancellationDetails.Comment == offboardingSurveyResponse.Feedback && + options.CancellationDetails.Feedback == offboardingSurveyResponse.Reason)); + } + + [Theory, BitAutoData] + public async Task CancelSubscription_DoNotCancelImmediately_UpdateSubscriptionToCancelAtEndOfPeriod( + Organization organization, + SutProvider sutProvider) + { + var userId = Guid.NewGuid(); + + const string subscriptionId = "subscription_id"; + + organization.ExpirationDate = DateTime.UtcNow.AddDays(5); + + var subscription = new Subscription + { + Id = subscriptionId, + Status = "active" + }; + + var stripeAdapter = sutProvider.GetDependency(); + + stripeAdapter + .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .Returns(subscription); + + var offboardingSurveyResponse = new OffboardingSurveyResponse + { + UserId = userId, + Reason = "missing_features", + Feedback = "Lorem ipsum" + }; + + await sutProvider.Sut.CancelSubscription(organization, offboardingSurveyResponse, false); + + await stripeAdapter + .Received(1) + .SubscriptionUpdateAsync(subscriptionId, Arg.Is(options => + options.CancelAtPeriodEnd == true && + options.CancellationDetails.Comment == offboardingSurveyResponse.Feedback && + options.CancellationDetails.Feedback == offboardingSurveyResponse.Reason && + options.Metadata["cancellingUserId"] == userId.ToString())); + + await stripeAdapter + .DidNotReceiveWithAnyArgs() + .SubscriptionCancelAsync(Arg.Any(), Arg.Any()); ; + } + #endregion + + #region GetCustomer + [Theory, BitAutoData] + public async Task GetCustomer_NullSubscriber_ThrowsArgumentNullException( + SutProvider sutProvider) + => await Assert.ThrowsAsync( + async () => await sutProvider.Sut.GetCustomer(null)); + + [Theory, BitAutoData] + public async Task GetCustomer_NoGatewayCustomerId_ReturnsNull( + Organization organization, + SutProvider sutProvider) + { + organization.GatewayCustomerId = null; + + var customer = await sutProvider.Sut.GetCustomer(organization); + + Assert.Null(customer); + } + + [Theory, BitAutoData] + public async Task GetCustomer_NoCustomer_ReturnsNull( + Organization organization, + SutProvider sutProvider) + { + sutProvider.GetDependency() + .CustomerGetAsync(organization.GatewayCustomerId) + .ReturnsNull(); + + var customer = await sutProvider.Sut.GetCustomer(organization); + + Assert.Null(customer); + } + + [Theory, BitAutoData] + public async Task GetCustomer_StripeException_ReturnsNull( + Organization organization, + SutProvider sutProvider) + { + sutProvider.GetDependency() + .CustomerGetAsync(organization.GatewayCustomerId) + .ThrowsAsync(); + + var customer = await sutProvider.Sut.GetCustomer(organization); + + Assert.Null(customer); + } + + [Theory, BitAutoData] + public async Task GetCustomer_Succeeds( + Organization organization, + SutProvider sutProvider) + { + var customer = new Customer(); + + sutProvider.GetDependency() + .CustomerGetAsync(organization.GatewayCustomerId) + .Returns(customer); + + var gotCustomer = await sutProvider.Sut.GetCustomer(organization); + + Assert.Equivalent(customer, gotCustomer); + } + #endregion + + #region GetCustomerOrThrow + [Theory, BitAutoData] + public async Task GetCustomerOrThrow_NullSubscriber_ThrowsArgumentNullException( + SutProvider sutProvider) + => await Assert.ThrowsAsync( + async () => await sutProvider.Sut.GetCustomerOrThrow(null)); + + [Theory, BitAutoData] + public async Task GetCustomerOrThrow_NoGatewayCustomerId_ContactSupport( + Organization organization, + SutProvider sutProvider) + { + organization.GatewayCustomerId = null; + + await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetCustomerOrThrow(organization)); + } + + [Theory, BitAutoData] + public async Task GetCustomerOrThrow_NoCustomer_ContactSupport( + Organization organization, + SutProvider sutProvider) + { + sutProvider.GetDependency() + .CustomerGetAsync(organization.GatewayCustomerId) + .ReturnsNull(); + + await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetCustomerOrThrow(organization)); + } + + [Theory, BitAutoData] + public async Task GetCustomerOrThrow_StripeException_ContactSupport( + Organization organization, + SutProvider sutProvider) + { + var stripeException = new StripeException(); + + sutProvider.GetDependency() + .CustomerGetAsync(organization.GatewayCustomerId) + .ThrowsAsync(stripeException); + + await ThrowsContactSupportAsync( + async () => await sutProvider.Sut.GetCustomerOrThrow(organization), + "An error occurred while trying to retrieve a Stripe Customer", + stripeException); + } + + [Theory, BitAutoData] + public async Task GetCustomerOrThrow_Succeeds( + Organization organization, + SutProvider sutProvider) + { + var customer = new Customer(); + + sutProvider.GetDependency() + .CustomerGetAsync(organization.GatewayCustomerId) + .Returns(customer); + + var gotCustomer = await sutProvider.Sut.GetCustomerOrThrow(organization); + + Assert.Equivalent(customer, gotCustomer); + } + #endregion + + #region GetSubscription + [Theory, BitAutoData] + public async Task GetSubscription_NullSubscriber_ThrowsArgumentNullException( + SutProvider sutProvider) + => await Assert.ThrowsAsync( + async () => await sutProvider.Sut.GetSubscription(null)); + + [Theory, BitAutoData] + public async Task GetSubscription_NoGatewaySubscriptionId_ReturnsNull( + Organization organization, + SutProvider sutProvider) + { + organization.GatewaySubscriptionId = null; + + var subscription = await sutProvider.Sut.GetSubscription(organization); + + Assert.Null(subscription); + } + + [Theory, BitAutoData] + public async Task GetSubscription_NoSubscription_ReturnsNull( + Organization organization, + SutProvider sutProvider) + { + sutProvider.GetDependency() + .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .ReturnsNull(); + + var subscription = await sutProvider.Sut.GetSubscription(organization); + + Assert.Null(subscription); + } + + [Theory, BitAutoData] + public async Task GetSubscription_StripeException_ReturnsNull( + Organization organization, + SutProvider sutProvider) + { + sutProvider.GetDependency() + .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .ThrowsAsync(); + + var subscription = await sutProvider.Sut.GetSubscription(organization); + + Assert.Null(subscription); + } + + [Theory, BitAutoData] + public async Task GetSubscription_Succeeds( + Organization organization, + SutProvider sutProvider) + { + var subscription = new Subscription(); + + sutProvider.GetDependency() + .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .Returns(subscription); + + var gotSubscription = await sutProvider.Sut.GetSubscription(organization); + + Assert.Equivalent(subscription, gotSubscription); + } + #endregion + + #region GetSubscriptionOrThrow + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_NullSubscriber_ThrowsArgumentNullException( + SutProvider sutProvider) + => await Assert.ThrowsAsync( + async () => await sutProvider.Sut.GetSubscriptionOrThrow(null)); + + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_NoGatewaySubscriptionId_ContactSupport( + Organization organization, + SutProvider sutProvider) + { + organization.GatewaySubscriptionId = null; + + await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization)); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_NoSubscription_ContactSupport( + Organization organization, + SutProvider sutProvider) + { + sutProvider.GetDependency() + .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .ReturnsNull(); + + await ThrowsContactSupportAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization)); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_StripeException_ContactSupport( + Organization organization, + SutProvider sutProvider) + { + var stripeException = new StripeException(); + + sutProvider.GetDependency() + .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .ThrowsAsync(stripeException); + + await ThrowsContactSupportAsync( + async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization), + "An error occurred while trying to retrieve a Stripe Subscription", + stripeException); + } + + [Theory, BitAutoData] + public async Task GetSubscriptionOrThrow_Succeeds( + Organization organization, + SutProvider sutProvider) + { + var subscription = new Subscription(); + + sutProvider.GetDependency() + .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .Returns(subscription); + + var gotSubscription = await sutProvider.Sut.GetSubscriptionOrThrow(organization); + + Assert.Equivalent(subscription, gotSubscription); + } + #endregion + + #region RemovePaymentMethod + [Theory, BitAutoData] + public async Task RemovePaymentMethod_NullSubscriber_ArgumentNullException( + SutProvider sutProvider) => + await Assert.ThrowsAsync(() => sutProvider.Sut.RemovePaymentMethod(null)); + + [Theory, BitAutoData] + public async Task RemovePaymentMethod_Braintree_NoCustomer_ContactSupport( + Organization organization, + SutProvider sutProvider) + { + const string braintreeCustomerId = "1"; + + var stripeCustomer = new Customer + { + Metadata = new Dictionary + { + { "btCustomerId", braintreeCustomerId } + } + }; + + sutProvider.GetDependency() + .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) + .Returns(stripeCustomer); + + var (braintreeGateway, customerGateway, paymentMethodGateway) = SetupBraintree(sutProvider.GetDependency()); + + customerGateway.FindAsync(braintreeCustomerId).ReturnsNull(); + + braintreeGateway.Customer.Returns(customerGateway); + + await ThrowsContactSupportAsync(() => sutProvider.Sut.RemovePaymentMethod(organization)); + + await customerGateway.Received(1).FindAsync(braintreeCustomerId); + + await customerGateway.DidNotReceiveWithAnyArgs() + .UpdateAsync(Arg.Any(), Arg.Any()); + + await paymentMethodGateway.DidNotReceiveWithAnyArgs().DeleteAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task RemovePaymentMethod_Braintree_NoPaymentMethod_NoOp( + Organization organization, + SutProvider sutProvider) + { + const string braintreeCustomerId = "1"; + + var stripeCustomer = new Customer + { + Metadata = new Dictionary + { + { "btCustomerId", braintreeCustomerId } + } + }; + + sutProvider.GetDependency() + .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) + .Returns(stripeCustomer); + + var (_, customerGateway, paymentMethodGateway) = SetupBraintree(sutProvider.GetDependency()); + + var braintreeCustomer = Substitute.For(); + + braintreeCustomer.PaymentMethods.Returns([]); + + customerGateway.FindAsync(braintreeCustomerId).Returns(braintreeCustomer); + + await sutProvider.Sut.RemovePaymentMethod(organization); + + await customerGateway.Received(1).FindAsync(braintreeCustomerId); + + await customerGateway.DidNotReceiveWithAnyArgs().UpdateAsync(Arg.Any(), Arg.Any()); + + await paymentMethodGateway.DidNotReceiveWithAnyArgs().DeleteAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task RemovePaymentMethod_Braintree_CustomerUpdateFails_ContactSupport( + Organization organization, + SutProvider sutProvider) + { + const string braintreeCustomerId = "1"; + const string braintreePaymentMethodToken = "TOKEN"; + + var stripeCustomer = new Customer + { + Metadata = new Dictionary + { + { "btCustomerId", braintreeCustomerId } + } + }; + + sutProvider.GetDependency() + .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) + .Returns(stripeCustomer); + + var (_, customerGateway, paymentMethodGateway) = SetupBraintree(sutProvider.GetDependency()); + + var braintreeCustomer = Substitute.For(); + + var paymentMethod = Substitute.For(); + paymentMethod.Token.Returns(braintreePaymentMethodToken); + paymentMethod.IsDefault.Returns(true); + + braintreeCustomer.PaymentMethods.Returns([ + paymentMethod + ]); + + customerGateway.FindAsync(braintreeCustomerId).Returns(braintreeCustomer); + + var updateBraintreeCustomerResult = Substitute.For>(); + updateBraintreeCustomerResult.IsSuccess().Returns(false); + + customerGateway.UpdateAsync( + braintreeCustomerId, + Arg.Is(request => request.DefaultPaymentMethodToken == null)) + .Returns(updateBraintreeCustomerResult); + + await ThrowsContactSupportAsync(() => sutProvider.Sut.RemovePaymentMethod(organization)); + + await customerGateway.Received(1).FindAsync(braintreeCustomerId); + + await customerGateway.Received(1).UpdateAsync(braintreeCustomerId, Arg.Is(request => + request.DefaultPaymentMethodToken == null)); + + await paymentMethodGateway.DidNotReceiveWithAnyArgs().DeleteAsync(paymentMethod.Token); + + await customerGateway.DidNotReceive().UpdateAsync(braintreeCustomerId, Arg.Is(request => + request.DefaultPaymentMethodToken == paymentMethod.Token)); + } + + [Theory, BitAutoData] + public async Task RemovePaymentMethod_Braintree_PaymentMethodDeleteFails_RollBack_ContactSupport( + Organization organization, + SutProvider sutProvider) + { + const string braintreeCustomerId = "1"; + const string braintreePaymentMethodToken = "TOKEN"; + + var stripeCustomer = new Customer + { + Metadata = new Dictionary + { + { "btCustomerId", braintreeCustomerId } + } + }; + + sutProvider.GetDependency() + .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) + .Returns(stripeCustomer); + + var (_, customerGateway, paymentMethodGateway) = SetupBraintree(sutProvider.GetDependency()); + + var braintreeCustomer = Substitute.For(); + + var paymentMethod = Substitute.For(); + paymentMethod.Token.Returns(braintreePaymentMethodToken); + paymentMethod.IsDefault.Returns(true); + + braintreeCustomer.PaymentMethods.Returns([ + paymentMethod + ]); + + customerGateway.FindAsync(braintreeCustomerId).Returns(braintreeCustomer); + + var updateBraintreeCustomerResult = Substitute.For>(); + updateBraintreeCustomerResult.IsSuccess().Returns(true); + + customerGateway.UpdateAsync(braintreeCustomerId, Arg.Any()) + .Returns(updateBraintreeCustomerResult); + + var deleteBraintreePaymentMethodResult = Substitute.For>(); + deleteBraintreePaymentMethodResult.IsSuccess().Returns(false); + + paymentMethodGateway.DeleteAsync(paymentMethod.Token).Returns(deleteBraintreePaymentMethodResult); + + await ThrowsContactSupportAsync(() => sutProvider.Sut.RemovePaymentMethod(organization)); + + await customerGateway.Received(1).FindAsync(braintreeCustomerId); + + await customerGateway.Received(1).UpdateAsync(braintreeCustomerId, Arg.Is(request => + request.DefaultPaymentMethodToken == null)); + + await paymentMethodGateway.Received(1).DeleteAsync(paymentMethod.Token); + + await customerGateway.Received(1).UpdateAsync(braintreeCustomerId, Arg.Is(request => + request.DefaultPaymentMethodToken == paymentMethod.Token)); + } + + [Theory, BitAutoData] + public async Task RemovePaymentMethod_Stripe_Legacy_RemovesSources( + Organization organization, + SutProvider sutProvider) + { + const string bankAccountId = "bank_account_id"; + const string cardId = "card_id"; + + var sources = new List + { + new BankAccount { Id = bankAccountId }, new Card { Id = cardId } + }; + + var stripeCustomer = new Customer { Sources = new StripeList { Data = sources } }; + + var stripeAdapter = sutProvider.GetDependency(); + + stripeAdapter + .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) + .Returns(stripeCustomer); + + stripeAdapter + .PaymentMethodListAutoPagingAsync(Arg.Any()) + .Returns(GetPaymentMethodsAsync(new List())); + + await sutProvider.Sut.RemovePaymentMethod(organization); + + await stripeAdapter.Received(1).BankAccountDeleteAsync(stripeCustomer.Id, bankAccountId); + + await stripeAdapter.Received(1).CardDeleteAsync(stripeCustomer.Id, cardId); + + await stripeAdapter.DidNotReceiveWithAnyArgs() + .PaymentMethodDetachAsync(Arg.Any(), Arg.Any()); + } + + [Theory, BitAutoData] + public async Task RemovePaymentMethod_Stripe_DetachesPaymentMethods( + Organization organization, + SutProvider sutProvider) + { + const string bankAccountId = "bank_account_id"; + const string cardId = "card_id"; + + var sources = new List(); + + var stripeCustomer = new Customer { Sources = new StripeList { Data = sources } }; + + var stripeAdapter = sutProvider.GetDependency(); + + stripeAdapter + .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) + .Returns(stripeCustomer); + + stripeAdapter + .PaymentMethodListAutoPagingAsync(Arg.Any()) + .Returns(GetPaymentMethodsAsync(new List + { + new () + { + Id = bankAccountId + }, + new () + { + Id = cardId + } + })); + + await sutProvider.Sut.RemovePaymentMethod(organization); + + await stripeAdapter.DidNotReceiveWithAnyArgs().BankAccountDeleteAsync(Arg.Any(), Arg.Any()); + + await stripeAdapter.DidNotReceiveWithAnyArgs().CardDeleteAsync(Arg.Any(), Arg.Any()); + + await stripeAdapter.Received(1) + .PaymentMethodDetachAsync(bankAccountId); + + await stripeAdapter.Received(1) + .PaymentMethodDetachAsync(cardId); + } + + private static async IAsyncEnumerable GetPaymentMethodsAsync( + IEnumerable paymentMethods) + { + foreach (var paymentMethod in paymentMethods) + { + yield return paymentMethod; + } + + await Task.CompletedTask; + } + + private static (IBraintreeGateway, ICustomerGateway, IPaymentMethodGateway) SetupBraintree( + IBraintreeGateway braintreeGateway) + { + var customerGateway = Substitute.For(); + var paymentMethodGateway = Substitute.For(); + + braintreeGateway.Customer.Returns(customerGateway); + braintreeGateway.PaymentMethod.Returns(paymentMethodGateway); + + return (braintreeGateway, customerGateway, paymentMethodGateway); + } + #endregion + + #region GetTaxInformationAsync + [Theory, BitAutoData] + public async Task GetTaxInformationAsync_NullSubscriber_ThrowsArgumentNullException( + SutProvider sutProvider) + => await Assert.ThrowsAsync( + async () => await sutProvider.Sut.GetTaxInformationAsync(null)); + + [Theory, BitAutoData] + public async Task GetTaxInformationAsync_NoGatewayCustomerId_ReturnsNull( + Provider subscriber, + SutProvider sutProvider) + { + subscriber.GatewayCustomerId = null; + + var taxInfo = await sutProvider.Sut.GetTaxInformationAsync(subscriber); + + Assert.Null(taxInfo); + } + + [Theory, BitAutoData] + public async Task GetTaxInformationAsync_NoCustomer_ReturnsNull( + Provider subscriber, + SutProvider sutProvider) + { + sutProvider.GetDependency() + .CustomerGetAsync(subscriber.GatewayCustomerId, Arg.Any()) + .Returns((Customer)null); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.GetTaxInformationAsync(subscriber)); + } + + [Theory, BitAutoData] + public async Task GetTaxInformationAsync_StripeException_ReturnsNull( + Provider subscriber, + SutProvider sutProvider) + { + sutProvider.GetDependency() + .CustomerGetAsync(subscriber.GatewayCustomerId, Arg.Any()) + .ThrowsAsync(new StripeException()); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.GetTaxInformationAsync(subscriber)); + } + + [Theory, BitAutoData] + public async Task GetTaxInformationAsync_Succeeds( + Provider subscriber, + SutProvider sutProvider) + { + var customer = new Customer + { + Address = new Stripe.Address + { + Line1 = "123 Main St", + Line2 = "Apt 4B", + City = "Metropolis", + State = "NY", + PostalCode = "12345", + Country = "US" + } + }; + + sutProvider.GetDependency() + .CustomerGetAsync(subscriber.GatewayCustomerId, Arg.Any()) + .Returns(customer); + + var taxInfo = await sutProvider.Sut.GetTaxInformationAsync(subscriber); + + Assert.NotNull(taxInfo); + Assert.Equal("123 Main St", taxInfo.BillingAddressLine1); + Assert.Equal("Apt 4B", taxInfo.BillingAddressLine2); + Assert.Equal("Metropolis", taxInfo.BillingAddressCity); + Assert.Equal("NY", taxInfo.BillingAddressState); + Assert.Equal("12345", taxInfo.BillingAddressPostalCode); + Assert.Equal("US", taxInfo.BillingAddressCountry); + } + #endregion + + #region GetPaymentMethodAsync + [Theory, BitAutoData] + public async Task GetPaymentMethodAsync_NullSubscriber_ThrowsArgumentNullException( + SutProvider sutProvider) + { + await Assert.ThrowsAsync( + async () => await sutProvider.Sut.GetPaymentMethodAsync(null)); + } + + [Theory, BitAutoData] + public async Task GetPaymentMethodAsync_NoCustomer_ReturnsNull( + Provider subscriber, + SutProvider sutProvider) + { + subscriber.GatewayCustomerId = null; + sutProvider.GetDependency() + .CustomerGetAsync(subscriber.GatewayCustomerId, Arg.Any()) + .Returns((Customer)null); + + await Assert.ThrowsAsync(() => sutProvider.Sut.GetPaymentMethodAsync(subscriber)); + } + + [Theory, BitAutoData] + public async Task GetPaymentMethodAsync_StripeCardPaymentMethod_ReturnsBillingSource( + Provider subscriber, + SutProvider sutProvider) + { + var customer = new Customer(); + var paymentMethod = CreateSamplePaymentMethod(); + subscriber.GatewayCustomerId = "test_customer_id"; + customer.InvoiceSettings = new CustomerInvoiceSettings + { + DefaultPaymentMethod = paymentMethod + }; + + sutProvider.GetDependency() + .CustomerGetAsync(subscriber.GatewayCustomerId, Arg.Any()) + .Returns(customer); + + var billingSource = await sutProvider.Sut.GetPaymentMethodAsync(subscriber); + + Assert.NotNull(billingSource); + Assert.Equal(paymentMethod.Card.Brand, billingSource.CardBrand); + } + + private static PaymentMethod CreateSamplePaymentMethod() + { + var paymentMethod = new PaymentMethod + { + Id = "pm_test123", + Type = "card", + Card = new PaymentMethodCard + { + Brand = "visa", + Last4 = "4242", + ExpMonth = 12, + ExpYear = 2024 + } + }; + return paymentMethod; + } + #endregion +}