diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Services/Implementations/StripePaymentService.cs index ac2a14846..1f84e1b0e 100644 --- a/src/Core/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Services/Implementations/StripePaymentService.cs @@ -3,6 +3,7 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Business; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -32,6 +33,7 @@ public class StripePaymentService : IPaymentService private readonly IStripeAdapter _stripeAdapter; private readonly IGlobalSettings _globalSettings; private readonly IFeatureService _featureService; + private readonly ITaxService _taxService; public StripePaymentService( ITransactionRepository transactionRepository, @@ -40,7 +42,8 @@ public class StripePaymentService : IPaymentService IStripeAdapter stripeAdapter, Braintree.IBraintreeGateway braintreeGateway, IGlobalSettings globalSettings, - IFeatureService featureService) + IFeatureService featureService, + ITaxService taxService) { _transactionRepository = transactionRepository; _logger = logger; @@ -49,6 +52,7 @@ public class StripePaymentService : IPaymentService _btGateway = braintreeGateway; _globalSettings = globalSettings; _featureService = featureService; + _taxService = taxService; } public async Task PurchaseOrganizationAsync(Organization org, PaymentMethodType paymentMethodType, @@ -112,6 +116,20 @@ public class StripePaymentService : IPaymentService Subscription subscription; try { + if (taxInfo.TaxIdNumber != null && taxInfo.TaxIdType == null) + { + taxInfo.TaxIdType = _taxService.GetStripeTaxCode(taxInfo.BillingAddressCountry, + taxInfo.TaxIdNumber); + + if (taxInfo.TaxIdType == null) + { + _logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.", + taxInfo.BillingAddressCountry, + taxInfo.TaxIdNumber); + throw new BadRequestException("billingTaxIdTypeInferenceError"); + } + } + var customerCreateOptions = new CustomerCreateOptions { Description = org.DisplayBusinessName(), @@ -146,12 +164,9 @@ public class StripePaymentService : IPaymentService City = taxInfo?.BillingAddressCity, State = taxInfo?.BillingAddressState, }, - TaxIdData = taxInfo?.HasTaxId != true - ? null - : - [ - new CustomerTaxIdDataOptions { Type = taxInfo.TaxIdType, Value = taxInfo.TaxIdNumber, } - ], + TaxIdData = taxInfo.HasTaxId + ? [new CustomerTaxIdDataOptions { Type = taxInfo.TaxIdType, Value = taxInfo.TaxIdNumber }] + : null }; customerCreateOptions.AddExpand("tax"); diff --git a/test/Core.Test/Services/StripePaymentServiceTests.cs b/test/Core.Test/Services/StripePaymentServiceTests.cs index e15f07b11..35e1901a2 100644 --- a/test/Core.Test/Services/StripePaymentServiceTests.cs +++ b/test/Core.Test/Services/StripePaymentServiceTests.cs @@ -77,7 +77,8 @@ public class StripePaymentServiceTests c.Address.Line2 == taxInfo.BillingAddressLine2 && c.Address.City == taxInfo.BillingAddressCity && c.Address.State == taxInfo.BillingAddressState && - c.TaxIdData == null + c.TaxIdData.First().Value == taxInfo.TaxIdNumber && + c.TaxIdData.First().Type == taxInfo.TaxIdType )); await stripeAdapter.Received().SubscriptionCreateAsync(Arg.Is(s => @@ -134,7 +135,8 @@ public class StripePaymentServiceTests c.Address.Line2 == taxInfo.BillingAddressLine2 && c.Address.City == taxInfo.BillingAddressCity && c.Address.State == taxInfo.BillingAddressState && - c.TaxIdData == null + c.TaxIdData.First().Value == taxInfo.TaxIdNumber && + c.TaxIdData.First().Type == taxInfo.TaxIdType )); await stripeAdapter.Received().SubscriptionCreateAsync(Arg.Is(s => @@ -190,7 +192,8 @@ public class StripePaymentServiceTests c.Address.Line2 == taxInfo.BillingAddressLine2 && c.Address.City == taxInfo.BillingAddressCity && c.Address.State == taxInfo.BillingAddressState && - c.TaxIdData == null + c.TaxIdData.First().Value == taxInfo.TaxIdNumber && + c.TaxIdData.First().Type == taxInfo.TaxIdType )); await stripeAdapter.Received().SubscriptionCreateAsync(Arg.Is(s => @@ -247,7 +250,8 @@ public class StripePaymentServiceTests c.Address.Line2 == taxInfo.BillingAddressLine2 && c.Address.City == taxInfo.BillingAddressCity && c.Address.State == taxInfo.BillingAddressState && - c.TaxIdData == null + c.TaxIdData.First().Value == taxInfo.TaxIdNumber && + c.TaxIdData.First().Type == taxInfo.TaxIdType )); await stripeAdapter.Received().SubscriptionCreateAsync(Arg.Is(s => @@ -441,7 +445,8 @@ public class StripePaymentServiceTests c.Address.Line2 == taxInfo.BillingAddressLine2 && c.Address.City == taxInfo.BillingAddressCity && c.Address.State == taxInfo.BillingAddressState && - c.TaxIdData == null + c.TaxIdData.First().Value == taxInfo.TaxIdNumber && + c.TaxIdData.First().Type == taxInfo.TaxIdType )); await stripeAdapter.Received().SubscriptionCreateAsync(Arg.Is(s => @@ -510,7 +515,8 @@ public class StripePaymentServiceTests c.Address.Line2 == taxInfo.BillingAddressLine2 && c.Address.City == taxInfo.BillingAddressCity && c.Address.State == taxInfo.BillingAddressState && - c.TaxIdData == null + c.TaxIdData.First().Value == taxInfo.TaxIdNumber && + c.TaxIdData.First().Type == taxInfo.TaxIdType )); await stripeAdapter.Received().SubscriptionCreateAsync(Arg.Is(s =>