diff --git a/src/Api/AdminConsole/Controllers/OrganizationsController.cs b/src/Api/AdminConsole/Controllers/OrganizationsController.cs index 822f9635e..7231f29c4 100644 --- a/src/Api/AdminConsole/Controllers/OrganizationsController.cs +++ b/src/Api/AdminConsole/Controllers/OrganizationsController.cs @@ -724,7 +724,7 @@ public class OrganizationsController : Controller [HttpPut("{id}/tax")] [SelfHosted(NotSelfHostedOnly = true)] - public async Task PutTaxInfo(string id, [FromBody] OrganizationTaxInfoUpdateRequestModel model) + public async Task PutTaxInfo(string id, [FromBody] ExpandedTaxInfoUpdateRequestModel model) { var orgIdGuid = new Guid(id); if (!await _currentContext.OrganizationOwner(orgIdGuid)) diff --git a/src/Api/AdminConsole/Controllers/ProvidersController.cs b/src/Api/AdminConsole/Controllers/ProvidersController.cs index cd39a90a8..9039779f1 100644 --- a/src/Api/AdminConsole/Controllers/ProvidersController.cs +++ b/src/Api/AdminConsole/Controllers/ProvidersController.cs @@ -1,9 +1,12 @@ using Bit.Api.AdminConsole.Models.Request.Providers; using Bit.Api.AdminConsole.Models.Response.Providers; +using Bit.Core; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; +using Bit.Core.Billing.Commands; using Bit.Core.Context; using Bit.Core.Exceptions; +using Bit.Core.Models.Business; using Bit.Core.Services; using Bit.Core.Settings; using Microsoft.AspNetCore.Authorization; @@ -20,15 +23,23 @@ public class ProvidersController : Controller private readonly IProviderService _providerService; private readonly ICurrentContext _currentContext; private readonly GlobalSettings _globalSettings; + private readonly IFeatureService _featureService; + private readonly IStartSubscriptionCommand _startSubscriptionCommand; + private readonly ILogger _logger; public ProvidersController(IUserService userService, IProviderRepository providerRepository, - IProviderService providerService, ICurrentContext currentContext, GlobalSettings globalSettings) + IProviderService providerService, ICurrentContext currentContext, GlobalSettings globalSettings, + IFeatureService featureService, IStartSubscriptionCommand startSubscriptionCommand, + ILogger logger) { _userService = userService; _providerRepository = providerRepository; _providerService = providerService; _currentContext = currentContext; _globalSettings = globalSettings; + _featureService = featureService; + _startSubscriptionCommand = startSubscriptionCommand; + _logger = logger; } [HttpGet("{id:guid}")] @@ -86,6 +97,30 @@ public class ProvidersController : Controller var response = await _providerService.CompleteSetupAsync(model.ToProvider(provider), userId, model.Token, model.Key); + if (_featureService.IsEnabled(FeatureFlagKeys.EnableConsolidatedBilling)) + { + var taxInfo = new TaxInfo + { + BillingAddressCountry = model.TaxInfo.Country, + BillingAddressPostalCode = model.TaxInfo.PostalCode, + TaxIdNumber = model.TaxInfo.TaxId, + BillingAddressLine1 = model.TaxInfo.Line1, + BillingAddressLine2 = model.TaxInfo.Line2, + BillingAddressCity = model.TaxInfo.City, + BillingAddressState = model.TaxInfo.State + }; + + try + { + await _startSubscriptionCommand.StartSubscription(provider, taxInfo); + } + catch + { + // We don't want to trap the user on the setup page, so we'll let this go through but the provider will be in an un-billable state. + _logger.LogError("Failed to create subscription for provider with ID {ID} during setup", provider.Id); + } + } + return new ProviderResponseModel(response); } } diff --git a/src/Api/AdminConsole/Models/Request/Providers/ProviderSetupRequestModel.cs b/src/Api/AdminConsole/Models/Request/Providers/ProviderSetupRequestModel.cs index f68d3b92a..5e10807c6 100644 --- a/src/Api/AdminConsole/Models/Request/Providers/ProviderSetupRequestModel.cs +++ b/src/Api/AdminConsole/Models/Request/Providers/ProviderSetupRequestModel.cs @@ -1,5 +1,6 @@ using System.ComponentModel.DataAnnotations; using System.Text.Json.Serialization; +using Bit.Api.Models.Request; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Utilities; @@ -22,6 +23,7 @@ public class ProviderSetupRequestModel public string Token { get; set; } [Required] public string Key { get; set; } + public ExpandedTaxInfoUpdateRequestModel TaxInfo { get; set; } public virtual Provider ToProvider(Provider provider) { diff --git a/src/Api/Models/Request/Organizations/OrganizationTaxInfoUpdateRequestModel.cs b/src/Api/Models/Request/ExpandedTaxInfoUpdateRequestModel.cs similarity index 65% rename from src/Api/Models/Request/Organizations/OrganizationTaxInfoUpdateRequestModel.cs rename to src/Api/Models/Request/ExpandedTaxInfoUpdateRequestModel.cs index c20fa07af..7f95d755a 100644 --- a/src/Api/Models/Request/Organizations/OrganizationTaxInfoUpdateRequestModel.cs +++ b/src/Api/Models/Request/ExpandedTaxInfoUpdateRequestModel.cs @@ -1,8 +1,8 @@ using Bit.Api.Models.Request.Accounts; -namespace Bit.Api.Models.Request.Organizations; +namespace Bit.Api.Models.Request; -public class OrganizationTaxInfoUpdateRequestModel : TaxInfoUpdateRequestModel +public class ExpandedTaxInfoUpdateRequestModel : TaxInfoUpdateRequestModel { public string TaxId { get; set; } public string Line1 { get; set; } diff --git a/src/Api/Models/Request/PaymentRequestModel.cs b/src/Api/Models/Request/PaymentRequestModel.cs index 47e39b010..eae1abfce 100644 --- a/src/Api/Models/Request/PaymentRequestModel.cs +++ b/src/Api/Models/Request/PaymentRequestModel.cs @@ -1,10 +1,9 @@ using System.ComponentModel.DataAnnotations; -using Bit.Api.Models.Request.Organizations; using Bit.Core.Enums; namespace Bit.Api.Models.Request; -public class PaymentRequestModel : OrganizationTaxInfoUpdateRequestModel +public class PaymentRequestModel : ExpandedTaxInfoUpdateRequestModel { [Required] public PaymentMethodType? PaymentMethodType { get; set; } diff --git a/src/Billing/Controllers/StripeController.cs b/src/Billing/Controllers/StripeController.cs index 7fc5189e5..1395b4081 100644 --- a/src/Billing/Controllers/StripeController.cs +++ b/src/Billing/Controllers/StripeController.cs @@ -255,7 +255,7 @@ public class StripeController : Controller customerGetOptions.AddExpand("tax"); var customer = await _stripeFacade.GetCustomer(subscription.CustomerId, customerGetOptions); if (!subscription.AutomaticTax.Enabled && - customer.Tax?.AutomaticTax == StripeCustomerAutomaticTaxStatus.Supported) + customer.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported) { subscription = await _stripeFacade.UpdateSubscription(subscription.Id, new SubscriptionUpdateOptions diff --git a/src/Core/AdminConsole/Enums/Provider/ProviderStatusType.cs b/src/Core/AdminConsole/Enums/Provider/ProviderStatusType.cs index 794bd36bf..1beedebf5 100644 --- a/src/Core/AdminConsole/Enums/Provider/ProviderStatusType.cs +++ b/src/Core/AdminConsole/Enums/Provider/ProviderStatusType.cs @@ -4,4 +4,5 @@ public enum ProviderStatusType : byte { Pending = 0, Created = 1, + Billable = 2 } diff --git a/src/Core/Billing/Commands/IStartSubscriptionCommand.cs b/src/Core/Billing/Commands/IStartSubscriptionCommand.cs new file mode 100644 index 000000000..9a5ce7d79 --- /dev/null +++ b/src/Core/Billing/Commands/IStartSubscriptionCommand.cs @@ -0,0 +1,11 @@ +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.Models.Business; + +namespace Bit.Core.Billing.Commands; + +public interface IStartSubscriptionCommand +{ + Task StartSubscription( + Provider provider, + TaxInfo taxInfo); +} diff --git a/src/Core/Billing/Commands/Implementations/StartSubscriptionCommand.cs b/src/Core/Billing/Commands/Implementations/StartSubscriptionCommand.cs new file mode 100644 index 000000000..a3223f0ce --- /dev/null +++ b/src/Core/Billing/Commands/Implementations/StartSubscriptionCommand.cs @@ -0,0 +1,209 @@ +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Repositories; +using Bit.Core.Enums; +using Bit.Core.Models.Business; +using Bit.Core.Services; +using Bit.Core.Settings; +using Bit.Core.Utilities; +using Microsoft.Extensions.Logging; +using Stripe; +using static Bit.Core.Billing.Utilities; + +namespace Bit.Core.Billing.Commands.Implementations; + +public class StartSubscriptionCommand( + IGlobalSettings globalSettings, + ILogger logger, + IProviderPlanRepository providerPlanRepository, + IProviderRepository providerRepository, + IStripeAdapter stripeAdapter) : IStartSubscriptionCommand +{ + public async Task StartSubscription( + Provider provider, + TaxInfo taxInfo) + { + ArgumentNullException.ThrowIfNull(provider); + ArgumentNullException.ThrowIfNull(taxInfo); + + if (!string.IsNullOrEmpty(provider.GatewaySubscriptionId)) + { + logger.LogWarning("Cannot start Provider subscription - Provider ({ID}) already has a {FieldName}", provider.Id, nameof(provider.GatewaySubscriptionId)); + + throw ContactSupport(); + } + + if (string.IsNullOrEmpty(taxInfo.BillingAddressCountry) || + string.IsNullOrEmpty(taxInfo.BillingAddressPostalCode)) + { + logger.LogError("Cannot start Provider subscription - Both the Provider's ({ID}) country and postal code are required", provider.Id); + + throw ContactSupport(); + } + + var customer = await GetOrCreateCustomerAsync(provider, taxInfo); + + if (taxInfo.BillingAddressCountry == "US" && customer.Tax is not { AutomaticTax: StripeConstants.AutomaticTaxStatus.Supported }) + { + logger.LogError("Cannot start Provider subscription - Provider's ({ProviderID}) Stripe customer ({CustomerID}) is in the US and does not support automatic tax", provider.Id, customer.Id); + + throw ContactSupport(); + } + + var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); + + if (providerPlans == null || providerPlans.Count == 0) + { + logger.LogError("Cannot start Provider subscription - Provider ({ID}) has no configured plans", provider.Id); + + throw ContactSupport(); + } + + var subscriptionItemOptionsList = new List(); + + var teamsProviderPlan = + providerPlans.SingleOrDefault(providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly); + + if (teamsProviderPlan == null) + { + logger.LogError("Cannot start Provider subscription - Provider ({ID}) has no configured Teams Monthly plan", provider.Id); + + throw ContactSupport(); + } + + var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + + subscriptionItemOptionsList.Add(new SubscriptionItemOptions + { + Price = teamsPlan.PasswordManager.StripeSeatPlanId, + Quantity = teamsProviderPlan.SeatMinimum + }); + + var enterpriseProviderPlan = + providerPlans.SingleOrDefault(providerPlan => providerPlan.PlanType == PlanType.EnterpriseMonthly); + + if (enterpriseProviderPlan == null) + { + logger.LogError("Cannot start Provider subscription - Provider ({ID}) has no configured Enterprise Monthly plan", provider.Id); + + throw ContactSupport(); + } + + var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); + + subscriptionItemOptionsList.Add(new SubscriptionItemOptions + { + Price = enterprisePlan.PasswordManager.StripeSeatPlanId, + Quantity = enterpriseProviderPlan.SeatMinimum + }); + + var subscriptionCreateOptions = new SubscriptionCreateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions + { + Enabled = true + }, + CollectionMethod = StripeConstants.CollectionMethod.SendInvoice, + Customer = customer.Id, + DaysUntilDue = 30, + Items = subscriptionItemOptionsList, + Metadata = new Dictionary + { + { "providerId", provider.Id.ToString() } + }, + OffSession = true, + ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations + }; + + var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); + + provider.GatewaySubscriptionId = subscription.Id; + + if (subscription.Status == StripeConstants.SubscriptionStatus.Incomplete) + { + await providerRepository.ReplaceAsync(provider); + + logger.LogError("Started incomplete Provider ({ProviderID}) subscription ({SubscriptionID})", provider.Id, subscription.Id); + + throw ContactSupport(); + } + + provider.Status = ProviderStatusType.Billable; + + await providerRepository.ReplaceAsync(provider); + } + + // ReSharper disable once SuggestBaseTypeForParameter + private async Task GetOrCreateCustomerAsync( + Provider provider, + TaxInfo taxInfo) + { + if (!string.IsNullOrEmpty(provider.GatewayCustomerId)) + { + var existingCustomer = await stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId, new CustomerGetOptions + { + Expand = ["tax"] + }); + + if (existingCustomer != null) + { + return existingCustomer; + } + + logger.LogError("Cannot start Provider subscription - Provider's ({ProviderID}) {CustomerIDFieldName} did not relate to a Stripe customer", provider.Id, nameof(provider.GatewayCustomerId)); + + throw ContactSupport(); + } + + var providerDisplayName = provider.DisplayName(); + + var customerCreateOptions = new CustomerCreateOptions + { + Address = new AddressOptions + { + Country = taxInfo.BillingAddressCountry, + PostalCode = taxInfo.BillingAddressPostalCode, + Line1 = taxInfo.BillingAddressLine1, + Line2 = taxInfo.BillingAddressLine2, + City = taxInfo.BillingAddressCity, + State = taxInfo.BillingAddressState + }, + Coupon = "msp-discount-35", + Description = provider.DisplayBusinessName(), + Email = provider.BillingEmail, + Expand = ["tax"], + InvoiceSettings = new CustomerInvoiceSettingsOptions + { + CustomFields = + [ + new CustomerInvoiceSettingsCustomFieldOptions + { + Name = provider.SubscriberType(), + Value = providerDisplayName.Length <= 30 + ? providerDisplayName + : providerDisplayName[..30] + } + ] + }, + Metadata = new Dictionary + { + { "region", globalSettings.BaseServiceUri.CloudRegion } + }, + TaxIdData = taxInfo.HasTaxId ? + [ + new CustomerTaxIdDataOptions { Type = taxInfo.TaxIdType, Value = taxInfo.TaxIdNumber } + ] + : null + }; + + var createdCustomer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions); + + provider.GatewayCustomerId = createdCustomer.Id; + + await providerRepository.ReplaceAsync(provider); + + return createdCustomer; + } +} diff --git a/src/Core/Billing/Constants/StripeConstants.cs b/src/Core/Billing/Constants/StripeConstants.cs new file mode 100644 index 000000000..9fd4e8489 --- /dev/null +++ b/src/Core/Billing/Constants/StripeConstants.cs @@ -0,0 +1,37 @@ +namespace Bit.Core.Billing.Constants; + +public static class StripeConstants +{ + public static class AutomaticTaxStatus + { + public const string Failed = "failed"; + public const string NotCollecting = "not_collecting"; + public const string Supported = "supported"; + public const string UnrecognizedLocation = "unrecognized_location"; + } + + public static class CollectionMethod + { + public const string ChargeAutomatically = "charge_automatically"; + public const string SendInvoice = "send_invoice"; + } + + public static class ProrationBehavior + { + public const string AlwaysInvoice = "always_invoice"; + public const string CreateProrations = "create_prorations"; + public const string None = "none"; + } + + public static class SubscriptionStatus + { + public const string Trialing = "trialing"; + public const string Active = "active"; + public const string Incomplete = "incomplete"; + public const string IncompleteExpired = "incomplete_expired"; + public const string PastDue = "past_due"; + public const string Canceled = "canceled"; + public const string Unpaid = "unpaid"; + public const string Paused = "paused"; + } +} diff --git a/src/Core/Billing/Constants/StripeCustomerAutomaticTaxStatus.cs b/src/Core/Billing/Constants/StripeCustomerAutomaticTaxStatus.cs deleted file mode 100644 index f9f352647..000000000 --- a/src/Core/Billing/Constants/StripeCustomerAutomaticTaxStatus.cs +++ /dev/null @@ -1,9 +0,0 @@ -namespace Bit.Core.Billing.Constants; - -public static class StripeCustomerAutomaticTaxStatus -{ - public const string Failed = "failed"; - public const string NotCollecting = "not_collecting"; - public const string Supported = "supported"; - public const string UnrecognizedLocation = "unrecognized_location"; -} diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs index 8e28b2339..c4f25e2f6 100644 --- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs +++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs @@ -19,5 +19,6 @@ public static class ServiceCollectionExtensions services.AddTransient(); services.AddTransient(); services.AddTransient(); + services.AddTransient(); } } diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Services/Implementations/StripePaymentService.cs index 234543a8f..2fd1d7237 100644 --- a/src/Core/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Services/Implementations/StripePaymentService.cs @@ -1923,7 +1923,7 @@ public class StripePaymentService : IPaymentService /// /// private static bool CustomerHasTaxLocationVerified(Customer customer) => - customer?.Tax?.AutomaticTax == StripeCustomerAutomaticTaxStatus.Supported; + customer?.Tax?.AutomaticTax == StripeConstants.AutomaticTaxStatus.Supported; // We are taking only first 30 characters of the SubscriberName because stripe provide // for 30 characters for custom_fields,see the link: https://stripe.com/docs/api/invoices/create diff --git a/test/Core.Test/Billing/Commands/StartSubscriptionCommandTests.cs b/test/Core.Test/Billing/Commands/StartSubscriptionCommandTests.cs new file mode 100644 index 000000000..308126380 --- /dev/null +++ b/test/Core.Test/Billing/Commands/StartSubscriptionCommandTests.cs @@ -0,0 +1,446 @@ +using System.Net; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Commands.Implementations; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Entities; +using Bit.Core.Billing.Repositories; +using Bit.Core.Enums; +using Bit.Core.Models.Business; +using Bit.Core.Services; +using Bit.Core.Utilities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Stripe; +using Xunit; + +using static Bit.Core.Test.Billing.Utilities; + +namespace Bit.Core.Test.Billing.Commands; + +[SutProviderCustomize] +public class StartSubscriptionCommandTests +{ + private const string _customerId = "customer_id"; + private const string _subscriptionId = "subscription_id"; + + // These tests are only trying to assert on the thrown exceptions and thus use the least amount of data setup possible. + #region Error Cases + [Theory, BitAutoData] + public async Task StartSubscription_NullProvider_ThrowsArgumentNullException( + SutProvider sutProvider, + TaxInfo taxInfo) => + await Assert.ThrowsAsync(() => sutProvider.Sut.StartSubscription(null, taxInfo)); + + [Theory, BitAutoData] + public async Task StartSubscription_NullTaxInfo_ThrowsArgumentNullException( + SutProvider sutProvider, + Provider provider) => + await Assert.ThrowsAsync(() => sutProvider.Sut.StartSubscription(provider, null)); + + [Theory, BitAutoData] + public async Task StartSubscription_AlreadyHasGatewaySubscriptionId_ThrowsBillingException( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + provider.GatewayCustomerId = _customerId; + + provider.GatewaySubscriptionId = _subscriptionId; + + await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider, taxInfo)); + + await DidNotRetrieveCustomerAsync(sutProvider); + } + + [Theory, BitAutoData] + public async Task StartSubscription_MissingCountry_ThrowsBillingException( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + provider.GatewayCustomerId = _customerId; + + provider.GatewaySubscriptionId = null; + + taxInfo.BillingAddressCountry = null; + + await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider, taxInfo)); + + await DidNotRetrieveCustomerAsync(sutProvider); + } + + [Theory, BitAutoData] + public async Task StartSubscription_MissingPostalCode_ThrowsBillingException( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + provider.GatewayCustomerId = _customerId; + + provider.GatewaySubscriptionId = null; + + taxInfo.BillingAddressPostalCode = null; + + await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider, taxInfo)); + + await DidNotRetrieveCustomerAsync(sutProvider); + } + + [Theory, BitAutoData] + public async Task StartSubscription_MissingStripeCustomer_ThrowsBillingException( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + provider.GatewayCustomerId = _customerId; + + provider.GatewaySubscriptionId = null; + + SetCustomerRetrieval(sutProvider, null); + + await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider, taxInfo)); + + await DidNotRetrieveProviderPlansAsync(sutProvider); + } + + [Theory, BitAutoData] + public async Task StartSubscription_CustomerDoesNotSupportAutomaticTax_ThrowsBillingException( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + provider.GatewayCustomerId = _customerId; + + provider.GatewaySubscriptionId = null; + + taxInfo.BillingAddressCountry = "US"; + + SetCustomerRetrieval(sutProvider, new Customer + { + Id = _customerId, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.NotCollecting + } + }); + + await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider, taxInfo)); + + await DidNotRetrieveProviderPlansAsync(sutProvider); + } + + [Theory, BitAutoData] + public async Task StartSubscription_NoProviderPlans_ThrowsBillingException( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + provider.GatewayCustomerId = _customerId; + + provider.GatewaySubscriptionId = null; + + SetCustomerRetrieval(sutProvider, new Customer + { + Id = _customerId, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + } + }); + + sutProvider.GetDependency().GetByProviderId(provider.Id) + .Returns(new List()); + + await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider, taxInfo)); + + await DidNotCreateSubscriptionAsync(sutProvider); + } + + [Theory, BitAutoData] + public async Task StartSubscription_NoProviderTeamsPlan_ThrowsBillingException( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + provider.GatewayCustomerId = _customerId; + + provider.GatewaySubscriptionId = null; + + SetCustomerRetrieval(sutProvider, new Customer + { + Id = _customerId, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + } + }); + + var providerPlans = new List + { + new () + { + PlanType = PlanType.EnterpriseMonthly + } + }; + + sutProvider.GetDependency().GetByProviderId(provider.Id) + .Returns(providerPlans); + + await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider, taxInfo)); + + await DidNotCreateSubscriptionAsync(sutProvider); + } + + [Theory, BitAutoData] + public async Task StartSubscription_NoProviderEnterprisePlan_ThrowsBillingException( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + provider.GatewayCustomerId = _customerId; + + provider.GatewaySubscriptionId = null; + + SetCustomerRetrieval(sutProvider, new Customer + { + Id = _customerId, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + } + }); + + var providerPlans = new List + { + new () + { + PlanType = PlanType.TeamsMonthly + } + }; + + sutProvider.GetDependency().GetByProviderId(provider.Id) + .Returns(providerPlans); + + await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider, taxInfo)); + + await DidNotCreateSubscriptionAsync(sutProvider); + } + + [Theory, BitAutoData] + public async Task StartSubscription_SubscriptionIncomplete_ThrowsBillingException( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + provider.GatewayCustomerId = _customerId; + + provider.GatewaySubscriptionId = null; + + SetCustomerRetrieval(sutProvider, new Customer + { + Id = _customerId, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + } + }); + + var providerPlans = new List + { + new () + { + PlanType = PlanType.TeamsMonthly, + SeatMinimum = 100 + }, + new () + { + PlanType = PlanType.EnterpriseMonthly, + SeatMinimum = 100 + } + }; + + sutProvider.GetDependency().GetByProviderId(provider.Id) + .Returns(providerPlans); + + sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Any()).Returns(new Subscription + { + Id = _subscriptionId, + Status = StripeConstants.SubscriptionStatus.Incomplete + }); + + await ThrowsContactSupportAsync(() => sutProvider.Sut.StartSubscription(provider, taxInfo)); + + await sutProvider.GetDependency().Received(1).ReplaceAsync(provider); + } + #endregion + + #region Success Cases + [Theory, BitAutoData] + public async Task StartSubscription_ExistingCustomer_Succeeds( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + provider.GatewayCustomerId = _customerId; + + provider.GatewaySubscriptionId = null; + + SetCustomerRetrieval(sutProvider, new Customer + { + Id = _customerId, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + } + }); + + var providerPlans = new List + { + new () + { + PlanType = PlanType.TeamsMonthly, + SeatMinimum = 100 + }, + new () + { + PlanType = PlanType.EnterpriseMonthly, + SeatMinimum = 100 + } + }; + + sutProvider.GetDependency().GetByProviderId(provider.Id) + .Returns(providerPlans); + + var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); + + sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sub => + sub.AutomaticTax.Enabled == true && + sub.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && + sub.Customer == _customerId && + sub.DaysUntilDue == 30 && + sub.Items.Count == 2 && + sub.Items.ElementAt(0).Price == teamsPlan.PasswordManager.StripeSeatPlanId && + sub.Items.ElementAt(0).Quantity == 100 && + sub.Items.ElementAt(1).Price == enterprisePlan.PasswordManager.StripeSeatPlanId && + sub.Items.ElementAt(1).Quantity == 100 && + sub.Metadata["providerId"] == provider.Id.ToString() && + sub.OffSession == true && + sub.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations)).Returns(new Subscription + { + Id = _subscriptionId, + Status = StripeConstants.SubscriptionStatus.Active + }); + + await sutProvider.Sut.StartSubscription(provider, taxInfo); + + await sutProvider.GetDependency().Received(1).ReplaceAsync(provider); + } + + [Theory, BitAutoData] + public async Task StartSubscription_NewCustomer_Succeeds( + SutProvider sutProvider, + Provider provider, + TaxInfo taxInfo) + { + provider.GatewayCustomerId = null; + + provider.GatewaySubscriptionId = null; + + provider.Name = "MSP"; + + taxInfo.BillingAddressCountry = "AD"; + + sutProvider.GetDependency().CustomerCreateAsync(Arg.Is(o => + o.Address.Country == taxInfo.BillingAddressCountry && + o.Address.PostalCode == taxInfo.BillingAddressPostalCode && + o.Address.Line1 == taxInfo.BillingAddressLine1 && + o.Address.Line2 == taxInfo.BillingAddressLine2 && + o.Address.City == taxInfo.BillingAddressCity && + o.Address.State == taxInfo.BillingAddressState && + o.Coupon == "msp-discount-35" && + o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Email == provider.BillingEmail && + o.Expand.FirstOrDefault() == "tax" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.Metadata["region"] == "" && + o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && + o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) + .Returns(new Customer + { + Id = _customerId, + Tax = new CustomerTax + { + AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported + } + }); + + var providerPlans = new List + { + new () + { + PlanType = PlanType.TeamsMonthly, + SeatMinimum = 100 + }, + new () + { + PlanType = PlanType.EnterpriseMonthly, + SeatMinimum = 100 + } + }; + + sutProvider.GetDependency().GetByProviderId(provider.Id) + .Returns(providerPlans); + + var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); + + sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sub => + sub.AutomaticTax.Enabled == true && + sub.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && + sub.Customer == _customerId && + sub.DaysUntilDue == 30 && + sub.Items.Count == 2 && + sub.Items.ElementAt(0).Price == teamsPlan.PasswordManager.StripeSeatPlanId && + sub.Items.ElementAt(0).Quantity == 100 && + sub.Items.ElementAt(1).Price == enterprisePlan.PasswordManager.StripeSeatPlanId && + sub.Items.ElementAt(1).Quantity == 100 && + sub.Metadata["providerId"] == provider.Id.ToString() && + sub.OffSession == true && + sub.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations)).Returns(new Subscription + { + Id = _subscriptionId, + Status = StripeConstants.SubscriptionStatus.Active + }); + + await sutProvider.Sut.StartSubscription(provider, taxInfo); + + await sutProvider.GetDependency().Received(2).ReplaceAsync(provider); + } + #endregion + + private static async Task DidNotCreateSubscriptionAsync(SutProvider sutProvider) => + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SubscriptionCreateAsync(Arg.Any()); + + private static async Task DidNotRetrieveCustomerAsync(SutProvider sutProvider) => + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .CustomerGetAsync(Arg.Any(), Arg.Any()); + + private static async Task DidNotRetrieveProviderPlansAsync(SutProvider sutProvider) => + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .GetByProviderId(Arg.Any()); + + private static void SetCustomerRetrieval(SutProvider sutProvider, + Customer customer) => sutProvider.GetDependency() + .CustomerGetAsync(_customerId, Arg.Is(o => o.Expand.FirstOrDefault() == "tax")) + .Returns(customer); +}