1
0
mirror of https://github.com/bitwarden/server.git synced 2024-11-24 12:35:25 +01:00

[PM-5766] Enabled Automatic Tax for all customers (#3685)

* Removed TaxRate logic when creating or updating a Stripe subscription and replaced it with AutomaticTax enabled flag

* Updated Stripe webhook to update subscription to automatically calculate tax

* Removed TaxRate unit tests since Stripe now handles tax

* Removed test proration logic

* Including taxInfo when updating payment method

* Adding the address to the upgrade free org flow if it doesn't exist

* Fixed failing tests and added a new test to validate that the customer is updated
This commit is contained in:
Conner Turnbull 2024-01-29 09:48:59 -05:00 committed by GitHub
parent c2b4ee7eac
commit a2e6550b61
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 127 additions and 167 deletions

View File

@ -21,7 +21,6 @@ using Customer = Stripe.Customer;
using Event = Stripe.Event;
using PaymentMethod = Stripe.PaymentMethod;
using Subscription = Stripe.Subscription;
using TaxRate = Bit.Core.Entities.TaxRate;
using Transaction = Bit.Core.Entities.Transaction;
using TransactionType = Bit.Core.Enums.TransactionType;
@ -223,9 +222,17 @@ public class StripeController : Controller
$"Received null Subscription from Stripe for ID '{invoice.SubscriptionId}' while processing Event with ID '{parsedEvent.Id}'");
}
var updatedSubscription = await VerifyCorrectTaxRateForCharge(invoice, subscription);
if (!subscription.AutomaticTax.Enabled)
{
subscription = await _stripeFacade.UpdateSubscription(subscription.Id,
new SubscriptionUpdateOptions
{
DefaultTaxRates = new List<string>(),
AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }
});
}
var (organizationId, userId) = GetIdsFromMetaData(updatedSubscription.Metadata);
var (organizationId, userId) = GetIdsFromMetaData(subscription.Metadata);
var invoiceLineItemDescriptions = invoice.Lines.Select(i => i.Description).ToList();
@ -246,7 +253,7 @@ public class StripeController : Controller
if (organizationId.HasValue)
{
if (IsSponsoredSubscription(updatedSubscription))
if (IsSponsoredSubscription(subscription))
{
await _validateSponsorshipCommand.ValidateSponsorshipAsync(organizationId.Value);
}
@ -828,32 +835,6 @@ public class StripeController : Controller
invoice.BillingReason == "subscription_cycle" && invoice.SubscriptionId != null;
}
private async Task<Subscription> VerifyCorrectTaxRateForCharge(Invoice invoice, Subscription subscription)
{
if (!string.IsNullOrWhiteSpace(invoice?.CustomerAddress?.Country) && !string.IsNullOrWhiteSpace(invoice?.CustomerAddress?.PostalCode))
{
var localBitwardenTaxRates = await _taxRateRepository.GetByLocationAsync(
new TaxRate()
{
Country = invoice.CustomerAddress.Country,
PostalCode = invoice.CustomerAddress.PostalCode
}
);
if (localBitwardenTaxRates.Any())
{
var stripeTaxRate = await new TaxRateService().GetAsync(localBitwardenTaxRates.First().Id);
if (stripeTaxRate != null && !subscription.DefaultTaxRates.Any(x => x == stripeTaxRate))
{
subscription.DefaultTaxRates = new List<Stripe.TaxRate> { stripeTaxRate };
var subscriptionOptions = new SubscriptionUpdateOptions() { DefaultTaxRates = new List<string>() { stripeTaxRate.Id } };
subscription = await new SubscriptionService().UpdateAsync(subscription.Id, subscriptionOptions);
}
}
}
return subscription;
}
private static bool IsSponsoredSubscription(Subscription subscription) =>
StaticStore.SponsoredPlans.Any(p => p.StripePlanId == subscription.Id);

View File

@ -33,4 +33,10 @@ public interface IStripeFacade
SubscriptionGetOptions subscriptionGetOptions = null,
RequestOptions requestOptions = null,
CancellationToken cancellationToken = default);
Task<Subscription> UpdateSubscription(
string subscriptionId,
SubscriptionUpdateOptions subscriptionGetOptions = null,
RequestOptions requestOptions = null,
CancellationToken cancellationToken = default);
}

View File

@ -44,4 +44,11 @@ public class StripeFacade : IStripeFacade
RequestOptions requestOptions = null,
CancellationToken cancellationToken = default) =>
await _subscriptionService.GetAsync(subscriptionId, subscriptionGetOptions, requestOptions, cancellationToken);
public async Task<Subscription> UpdateSubscription(
string subscriptionId,
SubscriptionUpdateOptions subscriptionUpdateOptions = null,
RequestOptions requestOptions = null,
CancellationToken cancellationToken = default) =>
await _subscriptionService.UpdateAsync(subscriptionId, subscriptionUpdateOptions, requestOptions, cancellationToken);
}

View File

@ -140,7 +140,7 @@ public class OrganizationService : IOrganizationService
await _paymentService.SaveTaxInfoAsync(organization, taxInfo);
var updated = await _paymentService.UpdatePaymentMethodAsync(organization,
paymentMethodType, paymentToken);
paymentMethodType, paymentToken, taxInfo);
if (updated)
{
await ReplaceAndUpdateCacheAsync(organization);

View File

@ -98,23 +98,6 @@ public class StripePaymentService : IPaymentService
throw new GatewayException("Payment method is not supported at this time.");
}
if (taxInfo != null && !string.IsNullOrWhiteSpace(taxInfo.BillingAddressCountry) && !string.IsNullOrWhiteSpace(taxInfo.BillingAddressPostalCode))
{
var taxRateSearch = new TaxRate
{
Country = taxInfo.BillingAddressCountry,
PostalCode = taxInfo.BillingAddressPostalCode
};
var taxRates = await _taxRateRepository.GetByLocationAsync(taxRateSearch);
// should only be one tax rate per country/zip combo
var taxRate = taxRates.FirstOrDefault();
if (taxRate != null)
{
taxInfo.StripeTaxRateId = taxRate.Id;
}
}
var subCreateOptions = new OrganizationPurchaseSubscriptionOptions(org, plan, taxInfo, additionalSeats, additionalStorageGb, premiumAccessAddon
, additionalSmSeats, additionalServiceAccount);
@ -163,6 +146,9 @@ public class StripePaymentService : IPaymentService
});
subCreateOptions.AddExpand("latest_invoice.payment_intent");
subCreateOptions.Customer = customer.Id;
subCreateOptions.AutomaticTax = new Stripe.SubscriptionAutomaticTaxOptions { Enabled = true };
subscription = await _stripeAdapter.SubscriptionCreateAsync(subCreateOptions);
if (subscription.Status == "incomplete" && subscription.LatestInvoice?.PaymentIntent != null)
{
@ -244,25 +230,31 @@ public class StripePaymentService : IPaymentService
throw new GatewayException("Could not find customer payment profile.");
}
var taxInfo = upgrade.TaxInfo;
if (taxInfo != null && !string.IsNullOrWhiteSpace(taxInfo.BillingAddressCountry) && !string.IsNullOrWhiteSpace(taxInfo.BillingAddressPostalCode))
if (customer.Address is null &&
!string.IsNullOrEmpty(upgrade.TaxInfo?.BillingAddressCountry) &&
!string.IsNullOrEmpty(upgrade.TaxInfo?.BillingAddressPostalCode))
{
var taxRateSearch = new TaxRate
var addressOptions = new Stripe.AddressOptions
{
Country = taxInfo.BillingAddressCountry,
PostalCode = taxInfo.BillingAddressPostalCode
Country = upgrade.TaxInfo.BillingAddressCountry,
PostalCode = upgrade.TaxInfo.BillingAddressPostalCode,
// Line1 is required in Stripe's API, suggestion in Docs is to use Business Name instead.
Line1 = upgrade.TaxInfo.BillingAddressLine1 ?? string.Empty,
Line2 = upgrade.TaxInfo.BillingAddressLine2,
City = upgrade.TaxInfo.BillingAddressCity,
State = upgrade.TaxInfo.BillingAddressState,
};
var taxRates = await _taxRateRepository.GetByLocationAsync(taxRateSearch);
// should only be one tax rate per country/zip combo
var taxRate = taxRates.FirstOrDefault();
if (taxRate != null)
{
taxInfo.StripeTaxRateId = taxRate.Id;
}
var customerUpdateOptions = new Stripe.CustomerUpdateOptions { Address = addressOptions };
customerUpdateOptions.AddExpand("default_source");
customerUpdateOptions.AddExpand("invoice_settings.default_payment_method");
customer = await _stripeAdapter.CustomerUpdateAsync(org.GatewayCustomerId, customerUpdateOptions);
}
var subCreateOptions = new OrganizationUpgradeSubscriptionOptions(customer.Id, org, plan, upgrade);
var subCreateOptions = new OrganizationUpgradeSubscriptionOptions(customer.Id, org, plan, upgrade)
{
DefaultTaxRates = new List<string>(),
AutomaticTax = new Stripe.SubscriptionAutomaticTaxOptions { Enabled = true }
};
var (stripePaymentMethod, paymentMethodType) = IdentifyPaymentMethod(customer, subCreateOptions);
var subscription = await ChargeForNewSubscriptionAsync(org, customer, false,
@ -459,26 +451,6 @@ public class StripePaymentService : IPaymentService
Quantity = 1
});
if (!string.IsNullOrWhiteSpace(taxInfo?.BillingAddressCountry)
&& !string.IsNullOrWhiteSpace(taxInfo?.BillingAddressPostalCode))
{
var taxRates = await _taxRateRepository.GetByLocationAsync(
new TaxRate()
{
Country = taxInfo.BillingAddressCountry,
PostalCode = taxInfo.BillingAddressPostalCode
}
);
var taxRate = taxRates.FirstOrDefault();
if (taxRate != null)
{
subCreateOptions.DefaultTaxRates = new List<string>(1)
{
taxRate.Id
};
}
}
if (additionalStorageGb > 0)
{
subCreateOptions.Items.Add(new Stripe.SubscriptionItemOptions
@ -488,6 +460,8 @@ public class StripePaymentService : IPaymentService
});
}
subCreateOptions.AutomaticTax = new Stripe.SubscriptionAutomaticTaxOptions { Enabled = true };
var subscription = await ChargeForNewSubscriptionAsync(user, customer, createdStripeCustomer,
stripePaymentMethod, paymentMethodType, subCreateOptions, braintreeCustomer);
@ -525,7 +499,8 @@ public class StripePaymentService : IPaymentService
{
Customer = customer.Id,
SubscriptionItems = ToInvoiceSubscriptionItemOptions(subCreateOptions.Items),
SubscriptionDefaultTaxRates = subCreateOptions.DefaultTaxRates,
AutomaticTax =
new Stripe.InvoiceAutomaticTaxOptions { Enabled = subCreateOptions.AutomaticTax.Enabled }
});
if (previewInvoice.AmountDue > 0)
@ -583,7 +558,8 @@ public class StripePaymentService : IPaymentService
{
Customer = customer.Id,
SubscriptionItems = ToInvoiceSubscriptionItemOptions(subCreateOptions.Items),
SubscriptionDefaultTaxRates = subCreateOptions.DefaultTaxRates,
AutomaticTax =
new Stripe.InvoiceAutomaticTaxOptions { Enabled = subCreateOptions.AutomaticTax.Enabled }
});
if (previewInvoice.AmountDue > 0)
{
@ -593,6 +569,7 @@ public class StripePaymentService : IPaymentService
subCreateOptions.OffSession = true;
subCreateOptions.AddExpand("latest_invoice.payment_intent");
subCreateOptions.AutomaticTax = new Stripe.SubscriptionAutomaticTaxOptions { Enabled = true };
subscription = await _stripeAdapter.SubscriptionCreateAsync(subCreateOptions);
if (subscription.Status == "incomplete" && subscription.LatestInvoice?.PaymentIntent != null)
{
@ -692,6 +669,8 @@ public class StripePaymentService : IPaymentService
DaysUntilDue = daysUntilDue ?? 1,
CollectionMethod = "send_invoice",
ProrationDate = prorationDate,
DefaultTaxRates = new List<string>(),
AutomaticTax = new Stripe.SubscriptionAutomaticTaxOptions { Enabled = true }
};
if (!subscriptionUpdate.UpdateNeeded(sub))
@ -700,28 +679,6 @@ public class StripePaymentService : IPaymentService
return null;
}
var customer = await _stripeAdapter.CustomerGetAsync(sub.CustomerId);
if (!string.IsNullOrWhiteSpace(customer?.Address?.Country)
&& !string.IsNullOrWhiteSpace(customer?.Address?.PostalCode))
{
var taxRates = await _taxRateRepository.GetByLocationAsync(
new TaxRate()
{
Country = customer.Address.Country,
PostalCode = customer.Address.PostalCode
}
);
var taxRate = taxRates.FirstOrDefault();
if (taxRate != null && !sub.DefaultTaxRates.Any(x => x.Equals(taxRate.Id)))
{
subUpdateOptions.DefaultTaxRates = new List<string>(1)
{
taxRate.Id
};
}
}
string paymentIntentClientSecret = null;
try
{

View File

@ -2,7 +2,6 @@
using Bit.Core.Enums;
using Bit.Core.Exceptions;
using Bit.Core.Models.Business;
using Bit.Core.Repositories;
using Bit.Core.Services;
using Bit.Core.Settings;
using Bit.Core.Utilities;
@ -14,7 +13,6 @@ using Xunit;
using Customer = Braintree.Customer;
using PaymentMethod = Braintree.PaymentMethod;
using PaymentMethodType = Bit.Core.Enums.PaymentMethodType;
using TaxRate = Bit.Core.Entities.TaxRate;
namespace Bit.Core.Test.Services;
@ -259,65 +257,6 @@ public class StripePaymentServiceTests
));
}
[Theory, BitAutoData]
public async void PurchaseOrganizationAsync_Stripe_TaxRate(SutProvider<StripePaymentService> sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo)
{
var plan = StaticStore.GetPlan(PlanType.EnterpriseAnnually);
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer
{
Id = "C-1",
});
stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription
{
Id = "S-1",
CurrentPeriodEnd = DateTime.Today.AddDays(10),
});
sutProvider.GetDependency<ITaxRateRepository>().GetByLocationAsync(Arg.Is<TaxRate>(t =>
t.Country == taxInfo.BillingAddressCountry && t.PostalCode == taxInfo.BillingAddressPostalCode))
.Returns(new List<TaxRate> { new() { Id = "T-1" } });
var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 0, 0, false, taxInfo);
Assert.Null(result);
await stripeAdapter.Received().SubscriptionCreateAsync(Arg.Is<Stripe.SubscriptionCreateOptions>(s =>
s.DefaultTaxRates.Count == 1 &&
s.DefaultTaxRates[0] == "T-1"
));
}
[Theory, BitAutoData]
public async void PurchaseOrganizationAsync_Stripe_TaxRate_SM(SutProvider<StripePaymentService> sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo)
{
var plan = StaticStore.GetPlan(PlanType.EnterpriseAnnually);
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
stripeAdapter.CustomerCreateAsync(default).ReturnsForAnyArgs(new Stripe.Customer
{
Id = "C-1",
});
stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription
{
Id = "S-1",
CurrentPeriodEnd = DateTime.Today.AddDays(10),
});
sutProvider.GetDependency<ITaxRateRepository>().GetByLocationAsync(Arg.Is<TaxRate>(t =>
t.Country == taxInfo.BillingAddressCountry && t.PostalCode == taxInfo.BillingAddressPostalCode))
.Returns(new List<TaxRate> { new() { Id = "T-1" } });
var result = await sutProvider.Sut.PurchaseOrganizationAsync(organization, PaymentMethodType.Card, paymentToken, plan, 2, 2,
false, taxInfo, false, 2, 2);
Assert.Null(result);
await stripeAdapter.Received().SubscriptionCreateAsync(Arg.Is<Stripe.SubscriptionCreateOptions>(s =>
s.DefaultTaxRates.Count == 1 &&
s.DefaultTaxRates[0] == "T-1"
));
}
[Theory, BitAutoData]
public async void PurchaseOrganizationAsync_Stripe_Declined(SutProvider<StripePaymentService> sutProvider, Organization organization, string paymentToken, TaxInfo taxInfo)
{
@ -678,6 +617,14 @@ public class StripePaymentServiceTests
{ "btCustomerId", "B-123" },
}
});
stripeAdapter.CustomerUpdateAsync(default).ReturnsForAnyArgs(new Stripe.Customer
{
Id = "C-1",
Metadata = new Dictionary<string, string>
{
{ "btCustomerId", "B-123" },
}
});
stripeAdapter.InvoiceUpcomingAsync(default).ReturnsForAnyArgs(new Stripe.Invoice
{
PaymentIntent = new Stripe.PaymentIntent { Status = "requires_payment_method", },
@ -715,6 +662,14 @@ public class StripePaymentServiceTests
{ "btCustomerId", "B-123" },
}
});
stripeAdapter.CustomerUpdateAsync(default).ReturnsForAnyArgs(new Stripe.Customer
{
Id = "C-1",
Metadata = new Dictionary<string, string>
{
{ "btCustomerId", "B-123" },
}
});
stripeAdapter.InvoiceUpcomingAsync(default).ReturnsForAnyArgs(new Stripe.Invoice
{
PaymentIntent = new Stripe.PaymentIntent { Status = "requires_payment_method", },
@ -737,4 +692,58 @@ public class StripePaymentServiceTests
Assert.Null(result);
}
[Theory, BitAutoData]
public async void UpgradeFreeOrganizationAsync_WhenCustomerHasNoAddress_UpdatesCustomerAddressWithTaxInfo(
SutProvider<StripePaymentService> sutProvider,
Organization organization,
TaxInfo taxInfo)
{
organization.GatewaySubscriptionId = null;
var stripeAdapter = sutProvider.GetDependency<IStripeAdapter>();
stripeAdapter.CustomerGetAsync(default).ReturnsForAnyArgs(new Stripe.Customer
{
Id = "C-1",
Metadata = new Dictionary<string, string>
{
{ "btCustomerId", "B-123" },
}
});
stripeAdapter.CustomerUpdateAsync(default).ReturnsForAnyArgs(new Stripe.Customer
{
Id = "C-1",
Metadata = new Dictionary<string, string>
{
{ "btCustomerId", "B-123" },
}
});
stripeAdapter.InvoiceUpcomingAsync(default).ReturnsForAnyArgs(new Stripe.Invoice
{
PaymentIntent = new Stripe.PaymentIntent { Status = "requires_payment_method", },
AmountDue = 0
});
stripeAdapter.SubscriptionCreateAsync(default).ReturnsForAnyArgs(new Stripe.Subscription { });
var upgrade = new OrganizationUpgrade()
{
AdditionalStorageGb = 1,
AdditionalSeats = 10,
PremiumAccessAddon = false,
TaxInfo = taxInfo,
AdditionalSmSeats = 5,
AdditionalServiceAccounts = 50
};
var plan = StaticStore.GetPlan(PlanType.EnterpriseAnnually);
_ = await sutProvider.Sut.UpgradeFreeOrganizationAsync(organization, plan, upgrade);
await stripeAdapter.Received()
.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is<Stripe.CustomerUpdateOptions>(c =>
c.Address.Country == taxInfo.BillingAddressCountry &&
c.Address.PostalCode == taxInfo.BillingAddressPostalCode &&
c.Address.Line1 == taxInfo.BillingAddressLine1 &&
c.Address.Line2 == taxInfo.BillingAddressLine2 &&
c.Address.City == taxInfo.BillingAddressCity &&
c.Address.State == taxInfo.BillingAddressState));
}
}